From 0478ce4ad754671c03dddc5329984b4478fd7904 Mon Sep 17 00:00:00 2001
From: John Stachurski
Date: Sun, 5 Apr 2026 08:20:04 +0900
Subject: [PATCH] Modernize and improve Numba lecture

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 lectures/numba.md | 262 ++++++++++++++--------------------------------
 1 file changed, 79 insertions(+), 183 deletions(-)

diff --git a/lectures/numba.md b/lectures/numba.md
index b6a3cccb..8d0fb182 100644
--- a/lectures/numba.md
+++ b/lectures/numba.md
@@ -45,31 +45,24 @@ import matplotlib.pyplot as plt

 ## Overview

-In an {doc}`earlier lecture ` we learned about vectorization, which is one method to improve speed and efficiency in numerical work.
+In an {doc}`earlier lecture ` we discussed vectorization, which
+can improve execution speed by sending array processing operations in batch to efficient low-level code.

-Vectorization involves sending array processing
-operations in batch to efficient low-level code.
+However, as {ref}`discussed previously `, traditional vectorization schemes, such as those found in Matlab, Julia, and NumPy, have several weaknesses.

-However, as {ref}`discussed previously `, vectorization has several weaknesses.
+For example, they can be highly memory-intensive and, for some algorithms, vectorization is ineffective or impossible.

-One is that it is highly memory-intensive when working with large amounts of data.
+One way around these problems is through [Numba](https://numba.pydata.org/), a
+**just-in-time (JIT) compiler** for Python that is oriented towards numerical work.

-Another is that the set of algorithms that can be entirely vectorized is not universal.
+Numba compiles functions to native machine code instructions during runtime.

-In fact, for some algorithms, vectorization is ineffective.
+When it succeeds, Numba will be on par with machine code from low-level languages.

-Fortunately, a new Python library called [Numba](https://numba.pydata.org/)
-solves many of these problems.
+In addition, Numba can do other useful tricks, such as {ref}`multithreading` or +interfacing with GPUs (through `numba.cuda`). -It does so through something called **just in time (JIT) compilation**. - -The key idea is to compile functions to native machine code instructions on the fly. - -When it succeeds, the compiled code is extremely fast. - -Beyond speed gains from compilation, Numba is specifically designed for numerical work and can also do other tricks such as {ref}`multithreading`. - -This lecture introduces the main ideas. +This lecture introduces the core ideas. (numba_link)= ## {index}`Compiling Functions ` @@ -77,18 +70,17 @@ This lecture introduces the main ideas. ```{index} single: Python; Numba ``` -As stated above, Numba's primary use is compiling functions to fast native -machine code during runtime. (quad_map_eg)= ### An Example -Let's consider a problem that is difficult to vectorize: generating the trajectory of a difference equation given an initial condition. +Let's consider a problem that's difficult to vectorize: generating the +trajectory of a difference equation given an initial condition. We will take the difference equation to be the quadratic map $$ -x_{t+1} = \alpha x_t (1 - x_t) + x_{t+1} = \alpha x_t (1 - x_t) $$ In what follows we set @@ -162,7 +154,6 @@ time3 = timer3.elapsed time1 / time3 # Calculate speed gain ``` -This kind of speed gain is impressive relative to how simple and clear the modification is. ### How and When it Works @@ -180,18 +171,19 @@ The basic idea is this: * This makes it hard to *pre*-compile the function (i.e., compile before runtime). * However, when we do actually call the function, say by running `qm(0.5, 10)`, the types of `x0` and `n` become clear. -* Moreover, the types of other variables in `qm` can be inferred once the input types are known. +* Moreover, the types of *other variables* in `qm` *can be inferred once the input types are known*. 
* So the strategy of Numba and other JIT compilers is to wait until this - moment, and *then* compile the function. + moment, and then compile the function. That's why it is called "just-in-time" compilation. Note that, if you make the call `qm(0.5, 10)` and then follow it with `qm(0.9, 20)`, compilation only takes place on the first call. -The compiled code is then cached and recycled as required. +This is because compiled code is cached and reused as required. This is why, in the code above, `time3` is smaller than `time2`. + ## Decorator Notation In the code above we created a JIT compiled version of `qm` via the call @@ -204,9 +196,7 @@ In practice this would typically be done using an alternative *decorator* syntax (We discuss decorators in a {doc}`separate lecture ` but you can skip the details at this stage.) -Let's see how this is done. - -To target a function for JIT compilation we can put `@jit` before the function definition. +Specifically, to target a function for JIT compilation we can put `@jit` before the function definition. Here's what this looks like for `qm` @@ -248,16 +238,24 @@ In an ideal setting, Numba can infer all necessary type information. This allows it to generate native machine code, without having to call the Python runtime environment. -In such a setting, Numba will be on par with machine code from low-level languages. - When Numba cannot infer all type information, it will raise an error. +```{note} +In older versions of Numba, the `@jit` decorator would silently fall back +to "object mode" when it could not infer all types, which provided little or +no speed gain. Current versions of Numba use `nopython` mode by default, +meaning the compiler insists on full type inference and raises an error if +it fails. You will often see `@njit` used in other code, which is simply +an alias for `@jit(nopython=True)`. Since nopython mode is now the default, +`@jit` and `@njit` are equivalent. 
+``` + For example, in the (artificial) setting below, Numba is unable to determine the type of function `mean` when compiling the function `bootstrap` ```{code-cell} ipython3 @jit -def bootstrap(data, statistics, n): - bootstrap_stat = np.empty(n) +def bootstrap(data, statistics, n_resamples): + bootstrap_stat = np.empty(n_resamples) n = len(data) for i in range(n_resamples): resample = np.random.choice(data, size=n, replace=True) @@ -289,121 +287,10 @@ with qe.Timer(): bootstrap(data, mean, n_resamples) ``` -## Compiling Classes - -As mentioned above, at present Numba can only compile a subset of Python. - -However, that subset is ever expanding. - -Notably, Numba is now quite effective at compiling classes. - -If a class is successfully compiled, then its methods act as JIT-compiled -functions. - -To give one example, let's consider the class for analyzing the Solow growth model we -created in {doc}`this lecture `. - -To compile this class we use the `@jitclass` decorator: - -```{code-cell} ipython3 -from numba import float64 -from numba.experimental import jitclass -``` - -Notice that we also imported something called `float64`. - -This is a data type representing standard floating point numbers. - -We are importing it here because Numba needs a bit of extra help with types when it tries to deal with classes. 
- -Here's our code: - -```{code-cell} ipython3 -solow_data = [ - ('n', float64), - ('s', float64), - ('δ', float64), - ('α', float64), - ('z', float64), - ('k', float64) -] - -@jitclass(solow_data) -class Solow: - r""" - Implements the Solow growth model with the update rule - - k_{t+1} = [(s z k^α_t) + (1 - δ)k_t] /(1 + n) - - """ - def __init__(self, n=0.05, # population growth rate - s=0.25, # savings rate - δ=0.1, # depreciation rate - α=0.3, # share of labor - z=2.0, # productivity - k=1.0): # current capital stock - - self.n, self.s, self.δ, self.α, self.z = n, s, δ, α, z - self.k = k - - def h(self): - "Evaluate the h function" - # Unpack parameters (get rid of self to simplify notation) - n, s, δ, α, z = self.n, self.s, self.δ, self.α, self.z - # Apply the update rule - return (s * z * self.k**α + (1 - δ) * self.k) / (1 + n) - - def update(self): - "Update the current state (i.e., the capital stock)." - self.k = self.h() - - def steady_state(self): - "Compute the steady state value of capital." - # Unpack parameters (get rid of self to simplify notation) - n, s, δ, α, z = self.n, self.s, self.δ, self.α, self.z - # Compute and return steady state - return ((s * z) / (n + δ))**(1 / (1 - α)) - - def generate_sequence(self, t): - "Generate and return a time series of length t" - path = [] - for i in range(t): - path.append(self.k) - self.update() - return path -``` - -First we specified the types of the instance data for the class in -`solow_data`. - -After that, targeting the class for JIT compilation only requires adding -`@jitclass(solow_data)` before the class definition. - -When we call the methods in the class, the methods are compiled just like functions. 
- -```{code-cell} ipython3 -s1 = Solow() -s2 = Solow(k=8.0) - -T = 60 -fig, ax = plt.subplots() - -# Plot the common steady state value of capital -ax.plot([s1.steady_state()]*T, 'k-', label='steady state') - -# Plot time series for each economy -for s in s1, s2: - lb = f'capital series from initial state {s.k}' - ax.plot(s.generate_sequence(T), 'o-', lw=2, alpha=0.6, label=lb) -ax.set_ylabel('$k_{t}$', fontsize=12) -ax.set_xlabel('$t$', fontsize=12) -ax.legend() -plt.show() -``` ## Dangers and Limitations -Let's review the above and add some cautionary notes. +Let's add some cautionary notes. ### Limitations @@ -414,9 +301,10 @@ For simple routines, Numba infers types very well. For larger ones, or for routines using external libraries, it can easily fail. -Hence, it's prudent when using Numba to focus on speeding up small, time-critical snippets of code. +Hence, it's best to focus on speeding up small, time-critical snippets of code. + +This will give you much better performance than blanketing your Python programs with `@jit` statements. -This will give you much better performance than blanketing your Python programs with `@njit` statements. ### A Gotcha: Global Variables @@ -445,16 +333,32 @@ function. When Numba compiles machine code for functions, it treats global variables as constants to ensure type stability. -(multithreading)= -## Multithreaded Loops in Numba +### Caching Compiled Code -In addition to JIT compilation, Numba provides powerful support for parallel computing on CPUs. +By default, Numba recompiles functions each time a new Python session starts. -By distributing computations across multiple CPU cores, we can achieve significant speed gains for many numerical algorithms. +To avoid this overhead, you can pass `cache=True` to the decorator: -The key tool for parallelization in Numba is the `prange` function, which tells Numba to execute loop iterations in parallel across available CPU cores. 
+```{code-cell} ipython3 +@jit(cache=True) +def qm(x0, n): + x = np.empty(n+1) + x[0] = x0 + for t in range(n): + x[t+1] = α * x[t] * (1 - x[t]) + return x +``` + +This stores the compiled code on disk so that subsequent sessions can skip +the compilation step. + +(multithreading)= +## Multithreaded Loops in Numba -This approach to multithreading works well for a wide range of problems in scientific computing and quantitative economics. +In addition to JIT compilation, Numba provides support for parallel computing on CPUs. + +The key tool for parallelization in Numba is the `prange` function, which tells +Numba to execute loop iterations in parallel across available CPU cores. To illustrate, let's look first at a simple, single-threaded (i.e., non-parallelized) piece of code. @@ -476,18 +380,17 @@ distribution. Here's the code: ```{code-cell} ipython3 -from numpy.random import randn -from numba import njit +from numba import jit -@njit +@jit def h(w, r=0.1, s=0.3, v1=0.1, v2=1.0): """ Updates household wealth. """ # Draw shocks - R = np.exp(v1 * randn()) * (1 + r) - y = np.exp(v2 * randn()) + R = np.exp(v1 * np.random.randn()) * (1 + r) + y = np.exp(v2 * np.random.randn()) # Update wealth w = R * s * w + y @@ -522,14 +425,13 @@ calculate median wealth for this group. Suppose we are interested in the long-run average of this median over time. -It turns out that, for the specification that we've chosen above, we can -calculate this by taking a one-period snapshot of what has happened to median +For the specification that we've chosen above, we can +calculate this by taking a one-period cross-sectional snapshot of median wealth of the group at the end of a long simulation. -Moreover, provided the simulation period is long enough, initial conditions -don't matter. +Moreover, provided the simulation period is long enough, initial conditions don't matter. 
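To see why the starting point washes out, here is a small pure-NumPy sketch (separate from the lecture's Numba code; the helper names `update_wealth` and `median_after_T` are illustrative, not part of the lecture) that computes the cross-sectional median from two very different initial wealth levels:

```python
import numpy as np

def update_wealth(w, r=0.1, s=0.3, v1=0.1, v2=1.0, rng=None):
    # One step of the wealth dynamics w' = R s w + y, mirroring h above
    rng = np.random.default_rng() if rng is None else rng
    R = np.exp(v1 * rng.standard_normal()) * (1 + r)
    y = np.exp(v2 * rng.standard_normal())
    return R * s * w + y

def median_after_T(w0, T=100, num_reps=500, seed=1234):
    # Median wealth across num_reps households after T periods
    rng = np.random.default_rng(seed)
    obs = np.empty(num_reps)
    for i in range(num_reps):
        w = w0
        for t in range(T):
            w = update_wealth(w, rng=rng)
        obs[i] = w
    return np.median(obs)

# Very different initial conditions give essentially the same median,
# since the contribution of w0 shrinks by a factor of roughly R*s < 1 each period
print(median_after_T(1.0), median_after_T(1000.0))
```

Both starting points produce essentially the same long-run median, consistent with the claim above.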
-* This is due to something called ergodicity, which we will discuss [later on](https://python.quantecon.org/finite_markov.html#id15). +(This is due to [ergodicity](https://python.quantecon.org/finite_markov.html#id15).) So, in summary, we are going to simulate 50,000 households by @@ -541,7 +443,7 @@ Then we'll calculate median wealth at the end period. Here's the code: ```{code-cell} ipython3 -@njit +@jit def compute_long_run_median(w0=1, T=1000, num_reps=50_000): obs = np.empty(num_reps) @@ -568,7 +470,7 @@ To do so, we add the `parallel=True` flag and change `range` to `prange`: ```{code-cell} ipython3 from numba import prange -@njit(parallel=True) +@jit(parallel=True) def compute_long_run_median_parallel(w0=1, T=1000, num_reps=50_000): obs = np.empty(num_reps) @@ -611,13 +513,11 @@ Compare speed with and without Numba when the sample size is large. Here is one solution: ```{code-cell} ipython3 -from random import uniform - @jit def calculate_pi(n=1_000_000): count = 0 for i in range(n): - u, v = uniform(0, 1), uniform(0, 1) + u, v = np.random.uniform(0, 1), np.random.uniform(0, 1) d = np.sqrt((u - 0.5)**2 + (v - 0.5)**2) if d < 0.5: count += 1 @@ -638,11 +538,10 @@ with qe.Timer(): calculate_pi() ``` -If we switch off JIT compilation by removing `@njit`, the code takes around +If we switch off JIT compilation by removing `@jit`, the code takes around 150 times as long on our machine. -So we get a speed gain of 2 orders of magnitude--which is huge--by adding four -characters. +So we get a speed gain of 2 orders of magnitude by adding four characters. ```{solution-end} ``` @@ -686,7 +585,7 @@ If your code is correct, it should be about 2/3. :class: dropdown * Represent the low state as 0 and the high state as 1. -* If you want to store integers in a NumPy array and then apply JIT compilation, use `x = np.empty(n, dtype=np.int_)`. +* If you want to store integers in a NumPy array and then apply JIT compilation, use `x = np.empty(n, dtype=np.int64)`. 
``` @@ -710,7 +609,7 @@ Here's a pure Python version of the function ```{code-cell} ipython3 def compute_series(n): - x = np.empty(n, dtype=np.int_) + x = np.empty(n, dtype=np.int64) x[0] = 1 # Start in state 1 U = np.random.uniform(0, 1, size=n) for t in range(1, n): @@ -795,13 +694,11 @@ For the size of the Monte Carlo simulation, use something substantial, such as Here is one solution: ```{code-cell} ipython3 -from random import uniform - -@njit(parallel=True) +@jit(parallel=True) def calculate_pi(n=1_000_000): count = 0 for i in prange(n): - u, v = uniform(0, 1), uniform(0, 1) + u, v = np.random.uniform(0, 1), np.random.uniform(0, 1) d = np.sqrt((u - 0.5)**2 + (v - 0.5)**2) if d < 0.5: count += 1 @@ -823,7 +720,7 @@ with qe.Timer(): ``` By switching parallelization on and off (selecting `True` or -`False` in the `@njit` annotation), we can test the speed gain that +`False` in the `@jit` annotation), we can test the speed gain that multithreading provides on top of JIT compilation. On our workstation, we find that parallelization increases execution speed by @@ -913,13 +810,12 @@ Using this fact, the solution can be written as follows. ```{code-cell} ipython3 -from numpy.random import randn M = 10_000_000 n, β, K = 20, 0.99, 100 μ, ρ, ν, S0, h0 = 0.0001, 0.1, 0.001, 10, 0 -@njit(parallel=True) +@jit(parallel=True) def compute_call_price_parallel(β=β, μ=μ, S0=S0, @@ -936,10 +832,10 @@ def compute_call_price_parallel(β=β, h = h0 # Simulate forward in time for t in range(n): - s = s + μ + np.exp(h) * randn() - h = ρ * h + ν * randn() + s = s + μ + np.exp(h) * np.random.randn() + h = ρ * h + ν * np.random.randn() # And add the value max{S_n - K, 0} to current_sum - current_sum += np.maximum(np.exp(s) - K, 0) + current_sum += max(np.exp(s) - K, 0) return β**n * current_sum / M ```
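For comparison, the same Monte Carlo estimate can also be written in vectorized NumPy, carrying all paths forward at once. This is a sketch: the function name is illustrative, and `M` is reduced here to keep runtime modest.

```python
import numpy as np

def compute_call_price_numpy(β=0.99, μ=0.0001, S0=10, h0=0, K=100,
                             n=20, ρ=0.1, ν=0.001, M=100_000, seed=0):
    # Carry M paths of (s, h) forward n periods simultaneously
    rng = np.random.default_rng(seed)
    s = np.full(M, np.log(S0))
    h = np.full(M, float(h0))
    for t in range(n):
        s = s + μ + np.exp(h) * rng.standard_normal(M)
        h = ρ * h + ν * rng.standard_normal(M)
    # Discounted mean of the payoff max{S_n - K, 0}
    return β**n * np.mean(np.maximum(np.exp(s) - K, 0))

price = compute_call_price_numpy()
```

Note the trade-off discussed at the start of the lecture: this version is fast but holds `M` draws in memory at each step, whereas the Numba loop above keeps only scalars per path.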