Let the Compiler Help Us Compute: Practical Compiled Code for Neuroscience

Author
Dr. Nicholas Del Grosso

import numpy as np
import numexpr as ne
import matplotlib.pyplot as plt
from numba import njit, prange

Vectorized NumPy code is often fast, but there is still room to speed it up: the more work we hand over to compiled code in a single call, the less often that code has to stop and return to the Python interpreter.

In this notebook, we’ll look at two powerful tools for doing more computational work outside of Python’s runtime:

  • numexpr — to evaluate array expressions more efficiently
  • numba — to compile Python functions into machine code

By the end of this notebook, we should be able to:

  • Recognize when Python overhead limits performance
  • Use numexpr to reduce interpreter overhead in array expressions
  • Use Numba’s @njit decorator to compile loop-based code
  • Understand compilation overhead and when it matters
  • Avoid common pitfalls that prevent compilation
  • Request parallel loops with parallel=True and prange, and check whether Numba parallelized them

Section 1: Using numexpr to Reduce Python Overhead

NumPy is fast because its core operations run in compiled C code. But every time we write an expression like:

np.arange(10_000_000) + 1 + 2 + 3 + 4 + 5 + 6

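Python still evaluates it one operator at a time, which means:

  • Each + is a separate NumPy call, dispatched by the interpreter
  • Each call allocates a new temporary array for its result
  • Each intermediate result requires another full pass through memory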
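As a rough illustration of where those temporaries come from, the sketch below spells out a shorter chain step by step and compares it with a hand-written in-place version. The variable names are made up for this example, and the array is smaller than above to keep memory use modest.

```python
import numpy as np

x = np.arange(1_000_000, dtype=np.float64)

# Roughly what NumPy does for `x + 1 + 2 + 3`: each "+" is a separate call
# that walks the whole array and allocates a new result array.
tmp1 = x + 1        # temporary array no. 1
tmp2 = tmp1 + 2     # temporary array no. 2
chained = tmp2 + 3  # final result (three full-size allocations in total)

# Hand-tuned NumPy alternative: allocate once, then update in place.
in_place = x + 1    # the only new allocation
in_place += 2       # writes into the existing buffer
in_place += 3

print(np.array_equal(chained, in_place))  # True: same numbers, fewer allocations
```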

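numexpr compiles the whole expression string into a small program for its own virtual machine, which:

  • Evaluates the expression in cache-sized blocks, minimizing temporary arrays
  • Uses multiple CPU cores automatically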
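A minimal sketch of the basic ne.evaluate call patterns, assuming a 1-D float64 array: letting numexpr find variables in the calling scope, passing them explicitly with local_dict, and writing into a preallocated array with out=.

```python
import numpy as np
import numexpr as ne

x = np.random.random(1_000_000)

# By default, ne.evaluate looks up names like 'x' in the calling scope.
y = ne.evaluate('x + 1 + 2 + 3')

# Passing variables explicitly can be clearer, e.g. inside functions.
y_explicit = ne.evaluate('a + 1 + 2 + 3', local_dict={'a': x})

# out= writes into a preallocated array instead of allocating a new one.
result = np.empty_like(x)
ne.evaluate('x + 6', out=result)

print(np.allclose(y, y_explicit), np.allclose(result, x + 6))
```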

Reference

Code                     Description
import numexpr as ne     Import numexpr.
ne.evaluate(expr)        Evaluate an expression using numexpr.
ne.NumExpr(expr)         Pre-compile an expression.
f(x)                     Execute a pre-compiled expression (inputs are passed positionally).
ne.set_num_threads(n)    Control the number of threads.
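The thread count can be inspected and changed at runtime. This sketch assumes you want to compare single-threaded and default behaviour; for an element-wise expression like this one, the results should not depend on the number of threads.

```python
import numpy as np
import numexpr as ne

x = np.random.random(1_000_000)

print("cores detected:", ne.detect_number_of_cores())

# set_num_threads returns the previous setting, so it can be restored afterwards.
previous = ne.set_num_threads(1)
single = ne.evaluate('x * 2 + 3')
ne.set_num_threads(previous)

multi = ne.evaluate('x * 2 + 3')
print(np.array_equal(single, multi))
```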

Exercises

Example: Speed up the calculation below with ne.evaluate:

# Before (keep this cell the same, for later comparison)

x = np.random.random(10_000_000)

y = %time x * 2 + 3
y.sum()
CPU times: user 13.7 ms, sys: 90.3 ms, total: 104 ms
Wall time: 104 ms
np.float64(39998907.27311504)
# After (modify this code cell)

x = np.random.random(10_000_000)

y = %time ne.evaluate('x * 2 + 3')
y.sum()
CPU times: user 75.5 ms, sys: 399 ms, total: 475 ms
Wall time: 35.5 ms
np.float64(40004385.33772915)

Example: Speed up the calculation below with ne.evaluate:

# Before (keep this cell the same, for later comparison)

x = np.random.random(10_000_000)

y = %time 3 * x ** 2 + 2 * x + 1
y.sum()
CPU times: user 72.2 ms, sys: 77.7 ms, total: 150 ms
Wall time: 150 ms
np.float64(30000783.167188667)
# After (modify this code cell)

x = np.random.random(10_000_000)

y = %time ne.evaluate('3 * x ** 2 + 2 * x + 1')
y.sum()
CPU times: user 96.4 ms, sys: 275 ms, total: 371 ms
Wall time: 32.1 ms
np.float64(30000495.308897942)

Exercise: Speed up the timed calculation below using ne.evaluate():

# Before (keep this the same, for comparison)


x = np.random.random(20_000_000)
y = np.random.random(20_000_000)
z = np.random.random(20_000_000)

total = %time x + y + z

plt.figure(figsize=(10, 1.5))
plt.subplot(1, 4, 1); plt.hist(x, bins=201);
plt.subplot(1, 4, 2); plt.hist(y, bins=201);
plt.subplot(1, 4, 3); plt.hist(z, bins=201);
plt.subplot(1, 4, 4); plt.hist(total, bins=201);
plt.tight_layout();
CPU times: user 71.2 ms, sys: 125 ms, total: 196 ms
Wall time: 196 ms

Solution
# After (modify this cell)

x = np.random.random(20_000_000)
y = np.random.random(20_000_000)
z = np.random.random(20_000_000)

total = %time ne.evaluate('x + y + z')

plt.figure(figsize=(10, 1.5))
plt.subplot(1, 4, 1); plt.hist(x, bins=201);
plt.subplot(1, 4, 2); plt.hist(y, bins=201);
plt.subplot(1, 4, 3); plt.hist(z, bins=201);
plt.subplot(1, 4, 4); plt.hist(total, bins=201);
plt.tight_layout();
CPU times: user 95 ms, sys: 32 ms, total: 127 ms
Wall time: 127 ms

Exercise: Speed up the following calculation with ne.evaluate():

x = np.random.random(5_000_000)
%%time
# Before (keep this the same, for comparison)

for _ in range(50):
    x = x * 2 * 2 * 2 * 2
CPU times: user 1.23 s, sys: 641 ms, total: 1.87 s
Wall time: 1.86 s
Solution
%%time
# After (modify this cell)

for _ in range(50):
    x = ne.evaluate('x * 2 * 2 * 2 * 2')
CPU times: user 4.19 s, sys: 21.1 s, total: 25.3 s
Wall time: 1.77 s

Exercise: Let’s do it again, but this time with a smaller array. What is different about the performance?

x = np.random.random(50)
%%time
# Before (keep this the same, for comparison)

for _ in range(50):
    x = x * 2 * 2 * 2 * 2
CPU times: user 1.47 ms, sys: 0 ns, total: 1.47 ms
Wall time: 1.5 ms
Solution
%%time
# After (modify this cell)

for _ in range(50):
    x = ne.evaluate('x * 2 * 2 * 2 * 2')
CPU times: user 832 μs, sys: 5.14 ms, total: 5.97 ms
Wall time: 5.19 ms
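One way to explore the answer is to time both approaches over a range of array sizes. This is a rough benchmark sketch; compare is a hypothetical helper, and the absolute numbers will depend on your machine and thread settings.

```python
import timeit
import numpy as np
import numexpr as ne

def compare(sizes, number=20):
    """Hypothetical helper: (n, numpy_seconds, numexpr_seconds) per call for each size."""
    rows = []
    for n in sizes:
        x = np.random.random(n)
        t_np = timeit.timeit(lambda: x * 2 * 2 * 2 * 2, number=number) / number
        t_ne = timeit.timeit(
            lambda: ne.evaluate('x * 2 * 2 * 2 * 2', local_dict={'x': x}),
            number=number,
        ) / number
        rows.append((n, t_np, t_ne))
    return rows

for n, t_np, t_ne in compare([50, 5_000, 500_000, 2_000_000]):
    print(f"n={n:>9,}  numpy: {t_np * 1e6:10.1f} us  numexpr: {t_ne * 1e6:10.1f} us")
```

On small arrays, numexpr's fixed per-call overhead tends to dominate, so plain NumPy can win; the advantage usually only shows up once the arrays are large.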

Exercise: Let’s try the smaller case again, but this time pre-compiling the function using the pattern f = ne.NumExpr('x + 1'); f(x). How does this affect the performance?

%%time
# Before (keep this the same, for comparison)

for _ in range(50):
    x = x * 2 * 2 * 2 * 2
CPU times: user 620 μs, sys: 0 ns, total: 620 μs
Wall time: 627 μs
Solution
%%time
# After (modify this cell)

f = ne.NumExpr('x * 2 * 2 * 2 * 2')
for _ in range(50):
    x = f(x)
CPU times: user 353 μs, sys: 448 μs, total: 801 μs
Wall time: 820 μs
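numexpr also offers ne.re_evaluate(), which reruns the most recently evaluated expression with new inputs and skips the work evaluate() does to find the compiled program. The sketch below assumes the same small 50-element array, applies the same expression 50 times, and checks the result against plain NumPy.

```python
import numpy as np
import numexpr as ne

x = np.random.random(50)
expected = x.copy()

# First call compiles the expression (or reuses numexpr's cached program).
x = ne.evaluate('x * 2 * 2 * 2 * 2', local_dict={'x': x})
for _ in range(49):
    # Rerun the last evaluated expression with the current value of x.
    x = ne.re_evaluate(local_dict={'x': x})

for _ in range(50):
    expected = expected * 2 * 2 * 2 * 2

print(np.allclose(x, expected))
```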

Exercise: numexpr is great at speeding up element-wise expressions, but it is not a full NumPy replacement: it supports only a limited set of functions and reductions. Let’s try it out here, and see if it helps:

x = np.random.random(10_000_000)

%time np.sum(np.sqrt(x))
CPU times: user 31.6 ms, sys: 88.5 ms, total: 120 ms
Wall time: 119 ms
np.float64(6665823.263745766)
Solution
x = np.random.random(10_000_000)

%time ne.evaluate("sum(sqrt(x))")
CPU times: user 71.8 ms, sys: 0 ns, total: 71.8 ms
Wall time: 73.5 ms
array(6665720.97519577)
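The sketch below, assuming the same uniform random data, shows how far numexpr's reductions go: sum() works as the outermost operation and returns a 0-d array, while functions numexpr does not provide (here a median) have to stay in NumPy.

```python
import numpy as np
import numexpr as ne

x = np.random.random(1_000_000)

# A reduction is allowed as the outermost operation; the result is a 0-d array.
total = ne.evaluate('sum(sqrt(x))')
print(float(total), np.sqrt(x).sum())

# numexpr has no median, so only the element-wise part goes through numexpr.
median_root = np.median(ne.evaluate('sqrt(x)'))
print(median_root)
```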

Section 2: Just-in-Time Compilation with @njit

In Section 1, we improved performance for vectorized array expressions, but not all problems are easy to express with pure NumPy operations.

In neuroscience, we often need:

  • Custom spike detection rules
  • Trial-by-trial logic
  • Nested loops
  • Conditional branching
  • Simulation steps

Pure Python loops are slow because:

  • Every iteration runs through the Python interpreter
  • Types are dynamic
  • Operations are dispatched at runtime

Numba allows us to compile Python functions into machine code using Just-in-Time (JIT) compilation.

With @njit, we:

  • Keep normal Python syntax
  • Restrict ourselves to supported features
  • Often approach the speed of compiled C code

Reference

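Code                       Description
from numba import njit     Import the JIT decorator.
@njit                      Compile a function in nopython mode.
@jit                       General JIT decorator; since Numba 0.59 it also defaults to nopython mode (older versions could fall back to Python object mode).
f.inspect_types()          Show the types Numba inferred.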

Exercises

For each of the exercises below, we’ll take an existing Python function, add @njit, and compare performance. On the first call, @njit compiles the function, so that call includes compilation time; from the second call on, you’ll see the full performance benefit.

Exercise: Basic Spike Counter

## Before (keep this the same, for later comparison)

def count_spikes(x, threshold):
    count = 0
    for i in range(len(x)):
        if x[i] > threshold:
            count += 1
    return count
## Call this cell multiple times and check if the performance changes

x = np.random.random(10_000_000)
%time count_spikes(x, 0.1)
CPU times: user 2.01 s, sys: 0 ns, total: 2.01 s
Wall time: 2.01 s
9000080
Solution
## After (modify this cell)

@njit
def count_spikes(x, threshold):
    count = 0
    for i in range(len(x)):
        if x[i] > threshold:
            count += 1
    return count
## Call this cell multiple times and check if the performance changes

x = np.random.random(10_000_000)
%time count_spikes(x, 0.1)
CPU times: user 3.1 s, sys: 240 ms, total: 3.33 s
Wall time: 966 ms
9000304

Exercise: Sum of Squares

# Before (keep this the same, for later comparison)

def sum_of_squares(x):
    total = 0.0
    for i in range(len(x)):
        total += x[i] * x[i]
    return total
# Time the function here
%time sum_of_squares(np.ones(100_000))
CPU times: user 42.2 ms, sys: 1.19 ms, total: 43.3 ms
Wall time: 42.3 ms
np.float64(100000.0)
Solution
# After (add @njit to this function)
@njit
def sum_of_squares(x):
    total = 0.0
    for i in range(len(x)):
        total += x[i] * x[i]
    return total
# Time the function here
%time sum_of_squares(np.ones(100_000))
CPU times: user 115 ms, sys: 20.8 ms, total: 136 ms
Wall time: 134 ms
100000.0

Exercise: Conditional Logic

# Before (keep this the same, for later comparison)

def weighted_sum(x, threshold):
    total = 0.0
    for i in range(len(x)):
        if x[i] > threshold:
            total += x[i]
        else:
            total -= x[i]
    return total
# time it here
%time weighted_sum(np.ones(100_000), 1)
CPU times: user 35.8 ms, sys: 3.08 ms, total: 38.9 ms
Wall time: 37.8 ms
np.float64(-100000.0)
Solution
# After (add @njit to this one)
@njit
def weighted_sum(x, threshold):
    total = 0.0
    for i in range(len(x)):
        if x[i] > threshold:
            total += x[i]
        else:
            total -= x[i]
    return total
# time it here
%time weighted_sum(np.ones(100_000), 1)
CPU times: user 153 ms, sys: 5.6 ms, total: 159 ms
Wall time: 157 ms
-100000.0

Section 3: Requesting Parallel Execution with parallel=True and prange

Now we go one step further: we ask Numba to run loops in parallel across CPU cores. This is helpful, for example, if we are computing spike counts independently across 200 recording channels. Each channel can be processed independently — which makes it a good candidate for parallelization.

We do this with:

  • @njit(parallel=True)
  • prange() instead of range()

However:

  • Not all loops can be parallelized.
  • Parallel overhead is real: starting and coordinating threads takes time, so small workloads can even get slower.
  • Numba does not always parallelize even when we request it.

Reference

Code Description
Code                               Description
@njit(parallel=True)               Enable parallel compilation.
from numba import prange           Import the parallel range.
prange(n)                          Parallel loop iterator.
f.parallel_diagnostics(level=4)    Show the parallel analysis report.

Exercises

Exercise: Add (parallel=True) and prange() to this function. What performance benefit is there?

def sum_seq(x):
    total = 0.0
    for i in range(len(x)):
        total += x[i]
    return total
x = np.random.random(50_000_000)
%time sum_seq(x)
CPU times: user 10.3 s, sys: 0 ns, total: 10.3 s
Wall time: 10.2 s
np.float64(24996706.24008694)
Solution
@njit(parallel=True)
def sum_seq(x):
    total = 0.0
    for i in prange(len(x)):
        total += x[i]
    return total
x = np.random.random(50_000_000)
%time sum_seq(x)
CPU times: user 1.41 s, sys: 108 ms, total: 1.52 s
Wall time: 825 ms
24999219.29505931

Exercise: Run sum_seq.parallel_diagnostics() and give the output to an AI. Was Numba able to parallelize this code?

Solution
sum_seq.parallel_diagnostics(level=4)
 
================================================================================
 Parallel Accelerator Optimizing:  Function sum_seq, 
/tmp/ipykernel_838177/1605433682.py (1)  
================================================================================


Parallel loop listing for  Function sum_seq, /tmp/ipykernel_838177/1605433682.py (1) 
--------------------------------|loop #ID
@njit(parallel=True)            | 
def sum_seq(x):                 | 
    total = 0.0                 | 
    for i in prange(len(x)):----| #0
        total += x[i]           | 
    return total                | 
--------------------------------- Fusing loops ---------------------------------
Attempting fusion of parallel loops (combines loops with similar properties)...
----------------------------- Before Optimisation ------------------------------
--------------------------------------------------------------------------------
------------------------------ After Optimisation ------------------------------
Parallel structure is already optimal.
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
 
---------------------------Loop invariant code motion---------------------------
Allocation hoisting:
No allocation hoisting found

Instruction hoisting:
loop #0:
  Failed to hoist the following:
    dependency: $58binary_subscr.5 = getitem(value=x, index=$parfor__index_2.20, fn=<built-in function getitem>)
    dependency: $total_2.23 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=total_2, rhs=$58binary_subscr.5, static_lhs=Undefined, static_rhs=Undefined)
    dependency: total_2 = $total_2.23
--------------------------------------------------------------------------------

Exercise: Add (parallel=True) and prange() to this function. What performance benefit is there?

def dependent_loop(x):
    for i in range(1, len(x)):
        x[i] = x[i] + x[i - 1]
    return x
x = np.random.random(5_000_000)
%time dependent_loop(x)
CPU times: user 2.02 s, sys: 0 ns, total: 2.02 s
Wall time: 2.01 s
array([1.31466274e-01, 9.42240651e-01, 1.88078036e+00, ...,
       2.49941364e+06, 2.49941389e+06, 2.49941395e+06], shape=(5000000,))
Solution
@njit(parallel=True)
def dependent_loop(x):
    for i in prange(1, len(x)):
        x[i] = x[i] + x[i - 1]
    return x
x = np.random.random(5_000_000)
%time dependent_loop(x)
CPU times: user 655 ms, sys: 14.6 ms, total: 670 ms
Wall time: 532 ms
array([6.39768101e-02, 9.46011035e-01, 1.76934503e+00, ...,
       1.24927298e+05, 1.24927854e+05, 1.24928719e+05], shape=(5000000,))

Exercise: Run dependent_loop.parallel_diagnostics() and give the output to an AI. Was Numba able to parallelize this code?

Solution
dependent_loop.parallel_diagnostics(level=4)
 
================================================================================
 Parallel Accelerator Optimizing:  Function dependent_loop, 
/tmp/ipykernel_838177/2706474430.py (1)  
================================================================================


Parallel loop listing for  Function dependent_loop, /tmp/ipykernel_838177/2706474430.py (1) 
-----------------------------------|loop #ID
@njit(parallel=True)               | 
def dependent_loop(x):             | 
    for i in prange(1, len(x)):----| #1
        x[i] = x[i] + x[i - 1]     | 
    return x                       | 
--------------------------------- Fusing loops ---------------------------------
Attempting fusion of parallel loops (combines loops with similar properties)...
----------------------------- Before Optimisation ------------------------------
--------------------------------------------------------------------------------
------------------------------ After Optimisation ------------------------------
Parallel structure is already optimal.
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
 
---------------------------Loop invariant code motion---------------------------
Allocation hoisting:
No allocation hoisting found

Instruction hoisting:
loop #1:
  Has the following hoisted:
    $const60.7.1 = const(int, 1)
  Failed to hoist the following:
    dependency: $54binary_subscr.4 = getitem(value=x, index=$parfor__index_26.40, fn=<built-in function getitem>)
    dependency: $binop_sub62.8 = $parfor__index_26.40 - $const60.7.1
    dependency: $66binary_subscr.9 = getitem(value=x, index=$binop_sub62.8, fn=<built-in function getitem>)
    dependency: $binop_add70.10 = $54binary_subscr.4 + $66binary_subscr.9
--------------------------------------------------------------------------------