Multi-Channel Spike Detection [ex112.0]

This example shows a spike detection algorithm that uses autonomous linear state space models (ALSSMs) together with exponentially decaying windows. The input is a multi-channel signal containing multiple spikes (each a single sinusoidal cycle with exponentially decaying amplitude), corrupted by additive white Gaussian noise and a baseline.

../../_images/sphx_glr_example-ex112.0-mc-pulse-detection_001.png
import matplotlib.pyplot as plt
import numpy as np
import lmlib as lm
from scipy.linalg import block_diag
from scipy.signal import find_peaks

from lmlib.utils.generator import gen_convolve, gen_sine, gen_exponential, gen_pulse, gen_wgn, \
    gen_baseline_sin, k_period_to_omega

# --- signal generation ------------------------------------------------------
K = 550  # number of samples
L = 3  # number of channels
spike_length = 20
spike_decay = 0.88
spike_locations = [100, 240, 370]

# Spike template: one sinusoidal cycle of `spike_length` samples, multiplied by
# an exponential with factor `spike_decay`.
spike = gen_sine(spike_length, spike_length) * gen_exponential(spike_length, spike_decay)
# Place the template at every entry of `spike_locations`.
y_sp = gen_convolve(gen_pulse(K, spike_locations), spike)

# Each channel = scaled spikes + channel-seeded white noise + a baseline whose
# amplitude grows with the channel index (channel 0 has no baseline).
channels = []
for ch in range(L):
    noise = gen_wgn(K, sigma=0.2, seed=10000 - ch)
    baseline = 10 * ch * gen_baseline_sin(K, int(K * 0.9 / (ch + 1)))
    channels.append(0.8 * y_sp + noise + baseline)
y = np.column_stack(channels).reshape(K, 1, L)

# --- models -----------------------------------------------------------------
# Baseline: polynomial ALSSM of degree 3.
alssm_bl = lm.AlssmPoly(poly_degree=3)
# Spike: sinusoidal ALSSM at the angular frequency of a `spike_length`-sample
# period, with `spike_decay` as its second argument (presumably the damping
# factor, matching the decay used to generate the spikes).
omega_sp = k_period_to_omega(spike_length)
alssm_sp = lm.AlssmSin(omega_sp, spike_decay)

# --- segments ---------------------------------------------------------------
# Layout relative to the evaluation index:
#   [-len_bl, -1] left (baseline) | [0, len_sp] middle (spike) | right (baseline)
g_bl, g_sp = 500, 5000  # `g` window parameters for baseline / spike segments
len_sp, len_bl = spike_length, int(1.5 * spike_length)

segment_left = lm.Segment(a=-len_bl, b=-1, direction=lm.FORWARD, delta=-1, g=g_bl)
segment_middle = lm.Segment(a=0, b=len_sp, direction=lm.BACKWARD, g=g_sp)
segment_right = lm.Segment(
    a=len_sp + 1,
    b=len_sp + 1 + len_bl,
    direction=lm.BACKWARD,
    delta=len_sp,
    g=g_bl,
)

# Cost
# F maps models (rows: alssm_sp, alssm_bl) onto segments (columns: left,
# middle, right): the spike model is active only in the middle segment, the
# baseline model in all three.
F = [[0, 1, 0],
     [1, 1, 1]]
cost = lm.CompositeCost((alssm_sp, alssm_bl), (segment_left, segment_middle, segment_right), F)

se_param = lm.SEParam(cost)
se_param.filter(y)

# Constraint matrices on the stacked state [x_sp; x_bl] (row order follows the
# model order above); presumably minimize_x solves for x = H @ v.
# Spike + baseline hypothesis: only the first spike-state component is free
# (np.eye(N, 1) is the first unit column, so the spike model's state dimension
# is taken from alssm_sp.N instead of being hard-coded as 2), plus the full
# baseline state.
H_sp = block_diag(np.eye(alssm_sp.N, 1), np.eye(alssm_bl.N))
xs_sp = se_param.minimize_x(H_sp)
# Baseline-only hypothesis: the spike state is forced to zero.
H_bl = np.vstack([np.zeros((alssm_sp.N, alssm_bl.N)), np.eye(alssm_bl.N)])
xs_bl = se_param.minimize_x(H_bl)

# --- errors, log-cost ratio and detection ------------------------------------
J = se_param.eval_errors(xs_sp, range(K))  # spike + baseline hypothesis
J_bl = se_param.eval_errors(xs_bl, range(K))  # baseline-only hypothesis
# Sum over the last axis (presumably the L channels) -> one value per sample.
J_sum, J_bl_sum = (np.sum(err, axis=-1) for err in (J, J_bl))

# Log-cost ratio: grows as the spike hypothesis lowers the cost relative to
# the baseline-only hypothesis.
lcr = -0.5 * np.log(J_sum / J_bl_sum)

# Detections: LCR peaks above a fixed height, at least 30 samples apart.
lcr_threshold = 0.041
min_peak_distance = 30
peaks, _ = find_peaks(lcr, height=lcr_threshold, distance=min_peak_distance)

# Plot
fig, axs = plt.subplots(4, 1, figsize=(9, 8), gridspec_kw={'height_ratios': [1, 3, 1, 1]}, sharex='all')
ks = range(K)

# Window: the three segment windows mapped onto every detected peak
wins = lm.map_window(cost.window(segment_selection=[1, 1, 1]), peaks, K, merge_ks=True)

axs[0].set(ylabel='$w_k$')
win_labels = ('left segment (bl)', 'middle segment (bl+sp)', 'right segment (bl)')
for win, label in zip(wins, win_labels):
    axs[0].plot(ks, win, lw=1, label=label)
axs[0].legend(loc=1, fontsize='small')

# Signals: channels are stacked with a fixed vertical spacing and the true
# spikes are drawn above the top channel. Offsets are derived from L so the
# plot keeps working for any channel count (previously hard-coded for L = 3).
channel_spacing = 2.5
axs[1].set(ylabel='$y_k$')
axs[1].plot(ks, y_sp + channel_spacing * L, c='g', lw=1)
axs[1].plot(ks, y[:, 0] + channel_spacing * np.arange(L), c='k', lw=1)
axs[1].legend(('true spikes', 'observations'), loc=1)

# LCR with the detected peaks marked
axs[2].set(ylabel='lcr', ylim=[0, 0.15])
axs[2].plot(ks, lcr, c='red', lw=0.7, label='LCR')
axs[2].scatter(peaks, lcr[peaks], marker=7)
axs[2].legend(loc=1)

# Error of both hypotheses, summed over the last axis
axs[3].set(ylabel='$J_k$', xlabel='$k$')
axs[3].plot(ks, J_sum, c='b', lw=0.5, label="$J^{bl+sp}_{sum}$")
axs[3].plot(ks, J_bl_sum, c='r', lw=0.5, label="$J^{bl}_{sum}$")
axs[3].legend(loc=1)

plt.show()

Total running time of the script: (0 minutes 0.528 seconds)

Gallery generated by Sphinx-Gallery