Spike Inference From Calcium Traces

Authors
Dr. Sangeetha Nandakumar | Dr. Nicholas Del Grosso

Setup

Import Libraries

#| output: false
import numpy as np
from scipy.signal import convolve, windows
import matplotlib.pyplot as plt
from oasis.functions import deconvolve, gen_data
from scipy.signal import find_peaks

Generate Data

y1, c1, s1 = map(np.squeeze, gen_data(N=1, seed=5, sn=0.1, framerate=10))
y2, c2, s2 = map(np.squeeze, gen_data(N=1, seed=5, sn=0.1, framerate=20))
y3, c3, s3 = map(np.squeeze, gen_data(N=1, seed=5, sn=0.1, framerate=30))

Inferring spikes from calcium imaging data is a key step in understanding neuronal activity. In this notebook we explore how the fast and discrete events of neural spiking are transformed into the slow and continuous signals observed in calcium imaging. We begin by using convolution to simulate how spike trains generate fluorescence signals, helping us understand the shape and dynamics of calcium responses. Then we reverse the process using the OASIS algorithm to extract likely spike times from observed calcium traces. We also learn how to apply thresholding to convert continuous deconvolution output into discrete spike events. Finally, we save these inferred spike times for further analysis.

Section 1: How would a spike train appear as a calcium trace? (Convolution)

In this section, we will see how calcium signals are produced from spikes through a process called convolution. When a neuron fires, the calcium signal does not change in a sharp, brief way. Instead, each spike produces a smooth, slowly fading signal, and that is what we observe in calcium imaging. We simulate this by convolving a spike train with a calcium kernel, which is a shape that describes how the signal should look after a single spike. This helps us understand how fast spiking activity becomes the slower calcium traces we record.
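Below is a minimal numpy sketch of this idea. It uses a hand-made spike train and an exponential kernel, and the names `spikes`, `kernel`, `trace`, and `manual` are illustrative only, not the notebook's data. Each spike adds one copy of the kernel, scaled by the spike's amplitude and shifted to the spike's frame. The calcium trace is the sum of those copies, which is exactly what convolution computes.

```python
import numpy as np

# Illustrative spike train: amplitude 1 at frame 5, amplitude 2 at frame 8
spikes = np.zeros(30)
spikes[5] = 1.0
spikes[8] = 2.0

# Exponential decay kernel (tau = 3 frames), normalized to sum to 1
t = np.arange(15)
kernel = np.exp(-t / 3.0)
kernel = kernel / kernel.sum()

# Convolution, trimmed back to the length of the spike train
trace = np.convolve(spikes, kernel, mode="full")[:len(spikes)]

# The same trace built by hand: one shifted, scaled kernel copy per spike
manual = np.zeros(len(spikes) + len(kernel) - 1)
for frame in np.flatnonzero(spikes):
    manual[frame:frame + len(kernel)] += spikes[frame] * kernel
manual = manual[:len(spikes)]

print(np.allclose(trace, manual))  # True
```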

Code Description
windows.boxcar(win_len) Create a boxcar kernel of specified length (win_len).
windows.triang(win_len) Create a triangle kernel of specified length (win_len).
np.exp(-t / tau) Create an exponential decay kernel with decay constant tau.
np.exp(-t / tau_decay) - np.exp(-t / tau_rise) Create a double exponential decay kernel with rise (tau_rise) and decay (tau_decay) times.
kernel_unnorm / kernel_unnorm.sum() Normalize the kernel by dividing by the sum of its elements.
convolve(s1, kernel, mode='full') Convolve the spike train (s1) with the kernel, generating a calcium trace.
plt.plot(kernel) Plot the kernel.

Exercises

Example: How would my spikes look if they were convolved with a boxcar kernel of window size 3?

win_len = 3
kernel_unnorm = windows.boxcar(win_len)
kernel = kernel_unnorm / kernel_unnorm.sum()
convolved_trace = convolve(s1, kernel, mode='full')

plt.subplot(211)
plt.plot(kernel)

plt.subplot(212)
plt.plot(s1)
plt.plot(convolved_trace[:-win_len+1], color='r')
plt.xlim(0, 100)
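The slice `convolved_trace[:-win_len+1]` is there because `mode='full'` returns `len(s1) + win_len - 1` samples. Dropping the last `win_len - 1` samples keeps the trace aligned with the spike train. The sketch below uses `np.convolve`, which gives the same "full" output length as `scipy.signal.convolve` for 1-D inputs, and an illustrative spike train. It shows one edge case: for `win_len = 1`, the slice becomes `[:0]` and returns an empty array. Slicing by the signal length does not have this problem.

```python
import numpy as np

spikes = np.zeros(50)
spikes[[10, 30]] = 1.0                       # illustrative spike train

lengths = []
for win_len in (1, 3):
    kernel = np.ones(win_len) / win_len      # normalized boxcar kernel
    full = np.convolve(spikes, kernel, mode="full")
    lengths.append((win_len,
                    len(full),                   # len(spikes) + win_len - 1
                    len(full[:-win_len + 1]),    # trim used in this notebook
                    len(full[:len(spikes)])))    # trim by signal length
print(lengths)  # [(1, 50, 0, 50), (3, 52, 50, 50)]
```

In the exercises, `convolved_trace[:len(s1)]` is therefore a drop-in replacement that also works for a one-frame kernel.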

Exercise: How would my spikes look if they were convolved with a triangle kernel of window size 3?

Solution
win_len = 3
kernel_unnorm = windows.triang(win_len)
kernel = kernel_unnorm / kernel_unnorm.sum()
convolved_trace = convolve(s1, kernel, mode='full')

plt.subplot(211)
plt.plot(kernel)

plt.subplot(212)
plt.plot(s1)
plt.plot(convolved_trace[:-win_len+1], color='r')
plt.xlim(0, 100)

Exercise: How would my spikes look if they were convolved with a triangle kernel of window size 4?

Solution
win_len = 4
kernel_unnorm = windows.triang(win_len)
kernel = kernel_unnorm / kernel_unnorm.sum()
convolved_trace = convolve(s1, kernel, mode='full')

plt.subplot(211)
plt.plot(kernel)

plt.subplot(212)
plt.plot(s1)
plt.plot(convolved_trace[:-win_len+1], color='r')
plt.xlim(0, 100)

Example: How will my spikes look when convolved with an exponential decay kernel with tau of 10 frames and window size of 101 frames?

tau = 10
win_len = 101
t = np.arange(win_len)
kernel_unnorm = np.exp(-t / tau)
kernel = kernel_unnorm / kernel_unnorm.sum()
convolved_trace = convolve(s1, kernel, mode='full')

plt.subplot(211)
plt.plot(kernel)

plt.subplot(212)
plt.plot(s1)
plt.plot(convolved_trace[:-win_len+1], color='r')
plt.xlim(0, 100)
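The exponential kernel connects this section to the `g` parameter of OASIS in Section 2. Because exp(-t/tau) = g^t with g = exp(-1/tau), convolving with this kernel gives the same result as running the first-order autoregressive (AR(1)) recursion c[t] = g * c[t-1] + s[t]. The sketch below checks this numerically on an illustrative random spike train. It uses the unnormalized kernel, because normalizing only rescales the trace by a constant. For tau = 10 frames, g = exp(-0.1) ≈ 0.905, which is close to the `g=(0.9,)` used later.

```python
import numpy as np

tau = 10.0                  # decay constant in frames, as in the example above
g = np.exp(-1.0 / tau)      # per-frame decay factor

rng = np.random.default_rng(0)
spikes = (rng.random(200) < 0.05).astype(float)  # illustrative sparse spikes

# Route 1: convolve with the unnormalized exponential kernel exp(-t / tau)
t = np.arange(len(spikes))
kernel = np.exp(-t / tau)
conv_trace = np.convolve(spikes, kernel)[:len(spikes)]

# Route 2: AR(1) recursion c[t] = g * c[t-1] + s[t]
ar_trace = np.zeros_like(spikes)
for i, s in enumerate(spikes):
    ar_trace[i] = (g * ar_trace[i - 1] if i > 0 else 0.0) + s

print(round(g, 3), np.allclose(conv_trace, ar_trace))  # 0.905 True
```

The notebook's 101-frame window truncates the kernel, but for tau = 10 the dropped tail starts near exp(-10.1) ≈ 4e-5 of the peak, so the difference is negligible.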

Exercise: How will my spikes look when convolved with an exponential decay kernel with tau of 1 frame and window size of 101 frames?

Solution
tau = 1
win_len = 101
t = np.arange(win_len)
kernel_unnorm = np.exp(-t / tau)
kernel = kernel_unnorm / kernel_unnorm.sum()
convolved_trace = convolve(s1, kernel, mode='full')

plt.subplot(211)
plt.plot(kernel)

plt.subplot(212)
plt.plot(s1)
plt.plot(convolved_trace[:-win_len+1], color='r')
plt.xlim(0, 100)

Exercise: How will my spikes look when convolved with an exponential decay kernel with tau of 200 frames and window size of 101 frames?

Solution
tau = 200
win_len = 101
t = np.arange(win_len)
kernel_unnorm = np.exp(-t / tau)
kernel = kernel_unnorm / kernel_unnorm.sum()
convolved_trace = convolve(s1, kernel, mode='full')

plt.subplot(211)
plt.plot(kernel)

plt.subplot(212)
plt.plot(s1)
plt.plot(convolved_trace[:-win_len+1], color='r')
plt.xlim(0, 500)

Example: How will my spikes look when convolved with a double exponential decay kernel with tau_rise of 0.1 frames, tau_decay of 1.5 frames, and window size of 101 frames?

tau_rise = 0.1
tau_decay = 1.5
win_len = 101
t = np.arange(win_len)
kernel_unnorm = np.exp(-t / tau_decay) - np.exp(-t / tau_rise)
kernel = kernel_unnorm / kernel_unnorm.sum()
convolved_trace = convolve(s1, kernel, mode='full')

plt.subplot(211)
plt.plot(kernel)

plt.subplot(212)
plt.plot(s1)
plt.plot(convolved_trace[:-win_len+1], color='r')
plt.xlim(0, 100)
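The same link holds one order up. With d = exp(-1/tau_decay) and r = exp(-1/tau_rise), the double-exponential kernel d^t - r^t matches a second-order autoregressive (AR(2)) model with g1 = d + r and g2 = -d*r. The kernel is 0 at t = 0, so in this sketch the recursion is driven by the spikes delayed by one frame and scaled by (d - r). The time constants and spike train are illustrative choices. By the same reasoning, the `g=(1.8, -0.81)` used in Section 2 corresponds to d = r = 0.9 (1.8 = 0.9 + 0.9 and 0.81 = 0.9 * 0.9), which is the limit where the rise and decay constants coincide.

```python
import numpy as np

tau_rise, tau_decay = 0.5, 4.0     # illustrative time constants in frames
d = np.exp(-1.0 / tau_decay)       # decay root
r = np.exp(-1.0 / tau_rise)        # rise root
g1, g2 = d + r, -d * r             # AR(2) coefficients

rng = np.random.default_rng(1)
spikes = (rng.random(300) < 0.05).astype(float)

# Route 1: convolve with the unnormalized double-exponential kernel
t = np.arange(len(spikes))
kernel = np.exp(-t / tau_decay) - np.exp(-t / tau_rise)
conv_trace = np.convolve(spikes, kernel)[:len(spikes)]

# Route 2: AR(2) recursion, driven by spikes delayed one frame and
# scaled by (d - r), because this kernel is 0 at t = 0
ar_trace = np.zeros_like(spikes)
for i in range(len(spikes)):
    prev1 = ar_trace[i - 1] if i >= 1 else 0.0
    prev2 = ar_trace[i - 2] if i >= 2 else 0.0
    drive = (d - r) * spikes[i - 1] if i >= 1 else 0.0
    ar_trace[i] = g1 * prev1 + g2 * prev2 + drive

print(np.allclose(conv_trace, ar_trace))  # True
```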

Exercise: How will my spikes look when convolved with a double exponential decay kernel with tau_rise of 29.9 frames, tau_decay of 30.0 frames, and window size of 101 frames?

Solution
tau_rise = 29.9
tau_decay = 30
win_len = 101
t = np.arange(win_len)
kernel_unnorm = np.exp(-t / tau_decay) - np.exp(-t / tau_rise)
kernel = kernel_unnorm / kernel_unnorm.sum()
convolved_trace = convolve(s1, kernel, mode='full')

plt.subplot(211)
plt.plot(kernel)

plt.subplot(212)
plt.plot(s1)
plt.plot(convolved_trace[:-win_len+1], color='r')
plt.xlim(0, 100)

Exercise: How will my spikes look when convolved with a double exponential decay kernel with tau_rise of 0.1 frames, tau_decay of 4.0 frames, and window size of 101 frames?

Solution
tau_rise = 0.1
tau_decay = 4
win_len = 101
t = np.arange(win_len)
kernel_unnorm = np.exp(-t / tau_decay) - np.exp(-t / tau_rise)
kernel = kernel_unnorm / kernel_unnorm.sum()
convolved_trace = convolve(s1, kernel, mode='full')

plt.subplot(211)
plt.plot(kernel)

plt.subplot(212)
plt.plot(s1)
plt.plot(convolved_trace[:-win_len+1], color='r')
plt.xlim(0, 100)

Section 2: OASIS

Now that we know how spikes generate calcium signals through a kernel, we want to reverse the process and go from calcium traces back to the underlying spikes. This is called deconvolution. OASIS is a commonly used algorithm that estimates spike timings by finding a sparse, non-negative set of events that best match the observed signal when convolved with a calcium kernel. The kernel is described by autoregressive coefficients g, which OASIS estimates from the data unless you supply them. The output is a continuous signal where higher values suggest stronger or more likely spike events.
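To see why a constrained method like OASIS is needed, consider the naive alternative. This sketch is plain inverse filtering, not OASIS, and it assumes an AR(1) model with a known g. Under that model, c[t] = g * c[t-1] + s[t] can be inverted exactly as s[t] = c[t] - g * c[t-1]. Without noise, this recovers the spikes perfectly. With even modest noise, it produces negative, non-sparse "spikes", which is the problem that the non-negativity and sparsity constraints are designed to handle. The names `naive_inverse`, `calcium`, and `true_spikes` are hypothetical.

```python
import numpy as np

g = 0.9                                   # assumed known AR(1) coefficient
rng = np.random.default_rng(2)
true_spikes = (rng.random(500) < 0.03).astype(float)

# Forward model: c[t] = g * c[t-1] + s[t]
calcium = np.zeros_like(true_spikes)
for i, s in enumerate(true_spikes):
    calcium[i] = (g * calcium[i - 1] if i > 0 else 0.0) + s

def naive_inverse(c, g):
    """Invert c[t] = g*c[t-1] + s[t] exactly: s[t] = c[t] - g*c[t-1]."""
    s = c.copy()
    s[1:] -= g * c[:-1]
    return s

clean_est = naive_inverse(calcium, g)
noisy_est = naive_inverse(calcium + rng.normal(0, 0.1, calcium.shape), g)

print(np.allclose(clean_est, true_spikes))  # True: exact without noise
print(noisy_est.min() < 0)                  # True: noise creates negative "spikes"
```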

Code Description
plt.subplot(211) Set up the first subplot for plotting.
plt.subplot(212) Set up the second subplot for plotting.
deconvolve(y) Apply the deconvolution function to the calcium trace y to infer spikes and baseline.
plt.axhline(baseline) Plot a horizontal line at the estimated baseline value.
deconvolve(y, g=(0.9,)) Deconvolve y using a fixed first-order autoregressive coefficient g1 = 0.9 instead of estimating it from the data.
deconvolve(y, g=(1.8, -0.81)) Deconvolve y using a fixed second-order autoregressive model with g1 = 1.8 and g2 = -0.81.

Exercises

Example: Estimate spikes from calcium trace y1.

inferred_trace, estimated_spikes, estimated_baseline, g, _ = deconvolve(y1)

plt.subplot(211)
plt.plot(y1)

plt.subplot(212)
plt.plot(estimated_spikes)

Exercise: Estimate spikes from calcium trace y2 and also plot the estimated baseline.

Solution
inferred_trace, estimated_spikes, estimated_baseline, g, _ = deconvolve(y2)

plt.subplot(211)
plt.plot(y2)
plt.axhline(estimated_baseline, color='r')

plt.subplot(212)
plt.plot(estimated_spikes)

Exercise: Estimate spikes from calcium trace y3 and also plot the estimated baseline.

Solution
inferred_trace, estimated_spikes, estimated_baseline, g, _ = deconvolve(y3)

plt.subplot(211)
plt.plot(y3)
plt.axhline(estimated_baseline, color='r')


plt.subplot(212)
plt.plot(estimated_spikes)

Example: Set the AR coefficient g1 to 0.9.

inferred_trace, estimated_spikes, estimated_baseline, g, _ = deconvolve(y3, g=(0.9,))

plt.subplot(211)
plt.plot(y3)

plt.subplot(212)
plt.plot(estimated_spikes)

Exercise: Set the AR coefficient g1 to 0.1.

Solution
inferred_trace, estimated_spikes, estimated_baseline, g, _ = deconvolve(y3, g=(0.1,))

plt.subplot(211)
plt.plot(y3)

plt.subplot(212)
plt.plot(estimated_spikes)

Exercise: Set the AR coefficient g1 to 0.99.

Solution
inferred_trace, estimated_spikes, estimated_baseline, g, _ = deconvolve(y3, g=(0.99,))

plt.subplot(211)
plt.plot(y3)

plt.subplot(212)
plt.plot(estimated_spikes)
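It helps to turn g1 into a decay time when comparing these settings. For an AR(1) model, g = exp(-1/tau), so tau = -1/ln(g) frames. Dividing by the frame rate gives seconds, and y3 was generated with `framerate=30` in the setup. The helper name `g_to_tau` is hypothetical. With these values, g1 = 0.1 implies a decay of well under one frame, g1 = 0.9 about 9.5 frames, and g1 = 0.99 about 99.5 frames, which may explain why the three results look so different.

```python
import numpy as np

def g_to_tau(g, framerate=None):
    """Decay time constant implied by an AR(1) coefficient g (0 < g < 1).

    Returns tau in frames, or in seconds if a framerate (Hz) is given.
    """
    tau_frames = -1.0 / np.log(g)
    return tau_frames if framerate is None else tau_frames / framerate

for g in (0.1, 0.9, 0.99):
    print(g, round(g_to_tau(g), 2), round(g_to_tau(g, framerate=30), 3))
```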

Exercise: Set the AR coefficients g1 to 1.8 and g2 to -0.81.

Solution
inferred_trace, estimated_spikes, estimated_baseline, g, _ = deconvolve(y3, g=(1.8, -0.81))

plt.subplot(211)
plt.plot(y3)

plt.subplot(212)
plt.plot(estimated_spikes)

Example: Compare the inferred calcium trace with the observed calcium trace y1.

inferred_trace, estimated_spikes, estimated_baseline, g, _ = deconvolve(y1)

plt.subplot(211)
plt.plot(y1)

plt.subplot(212)
plt.plot(inferred_trace)

Exercise: Compare the inferred calcium trace with the observed calcium trace y2, using g1 = 0.9.

Solution
inferred_trace, estimated_spikes, estimated_baseline, g, _ = deconvolve(y2, g=(0.90,))

plt.subplot(211)
plt.plot(y2)

plt.subplot(212)
plt.plot(inferred_trace)

Exercise: Compare the inferred calcium trace with the observed calcium trace y3, using g1 = 1.8 and g2 = -0.81.

Solution
inferred_trace, estimated_spikes, estimated_baseline, g, _ = deconvolve(y3, g=(1.8, -0.81))

plt.subplot(211)
plt.plot(y3)

plt.subplot(212)
plt.plot(inferred_trace)
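A number can complement these plots. OASIS models the observed trace as the denoised calcium plus a baseline plus noise, so the spread of y - (c + b) estimates how much the fit leaves unexplained. In the notebook's calls, c and b are the first and third return values of `deconvolve`. The helper `residual_noise` is hypothetical and is tested here on synthetic data with a known noise level of 0.1.

```python
import numpy as np

def residual_noise(y, c, b):
    """Standard deviation of what the model leaves unexplained: y - (c + b)."""
    return np.std(np.asarray(y) - np.asarray(c) - b)

# Illustrative check: synthetic AR(1) calcium + baseline 0.5 + noise (sd 0.1)
rng = np.random.default_rng(3)
spikes = (rng.random(1000) < 0.02).astype(float)
c_true = np.convolve(spikes, 0.9 ** np.arange(100))[:1000]
y_synth = c_true + 0.5 + rng.normal(0, 0.1, 1000)

print(residual_noise(y_synth, c_true, 0.5))  # close to 0.1
```

With the notebook's variables, this would be `residual_noise(y1, inferred_trace, estimated_baseline)`. Assuming `gen_data`'s `sn=0.1` in the setup is the noise standard deviation, a residual near 0.1 suggests the fit is capturing the signal without overfitting the noise.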

Section 3: Thresholding and Spike times

The output from OASIS is continuous, showing how likely or strong each spike might be. But for many kinds of analysis, we need clear events where either a spike happened, or it did not. To do this, we apply a threshold to the OASIS output. Any value above the threshold is considered a spike.
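One detail to watch: `np.where(spikes > threshold)` returns every frame above the threshold. A single event that stays above threshold for several consecutive frames is therefore counted several times. The sketch below is one simple alternative heuristic that is not part of the notebook's method. It keeps only the first frame of each above-threshold run. The helper name `threshold_onsets` and the toy array are hypothetical.

```python
import numpy as np

def threshold_onsets(x, threshold, fr):
    """Times (s) where x first rises above threshold.

    Plain thresholding returns every frame above threshold; this keeps
    only the first frame of each consecutive above-threshold run.
    """
    above = np.flatnonzero(np.asarray(x) > threshold)
    if above.size == 0:
        return above / fr
    starts = above[np.insert(np.diff(above) > 1, 0, True)]
    return starts / fr

x = np.array([0, 0.2, 1.1, 1.3, 0.1, 0, 0.9, 0, 1.5])
print(np.flatnonzero(x > 0.5) / 10)   # [0.2 0.3 0.6 0.8]: every frame
print(threshold_onsets(x, 0.5, 10))   # [0.2 0.6 0.8]: one time per run
```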

Code Description
np.max(spikes) Get the maximum spike value.
np.percentile(spikes, 95) Find the 95th percentile of spikes.
np.mean(spikes) Calculate the mean of spikes.
np.std(spikes) Compute the standard deviation of spikes.
fr = 10 Set sampling frequency to 10 Hz.
spk_inds = np.where(spikes > threshold)[0] Identify spike indices above threshold.
spk_times = spk_inds / fr Convert spike indices to times.
plt.eventplot(spk_times) Plot spike times as events.
find_peaks(spikes, height=0.5, distance=5) Detect peaks in spikes with a height of at least 0.5 that are at least 5 samples apart.

Exercises

_, inferred_spikes1, _, _, _ = deconvolve(y1)
_, inferred_spikes2, _, _, _ = deconvolve(y2)
_, inferred_spikes3, _, _, _ = deconvolve(y3)

Example: For y1, find spike times of all spikes with amplitude larger than 1.0.

threshold = 1.0
fr = 10
spk_inds = np.where(inferred_spikes1 > threshold)[0]
spk_times = spk_inds / fr
plt.eventplot(spk_times)

Exercise: For y1, find spike times of all spikes with amplitude larger than 0.01.

Solution
threshold = 0.01
fr = 10
spk_inds = np.where(inferred_spikes1 > threshold)[0]
spk_times = spk_inds / fr
plt.eventplot(spk_times)

Exercise: For y1, find spike times of all spikes with amplitude larger than 0.5.

Solution
threshold = 0.5
fr = 10
spk_inds = np.where(inferred_spikes1 > threshold)[0]
spk_times = spk_inds / fr
plt.eventplot(spk_times)

Example: For y1, set the threshold to 10% of the maximum amplitude.

threshold = 0.1 * np.max(inferred_spikes1)
fr = 10
spk_inds = np.where(inferred_spikes1 > threshold)[0]
spk_times = spk_inds / fr
plt.eventplot(spk_times)

Exercise: For y1, set the threshold to the 95th percentile of the amplitudes.

Solution
threshold = np.percentile(inferred_spikes1, 95)
fr = 10
spk_inds = np.where(inferred_spikes1 > threshold)[0]
spk_times = spk_inds / fr
plt.eventplot(spk_times)

Exercise: For y1, set the threshold to three standard deviations above the mean (three-sigma).

Solution
threshold = np.mean(inferred_spikes1) + 3*np.std(inferred_spikes1)
fr = 10
spk_inds = np.where(inferred_spikes1 > threshold)[0]
spk_times = spk_inds / fr
plt.eventplot(spk_times)

Example: For y1, only get spike times for spikes with amplitudes larger than 0.5 with minimum distance of at least 5 frames.

spk_inds, properties = find_peaks(inferred_spikes1, height=0.5, distance=5)
fr = 10
spk_times = spk_inds / fr
plt.eventplot(spk_times)
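`find_peaks` differs from plain thresholding in two ways. It only returns local maxima, and `distance` enforces a minimum spacing in samples by removing the smaller of any two peaks that are too close. The toy array below is illustrative. Thresholding returns frames 1, 3, and 10. `find_peaks` with `distance=5` drops frame 3, because it lies within 5 samples of the taller peak at frame 1.

```python
import numpy as np
from scipy.signal import find_peaks

x = np.array([0, 1.0, 0, 0.8, 0, 0, 0, 0, 0, 0, 0.6, 0])

above = np.flatnonzero(x > 0.5)                       # every frame above 0.5
peaks, props = find_peaks(x, height=0.5, distance=5)  # spaced local maxima
print(above, peaks, props["peak_heights"])            # [ 1  3 10] [ 1 10] [1.  0.6]
```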

Exercise: For y1, only get spike times for spikes with amplitudes larger than 0.5 with minimum distance of at least 100 frames.

Solution
spk_inds, properties = find_peaks(inferred_spikes1, height=0.5, distance=100)
fr = 10
spk_times = spk_inds / fr
plt.eventplot(spk_times)

Exercise: For y1, only get spike times for spikes with amplitudes larger than 0.5 with minimum distance of at least 10 frames.

Solution
spk_inds, properties = find_peaks(inferred_spikes1, height=0.5, distance=10)
fr = 10
spk_times = spk_inds / fr
plt.eventplot(spk_times)

Section 4: Saving Timestamped Data

Once we have identified when spikes likely occurred by thresholding, we save the corresponding time points. These timestamped events are useful for further analysis, such as comparing activity across cells, aligning activity to behavioral events, or building summary statistics. In this section, we will save these spike times as an array of indices or timestamps.

Code Description
np.save('spk1.npy', spk_times1) Save spike times for neuron 1 to a .npy file.
spk1 = np.load('spk1.npy') Load spike times for neuron 1 from a .npy file.
plt.eventplot(spk2) Plot spike times for neuron 2 as an event plot.
spks = np.array([spk_times1, spk_times2], dtype=object) Pack spike-time arrays of different lengths for multiple neurons into one object array.
spks = np.load('spk_1_2.npy', allow_pickle=True) Load spike times for multiple neurons from a .npy file (object arrays require allow_pickle=True).
plt.eventplot(spks[1]) Plot spike times for neuron 2 from the loaded data.
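An alternative for several neurons, not used in this notebook, is `np.savez`. It stores each array under its own name in a single .npz file, so spike trains of different lengths need neither an object array nor `allow_pickle=True` when loading. In this sketch, the file name, the key names `neuron1`/`neuron2`, and the spike times are all illustrative.

```python
import os
import tempfile
import numpy as np

spk_times_a = np.array([0.5, 1.2, 3.4])   # illustrative spike times (s)
spk_times_b = np.array([0.1, 2.0])        # a different length is fine

with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, "spikes.npz")
    # One named array per neuron; nothing is pickled
    np.savez(path, neuron1=spk_times_a, neuron2=spk_times_b)

    with np.load(path) as data:           # allow_pickle is not needed
        names = sorted(data.files)
        loaded_b = data["neuron2"]        # read into memory before closing

print(names, loaded_b)
```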

Exercises

threshold = np.percentile(inferred_spikes1, 95)
fr = 10
spk_inds = np.where(inferred_spikes1 > threshold)[0]
spk_times1 = spk_inds / fr
threshold = np.percentile(inferred_spikes2, 95)
fr = 20
spk_inds = np.where(inferred_spikes2 > threshold)[0]
spk_times2 = spk_inds / fr
threshold = np.percentile(inferred_spikes3, 95)
fr = 30
spk_inds = np.where(inferred_spikes3 > threshold)[0]
spk_times3 = spk_inds / fr

Example: Save spk_times1 as spk1.npy.

np.save('spk1.npy', spk_times1)

Exercise: Save spk_times2 as spk2.npy.

Solution
np.save('spk2.npy', spk_times2)

Exercise: Save spk_times3 as spk3.npy.

Solution
np.save('spk3.npy', spk_times3)

Example: Load spk1.npy.

spk1 = np.load('spk1.npy')
plt.eventplot(spk1)

Exercise: Load spk2.npy.

Solution
spk2 = np.load('spk2.npy')
plt.eventplot(spk2)

Exercise: Load spk3.npy.

Solution
spk3 = np.load('spk3.npy')
plt.eventplot(spk3)

Example: Save spk_times1 and spk_times2 together as spk_1_2.npy.

spks = np.array([spk_times1, spk_times2], dtype=object)
np.save('spk_1_2.npy', spks)
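`np.array([...], dtype=object)` has one pitfall. If two neurons happen to have the same number of spikes, NumPy builds a 2-D grid instead of a list of spike trains. Indexing with `spks[0]` still returns the first row, but the shape is no longer one entry per neuron. The hypothetical helper below always builds a 1-D object array by filling the elements one at a time.

```python
import numpy as np

def to_object_array(arrays):
    """1-D object array holding each input array as one element."""
    out = np.empty(len(arrays), dtype=object)
    for i, arr in enumerate(arrays):
        out[i] = arr
    return out

a = np.array([0.1, 0.5, 0.9])
b = np.array([0.2, 0.4, 0.8])                  # same length as a

print(np.array([a, b], dtype=object).shape)    # (2, 3): a 2-D grid
print(to_object_array([a, b]).shape)           # (2,): one entry per neuron
```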

Exercise: Save spk_times2 and spk_times3 together as spk_2_3.npy.

Solution
spks = np.array([spk_times2, spk_times3], dtype=object)
np.save('spk_2_3.npy', spks)

Exercise: Save spk_times1, spk_times2, and spk_times3 together as spks.npy.

Solution
spks = np.array([spk_times1, spk_times2, spk_times3], dtype=object)
np.save('spks.npy', spks)

Example: Load spk_1_2.npy and plot the events from first neuron.

spks = np.load('spk_1_2.npy', allow_pickle=True)
plt.eventplot(spks[0])

Exercise: Load spk_1_2.npy and plot the events from second neuron.

Solution
spks = np.load('spk_1_2.npy', allow_pickle=True)
plt.eventplot(spks[1])

Exercise: Load spks.npy and plot the events from the last neuron.

Solution
spks = np.load('spks.npy', allow_pickle=True)
plt.eventplot(spks[2])