Let the Compiler Help Us Compute: Practical Compiled Code for Neuroscience
import numpy as np
import numexpr as ne
import matplotlib.pyplot as plt
from numba import njit, prange
Vectorized NumPy code is often fast. The more work we hand over to compiled functions, the less often execution has to return to the Python interpreter between steps.
In this notebook, we’ll look at two powerful tools for doing more computational work outside of Python’s runtime:
- numexpr: to evaluate array expressions more efficiently
- numba: to compile Python functions into machine code
By the end of this notebook, we should be able to:
- Recognize when Python overhead limits performance
- Use numexpr to reduce interpreter overhead in array expressions
- Use Numba’s @njit decorator to compile loop-based code
- Understand compilation overhead and when it matters
- Avoid common pitfalls that prevent compilation
Section 1: Using numexpr to Reduce Python Overhead
NumPy is fast because its core operations run in compiled C code. But every time we write an expression like:
np.arange(10_000_000) + 1 + 2 + 3 + 4 + 5 + 6
Python still:
- Dispatches each + as a separate NumPy call
- Allocates a temporary array for every intermediate result
- Manages those intermediate results in memory
numexpr lets you create a micro-function that:
- Minimizes temporary arrays
- Uses multiple CPU cores automatically
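As a rough illustration of the points above, the sketch below evaluates the same chained expression with plain NumPy and with numexpr and checks that the results agree. Passing `local_dict` explicitly is a choice made here so the snippet does not depend on the calling scope; the array size is arbitrary.

```python
import numpy as np
import numexpr as ne

x = np.arange(1_000_000, dtype=np.float64)

# NumPy runs one operator at a time; each "+" produces a new full-size array.
y_numpy = x + 1 + 2 + 3 + 4 + 5 + 6

# numexpr receives the whole expression as a string and evaluates it in
# blocks, so it does not need a full-size temporary for every step.
y_numexpr = ne.evaluate("x + 1 + 2 + 3 + 4 + 5 + 6", local_dict={"x": x})

results_match = np.allclose(y_numpy, y_numexpr)
print(results_match)
```

Whether numexpr is actually faster depends on array size and the machine, which is what the exercises below explore.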
Reference
| Code | Description |
|---|---|
| `import numexpr as ne` | Import numexpr. |
| `ne.evaluate(expr)` | Evaluate an expression string using numexpr. |
| `ne.NumExpr(expr)` | Pre-compile an expression. |
| `f(x)` | Execute a pre-compiled expression, passing the input arrays as positional arguments. |
| `ne.set_num_threads(n)` | Control the number of threads. |
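The reference entries above can be combined roughly as follows. This is a sketch rather than a benchmark: the expression, the array size, and the choice of a single thread are arbitrary, and the untyped `ne.NumExpr` call relies on numexpr treating `x` as a float64 input by default.

```python
import numpy as np
import numexpr as ne

x = np.random.random(50)

# Pre-compile once, then call the compiled object with positional arguments.
f = ne.NumExpr("x * 2 + 1")
y_compiled = f(x)

# ne.evaluate takes the expression string on every call.
y_evaluated = ne.evaluate("x * 2 + 1", local_dict={"x": x})

# set_num_threads returns the previous setting, so it can be restored later.
previous_threads = ne.set_num_threads(1)
y_single_thread = ne.evaluate("x * 2 + 1", local_dict={"x": x})
ne.set_num_threads(previous_threads)
```

All three results should hold the same values; only the per-call overhead and the thread count differ.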
Exercises
Example: Speed up the calculation below with ne.evaluate:
# Before (keep this cell the same, for later comparison)
x = np.random.random(10_000_000)
y = %time x * 2 + 3
y.sum()
CPU times: user 13.7 ms, sys: 90.3 ms, total: 104 ms
Wall time: 104 ms
np.float64(39998907.27311504)
# After (modify this code cell)
x = np.random.random(10_000_000)
y = %time ne.evaluate('x * 2 + 3')
y.sum()
CPU times: user 75.5 ms, sys: 399 ms, total: 475 ms
Wall time: 35.5 ms
np.float64(40004385.33772915)
Example: Speed up the calculation below with ne.evaluate:
# Before (keep this cell the same, for later comparison)
x = np.random.random(10_000_000)
y = %time 3 * x ** 2 + 2 * x + 1
y.sum()
CPU times: user 72.2 ms, sys: 77.7 ms, total: 150 ms
Wall time: 150 ms
np.float64(30000783.167188667)
# After (modify this code cell)
x = np.random.random(10_000_000)
y = %time ne.evaluate('3 * x ** 2 + 2 * x + 1')
y.sum()
CPU times: user 96.4 ms, sys: 275 ms, total: 371 ms
Wall time: 32.1 ms
np.float64(30000495.308897942)
Exercise: Speed up the timed calculation below using ne.evaluate():
# Before (keep this the same, for comparison)
x = np.random.random(20_000_000)
y = np.random.random(20_000_000)
z = np.random.random(20_000_000)
total = %time x + y + z
plt.figure(figsize=(10, 1.5))
plt.subplot(1, 4, 1); plt.hist(x, bins=201);
plt.subplot(1, 4, 2); plt.hist(y, bins=201);
plt.subplot(1, 4, 3); plt.hist(z, bins=201);
plt.subplot(1, 4, 4); plt.hist(total, bins=201);
plt.tight_layout();
CPU times: user 71.2 ms, sys: 125 ms, total: 196 ms
Wall time: 196 ms
Solution
# After (modify this cell)
x = np.random.random(20_000_000)
y = np.random.random(20_000_000)
z = np.random.random(20_000_000)
total = %time ne.evaluate('x + y + z')
plt.figure(figsize=(10, 1.5))
plt.subplot(1, 4, 1); plt.hist(x, bins=201);
plt.subplot(1, 4, 2); plt.hist(y, bins=201);
plt.subplot(1, 4, 3); plt.hist(z, bins=201);
plt.subplot(1, 4, 4); plt.hist(total, bins=201);
plt.tight_layout();
CPU times: user 95 ms, sys: 32 ms, total: 127 ms
Wall time: 127 ms
Exercise: Speed up the following calculation with ne.evaluate():
x = np.random.random(5_000_000)
%%time
# Before (keep this the same, for comparison)
for _ in range(50):
    x = x * 2 * 2 * 2 * 2
CPU times: user 1.23 s, sys: 641 ms, total: 1.87 s
Wall time: 1.86 s
Solution
%%time
# After (modify this cell)
for _ in range(50):
    x = ne.evaluate('x * 2 * 2 * 2 * 2')
CPU times: user 4.19 s, sys: 21.1 s, total: 25.3 s
Wall time: 1.77 s
Exercise: Let’s do it again, but this time with a smaller array. What is different about the performance?
x = np.random.random(50)
%%time
# Before (keep this the same, for comparison)
for _ in range(50):
    x = x * 2 * 2 * 2 * 2
CPU times: user 1.47 ms, sys: 0 ns, total: 1.47 ms
Wall time: 1.5 ms
Solution
%%time
# After (modify this cell)
for _ in range(50):
    x = ne.evaluate('x * 2 * 2 * 2 * 2')
CPU times: user 832 μs, sys: 5.14 ms, total: 5.97 ms
Wall time: 5.19 ms
Exercise: Let’s try the smaller case again, but this time pre-compiling the expression using the pattern f = ne.NumExpr('x + 1'); f(x). How does this affect the performance?
%%time
# Before (keep this the same, for comparison)
for _ in range(50):
    x = x * 2 * 2 * 2 * 2
CPU times: user 620 μs, sys: 0 ns, total: 620 μs
Wall time: 627 μs
Solution
%%time
# After (modify this cell)
f = ne.NumExpr('x * 2 * 2 * 2 * 2')
for _ in range(50):
    x = f(x)
CPU times: user 353 μs, sys: 448 μs, total: 801 μs
Wall time: 820 μs
Exercise: numexpr handles elementwise expressions (with broadcasting) and a few reductions such as sum, but it is not a full NumPy replacement. Let’s try it out here and see if it helps:
x = np.random.random(10_000_000)
%time np.sum(np.sqrt(x))
CPU times: user 31.6 ms, sys: 88.5 ms, total: 120 ms
Wall time: 119 ms
np.float64(6665823.263745766)
Solution
x = np.random.random(10_000_000)
%time ne.evaluate("sum(sqrt(x))")
CPU times: user 71.8 ms, sys: 0 ns, total: 71.8 ms
Wall time: 73.5 ms
array(6665720.97519577)
Section 2: Just-in-Time Compilation with @njit
In Section 1, we improved performance for vectorized array expressions, but not all problems are easy to express with pure NumPy operations.
In neuroscience, we often need:
- Custom spike detection rules
- Trial-by-trial logic
- Nested loops
- Conditional branching
- Simulation steps
Pure Python loops are slow because:
- Every iteration runs through the Python interpreter
- Types are dynamic
- Operations are dispatched at runtime
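To make the interpreter cost concrete, the sketch below counts threshold crossings once with a pure Python loop and once with a vectorized NumPy expression. The helper name `count_above_loop`, the array size, and the threshold are illustrative choices; timings vary by machine, so none are shown.

```python
import time
import numpy as np

def count_above_loop(x, threshold):
    # Every iteration is interpreted: indexing, comparison, and increment
    # are each dispatched at runtime on dynamically typed objects.
    count = 0
    for i in range(len(x)):
        if x[i] > threshold:
            count += 1
    return count

x = np.random.random(200_000)

start = time.perf_counter()
loop_count = count_above_loop(x, 0.1)
loop_seconds = time.perf_counter() - start

start = time.perf_counter()
vector_count = int(np.count_nonzero(x > 0.1))
vector_seconds = time.perf_counter() - start

print(f"loop: {loop_seconds:.4f} s, vectorized: {vector_seconds:.4f} s")
```

Both versions return the same count; only the amount of work done inside the interpreter differs.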
Numba allows us to compile Python functions into machine code using Just-in-Time (JIT) compilation.
With @njit, we:
- Keep normal Python syntax
- Restrict ourselves to supported features
- Often gain near-C speed for numeric loops
Reference
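| Code | Description |
|---|---|
| `from numba import njit` | Import the JIT decorator. |
| `@njit` | Compile the function in nopython mode. |
| `@jit` | General JIT decorator. Older Numba versions fell back to object mode when nopython compilation failed; since Numba 0.59 it defaults to nopython mode. |
| `f.inspect_types()` | Show inferred types. |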
Exercises
For each of the exercises below, we’ll take an existing Python function, add @njit, and compare performance. On the first call, @njit compiles the function, so that call includes the compilation time; from the second call on, we see the full performance benefit.
Exercise: Basic Spike Counter
## Before (keep this the same, for later comparison)
def count_spikes(x, threshold):
    count = 0
    for i in range(len(x)):
        if x[i] > threshold:
            count += 1
    return count
## Call this cell multiple times and check if the performance changes
x = np.random.random(10_000_000)
%time count_spikes(x, 0.1)
CPU times: user 2.01 s, sys: 0 ns, total: 2.01 s
Wall time: 2.01 s
9000080
Solution
## After (modify this cell)
@njit
def count_spikes(x, threshold):
    count = 0
    for i in range(len(x)):
        if x[i] > threshold:
            count += 1
    return count
## Call this cell multiple times and check if the performance changes
x = np.random.random(10_000_000)
%time count_spikes(x, 0.1)
CPU times: user 3.1 s, sys: 240 ms, total: 3.33 s
Wall time: 966 ms
9000304
Exercise: Sum of Squares
# Before (keep this the same, for later comparison)
def sum_of_squares(x):
    total = 0.0
    for i in range(len(x)):
        total += x[i] * x[i]
    return total
# Time the function here
%time sum_of_squares(np.ones(100_000))
CPU times: user 42.2 ms, sys: 1.19 ms, total: 43.3 ms
Wall time: 42.3 ms
np.float64(100000.0)
Solution
# After (add @njit to this function)
@njit
def sum_of_squares(x):
    total = 0.0
    for i in range(len(x)):
        total += x[i] * x[i]
    return total
# Time the function here
%time sum_of_squares(np.ones(100_000))
CPU times: user 115 ms, sys: 20.8 ms, total: 136 ms
Wall time: 134 ms
100000.0
Exercise: Conditional Logic
# Before (keep this the same, for later comparison)
def weighted_sum(x, threshold):
    total = 0.0
    for i in range(len(x)):
        if x[i] > threshold:
            total += x[i]
        else:
            total -= x[i]
    return total
# Time it here
%time weighted_sum(np.ones(100_000), 1)
CPU times: user 35.8 ms, sys: 3.08 ms, total: 38.9 ms
Wall time: 37.8 ms
np.float64(-100000.0)
Solution
# After (add @njit to this one)
@njit
def weighted_sum(x, threshold):
    total = 0.0
    for i in range(len(x)):
        if x[i] > threshold:
            total += x[i]
        else:
            total -= x[i]
    return total
# Time it here
%time weighted_sum(np.ones(100_000), 1)
CPU times: user 153 ms, sys: 5.6 ms, total: 159 ms
Wall time: 157 ms
-100000.0
Section 3: Requesting Parallel Execution with parallel=True and prange
Now we go one step further: we ask Numba to run loops in parallel across CPU cores. This is helpful, for example, if we are computing spike counts independently across 200 recording channels. Each channel can be processed independently — which makes it a good candidate for parallelization.
We do this with:
- @njit(parallel=True)
- prange() instead of range()
However:
- Not all loops can be parallelized.
- Parallel overhead is real.
- Numba does not always parallelize even when we request it.
Reference
| Code | Description |
|---|---|
| `@njit(parallel=True)` | Enable parallel compilation. |
| `from numba import prange` | Import the parallel range. |
| `prange(n)` | Parallel loop iterator. |
| `f.parallel_diagnostics(level=4)` | Show the parallel analysis report. |
Exercises
Exercise: Add (parallel=True) and prange() to this function. What performance benefit is there?
def sum_seq(x):
    total = 0.0
    for i in range(len(x)):
        total += x[i]
    return total
x = np.random.random(50_000_000)
%time sum_seq(x)
CPU times: user 10.3 s, sys: 0 ns, total: 10.3 s
Wall time: 10.2 s
np.float64(24996706.24008694)
Solution
@njit(parallel=True)
def sum_seq(x):
    total = 0.0
    for i in prange(len(x)):
        total += x[i]
    return total
x = np.random.random(50_000_000)
%time sum_seq(x)
CPU times: user 1.41 s, sys: 108 ms, total: 1.52 s
Wall time: 825 ms
24999219.29505931
Exercise: Run sum_seq.parallel_diagnostics() and give the output to an AI. Was Numba able to parallelize this code?
Solution
sum_seq.parallel_diagnostics(level=4)
================================================================================
Parallel Accelerator Optimizing: Function sum_seq,
/tmp/ipykernel_838177/1605433682.py (1)
================================================================================
Parallel loop listing for Function sum_seq, /tmp/ipykernel_838177/1605433682.py (1)
--------------------------------|loop #ID
@njit(parallel=True) |
def sum_seq(x): |
total = 0.0 |
for i in prange(len(x)):----| #0
total += x[i] |
return total |
--------------------------------- Fusing loops ---------------------------------
Attempting fusion of parallel loops (combines loops with similar properties)...
----------------------------- Before Optimisation ------------------------------
--------------------------------------------------------------------------------
------------------------------ After Optimisation ------------------------------
Parallel structure is already optimal.
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
---------------------------Loop invariant code motion---------------------------
Allocation hoisting:
No allocation hoisting found
Instruction hoisting:
loop #0:
Failed to hoist the following:
dependency: $58binary_subscr.5 = getitem(value=x, index=$parfor__index_2.20, fn=<built-in function getitem>)
dependency: $total_2.23 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=total_2, rhs=$58binary_subscr.5, static_lhs=Undefined, static_rhs=Undefined)
dependency: total_2 = $total_2.23
--------------------------------------------------------------------------------
Exercise: Add (parallel=True) and prange() to this function. What performance benefit is there, and is the result still correct?
def dependent_loop(x):
    for i in range(1, len(x)):
        x[i] = x[i] + x[i - 1]
    return x
x = np.random.random(5_000_000)
%time dependent_loop(x)
CPU times: user 2.02 s, sys: 0 ns, total: 2.02 s
Wall time: 2.01 s
array([1.31466274e-01, 9.42240651e-01, 1.88078036e+00, ...,
       2.49941364e+06, 2.49941389e+06, 2.49941395e+06], shape=(5000000,))
Solution
@njit(parallel=True)
def dependent_loop(x):
    for i in prange(1, len(x)):
        x[i] = x[i] + x[i - 1]
    return x
x = np.random.random(5_000_000)
%time dependent_loop(x)
CPU times: user 655 ms, sys: 14.6 ms, total: 670 ms
Wall time: 532 ms
array([6.39768101e-02, 9.46011035e-01, 1.76934503e+00, ...,
       1.24927298e+05, 1.24927854e+05, 1.24928719e+05], shape=(5000000,))
Exercise: Run dependent_loop.parallel_diagnostics() and give the output to an AI. Was Numba able to parallelize this code?
Solution
dependent_loop.parallel_diagnostics(level=4)
================================================================================
Parallel Accelerator Optimizing: Function dependent_loop,
/tmp/ipykernel_838177/2706474430.py (1)
================================================================================
Parallel loop listing for Function dependent_loop, /tmp/ipykernel_838177/2706474430.py (1)
-----------------------------------|loop #ID
@njit(parallel=True) |
def dependent_loop(x): |
for i in prange(1, len(x)):----| #1
x[i] = x[i] + x[i - 1] |
return x |
--------------------------------- Fusing loops ---------------------------------
Attempting fusion of parallel loops (combines loops with similar properties)...
----------------------------- Before Optimisation ------------------------------
--------------------------------------------------------------------------------
------------------------------ After Optimisation ------------------------------
Parallel structure is already optimal.
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
---------------------------Loop invariant code motion---------------------------
Allocation hoisting:
No allocation hoisting found
Instruction hoisting:
loop #1:
Has the following hoisted:
$const60.7.1 = const(int, 1)
Failed to hoist the following:
dependency: $54binary_subscr.4 = getitem(value=x, index=$parfor__index_26.40, fn=<built-in function getitem>)
dependency: $binop_sub62.8 = $parfor__index_26.40 - $const60.7.1
dependency: $66binary_subscr.9 = getitem(value=x, index=$binop_sub62.8, fn=<built-in function getitem>)
dependency: $binop_add70.10 = $54binary_subscr.4 + $66binary_subscr.9
--------------------------------------------------------------------------------