Differentiation in Python - Symbolic, Numerical and Automatic
In this lab you will explore which tools and libraries are available in Python to compute derivatives. You will perform symbolic differentiation with the SymPy library, numerical differentiation with NumPy, and automatic differentiation with JAX (which builds on Autograd). By comparing the speed of the calculations, you will investigate the computational efficiency of these three methods.
Functions in Python
This is just a reminder of how to define functions in Python. A simple function of one variable can be set up as:
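For instance, a minimal sketch, assuming $f(x) = x^2$ as the working example:

```python
# Assumed working example: f(x) = x**2
def f(x):
    return x**2
```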
You can easily find the derivative of this function analytically and define it as a separate function:
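A sketch of the derivative function, under the same assumption $f(x) = x^2$, so that $\frac{df}{dx} = 2x$:

```python
# Analytic derivative of the assumed example f(x) = x**2
def dfdx(x):
    return 2 * x
```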
Since you have been working with NumPy arrays, you can apply the function to each element of an array:
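A minimal sketch; the array values match the x_array used later in the lab, while f and dfdx are the assumed examples from above:

```python
import numpy as np

def f(x):
    return x**2

def dfdx(x):
    return 2 * x

x_array = np.array([1, 2, 3])

# Both functions broadcast element-wise over the array
print(f(x_array))     # [1 4 9]
print(dfdx(x_array))  # [2 4 6]
```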
Now you can apply those functions f and dfdx to an array of a larger size. The following code will plot the function and its derivative (you don't have to understand the details of the plot_f1_and_f2 function at this stage):
In real life the functions are more complicated and it is not possible to calculate the derivatives analytically every time. Let's explore which tools and libraries are available in Python for the computation of derivatives without manual derivation.
Symbolic Differentiation
Symbolic computation deals with the computation of mathematical objects that are represented exactly, not approximately (e.g. $\sqrt{2}$ is kept written as $\sqrt{2}$, not as $1.41421356\ldots$). For differentiation this means that the output will be similar to what you would get computing the derivative by hand using the rules (analytically). Thus, symbolic differentiation can produce exact derivatives.
Introduction to Symbolic Computation with SymPy
Let's explore symbolic differentiation in Python with the commonly used SymPy library.
If you want to compute the approximate decimal value of a square root such as $\sqrt{18}$, you could normally do it in the following way:
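For instance, assuming $\sqrt{18}$ as the example value:

```python
import math

# Standard floating point evaluation gives only an approximation
math.sqrt(18)   # 4.242640687119285
```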
The output is an approximate result. You may recall that $\sqrt{18} = 3\sqrt{2}$, and it is pretty much impossible to deduce that from the approximate decimal result. But in symbolic computation systems the roots are not approximated with a decimal number; they are only simplified, so the output is exact:
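A sketch of the same computation done symbolically, again assuming $\sqrt{18}$:

```python
from sympy import sqrt

# SymPy keeps the result exact and only simplifies it
sqrt(18)   # 3*sqrt(2)
```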
Numerical evaluation of the result is still available, and you can set the number of digits to show in the approximated output:
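For example, continuing the $\sqrt{18}$ assumption:

```python
from sympy import sqrt, N

# Evaluate numerically with 8 significant digits
N(sqrt(18), 8)   # 4.2426407
```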
In SymPy variables are defined using symbols. In this particular library they need to be predefined (a list of them should be provided). Have a look in the cell below at how a symbolic expression corresponding to a mathematical expression is defined:
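A minimal sketch of defining a symbol and an expression; the quadratic here is purely illustrative:

```python
from sympy import symbols

# Variables must be declared as symbols before use
x = symbols('x')

# Illustrative symbolic expression
expr = 2 * x**2 - 3 * x + 5
expr
```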
Now you can perform various manipulations with this expression: add or subtract terms, multiply it by other expressions and so on, just as if you were doing it by hand. You can also expand the expression, or factorise it:
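A sketch of these manipulations, using the illustrative expression from above:

```python
from sympy import symbols, expand, factor

x = symbols('x')
expr = 2 * x**2 - 3 * x + 5   # illustrative expression

# Manipulate it as you would by hand
expr_manip = x * (expr + x)

# Expand into a sum of monomials
print(expand(expr_manip))        # 2*x**3 - 2*x**2 + 5*x

# Or factorise a suitable expression
print(factor(x**2 - 5 * x + 6))  # (x - 2)*(x - 3)
```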
To substitute particular values for the variables in the expression, and so evaluate a function at a given point, you can use the following code:
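A sketch, treating the illustrative quadratic above as the symbolic function f_symb (the name matches the one used later in the lab):

```python
from sympy import symbols

x = symbols('x')
f_symb = 2 * x**2 - 3 * x + 5   # illustrative symbolic function

# Substitute a particular value of x to evaluate the function at that point
f_symb.subs(x, 3)   # 14 for this illustrative f_symb
```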
You might be wondering now: is it possible to evaluate symbolic functions for each element of an array? At the beginning of the lab you defined a NumPy array x_array:
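For reference, the definition used throughout this lab is:

```python
import numpy as np

x_array = np.array([1, 2, 3])
print(x_array)
```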
Now try to evaluate the function f_symb for each element of the array. You will get an error:
It is possible to evaluate the symbolic functions for each element of an array, but you need to make the function NumPy-friendly first. The following code should work now:
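A sketch using SymPy's lambdify, which converts a symbolic expression into a NumPy-callable function (f_symb is the illustrative quadratic used above):

```python
import numpy as np
from sympy import symbols, lambdify

x = symbols('x')
f_symb = 2 * x**2 - 3 * x + 5        # illustrative symbolic function
x_array = np.array([1, 2, 3])

# Make the symbolic function NumPy-friendly
f_symb_numpy = lambdify(x, f_symb, 'numpy')

# Now it can be evaluated element-wise on the array
print(f_symb_numpy(x_array))   # [ 4  7 14] for this illustrative f_symb
```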
SymPy has lots of great functions to manipulate expressions and perform various operations from calculus. More information about them can be found in the official documentation here.
Symbolic Differentiation with SymPy
Let's try to find a derivative of a simple power function using SymPy:
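For example, with $x^3$ as an illustrative choice:

```python
from sympy import symbols, diff

x = symbols('x')

# Derivative of a simple power function
diff(x**3, x)   # 3*x**2
```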
Some standard functions can be used in the expression, and SymPy will apply the required rules (sum, product, chain) to calculate the derivative:
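A sketch with an illustrative expression combining exponential and trigonometric terms:

```python
from sympy import symbols, diff, sin, exp

x = symbols('x')

# SymPy applies the sum, product and chain rules automatically
diff(exp(-2 * x) + 3 * sin(3 * x), x)   # -2*exp(-2*x) + 9*cos(3*x)
```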
Now calculate the derivative of the function f_symb defined in 2.1 and make it NumPy-friendly:
Evaluate the function dfdx_symb_numpy for each element of x_array:
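A sketch putting the two steps together, with the illustrative f_symb from above (the names dfdx_symb_numpy and x_array match those used in the lab):

```python
import numpy as np
from sympy import symbols, diff, lambdify

x = symbols('x')
f_symb = 2 * x**2 - 3 * x + 5    # illustrative symbolic function from 2.1
x_array = np.array([1, 2, 3])

# Differentiate symbolically, then make the derivative NumPy-friendly
dfdx_symb = diff(f_symb, x)
dfdx_symb_numpy = lambdify(x, dfdx_symb, 'numpy')

# Evaluate the derivative for each element of x_array
print(dfdx_symb_numpy(x_array))   # [1 5 9] for this illustrative f_symb
```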
You can apply symbolically defined functions to arrays of a larger size. The following code will plot the function and its derivative; you can see that it works:
Limitations of Symbolic Differentiation
Symbolic differentiation seems to be a great tool, but it also has some limitations. Sometimes the output expressions are too complicated, or even impossible to evaluate. For example, find the derivative of the function $|x|$. Analytically, its derivative is:

$$\frac{d}{dx}\left(|x|\right) = \frac{x}{|x|}$$
Have a look at the output from the symbolic differentiation:
It looks complicated, but that would not be a problem if it were possible to evaluate. However, check that for a particular value of $x$ it outputs some unevaluated expression instead of the derivative value:
And the NumPy-friendly version will also give an error:
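A sketch of this limitation, assuming the absolute value function $|x|$ discussed here:

```python
import numpy as np
from sympy import symbols, diff, lambdify

x = symbols('x')

# For a generic symbol, SymPy returns a complicated expression
# involving re(x), im(x), sign(x) and unevaluated Derivative terms
dfdx_abs = diff(abs(x), x)
print(dfdx_abs)

# Substituting a value still leaves unevaluated terms instead of a number
print(dfdx_abs.subs(x, -2))

# And the NumPy-friendly version fails to evaluate on an array
try:
    dfdx_abs_numpy = lambdify(x, dfdx_abs, 'numpy')
    print(dfdx_abs_numpy(np.array([-2., -1., 1.])))
except Exception as err:
    print("Error:", type(err).__name__, err)
```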
In fact, there are problems with the evaluation of symbolic expressions wherever there is a "jump" in the derivative (e.g. the function is defined by different expressions on different intervals of $x$), as happens with $|x|$.
Also, you can see in this example that you can get a very complicated expression as the output of symbolic computation. This is called expression swell, and it results in inefficiently slow computations. You will see an example of that below, after learning about the other differentiation approaches in Python.
Numerical Differentiation with NumPy
You can call the function np.gradient to find the derivative of the function defined above. The first argument is an array of function values, and the second defines the spacing for the evaluation; here you pass it as an array of values, and the differences will be calculated automatically. You can find the documentation here.
Try numerical differentiation for a more complicated function:
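A minimal sketch of both calls; the grid x_array_2 and the second function are illustrative choices:

```python
import numpy as np

def f(x):
    return x**2

# A grid of points where the function values are known
x_array_2 = np.linspace(-5, 5, 100)

# First argument: function values; second: the points where they were taken
dfdx_numerical = np.gradient(f(x_array_2), x_array_2)

# The same call works for a more complicated function,
# here compared against its analytic derivative (product rule)
def g(x):
    return np.sin(x) * np.exp(-0.1 * x)

dgdx_numerical = np.gradient(g(x_array_2), x_array_2)
dgdx_exact = np.exp(-0.1 * x_array_2) * (np.cos(x_array_2) - 0.1 * np.sin(x_array_2))
print(np.max(np.abs(dgdx_numerical - dgdx_exact)))
```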
The results are pretty impressive, keeping in mind that it does not matter at all how the function was calculated; only its final values are used!
Obviously, the first downside of numerical differentiation is that it is not exact. However, its accuracy is normally sufficient for machine learning applications. At this stage there is no need to evaluate the errors of numerical differentiation.
Limitations of Numerical Differentiation
Another problem is similar to the one which appeared in symbolic differentiation: it is inaccurate at the points where there are "jumps" in the derivative. Let's compare the exact derivative of the absolute value function with its numerical approximation:
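A sketch of the comparison, using the same illustrative grid x_array_2 as above:

```python
import numpy as np

x_array_2 = np.linspace(-5, 5, 100)

# Numerical derivative of |x| on the grid
dfdx_abs_numerical = np.gradient(np.abs(x_array_2), x_array_2)

# Exact derivative: -1 for x < 0 and 1 for x > 0
dfdx_abs_exact = np.sign(x_array_2)

# The largest disagreement is at the grid points closest to the "jump" at x = 0
print(np.max(np.abs(dfdx_abs_numerical - dfdx_abs_exact)))
```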
You can see that near the "jump" the numerical results take intermediate values, while the exact derivative values are $-1$ and $1$. Such cases can give significant errors in the computations.
But the biggest problem with numerical differentiation is its slow speed: it requires a full function evaluation every time. In machine learning models there are hundreds of parameters and hundreds of derivatives to be calculated, so performing a full function evaluation every time slows down the computation process. You will see an example of this below.
Automatic Differentiation
The automatic differentiation (autodiff) method breaks the function down into elementary functions ($\sin$, $\cos$, $\exp$, power functions, etc.) and constructs a computational graph consisting of these basic functions. The chain rule is then used to compute the derivative at any node of the graph. It is the most commonly used approach in machine learning applications and neural networks, as the computational graph for the function and its derivatives can be built during the construction of the neural network, saving time in future computations.
The main disadvantage of autodiff is that it is difficult to implement. However, nowadays there are libraries that are convenient to use, such as MyGrad, Autograd and JAX. Autograd and JAX are the most commonly used in frameworks for building neural networks. JAX brings together Autograd functionality for optimization problems and the XLA (Accelerated Linear Algebra) compiler for parallel computing.

The syntax of Autograd and JAX is slightly different, and it would be overwhelming to cover both at this stage. In this notebook you will perform automatic differentiation using one of them: JAX.
Introduction to JAX
To begin with, load the required libraries. From the jax package you need to load just a couple of functions for now (grad and vmap). The package jax.numpy is a wrapped NumPy, which pretty much replaces NumPy when JAX is used. It can be loaded as np, as if it were the original NumPy, in most cases. However, in this notebook you will import it as jnp to distinguish the two for now.
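A sketch of the imports described above:

```python
import numpy as np               # original NumPy, kept under its usual name here
import jax.numpy as jnp          # wrapped NumPy from JAX
from jax import grad, vmap       # the two JAX functions used in this lab
```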
Create a new jnp array and check its type.
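For example:

```python
import jax.numpy as jnp

x_array_jnp = jnp.array([1., 2., 3.])
print(x_array_jnp)
print(type(x_array_jnp))   # a JAX array type, e.g. DeviceArray, depending on the version
```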
The same array can be created by converting the previously defined x_array = np.array([1, 2, 3]), although in some cases JAX does not operate with integers, so the values need to be converted to floats. You will see an example of this below.
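A sketch of the conversion; the float cast is what avoids the integer issue mentioned above:

```python
import numpy as np
import jax.numpy as jnp

x_array = np.array([1, 2, 3])

# Convert the integer NumPy array, casting the values to floats
x_array_jnp = jnp.array(x_array.astype('float32'))
print(x_array_jnp, x_array_jnp.dtype)
```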
Note that a jnp array has a specific type, jaxlib.xla_extension.DeviceArray. In most cases the same operators and functions are applicable to it as in the original NumPy, for example:
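For instance:

```python
import jax.numpy as jnp

x_array_jnp = jnp.array([1., 2., 3.])

# Familiar NumPy-style operations work as usual
print(x_array_jnp * 2)        # [2. 4. 6.]
print(x_array_jnp[2])         # 3.0
print(jnp.log(x_array_jnp))
```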
But sometimes, working with jnp arrays, the approach needs to be changed. In the following code, trying to assign a new value to one of the elements, you will get an error:
To assign a new value to an element of a jnp array you need to apply the functions .at[i], stating which element to update, and .set(value), setting the new value. These functions operate out-of-place: the updated array is returned as a new array, and the original array is not modified by the update.
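A sketch of both the error and the out-of-place update:

```python
import jax.numpy as jnp

x_array_jnp = jnp.array([1., 2., 3.])

# Direct item assignment is not supported: jnp arrays are immutable
try:
    x_array_jnp[2] = 4.0
except TypeError as err:
    print(err)

# Out-of-place update: a new array is returned, the original is unchanged
y_array_jnp = x_array_jnp.at[2].set(4.0)
print(y_array_jnp)    # [1. 2. 4.]
print(x_array_jnp)    # [1. 2. 3.]
```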
However, some of the JAX functions will work with arrays defined with both np and jnp. In the following code you will get the same result from both lines:
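An illustrative pair of calls (jax.numpy functions accept both kinds of array):

```python
import numpy as np
import jax.numpy as jnp

print(jnp.sum(np.array([1., 2., 3.])))    # 6.0
print(jnp.sum(jnp.array([1., 2., 3.])))   # 6.0
```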
This is probably confusing: which NumPy should you use, then? Usually when JAX is used, only jax.numpy gets imported (as np) and used instead of the original one.
Automatic Differentiation with JAX
Time to do automatic differentiation with JAX. The following code will calculate the derivative of the previously defined function at a given point:
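A minimal sketch, assuming the working example $f(x) = x^2$ and the evaluation point $x = 3$:

```python
from jax import grad

def f(x):
    return x**2

# Derivative of f at x = 3 (note the float argument)
print(grad(f)(3.0))   # 6.0
```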
Very easy, right? Keep in mind, though, that this cannot be done with integer inputs. The following code will output an error:
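For instance, wrapped in try/except so the cell still runs:

```python
from jax import grad

def f(x):
    return x**2

# grad requires floating point (or complex) inputs, so an integer fails
try:
    grad(f)(3)
except TypeError as err:
    print(err)
```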
Try to apply the grad function to an array, calculating the derivative for each of its elements:
There is a broadcasting issue there. You don't need to get into the details of it at this stage; the function vmap can be used here to solve the problem, as shown in the sketch after the note below.
Note: Broadcasting is covered in Course 1 of this Specialization ("Linear Algebra"). You can also review it in the documentation here.
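A sketch of both the failing call and the vmap fix, using the assumed example $f(x) = x^2$:

```python
import jax.numpy as jnp
from jax import grad, vmap

def f(x):
    return x**2

x_array_jnp = jnp.array([1., 2., 3.])

# grad(f) expects a scalar-output function of a scalar, so this fails
try:
    grad(f)(x_array_jnp)
except TypeError as err:
    print(err)

# vmap vectorizes grad(f) over the elements of the array
print(vmap(grad(f))(x_array_jnp))   # [2. 4. 6.]
```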
Great: now vmap(grad(f)) can be used to calculate the derivative of the function f for arrays of a larger size, and you can plot the output:
In the following code you can comment/uncomment lines to visualize the common derivatives. All of them are found using JAX automatic differentiation, and the results look pretty good!
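A non-plotting sketch of the same idea: a few common derivatives computed with vmap(grad(...)) and compared against their analytic forms:

```python
import jax.numpy as jnp
from jax import grad, vmap

x = jnp.linspace(0.1, 3.0, 5)

print(vmap(grad(jnp.sin))(x), jnp.cos(x))   # d/dx sin(x) = cos(x)
print(vmap(grad(jnp.exp))(x), jnp.exp(x))   # d/dx exp(x) = exp(x)
print(vmap(grad(jnp.log))(x), 1 / x)        # d/dx log(x) = 1/x
```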
Computational Efficiency of Symbolic, Numerical and Automatic Differentiation
In sections 2.3 and 3.2 the low computational efficiency of symbolic and numerical differentiation was discussed. Now it is time to compare the speed of calculations for each of the three approaches. Try to find the derivative of the same simple function multiple times, evaluating it for an array of a larger size, and compare the results and the time used:
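A rough timing sketch; the names (x_big, timed) and the use of time.time are illustrative, and the measured times will vary between machines:

```python
import time
import numpy as np
import jax.numpy as jnp
from jax import grad, vmap
from sympy import symbols, diff, lambdify

# The same simple function f(x) = x**2, evaluated on a large array
x_big = np.linspace(-5, 5, 100_000)
x_big_jnp = jnp.array(x_big)

def f(x):
    return x**2

x_sym = symbols('x')
dfdx_symb_numpy = lambdify(x_sym, diff(x_sym**2, x_sym), 'numpy')
dfdx_auto = vmap(grad(f))

N_REPS = 100  # repeat the differentiation many times, as described above

def timed(label, fn):
    start = time.time()
    for _ in range(N_REPS):
        result = fn()
    print(f"{label}: {time.time() - start:.6f} s")
    return result

res_symb = timed("symbolic ", lambda: dfdx_symb_numpy(x_big))
res_num  = timed("numerical", lambda: np.gradient(f(x_big), x_big))
res_auto = timed("automatic", lambda: dfdx_auto(x_big_jnp).block_until_ready())
```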
The results are pretty much the same, but the time used differs. The numerical approach is obviously inefficient when differentiation needs to be performed many times, which happens a lot when training machine learning models. The symbolic and automatic approaches seem to perform similarly for this simple example. But if the function becomes a little more complicated, symbolic computation will experience significant expression swell and the calculations will slow down.
Note: Sometimes the execution time results may vary slightly, especially for automatic differentiation. You can run the code above a few times to see different outputs. That does not influence the conclusion that numerical differentiation is slower. The timeit module can be used to evaluate execution time more rigorously, but that would unnecessarily overcomplicate the code here.
Try to define some polynomial function, which should not be that hard to differentiate, and compare the computation time for its differentiation symbolically and automatically:
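A compact sketch under the same assumptions (an arbitrary illustrative polynomial; here the symbolic timing includes the differentiation itself, not only the evaluation, and the measured gap depends on the machine and the size of the expression):

```python
import time
import numpy as np
import jax.numpy as jnp
from jax import grad, vmap
from sympy import symbols, diff, lambdify

# An illustrative polynomial
x_sym = symbols('x')
f_poly_symb = 3 * x_sym**3 - 2 * x_sym**2 + 5 * x_sym - 1

def f_poly(x):
    return 3 * x**3 - 2 * x**2 + 5 * x - 1

x_big = np.linspace(-5, 5, 100_000)
x_big_jnp = jnp.array(x_big)

start = time.time()
res_symb = lambdify(x_sym, diff(f_poly_symb, x_sym), 'numpy')(x_big)
print(f"symbolic : {time.time() - start:.6f} s")

start = time.time()
res_auto = vmap(grad(f_poly))(x_big_jnp).block_until_ready()
print(f"automatic: {time.time() - start:.6f} s")
```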
Again, the results are similar, but automatic differentiation is noticeably faster.
As the computational graph of the function grows, the efficiency advantage of automatic differentiation over the other methods increases, because the autodiff method simply applies the chain rule along the graph!
Conclusion
Congratulations! Now you are equipped with Python tools to perform differentiation.