Let the Compiler Help Us Compute: Practical Compiled Code for Neuroscience


Author
Dr. Nicholas Del Grosso

import numpy as np
import numexpr as ne
import matplotlib.pyplot as plt
from numba import njit, prange


Vectorized NumPy code is often fast because its inner loops run in compiled code. The more work we hand over to compiled functions at once, the less time our computation spends in the Python interpreter.

In this notebook, we’ll look at two powerful tools for doing more computational work outside of Python’s runtime:

  • numexpr — to evaluate array expressions more efficiently
  • numba — to compile Python functions into machine code

By the end of this notebook, we should be able to:

  • Recognize when Python overhead limits performance
  • Use numexpr to reduce interpreter overhead in array expressions
  • Use Numba’s @njit decorator to compile loop-based code
  • Understand compilation overhead and when it matters
  • Avoid common pitfalls that prevent compilation

Section 1: Using numexpr to Reduce Python Overhead

NumPy is fast because its core operations run in compiled C code. But every time we write an expression like:

np.arange(10_000_000) + 1 + 2 + 3 + 4 + 5 + 6

Python still:

  • Dispatches each operation separately at runtime
  • Allocates a full-size temporary array for each intermediate result
  • Manages and discards those intermediates one by one

numexpr lets you compile a whole expression into a single micro-function that:

  • Minimizes temporary arrays
  • Uses multiple CPU cores automatically
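As a small sketch of the difference (the array size and expression here are illustrative, not from the exercises below):

```python
import numpy as np
import numexpr as ne

x = np.random.random(1_000_000)

# NumPy evaluates this operation by operation, allocating a full-size
# temporary array for each intermediate result:
y_np = 2 * x**2 + 3 * x + 1

# numexpr compiles the whole expression and evaluates it in cache-sized
# blocks, avoiding the full-size temporaries:
y_ne = ne.evaluate('2 * x**2 + 3 * x + 1')

print(np.allclose(y_np, y_ne))  # → True
```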

Reference

  • import numexpr as ne: Import numexpr.
  • ne.evaluate(expr): Evaluate an expression string with numexpr.
  • f = ne.NumExpr(expr): Pre-compile an expression into a reusable function.
  • f(x): Execute a pre-compiled expression.
  • ne.set_num_threads(n): Control the number of worker threads.
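A quick sketch of the pre-compiled pattern from the reference above (the expression and thread count are arbitrary choices):

```python
import numpy as np
import numexpr as ne

ne.set_num_threads(2)         # cap numexpr's worker threads
f = ne.NumExpr('x * 2 + 1')   # compile the expression once, up front
x = np.arange(5, dtype=np.float64)
result = f(x)                 # inputs are passed positionally
print(result)                 # → [1. 3. 5. 7. 9.]
```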

Exercises

Example: Speed up the calculation below with ne.evaluate:

# Before (keep this cell the same, for later comparison)

x = np.random.random(10_000_000)

y = %time x * 2 + 3
y.sum()
CPU times: user 5.96 ms, sys: 47.5 ms, total: 53.4 ms
Wall time: 59.8 ms
39998281.98519501
# After (modify this code cell)

x = np.random.random(10_000_000)

y = %time ne.evaluate('x * 2 + 3')
y.sum()
CPU times: user 23.6 ms, sys: 276 ms, total: 299 ms
Wall time: 23.8 ms
39997663.02705389

Example: Speed up the calculation below with ne.evaluate:

# Before (keep this cell the same, for later comparison)

x = np.random.random(10_000_000)

y = %time 3 * x ** 2 + 2 * x + 1
y.sum()
CPU times: user 44.4 ms, sys: 65 ms, total: 109 ms
Wall time: 109 ms
30004296.01429403
# After (modify this code cell)

x = np.random.random(10_000_000)

y = %time ne.evaluate('3 * x ** 2 + 2 * x + 1')
y.sum()
CPU times: user 84.4 ms, sys: 184 ms, total: 268 ms
Wall time: 20.9 ms
29997648.79665374

Exercise: Speed up the timed calculation below using ne.evaluate():

# Before (keep this the same, for comparison)


x = np.random.random(20_000_000)
y = np.random.random(20_000_000)
z = np.random.random(20_000_000)

total = %time x + y + z

plt.figure(figsize=(10, 1.5))
plt.subplot(1, 4, 1); plt.hist(x, bins=201);
plt.subplot(1, 4, 2); plt.hist(y, bins=201);
plt.subplot(1, 4, 3); plt.hist(z, bins=201);
plt.subplot(1, 4, 4); plt.hist(total, bins=201);
plt.tight_layout();
CPU times: user 27.1 ms, sys: 114 ms, total: 141 ms
Wall time: 141 ms

Solution
# After (modify this cell)

x = np.random.random(20_000_000)
y = np.random.random(20_000_000)
z = np.random.random(20_000_000)

total = %time ne.evaluate('x + y + z')

plt.figure(figsize=(10, 1.5))
plt.subplot(1, 4, 1); plt.hist(x, bins=201);
plt.subplot(1, 4, 2); plt.hist(y, bins=201);
plt.subplot(1, 4, 3); plt.hist(z, bins=201);
plt.subplot(1, 4, 4); plt.hist(total, bins=201);
plt.tight_layout();
CPU times: user 41.7 ms, sys: 16.9 ms, total: 58.6 ms
Wall time: 58.6 ms

Exercise: Speed up the following calculation with ne.evaluate():

x = np.random.random(5_000_000)
%%time
# Before (keep this the same, for comparison)

for _ in range(50):
    x = x * 2 * 2 * 2 * 2
CPU times: user 690 ms, sys: 284 ms, total: 974 ms
Wall time: 972 ms
Solution
%%time
# After (modify this cell)

for _ in range(50):
    x = ne.evaluate('x * 2 * 2 * 2 * 2')
CPU times: user 2.09 s, sys: 7.53 s, total: 9.62 s
Wall time: 754 ms

Exercise: Let’s do it again, but this time with a smaller array. What is different about the performance?

x = np.random.random(50)
%%time
# Before (keep this the same, for comparison)

for _ in range(50):
    x = x * 2 * 2 * 2 * 2
CPU times: user 262 μs, sys: 308 μs, total: 570 μs
Wall time: 577 μs
Solution
%%time
# After (modify this cell)

for _ in range(50):
    x = ne.evaluate('x * 2 * 2 * 2 * 2')
CPU times: user 3.63 ms, sys: 0 ns, total: 3.63 ms
Wall time: 3.03 ms

Exercise: Let’s try the smaller case again, but this time pre-compiling the function using the pattern f = ne.NumExpr('x + 1'); f(x). How does this affect the performance?

%%time
# Before (keep this the same, for comparison)

for _ in range(50):
    x = x * 2 * 2 * 2 * 2
CPU times: user 227 μs, sys: 265 μs, total: 492 μs
Wall time: 498 μs
Solution
%%time
# After (modify this cell)

f = ne.NumExpr('x * 2 * 2 * 2 * 2')
for _ in range(50):
    x = f(x)
CPU times: user 338 μs, sys: 396 μs, total: 734 μs
Wall time: 1.08 ms

Exercise: numexpr supports some whole-array reductions like sum() inside an expression, but it is not a full NumPy replacement. Let's try it here and see whether it helps:

x = np.random.random(10_000_000)

%time np.sum(np.sqrt(x))
CPU times: user 28.3 ms, sys: 15.9 ms, total: 44.2 ms
Wall time: 44 ms
6666012.275613346
Solution
x = np.random.random(10_000_000)

%time ne.evaluate("sum(sqrt(x))")
CPU times: user 44.2 ms, sys: 0 ns, total: 44.2 ms
Wall time: 44 ms
array(6667698.22208743)

Section 2: Just-in-Time Compilation with @njit

In Section 1, we improved performance for vectorized array expressions, but not all problems are easy to express with pure NumPy operations.

In neuroscience, we often need:

  • Custom spike detection rules
  • Trial-by-trial logic
  • Nested loops
  • Conditional branching
  • Simulation steps

Pure Python loops are slow because:

  • Every iteration runs through the Python interpreter
  • Types are dynamic
  • Operations are dispatched at runtime

Numba allows us to compile Python functions into machine code using Just-in-Time (JIT) compilation.

With @njit, we:

  • Keep normal Python syntax
  • Restrict ourselves to supported features
  • Gain near-C speed

Reference

  • from numba import njit: Import the JIT decorator.
  • @njit: Compile a function in nopython mode.
  • @jit: Compile, but allow Python fallback.
  • f.inspect_types(): Show inferred types.

Exercises

For each of the exercises below, we'll take an existing Python function, add @njit, and compare performance. The first call triggers compilation; call the cell again to see the full performance benefit.

Exercise: Basic Spike Counter

## Before (keep this the same, for later comparison)

def count_spikes(x, threshold):
    count = 0
    for i in range(len(x)):
        if x[i] > threshold:
            count += 1
    return count
## Call this cell multiple times and check if the performance changes

x = np.random.random(10_000_000)
%time count_spikes(x, 0.1)
CPU times: user 1.3 s, sys: 0 ns, total: 1.3 s
Wall time: 1.3 s
9000141
Solution
## After (modify this cell)

@njit
def count_spikes(x, threshold):
    count = 0
    for i in range(len(x)):
        if x[i] > threshold:
            count += 1
    return count
## Call this cell multiple times and check if the performance changes

x = np.random.random(10_000_000)
%time count_spikes(x, 0.1)
CPU times: user 3.22 s, sys: 241 ms, total: 3.46 s
Wall time: 1.47 s
9000593

Exercise: Sum of Squares

# Before (keep this the same, for later comparison)

def sum_of_squares(x):
    total = 0.0
    for i in range(len(x)):
        total += x[i] * x[i]
    return total
# Time the function here
%time sum_of_squares(np.ones(100_000))
CPU times: user 15 ms, sys: 0 ns, total: 15 ms
Wall time: 14.7 ms
100000.0
Solution
# After (add @njit to this function)
@njit
def sum_of_squares(x):
    total = 0.0
    for i in range(len(x)):
        total += x[i] * x[i]
    return total
# Time the function here
%time sum_of_squares(np.ones(100_000))
CPU times: user 583 μs, sys: 0 ns, total: 583 μs
Wall time: 468 μs
100000.0

Exercise: Conditional Logic

# Before (keep this the same, for later comparison)

def weighted_sum(x, threshold):
    total = 0.0
    for i in range(len(x)):
        if x[i] > threshold:
            total += x[i]
        else:
            total -= x[i]
    return total
# time it here
%time weighted_sum(np.ones(100_000), 1)
CPU times: user 13.2 ms, sys: 653 μs, total: 13.8 ms
Wall time: 13.7 ms
-100000.0
Solution
# After (add @njit to this one)
@njit
def weighted_sum(x, threshold):
    total = 0.0
    for i in range(len(x)):
        if x[i] > threshold:
            total += x[i]
        else:
            total -= x[i]
    return total
# time it here
%time weighted_sum(np.ones(100_000), 1)
CPU times: user 116 ms, sys: 0 ns, total: 116 ms
Wall time: 114 ms
-100000.0

Section 3: Requesting Parallel Execution with parallel=True and prange

Now we go one step further: we ask Numba to run loops in parallel across CPU cores. This is helpful, for example, if we are computing spike counts independently across 200 recording channels. Each channel can be processed independently — which makes it a good candidate for parallelization.

We do this with:

  • @njit(parallel=True)
  • prange() instead of range()

However:

  • Not all loops can be parallelized.
  • Parallel overhead is real.
  • Numba does not always parallelize even when we request it.

Reference

  • @njit(parallel=True): Enable parallel compilation.
  • from numba import prange: Import the parallel range.
  • prange(n): Parallel loop iterator.
  • f.parallel_diagnostics(level=4): Show the parallel analysis report.

Exercises

Exercise: Add @njit(parallel=True) and prange() to this function. What performance benefit is there?

def sum_seq(x):
    total = 0.0
    for i in range(len(x)):
        total += x[i]
    return total
x = np.random.random(50_000_000)
%time sum_seq(x)
CPU times: user 5.39 s, sys: 0 ns, total: 5.39 s
Wall time: 5.39 s
25001634.642637063
Solution
@njit(parallel=True)
def sum_seq(x):
    total = 0.0
    for i in prange(len(x)):
        total += x[i]
    return total
x = np.random.random(50_000_000)
%time sum_seq(x)
CPU times: user 127 ms, sys: 11.5 ms, total: 139 ms
Wall time: 16 ms
24999944.13798241

Exercise: Run sum_seq.parallel_diagnostics() and give the output to an AI. Was Numba able to parallelize this code?

Solution
sum_seq.parallel_diagnostics(level=4)
 
================================================================================
 Parallel Accelerator Optimizing:  Function sum_seq, 
C:\Users\delgr\AppData\Local\Temp\ipykernel_43432\1605433682.py (1)  
================================================================================


Parallel loop listing for  Function sum_seq, C:\Users\delgr\AppData\Local\Temp\ipykernel_43432\1605433682.py (1) 
--------------------------------|loop #ID
@njit(parallel=True)            | 
def sum_seq(x):                 | 
    total = 0.0                 | 
    for i in prange(len(x)):----| #0
        total += x[i]           | 
    return total                | 
--------------------------------- Fusing loops ---------------------------------
Attempting fusion of parallel loops (combines loops with similar properties)...
----------------------------- Before Optimisation ------------------------------
--------------------------------------------------------------------------------
------------------------------ After Optimisation ------------------------------
Parallel structure is already optimal.
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
 
---------------------------Loop invariant code motion---------------------------
Allocation hoisting:
No allocation hoisting found

Instruction hoisting:
loop #0:
  Failed to hoist the following:
    dependency: $58binary_subscr.5 = getitem(value=x, index=$parfor__index_8.26, fn=<built-in function getitem>)
    dependency: $total_2.29 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=total_2, rhs=$58binary_subscr.5, static_lhs=Undefined, static_rhs=Undefined)
    dependency: total_2 = $total_2.29
--------------------------------------------------------------------------------

Exercise: Add @njit(parallel=True) and prange() to this function. What performance benefit is there? Compare the returned values, too: does the parallel version agree with the sequential one?

def dependent_loop(x):
    for i in range(1, len(x)):
        x[i] = x[i] + x[i - 1]
    return x
x = np.random.random(5_000_000)
%time dependent_loop(x)
CPU times: user 1.57 s, sys: 0 ns, total: 1.57 s
Wall time: 1.57 s
array([7.07890959e-01, 1.13226800e+00, 1.77534046e+00, ...,
       2.49940318e+06, 2.49940328e+06, 2.49940378e+06])
Solution
@njit(parallel=True)
def dependent_loop(x):
    for i in prange(1, len(x)):
        x[i] = x[i] + x[i - 1]
    return x
x = np.random.random(5_000_000)
%time dependent_loop(x)
CPU times: user 172 ms, sys: 0 ns, total: 172 ms
Wall time: 12.5 ms
array([4.00641515e-01, 4.32420367e-01, 1.09493956e+00, ...,
       1.24947539e+05, 1.24948206e+05, 1.24949077e+05])

Exercise: Run dependent_loop.parallel_diagnostics() and give the output to an AI. Was Numba able to safely parallelize this code?

Solution
dependent_loop.parallel_diagnostics(level=4)
 
================================================================================
 Parallel Accelerator Optimizing:  Function dependent_loop, 
/tmp/ipykernel_877473/2706474430.py (1)  
================================================================================


Parallel loop listing for  Function dependent_loop, /tmp/ipykernel_877473/2706474430.py (1) 
-----------------------------------|loop #ID
@njit(parallel=True)               | 
def dependent_loop(x):             | 
    for i in prange(1, len(x)):----| #1
        x[i] = x[i] + x[i - 1]     | 
    return x                       | 
--------------------------------- Fusing loops ---------------------------------
Attempting fusion of parallel loops (combines loops with similar properties)...
----------------------------- Before Optimisation ------------------------------
--------------------------------------------------------------------------------
------------------------------ After Optimisation ------------------------------
Parallel structure is already optimal.
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
 
---------------------------Loop invariant code motion---------------------------
Allocation hoisting:
No allocation hoisting found

Instruction hoisting:
loop #1:
  Has the following hoisted:
    $const64.7.1 = const(int, 1)
  Failed to hoist the following:
    dependency: $56binary_subscr.4 = getitem(value=x, index=$parfor__index_26.40, fn=<built-in function getitem>)
    dependency: $binop_sub66.8 = $parfor__index_26.40 - $const64.7.1
    dependency: $70binary_subscr.9 = getitem(value=x, index=$binop_sub66.8, fn=<built-in function getitem>)
    dependency: $binop_add74.10 = $56binary_subscr.4 + $70binary_subscr.9
--------------------------------------------------------------------------------