Let the Compiler Help Us Compute: Practical Compiled Code for Neuroscience
Author
import numpy as np
import numexpr as ne
import matplotlib.pyplot as plt
from numba import njit, prange
Vectorized NumPy code is often fast, but there is still room to improve: the more work we hand over to compiled functions, the less often execution has to return to the Python interpreter.
In this notebook, we’ll look at two powerful tools for doing more computational work outside of Python’s runtime:
- numexpr — to evaluate array expressions more efficiently
- numba — to compile Python functions into machine code
By the end of this notebook, we should be able to:
- Recognize when Python overhead limits performance
- Use numexpr to reduce interpreter overhead in array expressions
- Use Numba's @njit decorator to compile loop-based code
- Understand compilation overhead and when it matters
- Avoid common pitfalls that prevent compilation
Section 1: Using numexpr to Reduce Python Overhead
NumPy is fast because its core operations run in compiled C code. But every time we write an expression like:
np.arange(10_000_000) + 1 + 2 + 3 + 4 + 5 + 6

Python still:
- Parses the expression
- Allocates temporary arrays
- Manages intermediate results
numexpr lets you create a micro-function that:
- Minimizes temporary arrays
- Uses multiple CPU cores automatically
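For instance, the difference can be seen directly. This is a minimal sketch (the timings you observe will vary by machine, and the expression is chosen only for illustration):

```python
import numpy as np
import numexpr as ne

x = np.random.random(1_000_000)

# NumPy evaluates operator by operator, allocating a temporary
# array for each intermediate result
y_np = 3 * x**2 + 2 * x + 1

# numexpr compiles the whole expression and evaluates it in one
# pass over the data, without the intermediate arrays
y_ne = ne.evaluate('3 * x**2 + 2 * x + 1')

print(np.allclose(y_np, y_ne))  # True: same result, fewer temporaries
```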
Reference
| Code | Description |
|---|---|
| `import numexpr as ne` | Import numexpr. |
| `ne.evaluate(expr)` | Evaluate an expression using numexpr. |
| `ne.NumExpr(expr)` | Pre-compile an expression. |
| `f(x=x)` | Execute a compiled expression. |
| `ne.set_num_threads(n)` | Control the number of threads. |
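The pre-compile pattern from the table looks like this in practice (a small sketch; the expression is chosen only for illustration):

```python
import numpy as np
import numexpr as ne

# Compile the expression once; reuse it many times. This avoids
# re-parsing the string on every call, which matters for small arrays.
f = ne.NumExpr('x * 2 + 1')

x = np.arange(5, dtype=np.float64)
print(f(x))  # [1. 3. 5. 7. 9.]
```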
Exercises
Example: Speed up the calculation below with ne.evaluate:
# Before (keep this cell the same, for later comparison)
x = np.random.random(10_000_000)
y = %time x * 2 + 3
y.sum()

CPU times: user 5.96 ms, sys: 47.5 ms, total: 53.4 ms
Wall time: 59.8 ms
39998281.98519501

# After (modify this code cell)
x = np.random.random(10_000_000)
y = %time ne.evaluate('x * 2 + 3')
y.sum()

CPU times: user 23.6 ms, sys: 276 ms, total: 299 ms
Wall time: 23.8 ms
39997663.02705389

Example: Speed up the calculation below with ne.evaluate:
# Before (keep this cell the same, for later comparison)
x = np.random.random(10_000_000)
y = %time 3 * x ** 2 + 2 * x + 1
y.sum()

CPU times: user 44.4 ms, sys: 65 ms, total: 109 ms
Wall time: 109 ms
30004296.01429403

# After (modify this code cell)
x = np.random.random(10_000_000)
y = %time ne.evaluate('3 * x ** 2 + 2 * x + 1')
y.sum()

CPU times: user 84.4 ms, sys: 184 ms, total: 268 ms
Wall time: 20.9 ms
29997648.79665374

Exercise: Speed up the timed calculation below using ne.evaluate():
# Before (keep this the same, for comparison)
x = np.random.random(20_000_000)
y = np.random.random(20_000_000)
z = np.random.random(20_000_000)
total = %time x + y + z
plt.figure(figsize=(10, 1.5))
plt.subplot(1, 4, 1); plt.hist(x, bins=201);
plt.subplot(1, 4, 2); plt.hist(y, bins=201);
plt.subplot(1, 4, 3); plt.hist(z, bins=201);
plt.subplot(1, 4, 4); plt.hist(total, bins=201);
plt.tight_layout();

CPU times: user 27.1 ms, sys: 114 ms, total: 141 ms
Wall time: 141 ms

Solution
# After (modify this cell)
x = np.random.random(20_000_000)
y = np.random.random(20_000_000)
z = np.random.random(20_000_000)
total = %time ne.evaluate('x + y + z')
plt.figure(figsize=(10, 1.5))
plt.subplot(1, 4, 1); plt.hist(x, bins=201);
plt.subplot(1, 4, 2); plt.hist(y, bins=201);
plt.subplot(1, 4, 3); plt.hist(z, bins=201);
plt.subplot(1, 4, 4); plt.hist(total, bins=201);
plt.tight_layout();

CPU times: user 41.7 ms, sys: 16.9 ms, total: 58.6 ms
Wall time: 58.6 ms

Exercise: Speed up the following calculation with ne.evaluate():
x = np.random.random(5_000_000)

%%time
# Before (keep this the same, for comparison)
for _ in range(50):
    x = x * 2 * 2 * 2 * 2

CPU times: user 690 ms, sys: 284 ms, total: 974 ms
Wall time: 972 ms

Solution
%%time
# After (modify this cell)
for _ in range(50):
    x = ne.evaluate('x * 2 * 2 * 2 * 2')

CPU times: user 2.09 s, sys: 7.53 s, total: 9.62 s
Wall time: 754 ms

Exercise: Let's do it again, but this time with a smaller array. What is different about the performance?
x = np.random.random(50)

%%time
# Before (keep this the same, for comparison)
for _ in range(50):
    x = x * 2 * 2 * 2 * 2

CPU times: user 262 μs, sys: 308 μs, total: 570 μs
Wall time: 577 μs

Solution
%%time
# After (modify this cell)
for _ in range(50):
    x = ne.evaluate('x * 2 * 2 * 2 * 2')

CPU times: user 3.63 ms, sys: 0 ns, total: 3.63 ms
Wall time: 3.03 ms

Exercise: Let's try the smaller case again, but this time pre-compiling the expression using the pattern f = ne.NumExpr('x + 1'); f(x). How does this affect the performance?
%%time
# Before (keep this the same, for comparison)
for _ in range(50):
    x = x * 2 * 2 * 2 * 2

CPU times: user 227 μs, sys: 265 μs, total: 492 μs
Wall time: 498 μs

Solution
%%time
# After (modify this cell)
f = ne.NumExpr('x * 2 * 2 * 2 * 2')
for _ in range(50):
    x = f(x)

CPU times: user 338 μs, sys: 396 μs, total: 734 μs
Wall time: 1.08 ms

Exercise: numexpr is great at helping with broadcasting, but is not a full numpy replacement. Let's try it out here, and see if it helps:
x = np.random.random(10_000_000)
%time np.sum(np.sqrt(x))

CPU times: user 28.3 ms, sys: 15.9 ms, total: 44.2 ms
Wall time: 44 ms
6666012.275613346

Solution
x = np.random.random(10_000_000)
%time ne.evaluate("sum(sqrt(x))")

CPU times: user 44.2 ms, sys: 0 ns, total: 44.2 ms
Wall time: 44 ms
array(6667698.22208743)

Section 2: Just-in-Time Compilation with @njit
In Section 1, we improved performance for vectorized array expressions, but not all problems are easy to express with pure NumPy operations.
In neuroscience, we often need:
- Custom spike detection rules
- Trial-by-trial logic
- Nested loops
- Conditional branching
- Simulation steps
Pure Python loops are slow because:
- Every iteration runs through the Python interpreter
- Types are dynamic
- Operations are dispatched at runtime
Numba allows us to compile Python functions into machine code using Just-in-Time (JIT) compilation.
With @njit, we:
- Keep normal Python syntax
- Restrict ourselves to supported features
- Gain near-C speed
Reference
| Code | Description |
|---|---|
| `from numba import njit` | Import the JIT decorator. |
| `@njit` | Compile a function in nopython mode. |
| `@jit` | Compile, but allow Python fallback. |
| `f.inspect_types()` | Show inferred types. |
Exercises
For each of the exercises below, we'll take an existing Python function, add @njit, and compare performance. On the first run, @njit compiles the function; on subsequent runs, you'll see (even more of) the performance benefit.
Exercise: Basic Spike Counter
## Before (keep this the same, for later comparison)
def count_spikes(x, threshold):
    count = 0
    for i in range(len(x)):
        if x[i] > threshold:
            count += 1
    return count

## Call this cell multiple times and check if the performance changes
x = np.random.random(10_000_000)
%time count_spikes(x, 0.1)

CPU times: user 1.3 s, sys: 0 ns, total: 1.3 s
Wall time: 1.3 s
9000141

Solution
## After (modify this cell)
@njit
def count_spikes(x, threshold):
    count = 0
    for i in range(len(x)):
        if x[i] > threshold:
            count += 1
    return count

## Call this cell multiple times and check if the performance changes
x = np.random.random(10_000_000)
%time count_spikes(x, 0.1)

CPU times: user 3.22 s, sys: 241 ms, total: 3.46 s
Wall time: 1.47 s
9000593

Exercise: Sum of Squares
# Before (keep this the same, for later comparison)
def sum_of_squares(x):
    total = 0.0
    for i in range(len(x)):
        total += x[i] * x[i]
    return total

# Time the function here
%time sum_of_squares(np.ones(100_000))

CPU times: user 15 ms, sys: 0 ns, total: 15 ms
Wall time: 14.7 ms
100000.0

Solution
# After (add @njit to this function)
@njit
def sum_of_squares(x):
    total = 0.0
    for i in range(len(x)):
        total += x[i] * x[i]
    return total

# Time the function here
%time sum_of_squares(np.ones(100_000))

CPU times: user 583 μs, sys: 0 ns, total: 583 μs
Wall time: 468 μs
100000.0

Exercise: Conditional Logic
# Before (keep this the same, for later comparison)
def weighted_sum(x, threshold):
    total = 0.0
    for i in range(len(x)):
        if x[i] > threshold:
            total += x[i]
        else:
            total -= x[i]
    return total

# time it here
%time weighted_sum(np.ones(100_000), 1)

CPU times: user 13.2 ms, sys: 653 μs, total: 13.8 ms
Wall time: 13.7 ms
-100000.0

Solution
# After (add @njit to this one)
@njit
def weighted_sum(x, threshold):
    total = 0.0
    for i in range(len(x)):
        if x[i] > threshold:
            total += x[i]
        else:
            total -= x[i]
    return total

# time it here
%time weighted_sum(np.ones(100_000), 1)

CPU times: user 116 ms, sys: 0 ns, total: 116 ms
Wall time: 114 ms
-100000.0

Section 3: Requesting Parallel Execution with parallel=True and prange
Now we go one step further: we ask Numba to run loops in parallel across CPU cores. This is helpful, for example, when computing spike counts across 200 recording channels: each channel can be processed independently, which makes the loop a good candidate for parallelization.
We do this with:
- @njit(parallel=True)
- prange() instead of range()
However:
- Not all loops can be parallelized.
- Parallel overhead is real.
- Numba does not always parallelize even when we request it.
Reference
| Code | Description |
|---|---|
| `@njit(parallel=True)` | Enable parallel compilation. |
| `from numba import prange` | Import parallel range. |
| `prange(n)` | Parallel loop iterator. |
| `f.parallel_diagnostics(level=4)` | Show parallel analysis report. |
Exercises
Exercise: Add parallel=True and prange() to this function. What performance benefit is there?
def sum_seq(x):
    total = 0.0
    for i in range(len(x)):
        total += x[i]
    return total

x = np.random.random(50_000_000)
%time sum_seq(x)

CPU times: user 5.39 s, sys: 0 ns, total: 5.39 s
Wall time: 5.39 s
25001634.642637063

Solution
@njit(parallel=True)
def sum_seq(x):
    total = 0.0
    for i in prange(len(x)):
        total += x[i]
    return total

x = np.random.random(50_000_000)
%time sum_seq(x)

CPU times: user 127 ms, sys: 11.5 ms, total: 139 ms
Wall time: 16 ms
24999944.13798241

Exercise: Run sum_seq.parallel_diagnostics() and give the output to an AI. Was Numba able to parallelize this code?
Solution
sum_seq.parallel_diagnostics(level=4)
================================================================================
Parallel Accelerator Optimizing: Function sum_seq,
C:\Users\delgr\AppData\Local\Temp\ipykernel_43432\1605433682.py (1)
================================================================================
Parallel loop listing for Function sum_seq, C:\Users\delgr\AppData\Local\Temp\ipykernel_43432\1605433682.py (1)
--------------------------------|loop #ID
@njit(parallel=True) |
def sum_seq(x): |
total = 0.0 |
for i in prange(len(x)):----| #0
total += x[i] |
return total |
--------------------------------- Fusing loops ---------------------------------
Attempting fusion of parallel loops (combines loops with similar properties)...
----------------------------- Before Optimisation ------------------------------
--------------------------------------------------------------------------------
------------------------------ After Optimisation ------------------------------
Parallel structure is already optimal.
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
---------------------------Loop invariant code motion---------------------------
Allocation hoisting:
No allocation hoisting found
Instruction hoisting:
loop #0:
Failed to hoist the following:
dependency: $58binary_subscr.5 = getitem(value=x, index=$parfor__index_8.26, fn=<built-in function getitem>)
dependency: $total_2.29 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=total_2, rhs=$58binary_subscr.5, static_lhs=Undefined, static_rhs=Undefined)
dependency: total_2 = $total_2.29
--------------------------------------------------------------------------------

Exercise: Add parallel=True and prange() to this function. What performance benefit is there?
def dependent_loop(x):
    for i in range(1, len(x)):
        x[i] = x[i] + x[i - 1]
    return x

x = np.random.random(5_000_000)
%time dependent_loop(x)

CPU times: user 1.57 s, sys: 0 ns, total: 1.57 s
Wall time: 1.57 s
array([7.07890959e-01, 1.13226800e+00, 1.77534046e+00, ...,
       2.49940318e+06, 2.49940328e+06, 2.49940378e+06])

Solution
@njit(parallel=True)
def dependent_loop(x):
    for i in prange(1, len(x)):
        x[i] = x[i] + x[i - 1]
    return x

x = np.random.random(5_000_000)
%time dependent_loop(x)

CPU times: user 172 ms, sys: 0 ns, total: 172 ms
Wall time: 12.5 ms
array([4.00641515e-01, 4.32420367e-01, 1.09493956e+00, ...,
       1.24947539e+05, 1.24948206e+05, 1.24949077e+05])

Note that the result no longer matches the sequential version: each iteration depends on the previous one, so this loop is not safe to parallelize.

Exercise: Run dependent_loop.parallel_diagnostics() and give the output to an AI. Was Numba able to parallelize this code?
Solution
dependent_loop.parallel_diagnostics(level=4)
================================================================================
Parallel Accelerator Optimizing: Function dependent_loop,
/tmp/ipykernel_877473/2706474430.py (1)
================================================================================
Parallel loop listing for Function dependent_loop, /tmp/ipykernel_877473/2706474430.py (1)
-----------------------------------|loop #ID
@njit(parallel=True) |
def dependent_loop(x): |
for i in prange(1, len(x)):----| #1
x[i] = x[i] + x[i - 1] |
return x |
--------------------------------- Fusing loops ---------------------------------
Attempting fusion of parallel loops (combines loops with similar properties)...
----------------------------- Before Optimisation ------------------------------
--------------------------------------------------------------------------------
------------------------------ After Optimisation ------------------------------
Parallel structure is already optimal.
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
---------------------------Loop invariant code motion---------------------------
Allocation hoisting:
No allocation hoisting found
Instruction hoisting:
loop #1:
Has the following hoisted:
$const64.7.1 = const(int, 1)
Failed to hoist the following:
dependency: $56binary_subscr.4 = getitem(value=x, index=$parfor__index_26.40, fn=<built-in function getitem>)
dependency: $binop_sub66.8 = $parfor__index_26.40 - $const64.7.1
dependency: $70binary_subscr.9 = getitem(value=x, index=$binop_sub66.8, fn=<built-in function getitem>)
dependency: $binop_add74.10 = $56binary_subscr.4 + $70binary_subscr.9
--------------------------------------------------------------------------------