diff --git a/lectures/_toc.yml b/lectures/_toc.yml index 97c429c4..0ea886b1 100644 --- a/lectures/_toc.yml +++ b/lectures/_toc.yml @@ -25,6 +25,7 @@ parts: - file: numba - file: jax_intro - file: numpy_vs_numba_vs_jax + - file: autodiff - caption: Working with Data numbered: true chapters: diff --git a/lectures/autodiff.md b/lectures/autodiff.md new file mode 100644 index 00000000..51ef258e --- /dev/null +++ b/lectures/autodiff.md @@ -0,0 +1,522 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.17.2 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +(autodiff)= + +# Adventures with Autodiff + + +```{include} _admonition/gpu.md +``` + +## Overview + +This lecture gives a more thorough introduction to automatic differentiation +using Google JAX, building on {doc}`our brief preview `. + +Automatic differentiation is one of the key elements of modern machine learning +and artificial intelligence. + +As such it has attracted a great deal of investment and there are several +powerful implementations available. + +One of the best of these is the automatic differentiation routines contained +in JAX. + +While other software packages also offer this feature, the JAX version is +particularly powerful because it integrates so well with other core +components of JAX (e.g., JIT compilation and parallelization). + +Automatic differentiation can be used not only +for AI but also for many problems faced in mathematical modeling, such as +multi-dimensional nonlinear optimization and root-finding problems. 
+ +In addition to what's in Anaconda, this lecture will need the following libraries: + +```{code-cell} ipython3 +:tags: [hide-output] + +!pip install jax +``` + +We need the following imports + +```{code-cell} ipython3 +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +from sympy import symbols +``` + +## What is automatic differentiation? + +Autodiff is a technique for calculating derivatives on a computer. + +### Autodiff is not finite differences + +The derivative of $f(x) = \exp(2x)$ is + +$$ + f'(x) = 2 \exp(2x) +$$ + + + +A computer that doesn't know how to take derivatives might approximate this with the finite difference ratio + +$$ + (Df)(x) := \frac{f(x+h) - f(x)}{h} +$$ + +where $h$ is a small positive number. + +```{code-cell} ipython3 +def f(x): + "Original function." + return np.exp(2 * x) + +def f_prime(x): + "True derivative." + return 2 * np.exp(2 * x) + +def Df(x, h=0.1): + "Approximate derivative (finite difference)." + return (f(x + h) - f(x))/h + +x_grid = np.linspace(-2, 1, 200) +fig, ax = plt.subplots() +ax.plot(x_grid, f_prime(x_grid), label="$f'$") +ax.plot(x_grid, Df(x_grid), label="$Df$") +ax.legend() +plt.show() +``` + +This kind of numerical derivative is often inaccurate and unstable. + +One reason is that + +$$ + \frac{f(x+h) - f(x)}{h} \approx \frac{0}{0} +$$ + +Small numbers in the numerator and denominator cause rounding errors. + +The situation is exponentially worse in high dimensions / with higher order derivatives. + ++++ + +### Autodiff is not symbolic calculus + ++++ + +Symbolic calculus tries to use rules for differentiation to produce a single +closed-form expression representing a derivative. + +```{code-cell} ipython3 +m, a, b, x = symbols('m a b x') +f_x = (a*x + b)**m +f_x.diff((x, 6)) # 6-th order derivative +``` + +Symbolic calculus is not well suited to high performance +computing. + +One disadvantage is that symbolic calculus cannot differentiate through control flow. 
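To make this concrete, here is a small made-up example (not from the lecture): a Babylonian square-root iteration whose step count depends on the input. The computation defines a perfectly good function of `x`, but there is no single closed-form expression for a symbolic engine to differentiate — passing a SymPy symbol fails immediately at the comparison `x > 1.5`.

```python
def f(x):
    # Babylonian iteration toward sqrt(2); how many steps run
    # depends on the numerical value of x, so the computation
    # cannot be unrolled into one fixed symbolic expression
    while x > 1.5:
        x = 0.5 * (x + 2.0 / x)
    return x

f(4.0)  # three iterations bring 4.0 down to about 1.42
```

Autodiff, by contrast, simply differentiates the sequence of operations that was actually executed for the given numerical input.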
+ +Also, using symbolic calculus might involve redundant calculations. + +For example, consider + +$$ + (f g h)' + = (f' g + g' f) h + (f g) h' +$$ + +If we evaluate at $x$, then we evaluate $f(x)$ and $g(x)$ twice each. + +Also, computing $f'(x)$ and $f(x)$ might involve similar terms (e.g., $f(x) = \exp(2x) \implies f'(x) = 2f(x)$) but this is not exploited in symbolic algebra. + ++++ + +### Autodiff + +Autodiff produces functions that evaluate derivatives at numerical values +passed in by the calling code, rather than producing a single symbolic +expression representing the entire derivative. + +Derivatives are constructed by breaking calculations into component parts via the chain rule. + +The chain rule is applied until the point where the terms reduce to primitive functions that the program knows how to differentiate exactly (addition, subtraction, exponentiation, sine and cosine, etc.) + ++++ + +## Some experiments + ++++ + +Let's start with some real-valued functions on $\mathbb R$. + ++++ + +### A differentiable function + ++++ + +Let's test JAX's autodiff with a relatively simple function. + +```{code-cell} ipython3 +def f(x): + return jnp.sin(x) - 2 * jnp.cos(3 * x) * jnp.exp(- x**2) +``` + +We use `grad` to compute the gradient of a real-valued function: + +```{code-cell} ipython3 +f_prime = jax.grad(f) +``` + +Let's plot the result: + +```{code-cell} ipython3 +x_grid = jnp.linspace(-5, 5, 100) +``` + +```{code-cell} ipython3 +fig, ax = plt.subplots() +ax.plot(x_grid, [f(x) for x in x_grid], label="$f$") +ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$") +ax.legend() +plt.show() +``` + +### Absolute value function + ++++ + +What happens if the function is not differentiable? 
+
+```{code-cell} ipython3
+def f(x):
+    return jnp.abs(x)
+```
+
+```{code-cell} ipython3
+f_prime = jax.grad(f)
+```
+
+```{code-cell} ipython3
+fig, ax = plt.subplots()
+ax.plot(x_grid, [f(x) for x in x_grid], label="$f$")
+ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$")
+ax.legend()
+plt.show()
+```
+
+At the nondifferentiable point $0$, `jax.grad` returns the right-hand derivative:
+
+```{code-cell} ipython3
+f_prime(0.0)
+```
+
+### Differentiating through control flow
+
++++
+
+Let's try differentiating through some loops and conditions.
+
+```{code-cell} ipython3
+def f(x):
+    def f1(x):
+        for i in range(2):
+            x *= 0.2 * x
+        return x
+    def f2(x):
+        x = sum((x**i + i) for i in range(3))
+        return x
+    y = f1(x) if x < 0 else f2(x)
+    return y
+```
+
+```{code-cell} ipython3
+f_prime = jax.grad(f)
+```
+
+```{code-cell} ipython3
+x_grid = jnp.linspace(-5, 5, 100)
+```
+
+```{code-cell} ipython3
+fig, ax = plt.subplots()
+ax.plot(x_grid, [f(x) for x in x_grid], label="$f$")
+ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$")
+ax.legend()
+plt.show()
+```
+
+### Differentiating through a linear interpolation
+
++++
+
+We can differentiate through linear interpolation, even though the function is not smooth:
+
+```{code-cell} ipython3
+n = 20
+xp = jnp.linspace(-5, 5, n)
+yp = jnp.cos(2 * xp)
+
+fig, ax = plt.subplots()
+ax.plot(x_grid, jnp.interp(x_grid, xp, yp))
+plt.show()
+```
+
+```{code-cell} ipython3
+f_prime = jax.grad(jnp.interp)
+```
+
+```{code-cell} ipython3
+f_prime_vec = jax.vmap(f_prime, in_axes=(0, None, None))
+```
+
+```{code-cell} ipython3
+fig, ax = plt.subplots()
+ax.plot(x_grid, f_prime_vec(x_grid, xp, yp))
+plt.show()
+```
+
+## Gradient Descent
+
++++
+
+Let's try implementing gradient descent.
+
+As a simple application, we'll use gradient descent to solve for the OLS parameter estimates in simple linear regression.
+
++++
+
+### A function for gradient descent
+
++++
+
+Here's an implementation of gradient descent.
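The implementation below updates the learning rate with the Barzilai–Borwein rule: writing $\Delta x_k = x_k - x_{k-1}$ for the latest step and $\Delta g_k = \nabla f(x_k) - \nabla f(x_{k-1})$ for the corresponding change in the gradient, the step size is

$$
    \lambda_k
    = \frac{| \Delta x_k^\top \Delta g_k |}{\Delta g_k^\top \Delta g_k}
$$

This secant-style step approximates the inverse curvature of $f$ along the most recent move, and typically converges much faster than a fixed learning rate.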
+
+```{code-cell} ipython3
+def grad_descent(f,                # Function to be minimized
+                 args,             # Extra arguments to the function
+                 x0,               # Initial condition
+                 λ=0.1,            # Initial learning rate
+                 tol=1e-5,
+                 max_iter=1_000):
+    """
+    Minimize the function f via gradient descent, starting from guess x0.
+
+    The learning rate is computed according to the Barzilai-Borwein method.
+
+    """
+
+    f_grad = jax.grad(f)
+    x = jnp.array(x0)
+    df = f_grad(x, args)
+    ϵ = tol + 1
+    i = 0
+    while ϵ > tol and i < max_iter:
+        new_x = x - λ * df
+        new_df = f_grad(new_x, args)
+        Δx = new_x - x
+        Δdf = new_df - df
+        λ = jnp.abs(Δx @ Δdf) / (Δdf @ Δdf)
+        ϵ = jnp.max(jnp.abs(Δx))
+        x, df = new_x, new_df
+        i += 1
+
+    return x
+
+```
+
+### Simulated data
+
+We're going to test our gradient descent function by minimizing the sum of squared residuals in a regression problem.
+
+Let's generate some simulated data:
+
+```{code-cell} ipython3
+n = 100
+key = jax.random.key(1234)
+x = jax.random.uniform(key, (n,))
+
+α, β, σ = 0.5, 1.0, 0.1  # True slope, intercept and noise level
+key, subkey = jax.random.split(key)
+ϵ = jax.random.normal(subkey, (n,))
+
+y = α * x + β + σ * ϵ
+```
+
+```{code-cell} ipython3
+fig, ax = plt.subplots()
+ax.scatter(x, y)
+plt.show()
+```
+
+Let's start by calculating the estimated slope and intercept using closed-form solutions.
+
+```{code-cell} ipython3
+mx = x.mean()
+my = y.mean()
+α_hat = jnp.sum((x - mx) * (y - my)) / jnp.sum((x - mx)**2)
+β_hat = my - α_hat * mx
+```
+
+```{code-cell} ipython3
+α_hat, β_hat
+```
+
+```{code-cell} ipython3
+fig, ax = plt.subplots()
+ax.scatter(x, y)
+ax.plot(x, α_hat * x + β_hat, 'k-')
+ax.text(0.1, 1.55, rf'$\hat \alpha = {α_hat:.3}$')
+ax.text(0.1, 1.50, rf'$\hat \beta = {β_hat:.3}$')
+plt.show()
+```
+
+### Minimizing squared loss by gradient descent
+
++++
+
+Let's see if we can get the same values with our gradient descent function.
+
+First we set up the least squares loss function.
+ +```{code-cell} ipython3 +@jax.jit +def loss(params, data): + a, b = params + x, y = data + return jnp.sum((y - a * x - b)**2) +``` + +Now we minimize it: + +```{code-cell} ipython3 +p0 = jnp.zeros(2) # Initial guess for α, β +data = x, y +α_hat, β_hat = grad_descent(loss, data, p0) +``` + +Let's plot the results. + +```{code-cell} ipython3 +fig, ax = plt.subplots() +x_grid = jnp.linspace(0, 1, 100) +ax.scatter(x, y) +ax.plot(x_grid, α_hat * x_grid + β_hat, 'k-', alpha=0.6) +ax.text(0.1, 1.55, rf'$\hat \alpha = {α_hat:.3}$') +ax.text(0.1, 1.50, rf'$\hat \beta = {β_hat:.3}$') +plt.show() +``` + +Notice that we get the same estimates as we did from the closed form solutions. + ++++ + +### Adding a squared term + +Now let's try fitting a second order polynomial. + +Here's our new loss function. + +```{code-cell} ipython3 +@jax.jit +def loss(params, data): + a, b, c = params + x, y = data + return jnp.sum((y - a * x**2 - b * x - c)**2) +``` + +Now we're minimizing in three dimensions. + +Let's try it. + +```{code-cell} ipython3 +p0 = jnp.zeros(3) +α_hat, β_hat, γ_hat = grad_descent(loss, data, p0) + +fig, ax = plt.subplots() +ax.scatter(x, y) +ax.plot(x_grid, α_hat * x_grid**2 + β_hat * x_grid + γ_hat, 'k-', alpha=0.6) +ax.text(0.1, 1.55, rf'$\hat \alpha = {α_hat:.3}$') +ax.text(0.1, 1.50, rf'$\hat \beta = {β_hat:.3}$') +plt.show() +``` + +## Exercises + +```{exercise-start} +:label: auto_ex1 +``` + +The function `jnp.polyval` evaluates polynomials. + +For example, if `len(p)` is 3, then `jnp.polyval(p, x)` returns + +$$ + f(p, x) := p_0 x^2 + p_1 x + p_2 +$$ + +Use this function for polynomial regression. + +The (empirical) loss becomes + +$$ + \ell(p, x, y) + = \sum_{i=1}^n (y_i - f(p, x_i))^2 +$$ + +Set $k=4$ and set the initial guess of `params` to `jnp.zeros(k)`. + +Use gradient descent to find the array `params` that minimizes the loss +function and plot the result (following the examples above). 
+ + +```{exercise-end} +``` + +```{solution-start} auto_ex1 +:class: dropdown +``` + +Here's one solution. + +```{code-cell} ipython3 +def loss(params, data): + x, y = data + return jnp.sum((y - jnp.polyval(params, x))**2) +``` + +```{code-cell} ipython3 +k = 4 +p0 = jnp.zeros(k) +p_hat = grad_descent(loss, data, p0) +print('Estimated parameter vector:') +print(p_hat) +print('\n\n') + +fig, ax = plt.subplots() +ax.scatter(x, y) +ax.plot(x_grid, jnp.polyval(p_hat, x_grid), 'k-', alpha=0.6) +plt.show() +``` + + +```{solution-end} +``` diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index adf3b80c..0fa36e26 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -11,6 +11,8 @@ kernelspec: name: python3 --- +(jax_intro)= + # JAX This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax). @@ -51,16 +53,16 @@ We'll use the following imports ```{code-cell} ipython3 import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +import numpy as np import quantecon as qe ``` -In addition, we replace `import numpy as np` with - -```{code-cell} ipython3 -import jax.numpy as jnp -``` +Notice that we import `jax.numpy as jnp`, which provides a NumPy-like interface. -Now we can use `jnp` in place of `np` for the usual array operations: +Here are some standard array operations using `jnp`: ```{code-cell} ipython3 a = jnp.asarray((1.0, 3.2, -1.5)) @@ -150,7 +152,6 @@ As a NumPy replacement, a more significant difference is that arrays are treated For example, with NumPy we can write ```{code-cell} ipython3 -import numpy as np a = np.linspace(0, 1, 3) a ``` @@ -310,7 +311,7 @@ First we produce a key, which seeds the random number generator. 
```{code-cell} ipython3 seed = 1234 -key = jax.random.PRNGKey(seed) +key = jax.random.key(seed) ``` Now we can use the key to generate some random numbers: @@ -340,6 +341,79 @@ jax.random.normal(key, (3, 3)) jax.random.normal(subkey, (3, 3)) ``` +The following diagram illustrates how `split` produces a tree of keys from a +single root, with each key generating independent random draws. + +```{code-cell} ipython3 +:tags: [hide-input] + +fig, ax = plt.subplots(figsize=(8, 4)) +ax.set_xlim(-0.5, 6.5) +ax.set_ylim(-0.5, 3.5) +ax.set_aspect('equal') +ax.axis('off') + +box_style = dict(boxstyle="round,pad=0.3", facecolor="white", + edgecolor="black", linewidth=1.5) +box_used = dict(boxstyle="round,pad=0.3", facecolor="#d4edda", + edgecolor="black", linewidth=1.5) + +# Root key +ax.text(3, 3, "key₀", ha='center', va='center', fontsize=11, + bbox=box_style) + +# Level 1 +ax.annotate("", xy=(1.5, 2), xytext=(3, 2.7), + arrowprops=dict(arrowstyle="->", lw=1.5)) +ax.annotate("", xy=(4.5, 2), xytext=(3, 2.7), + arrowprops=dict(arrowstyle="->", lw=1.5)) +ax.text(1.5, 2, "key₁", ha='center', va='center', fontsize=11, + bbox=box_style) +ax.text(4.5, 2, "subkey₁", ha='center', va='center', fontsize=11, + bbox=box_used) +ax.text(5.7, 2, "→ draw", ha='left', va='center', fontsize=10, + color='green') + +# Label the split +ax.text(2, 2.65, "split", ha='center', va='center', fontsize=9, + fontstyle='italic', color='gray') + +# Level 2 +ax.annotate("", xy=(0.5, 1), xytext=(1.5, 1.7), + arrowprops=dict(arrowstyle="->", lw=1.5)) +ax.annotate("", xy=(2.5, 1), xytext=(1.5, 1.7), + arrowprops=dict(arrowstyle="->", lw=1.5)) +ax.text(0.5, 1, "key₂", ha='center', va='center', fontsize=11, + bbox=box_style) +ax.text(2.5, 1, "subkey₂", ha='center', va='center', fontsize=11, + bbox=box_used) +ax.text(3.7, 1, "→ draw", ha='left', va='center', fontsize=10, + color='green') + +ax.text(0.7, 1.65, "split", ha='center', va='center', fontsize=9, + fontstyle='italic', color='gray') + +# Level 3 
+ax.annotate("", xy=(0, 0), xytext=(0.5, 0.7), + arrowprops=dict(arrowstyle="->", lw=1.5)) +ax.annotate("", xy=(1.5, 0), xytext=(0.5, 0.7), + arrowprops=dict(arrowstyle="->", lw=1.5)) +ax.text(0, 0, "key₃", ha='center', va='center', fontsize=11, + bbox=box_style) +ax.text(1.5, 0, "subkey₃", ha='center', va='center', fontsize=11, + bbox=box_used) +ax.text(2.7, 0, "→ draw", ha='left', va='center', fontsize=10, + color='green') +ax.text(0, 0.65, "split", ha='center', va='center', fontsize=9, + fontstyle='italic', color='gray') + +ax.text(3, -0.5, "⋮", ha='center', va='center', fontsize=14) + +ax.set_title("PRNG Key Splitting Tree", fontsize=13, pad=10) +plt.tight_layout() +plt.show() +``` + This syntax will seem unusual for a NumPy or Matlab user --- but will make a lot of sense when we progress to parallel programming. @@ -358,7 +432,7 @@ def gen_random_matrices(key, n=2, k=3): ```{code-cell} ipython3 seed = 42 -key = jax.random.PRNGKey(seed) +key = jax.random.key(seed) matrices = gen_random_matrices(key) ``` @@ -376,7 +450,7 @@ def gen_random_matrices(key, n=2, k=3): ``` ```{code-cell} ipython3 -key = jax.random.PRNGKey(seed) +key = jax.random.key(seed) matrices = gen_random_matrices(key) ``` @@ -426,7 +500,7 @@ def random_sum_jax(key): With the same key, we always get the same result: ```{code-cell} ipython3 -key = jax.random.PRNGKey(42) +key = jax.random.key(42) random_sum_jax(key) ``` @@ -576,7 +650,7 @@ Let's try the same thing with a more complex function. ```{code-cell} def f(x): - y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - 0.1 * x**2 + y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2 return y ``` @@ -628,10 +702,70 @@ with qe.Timer(): The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation. 
-Moreover, with JAX, we have another trick up our sleeve: +Moreover, with JAX, we have another trick up our sleeve --- we can JIT-compile +the *entire* function, not just individual operations. + +### How JIT compilation works + +When we apply `jax.jit` to a function, JAX *traces* it: instead of executing +the operations immediately, it records the sequence of operations as a +computational graph and hands that graph to the +[XLA](https://openxla.org/xla) compiler. +XLA then fuses and optimizes the operations into a single compiled kernel +tailored to the available hardware (CPU, GPU, or TPU). -### Compiling the Whole Function +The following diagram shows this pipeline for a simple function: + +```{code-cell} ipython3 +:tags: [hide-input] + +fig, ax = plt.subplots(figsize=(7, 2)) +ax.set_xlim(-0.2, 7.2) +ax.set_ylim(0.2, 2.2) +ax.axis('off') + +# Boxes for pipeline stages +stages = [ + (0.7, 1.2, "Python\nfunction"), + (2.6, 1.2, "computational\ngraph"), + (4.5, 1.2, "optimized\nkernel"), + (6.4, 1.2, "fast\nexecution"), +] + +colors = ["#e3f2fd", "#fff9c4", "#f3e5f5", "#d4edda"] + +for (x, y, label), color in zip(stages, colors): + box = mpatches.FancyBboxPatch( + (x - 0.7, y - 0.5), 1.4, 1.0, + boxstyle="round,pad=0.15", + facecolor=color, edgecolor="black", linewidth=1.5) + ax.add_patch(box) + ax.text(x, y, label, ha='center', va='center', fontsize=9) + +# Arrows with labels +arrows = [ + (1.4, 1.9, "trace"), + (3.3, 3.8, "XLA"), + (5.2, 5.7, "run"), +] + +for x_start, x_end, label in arrows: + ax.annotate("", xy=(x_end, 1.2), xytext=(x_start, 1.2), + arrowprops=dict(arrowstyle="->", lw=1.5, color="gray")) + ax.text((x_start + x_end) / 2, 1.55, label, + ha='center', fontsize=8, color='gray') + +plt.tight_layout() +plt.show() +``` + +The first call to a JIT-compiled function incurs compilation overhead, but +subsequent calls with the same input shapes and types reuse the cached +compiled code and run at full speed. 
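We can observe this shape-specialized caching directly by adding a Python side effect to a jitted function — the side effect runs only while JAX is tracing, so it fires once per new input shape. (This is a small illustrative sketch, not part of the lecture's code.)

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1       # Python side effect: executes only during tracing
    return jnp.sum(x ** 2)

f(jnp.ones(3))    # traced and compiled for shape (3,)
f(jnp.ones(3))    # same shape and dtype: cached kernel reused, no retrace
f(jnp.ones(4))    # new shape (4,): traced and compiled again
print(trace_count)  # 2
```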
+ + +### Compiling the whole function The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing linear algebra operations into a single optimized kernel. @@ -736,24 +870,85 @@ The compiler loves pure functions and functional programming because * Pure functions are easier to parallelize and optimize (don't depend on shared mutable state) -## Gradients +## Vectorization with `vmap` -JAX can use automatic differentiation to compute gradients. +Another powerful JAX transformation is `jax.vmap`, which automatically +vectorizes a function written for a single input so that it operates over +batches. -This can be extremely useful for optimization and solving nonlinear systems. +This avoids the need to manually write vectorized code or use explicit loops. -We will see significant applications later in this lecture series. +### A simple example -For now, here's a very simple illustration involving the function +Suppose we have a function that computes summary statistics for a single array: ```{code-cell} ipython3 -def f(x): - return (x**2) / 2 +def summary(x): + return jnp.mean(x), jnp.median(x) +``` + +We can apply it to a single vector: + +```{code-cell} ipython3 +x = jnp.array([1.0, 2.0, 5.0]) +summary(x) +``` + +Now suppose we have a matrix and want to compute these statistics for each row. + +Without `vmap`, we'd need an explicit loop: + +```{code-cell} ipython3 +X = jnp.array([[1.0, 2.0, 5.0], + [4.0, 5.0, 6.0], + [1.0, 8.0, 9.0]]) + +for row in X: + print(summary(row)) +``` + +However, Python loops are slow and cannot be efficiently compiled or +parallelized by JAX. 
+ +Using `vmap` keeps the computation on the accelerator and composes with other +JAX transformations like `jit` and `grad`: + +```{code-cell} ipython3 +batch_summary = jax.vmap(summary) +batch_summary(X) ``` -Let's take the derivative: +The function `summary` was written for a single array, and `vmap` automatically +lifted it to operate row-wise over a matrix --- no loops, no reshaping. + +### Combining transformations + +One of JAX's strengths is that transformations compose naturally. + +For example, we can JIT-compile a vectorized function: ```{code-cell} ipython3 +fast_batch_summary = jax.jit(jax.vmap(summary)) +fast_batch_summary(X) +``` + +This composition of `jit`, `vmap`, and (as we'll see next) `grad` is central to +JAX's design and makes it especially powerful for scientific computing and +machine learning. + + +## Automatic differentiation: a preview + +JAX can use automatic differentiation to compute gradients. + +This can be extremely useful for optimization and solving nonlinear systems. + +Here's a simple illustration involving the function $f(x) = x^2 / 2$: + +```{code-cell} ipython3 +def f(x): + return (x**2) / 2 + f_prime = jax.grad(f) ``` @@ -764,8 +959,6 @@ f_prime(10.0) Let's plot the function and derivative, noting that $f'(x) = x$. ```{code-cell} ipython3 -import matplotlib.pyplot as plt - fig, ax = plt.subplots() x_grid = jnp.linspace(-4, 4, 200) ax.plot(x_grid, f(x_grid), label="$f$") @@ -774,7 +967,9 @@ ax.legend(loc='upper center') plt.show() ``` -We defer further exploration of automatic differentiation with JAX until {doc}`jax:autodiff`. +Automatic differentiation is a deep topic with many applications in economics +and finance. We provide a more thorough treatment in {doc}`our lecture on +autodiff `. 
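As a taste of that composability, `grad` can be combined with `vmap` to evaluate a derivative over a whole grid at once — here is a small sketch using the same $f(x) = x^2/2$:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x**2 / 2

# grad(f) differentiates at a single point; vmap lifts it over an array
f_prime_vec = jax.vmap(jax.grad(f))

x_grid = jnp.linspace(-4.0, 4.0, 9)
f_prime_vec(x_grid)   # equals x_grid itself, since f'(x) = x
```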
## Exercises @@ -819,7 +1014,7 @@ def compute_call_price_jax(β=β, ρ=ρ, ν=ν, M=M, - key=jax.random.PRNGKey(1)): + key=jax.random.key(1)): s = jnp.full(M, np.log(S0)) h = jnp.full(M, h0) diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index aa9c85a9..084f2c0d 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -55,13 +55,17 @@ We will use the following imports. ```{code-cell} ipython3 import random +from functools import partial + import numpy as np +import numba import quantecon as qe import matplotlib.pyplot as plt from mpl_toolkits.mplot3d.axes3d import Axes3D from matplotlib import cm import jax import jax.numpy as jnp +from jax import lax ``` ## Vectorized operations @@ -101,7 +105,7 @@ ax.plot_surface(x, y, f(x, y), rstride=2, cstride=2, - cmap=cm.jet, + cmap=cm.viridis, alpha=0.7, linewidth=0.25) ax.set_zlim(-0.5, 1.0) @@ -162,8 +166,6 @@ before it sees the size of the arrays `x` and `y`.) Now let's see if we can achieve better performance using Numba with a simple loop. ```{code-cell} ipython3 -import numba - @numba.jit def compute_max_numba(grid): m = -np.inf @@ -177,9 +179,9 @@ def compute_max_numba(grid): grid = np.linspace(-3, 3, 3_000) with qe.Timer(precision=8): - z_max_numpy = compute_max_numba(grid) + z_max_numba = compute_max_numba(grid) -print(f"Numba result: {z_max_numpy:.6f}") +print(f"Numba result: {z_max_numba:.6f}") ``` Let's run again to eliminate compile time. @@ -232,7 +234,7 @@ The reason is that the variable `m` is shared across threads and not properly co When multiple threads try to read and write `m` simultaneously, they interfere with each other. -Threads read stale values of `m` or overwrite each other's updates --— or `m` never gets updated from its initial value. +Threads read stale values of `m` or overwrite each other's updates --- or `m` never gets updated from its initial value. Here's a more carefully written version. 
@@ -299,7 +301,7 @@ calculation, we can use a `meshgrid` operation designed for this purpose: ```{code-cell} ipython3 grid = jnp.linspace(-3, 3, 3_000) -x_mesh, y_mesh = np.meshgrid(grid, grid) +x_mesh, y_mesh = jnp.meshgrid(grid, grid) with qe.Timer(precision=8): z_max = jnp.max(f(x_mesh, y_mesh)) @@ -316,7 +318,7 @@ with qe.Timer(precision=8): z_max.block_until_ready() ``` -Once compiled, JAX is significantly faster than NumPy due to GPU acceleration. +Once compiled, JAX is significantly faster than NumPy, especially on a GPU. The compilation overhead is a one-time cost that pays off when the function is called repeatedly. @@ -419,7 +421,7 @@ Let's try it. with qe.Timer(precision=8): z_max = compute_max_vmap_v2(grid).block_until_ready() -print(f"JAX vmap v1 result: {z_max:.6f}") +print(f"JAX vmap v2 result: {z_max:.6f}") ``` Let's run it again to eliminate compilation time: @@ -508,9 +510,6 @@ Now let's create a JAX version using `lax.scan`: (We'll hold `n` static because it affects array size and hence JAX wants to specialize on its value in the compiled code.) ```{code-cell} ipython3 -from jax import lax -from functools import partial - cpu = jax.devices("cpu")[0] @partial(jax.jit, static_argnums=(1,), device=cpu) @@ -575,3 +574,40 @@ Additionally, JAX's immutable arrays mean we cannot simply update array elements For this type of sequential operation, Numba is the clear winner in terms of code clarity and ease of implementation, as well as high performance. + +## Overall recommendations + +Let's now step back and summarize the trade-offs. + +For **vectorized operations**, JAX is the strongest choice. + +It matches or exceeds NumPy in speed, thanks to JIT compilation and efficient +parallelization across CPUs and GPUs. + +The `vmap` transformation reduces memory usage and often leads to clearer code +than traditional meshgrid-based vectorization. + +In addition, JAX functions are automatically differentiable, as we explore in +{doc}`autodiff`. 
+ +For **sequential operations**, Numba has clear advantages. + +The code is natural and readable --- just a Python loop with a decorator --- +and performance is excellent. + +JAX can handle sequential problems via `lax.scan`, but the syntax is less +intuitive and the performance gain is minimal for purely sequential work. + +That said, `lax.scan` has one important advantage: it supports automatic +differentiation through the loop, which Numba cannot do. + +If you need to differentiate through a sequential computation (e.g., computing +sensitivities of a trajectory to model parameters), JAX is the better choice +despite the less natural syntax. + +In practice, many problems involve a mix of both patterns. + +A good rule of thumb: default to JAX for new projects, especially when +hardware acceleration or differentiability might be useful, and reach for Numba +when you have a tight sequential loop that needs to be fast and readable. +