Mining Statistical Patterns with SPADE
Authors
Setup
Import Modules
import numpy as np
from matplotlib import pyplot as plt
import quantities as pq
from elephant.spike_train_generation import compound_poisson_process, homogeneous_poisson_process
from elephant.spade import spade
from viziphant.patterns import plot_patterns
np.random.seed(100)
%matplotlib inline
Define the utility functions required for this notebook.
class utils:
    @staticmethod
    def find_synchronous_spikes(spike_trains):
        """
        Find the synchronous spikes in a list of spike trains.
        Arguments:
            spike_trains (list of SpikeTrain): list of spike train objects.
        Returns:
            (np.ndarray): 1-dimensional array of the synchronous spike times (times are repeated for each synchronous spike)
            (np.ndarray): 1-dimensional array with the indices of the spike trains containing the synchronous spikes
        """
        all_spikes = np.concatenate([spike_train.times for spike_train in spike_trains])
        all_trains = np.concatenate(
            [[i] * len(spike_train.times) for i, spike_train in enumerate(spike_trains)]
        )
        times = []
        units = []
        # collect every spike time that occurs in more than one spike train
        for s in np.unique(all_spikes):
            idx = np.where(all_spikes == s)[0]
            if len(idx) > 1:
                times.append(all_spikes[idx])
                units.append(all_trains[idx])
        if len(times) > 0:
            times = np.concatenate(times)
            units = np.concatenate(units)
        else:
            times = np.array([])
            units = np.array([])
            print("Found no synchronous spikes")
        return times, units
Section 1: Simulating Synchronous Spiking
In this section, we are going to explore Spike Pattern Detection and Evaluation (SPADE), a method for finding recurring patterns of synchronous firing in large numbers of neurons. We will explore SPADE and its parameters on data simulated with a compound Poisson process. This simulation defines a baseline Poisson process that fires randomly and lets that baseline process trigger spikes in other neurons, producing synchronous events. In this section, we explore how this simulation works.
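To build intuition for what compound_poisson_process does, here is a purely illustrative numpy sketch of the same idea (it is not Elephant's implementation, and the helper name toy_cpp is made up for this sketch): a single "mother" Poisson process generates events, and each event is copied into j randomly chosen spike trains, where j is drawn from the amplitude_distribution.
# Illustrative sketch only (not Elephant's implementation) of a compound Poisson process.
def toy_cpp(rate_hz, amplitude_distribution, t_stop_s, seed=0):
    rng = np.random.default_rng(seed)
    a = np.asarray(amplitude_distribution, dtype=float)
    n_trains = len(a) - 1                              # index j = number of trains firing together
    mean_amplitude = np.sum(np.arange(len(a)) * a)     # expected size of a synchronous event
    mother_rate = n_trains * rate_hz / mean_amplitude  # keeps each train near rate_hz
    n_events = rng.poisson(mother_rate * t_stop_s)
    event_times = np.sort(rng.uniform(0, t_stop_s, n_events))
    trains = [[] for _ in range(n_trains)]
    for t in event_times:
        j = rng.choice(len(a), p=a)                    # size of this synchronous event
        for unit in rng.choice(n_trains, size=j, replace=False):
            trains[unit].append(t)
    return [np.array(train) for train in trains]

toy_sts = toy_cpp(5, [0, 0.99, 0, 0, 0, 0, 0.01], 10)
len(toy_sts)  # 6 toy spike trains; about 1% of the mother events are copied into all 6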
| Code | Description |
|---|---|
| sts = compound_poisson_process(rate, amplitude_distribution, t_stop) | Generate a list of spike trains from a compound Poisson process with a given rate and an amplitude_distribution that determines the probability of synchronous spikes. Each spike train starts at time 0 and runs until t_stop. |
| rasterplot(sts) | Create a raster plot for a list of spike trains. |
| x, y = find_synchronous_spikes(spiketrains) | Return the times x and indices y of synchronous spikes in a list of spike trains. |
| plt.eventplot(x) | Plot the events in x. |
| plt.scatter(x, y, color="red") | Plot the data x, y as a red scatterplot. |
Exercises
Example: Generate 6 spike trains (1 less than the length of the amplitude_distribution) with a firing rate of 5 Hz and a duration of 10 s from a compound_poisson_process where the probability of synchronous spikes in all 6 spike trains is 0.01 (1%). Print the length of the returned list of spike trains sts.
amplitude_distribution = [0, 0.99, 0, 0, 0, 0, 0.01]
sts = compound_poisson_process(
rate=5 * pq.Hz, amplitude_distribution=amplitude_distribution, t_stop=10 * pq.s
)
len(sts)
6
Example: Find the time points x and the spike train indices y of spikes in the spike train list sts that occur synchronously in multiple trains. Then use plt.eventplot to create a rasterplot of the spike trains sts and mark the synchronous spikes x, y in red.
x, y = utils.find_synchronous_spikes(sts)
plt.eventplot([st.times for st in sts], color="black")
plt.scatter(x, y, color="red")
Exercise: Generate spike trains with a firing rate of 5 Hz and a duration of 10 s from a compound_poisson_process with the amplitude_distribution defined below. How many spike trains are generated by this simulation and what is the probability of synchronous spikes in all neurons?
amplitude_distribution = [0, 0.98, 0, 0, 0, 0, 0, 0, 0.02]
Solution
sts = compound_poisson_process(
rate=5 * pq.Hz, amplitude_distribution=amplitude_distribution, t_stop=10 * pq.s
)
len(sts)
8
Exercise: Generate spike trains with a firing rate of 5 Hz and a duration of 10 s from a compound_poisson_process with the amplitude_distribution defined below. How many spike trains are generated by this simulation and what is the probability of synchronous spikes in 3 neurons?
amplitude_distribution = [0, 0.95, 0, 0.04, 0, 0.01]
Solution
sts = compound_poisson_process(
rate=5 * pq.Hz, amplitude_distribution=amplitude_distribution, t_stop=10 * pq.s
)
len(sts)
5
Exercise: Find the time points x and the spike train indices y of spikes in the spike train list sts that occur synchronously in multiple trains. Then use plt.eventplot to create a rasterplot of the spike trains sts and mark the synchronous spikes x, y in red.
Solution
x, y = utils.find_synchronous_spikes(sts)
plt.eventplot([st.times for st in sts], color="black")
plt.scatter(x, y, color="red")
Exercise: Generate 7 spike trains with a firing rate of 5 Hz and a duration of 10 s from a compound_poisson_process where the probability of synchronous spikes in all 7 spike trains is 0.1 (10%). Then, plot the spike trains and mark the synchronous spikes in red.
Solution
amplitude_distribution = [0, 0.9, 0, 0, 0, 0, 0, 0.1]
sts = compound_poisson_process(
rate=5 * pq.Hz, amplitude_distribution=amplitude_distribution, t_stop=10 * pq.s
)
x, y = utils.find_synchronous_spikes(sts)
plt.eventplot([st.times for st in sts], color="black")
plt.scatter(x, y, color="red")
Exercise: Generate 4 spike trains with a firing rate of 5 Hz and a duration of 10 s from a compound_poisson_process where the probability of synchronous spikes in 2 spike trains is 0.05 (5%) and the probability of a synchronous spike in 3 spike trains is 0.01 (1%). Then, plot the spike trains and mark the synchronous spikes in red.
Solution
amplitude_distribution = [0, 0.94, 0.05, 0.01, 0]
sts = compound_poisson_process(
rate=5 * pq.Hz, amplitude_distribution=amplitude_distribution, t_stop=10 * pq.s
)
print(len(sts))
4
# plot the spike trains and mark the synchronous spikes in red, as the exercise asks
x, y = utils.find_synchronous_spikes(sts)
plt.eventplot([st.times for st in sts], color="black")
plt.scatter(x, y, color="red")
Section 2: Finding and Visualizing Patterns with SPADE
Now we can apply SPADE to the simulated data. SPADE uses an algorithm called frequent itemset mining, a data mining technique that discovers patterns in large datasets by identifying frequently occurring sets of items. Each pattern is assigned a signature that consists of two values: the number of items (neurons) contained in that pattern and the number of occurrences of that pattern. In this section, we are going to apply SPADE to the simulated data and evaluate the patterns it detects.
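As a toy illustration of these two ideas (this is not SPADE's code, and the small binary matrix below is made up), we can bin a few fake spike trains and count, for every set of neurons that is active in the same bin, its signature (number of neurons, number of occurrences):
# Toy illustration (not SPADE itself): count how often each set of neurons
# is active in the same time bin and report the (size, occurrences) signature.
from itertools import combinations
from collections import Counter

binned = np.array([      # rows = 3 fake neurons, columns = 5 time bins
    [1, 0, 1, 0, 1],
    [1, 0, 1, 0, 0],
    [0, 1, 1, 0, 1],
])
counts = Counter()
for t in range(binned.shape[1]):
    active = tuple(np.flatnonzero(binned[:, t]))
    for size in range(2, len(active) + 1):           # every subset of >= 2 co-active neurons
        for itemset in combinations(active, size):
            counts[itemset] += 1
for itemset, occurrences in counts.items():
    print(itemset, "signature:", (len(itemset), occurrences))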
| Code | Description |
|---|---|
| results = spade(sts, binsize, winlen) | Run spade() on the spike trains sts with the given binsize and winlen. |
| results = spade(sts, binsize, winlen, min_occ=3) | Run spade() on the spike trains sts but only consider patterns that occur at least 3 times. |
| results = spade(sts, binsize, winlen, min_neu=3) | Run spade() on the spike trains sts but only consider patterns that contain at least 3 neurons. |
| patterns = results["patterns"] | Get the detected "patterns" from the results. |
| len(patterns) | Get the number of detected patterns. |
| sig = [p["signature"] for p in patterns] | Get a list with the "signature" of every pattern. |
| np.stack(sig) | Stack the list sig into a 2D numpy array. |
| idx = np.argmax(sig[:, 0]) | Get the index where the first column of sig is maximal. |
| plot_patterns(sts, patterns[0]) | Plot the spike trains sts and highlight the first pattern in patterns. |
Exercises
Run the code below to generate 10 spike trains with synchronous spikes from a compound_poisson_process and add 90 purely random spike trains from a homogeneous_poisson_process.
rate = 3 * pq.Hz
t_stop = 15 * pq.s
sts = compound_poisson_process(
rate=rate,
amplitude_distribution=[0, 0.92, 0, 0, 0, 0, 0, 0, 0, 0, 0.08],
t_stop=t_stop,
)
for i in range(90):
sts.append(homogeneous_poisson_process(rate=rate, t_stop=t_stop))
f"Number of spike trains: {len(sts)}"'Number of spike trains: 100'Example: Apply spade() to th simulated spike trains sts with binsize=5*pq.ms and winlen=1 and get the detected list of detected "patterns" and print the first element.
results = spade(spiketrains=sts, binsize=5 * pq.ms, winlen=1)
patterns = results["patterns"]
patterns[0]
Time for data mining: 0.22992444038391113
{'itemset': (63, 99),
'windows_ids': (615, 1127),
'neurons': [63, 99],
'lags': array([0.]) * ms,
'times': array([3075., 5635.]) * ms,
'signature': (2, 2),
'pvalue': -1}
Exercise: How many patterns were detected (i.e., what is the length of patterns)?
Solution
len(patterns)
694
Exercise: Use plot_patterns to plot the first pattern in patterns.
Solution
plot_patterns(sts, patterns[0])
Example: Get the "signature" of every pattern p in patterns and stack them into one numpy array. Then print the signatures. Each row represents one pattern: the first column indicates how many neurons are part of the pattern, the second column indicates how often the pattern occurs.
sig = np.stack([p["signature"] for p in patterns])
sig
array([[ 2, 2],
[ 2, 2],
[ 2, 2],
...,
[ 2, 4],
[ 2, 22],
[ 2, 3]])
Exercise: Use np.argmax on the second column of sig to find the pattern that occurred most often. Get it from the list of patterns and print it.
Solution
idx = np.argmax(sig[:,1])
patterns[idx]
{'itemset': (2, 5),
'windows_ids': (25,
166,
174,
283,
346,
551,
667,
737,
959,
1063,
1135,
1242,
1292,
1460,
1716,
1988,
2147,
2317,
2362,
2408,
2643,
2815),
'neurons': [2, 5],
'lags': array([0.]) * ms,
'times': array([ 125., 830., 870., 1415., 1730., 2755., 3335., 3685.,
4795., 5315., 5675., 6210., 6460., 7300., 8580., 9940.,
10735., 11585., 11810., 12040., 13215., 14075.]) * ms,
'signature': (2, 22),
'pvalue': -1}
Exercise: Use plot_patterns to plot the pattern that occurred most often.
Solution
plot_patterns(sts, patterns[idx])
Exercise: Use np.argmax on the first column of sig to find the pattern that contains the most neurons. Get it from the list of patterns and print it.
Solution
idx = np.argmax(sig[:, 0])
patterns[idx]
{'itemset': (0, 2, 5, 9, 4, 7, 3, 16, 6, 8, 1),
'windows_ids': (1988, 2408),
'neurons': [0, 2, 5, 9, 4, 7, 3, 16, 6, 8, 1],
'lags': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) * ms,
'times': array([ 9940., 12040.]) * ms,
'signature': (11, 2),
'pvalue': -1}
Exercise: Use plot_patterns to plot the pattern that contains the most neurons.
Solution
plot_patterns(sts, patterns[idx])
Exercise: Compute the product of both columns of sig and apply np.argmax() to the result to get the pattern where the product of the pattern's number of occurrences and number of neurons is maximal. Then, get that pattern from the list of patterns and print it. Which "neurons" are contained in this pattern?
Solution
idx = np.argmax(sig[:, 0] * sig[:, 1])
patterns[idx]
{'itemset': (0, 2, 5, 9, 4, 7, 3, 6, 8, 1),
'windows_ids': (25,
166,
174,
283,
346,
551,
667,
737,
959,
1063,
1135,
1242,
1292,
1460,
1716,
1988,
2147,
2362,
2408,
2643,
2815),
'neurons': [0, 2, 5, 9, 4, 7, 3, 6, 8, 1],
'lags': array([0., 0., 0., 0., 0., 0., 0., 0., 0.]) * ms,
'times': array([ 125., 830., 870., 1415., 1730., 2755., 3335., 3685.,
4795., 5315., 5675., 6210., 6460., 7300., 8580., 9940.,
10735., 11810., 12040., 13215., 14075.]) * ms,
'signature': (10, 21),
'pvalue': -1}
Exercise: Use plot_patterns to plot that pattern.
Solution
plot_patterns(sts, patterns[idx])
Exercise: Rerun spade() with min_occ=3 and min_neu=3 to only include patterns that occur at least 3 times and contain at least 3 neurons, and return the detected "patterns".
Solution
results = spade(spiketrains=sts, binsize=5 * pq.ms, winlen=1, min_occ=3, min_neu=3)
patterns = results["patterns"]
patterns
Time for data mining: 0.173905611038208
[{'itemset': (0, 2, 5, 9, 4, 7, 3, 6, 8, 1),
'windows_ids': (25,
166,
174,
283,
346,
551,
667,
737,
959,
1063,
1135,
1242,
1292,
1460,
1716,
1988,
2147,
2362,
2408,
2643,
2815),
'neurons': [0, 2, 5, 9, 4, 7, 3, 6, 8, 1],
'lags': array([0., 0., 0., 0., 0., 0., 0., 0., 0.]) * ms,
'times': array([ 125., 830., 870., 1415., 1730., 2755., 3335., 3685.,
4795., 5315., 5675., 6210., 6460., 7300., 8580., 9940.,
10735., 11810., 12040., 13215., 14075.]) * ms,
'signature': (10, 21),
'pvalue': -1}]
Exercise: Get the "signature" of every pattern as shown in @exm-stack and print them.
Solution
sig = np.stack([p["signature"] for p in patterns])
sig
array([[10, 21]])
Exercise: Find the pattern that occurred most often and plot it.
Solution
idx = np.argmax(sig[:, 1])
plot_patterns(sts, patterns[idx])
Section 3: Statistical Inference with Surrogate Data
SPADE not only finds frequently occurring patterns, it can also test their statistical significance. This is done by generating surrogate data sets, which are copies of the original data in which the spike times are randomly dithered. By performing the frequent itemset mining on the surrogate data, the algorithm obtains a null distribution of the patterns to be expected from random data. We can then obtain a p-value for each actually observed pattern by checking how often a pattern with the same signature appears in the surrogate data. In this section, we will explore how to do statistical inference with SPADE and select the patterns that are significant.
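The core of the surrogate idea can be sketched in a few lines of plain numpy (this is only a conceptual sketch, not SPADE's internal code; Elephant ships proper surrogate generators in elephant.spike_train_surrogates):
# Conceptual sketch only (not SPADE's internal code): dithering jitters each spike
# time, which destroys fine temporal coordination while roughly preserving firing rates.
rng = np.random.default_rng(0)

def dither_times(times, dither=0.015, t_stop=10.0):
    """Toy surrogate: jitter every spike uniformly by +/- dither seconds."""
    jittered = times + rng.uniform(-dither, dither, size=len(times))
    return np.sort(np.clip(jittered, 0.0, t_stop))

# Repeating the pattern mining on many dithered copies yields a null distribution of
# pattern counts; the p-value of a signature is the fraction of surrogates in which a
# pattern with at least as many neurons and occurrences still appears by chance.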
| Code | Description |
|---|---|
| results = spade(sts, binsize, winlen, n_surr=100) | Run spade() on the spike trains sts and generate 100 surrogate data sets for statistical inference. |
| pvalue_spectrum = results["pvalue_spectrum"] | Get the spectrum of p-values computed on the surrogate data. |
| patterns = results["patterns"] | Get the "patterns" detected by spade(). |
| [p for p in patterns if p["pvalue"] < 0.05] | Get all patterns where the "pvalue" is below 0.05. |
| results = spade(sts, binsize, winlen, psr_param=[0, 1, 0]) | Run spade() and perform a pattern set reduction that drops patterns that are contained in a larger superset. |
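The last row of the table refers to pattern set reduction. As a rough intuition (this is not SPADE's actual PSR procedure, which also takes the significance spectrum into account), it removes detected patterns whose neurons form a subset of a larger detected pattern:
# Toy illustration of the subset-filtering idea behind pattern set reduction.
detected = [{0, 2, 5}, {0, 2}, {2, 4, 5, 9}, {4, 9}]
reduced = [p for p in detected if not any(p < q for q in detected)]  # drop strict subsets
print(reduced)  # only the two maximal patterns remain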
Exercises
Example: Apply spade() to the spike trains sts and test the significance of the detected patterns based on 10 surrogate data sets. Get the "pvalue_spectrum" from the results and print it. Each row contains a pattern signature and the associated p-value. For example [2, 11, 0.1] indicates that a pattern of 2 neurons that occurs 11 times has a p-value of 0.1.
results = spade(spiketrains=sts, binsize=5 * pq.ms, winlen=1, n_surr=10)
pvalue_spectrum = results["pvalue_spectrum"]
pvalue_spectrum
Time for data mining: 0.1987929344177246
Time for pvalue spectrum computation: 2.659097194671631
[[2, 2, 1.0],
[2, 3, 1.0],
[2, 4, 1.0],
[2, 5, 1.0],
[2, 6, 1.0],
[2, 7, 0.9],
[2, 8, 0.5],
[2, 9, 0.2],
[2, 10, 0.1],
[3, 2, 1.0],
[3, 3, 0.9],
[3, 4, 0.2],
[4, 2, 0.9],
[4, 3, 0.1]]
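Instead of reading values off the printed list, you can also query the spectrum programmatically. The helper below is only a suggestion for the following exercises (the name lookup_pvalue is made up here; it assumes each spectrum entry is a [n_neurons, n_occurrences, p_value] triple, as in the printed output above):
# Hypothetical helper (not part of Elephant): look up the p-value of a signature
# in the pvalue_spectrum returned by spade().
def lookup_pvalue(pvalue_spectrum, n_neurons, n_occurrences):
    for size, occurrences, pvalue in pvalue_spectrum:
        if size == n_neurons and occurrences == n_occurrences:
            return pvalue
    return None  # this signature never occurred in the surrogate data

lookup_pvalue(pvalue_spectrum, n_neurons=2, n_occurrences=10)  # 0.1 for the spectrum above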
Exercise: Run spade() with binsize=5*pq.ms, winlen=1 and n_surr=20. Then, get the "pvalue_spectrum". What is the p-value of a 2 neuron pattern that occurs 10 times?
Solution
results = spade(spiketrains=sts, binsize=5 * pq.ms, winlen=1, n_surr=20)
pvalue_spectrum = results["pvalue_spectrum"]
pvalue_spectrum
Time for data mining: 0.20080137252807617
Time for pvalue spectrum computation: 5.775552034378052
[[2, 2, 1.0],
[2, 3, 1.0],
[2, 4, 1.0],
[2, 5, 1.0],
[2, 6, 1.0],
[2, 7, 0.95],
[2, 8, 0.75],
[2, 9, 0.55],
[2, 10, 0.3],
[2, 11, 0.15],
[2, 12, 0.05],
[3, 2, 1.0],
[3, 3, 0.95],
[3, 4, 0.2],
[3, 5, 0.05],
[4, 2, 0.65],
[5, 2, 0.05]]
Exercise: Run spade() with binsize=5*pq.ms, winlen=1 and n_surr=100. Then, get the "pvalue_spectrum". What is the p-value of a 2 neuron pattern that occurs 12 times?
Solution
results = spade(spiketrains=sts, binsize=5 * pq.ms, winlen=1, n_surr=100)
pvalue_spectrum = results["pvalue_spectrum"]
pvalue_spectrumTime for data mining: 0.20802736282348633
Time for pvalue spectrum computation: 31.27750515937805[[2, 2, 1.0],
[2, 3, 1.0],
[2, 4, 1.0],
[2, 5, 1.0],
[2, 6, 1.0],
[2, 7, 0.99],
[2, 8, 0.8],
[2, 9, 0.38],
[2, 10, 0.14],
[2, 11, 0.03],
[3, 2, 1.0],
[3, 3, 0.88],
[3, 4, 0.3],
[3, 5, 0.07],
[4, 2, 0.75],
[4, 3, 0.02],
[5, 2, 0.05]]
Example: Get the "patterns" from the results and find all patterns with a "pvalue" below 0.05.
patterns = results["patterns"]
significant_patterns = [p for p in patterns if p["pvalue"] < 0.05]
significant_patterns
[{'itemset': (0, 2, 5, 9, 4, 7, 3, 16, 6, 8, 1),
'windows_ids': (1988, 2408),
'neurons': [0, 2, 5, 9, 4, 7, 3, 16, 6, 8, 1],
'lags': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) * ms,
'times': array([ 9940., 12040.]) * ms,
'signature': (11, 2),
'pvalue': 0.0},
{'itemset': (0, 2, 5, 9, 4, 7, 3, 6, 8, 1, 54),
'windows_ids': (2408, 2643),
'neurons': [0, 2, 5, 9, 4, 7, 3, 6, 8, 1, 54],
'lags': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) * ms,
'times': array([12040., 13215.]) * ms,
'signature': (11, 2),
'pvalue': 0.0},
{'itemset': (0, 2, 5, 9, 4, 7, 3, 6, 8, 1),
'windows_ids': (25,
166,
174,
283,
346,
551,
667,
737,
959,
1063,
1135,
1242,
1292,
1460,
1716,
1988,
2147,
2362,
2408,
2643,
2815),
'neurons': [0, 2, 5, 9, 4, 7, 3, 6, 8, 1],
'lags': array([0., 0., 0., 0., 0., 0., 0., 0., 0.]) * ms,
'times': array([ 125., 830., 870., 1415., 1730., 2755., 3335., 3685.,
4795., 5315., 5675., 6210., 6460., 7300., 8580., 9940.,
10735., 11810., 12040., 13215., 14075.]) * ms,
'signature': (10, 21),
'pvalue': 0.0},
{'itemset': (2, 5),
'windows_ids': (25,
166,
174,
283,
346,
551,
667,
737,
959,
1063,
1135,
1242,
1292,
1460,
1716,
1988,
2147,
2317,
2362,
2408,
2643,
2815),
'neurons': [2, 5],
'lags': array([0.]) * ms,
'times': array([ 125., 830., 870., 1415., 1730., 2755., 3335., 3685.,
4795., 5315., 5675., 6210., 6460., 7300., 8580., 9940.,
10735., 11585., 11810., 12040., 13215., 14075.]) * ms,
'signature': (2, 22),
'pvalue': 0.0},
{'itemset': (9, 7),
'windows_ids': (23,
25,
166,
174,
283,
346,
551,
667,
737,
959,
1063,
1135,
1242,
1292,
1460,
1716,
1988,
2147,
2362,
2408,
2643,
2815),
'neurons': [9, 7],
'lags': array([0.]) * ms,
'times': array([ 115., 125., 830., 870., 1415., 1730., 2755., 3335.,
3685., 4795., 5315., 5675., 6210., 6460., 7300., 8580.,
9940., 10735., 11810., 12040., 13215., 14075.]) * ms,
'signature': (2, 22),
'pvalue': 0.0},
{'itemset': (4, 7),
'windows_ids': (25,
166,
174,
283,
346,
551,
667,
737,
959,
1063,
1135,
1242,
1292,
1460,
1716,
1988,
2147,
2362,
2408,
2506,
2643,
2815),
'neurons': [4, 7],
'lags': array([0.]) * ms,
'times': array([ 125., 830., 870., 1415., 1730., 2755., 3335., 3685.,
4795., 5315., 5675., 6210., 6460., 7300., 8580., 9940.,
10735., 11810., 12040., 12530., 13215., 14075.]) * ms,
'signature': (2, 22),
'pvalue': 0.0},
{'itemset': (4, 1),
'windows_ids': (25,
166,
174,
283,
346,
551,
667,
737,
959,
1063,
1135,
1242,
1292,
1460,
1661,
1716,
1988,
2147,
2362,
2408,
2643,
2815),
'neurons': [4, 1],
'lags': array([0.]) * ms,
'times': array([ 125., 830., 870., 1415., 1730., 2755., 3335., 3685.,
4795., 5315., 5675., 6210., 6460., 7300., 8305., 8580.,
9940., 10735., 11810., 12040., 13215., 14075.]) * ms,
'signature': (2, 22),
'pvalue': 0.0},
{'itemset': (8, 1),
'windows_ids': (25,
166,
174,
283,
346,
551,
667,
737,
797,
959,
1063,
1135,
1242,
1292,
1460,
1716,
1988,
2147,
2362,
2408,
2643,
2815),
'neurons': [8, 1],
'lags': array([0.]) * ms,
'times': array([ 125., 830., 870., 1415., 1730., 2755., 3335., 3685.,
3985., 4795., 5315., 5675., 6210., 6460., 7300., 8580.,
9940., 10735., 11810., 12040., 13215., 14075.]) * ms,
'signature': (2, 22),
'pvalue': 0.0}]
Exercise: Plot all significant patterns (Hint: plot_patterns accepts multiple patterns).
Solution
plot_patterns(sts, significant_patterns)
Exercise: Rerun spade() with psr_param=[0, 1, 0] to perform pattern set reduction, where patterns that are subsets of a larger pattern are discarded. Then get the "patterns" from the results of spade().
Solution
results = spade(
spiketrains=sts, binsize=5 * pq.ms, winlen=1, n_surr=50, psr_param=[0, 1, 0]
)
patterns = results["patterns"]
Time for data mining: 0.23085927963256836
Time for pvalue spectrum computation: 13.977830171585083
Exercise: Get the "patterns" from the results, find all patterns with a "pvalue" below 0.05 and plot them.
Solution
patterns = results["patterns"]
significant_patterns = [p for p in patterns if p["pvalue"] < 0.05]
plot_patterns(sts, significant_patterns)
Exercise: What are the p-values of the significant patterns?
Solution
[p["pvalue"] for p in significant_patterns][0.0, 0.0, 0.0]