Note
Click here to download the full example code
This example shows a spike detection algorithm that uses autonomous linear state space models (ALSSMs) together with exponentially decaying windows. The input is a multi-channel signal containing several spikes, each a sinusoidal cycle with exponentially decaying amplitude. The spikes are superimposed on a baseline and corrupted by additive white Gaussian noise.
import matplotlib.pyplot as plt
import numpy as np
import lmlib as lm
from scipy.linalg import block_diag
from scipy.signal import find_peaks
from lmlib.utils.generator import (gen_convolve, gen_sinusoidal, gen_exponential, gen_unit_impulse, gen_wgn,
                                   gen_baseline_sin, k_period_to_omega)

# --- Signal generation --------------------------------------------------------
K = 550  # number of samples
L = 3  # number of channels
CHANNEL_SPACING = 2.5  # vertical distance between stacked traces in the signal plot

spike_length = 20
spike_decay = 0.88
spike_locations = [100, 240, 370]

# Spike template: a sinusoid with period spike_length, multiplied by an
# exponentially decaying envelope (decay factor spike_decay; presumably
# spike_decay**k per sample -- see lmlib.utils.generator.gen_exponential).
spike = gen_sinusoidal(spike_length, spike_length) * gen_exponential(spike_length, spike_decay)
# Place one copy of the template at every spike location.
y_sp = gen_convolve(gen_unit_impulse(K, spike_locations), spike)

# Channel ch = scaled spikes + white Gaussian noise (distinct seed per channel)
# + a baseline whose amplitude (10 * ch) and period parameter depend on ch.
# Final shape is (K, 1, L); the singleton middle axis is presumably the
# per-channel output dimension expected by lmlib -- confirm against lmlib docs.
y = np.column_stack([
    0.8 * y_sp
    + gen_wgn(K, sigma=0.2, seed=10000 - ch)
    + 10 * ch * gen_baseline_sin(K, int(K * 0.9 / (ch + 1)))
    for ch in range(L)
]).reshape(K, 1, L)

# --- Models -------------------------------------------------------------------
alssm_bl = lm.AlssmPoly(poly_degree=3)  # baseline: 3rd-degree polynomial model
alssm_sp = lm.AlssmSin(k_period_to_omega(spike_length), spike_decay)  # spike: decaying sinusoid

# --- Segments -----------------------------------------------------------------
# g presumably sets each segment's effective window length (larger g = longer
# memory); see lm.Segment documentation.
g_bl = 500
g_sp = 5000
len_sp = spike_length
len_bl = int(1.5 * spike_length)
segment_left = lm.Segment(a=-len_bl, b=-1, direction=lm.FORWARD, g=g_bl, delta=-1)
segment_middle = lm.Segment(a=0, b=len_sp, direction=lm.BACKWARD, g=g_sp)
segment_right = lm.Segment(a=len_sp + 1, b=len_sp + 1 + len_bl, direction=lm.BACKWARD, g=g_bl, delta=len_sp)

# --- Cost ---------------------------------------------------------------------
# F is 2 x 3 = (models: alssm_sp, alssm_bl) x (segments: left, middle, right).
# The spike model is enabled only in the middle segment; the baseline model is
# enabled in all three (matching the window legend in the plot below).
F = [[0, 1, 0], [1, 1, 1]]
cost = lm.CompositeCost((alssm_sp, alssm_bl), (segment_left, segment_middle, segment_right), F)
se_param = lm.SEParam(cost)
se_param.filter(y)

# Hypothesis "spike + baseline": H_sp has shape
# (alssm_sp.N + alssm_bl.N, 1 + alssm_bl.N). It leaves the baseline state free,
# keeps one free spike-state component and pins the other to zero.
# The [[1], [0]] block assumes alssm_sp.N == 2.
H_sp = block_diag([[1], [0]], np.eye(alssm_bl.N))
xs_sp = se_param.minimize_x(H_sp)

# Hypothesis "baseline only": the spike-model state is forced to zero.
H_bl = np.vstack([np.zeros((alssm_sp.N, alssm_bl.N)), np.eye(alssm_bl.N)])
xs_bl = se_param.minimize_x(H_bl)

# --- Error --------------------------------------------------------------------
J = se_param.eval_errors(xs_sp, range(K))
J_bl = se_param.eval_errors(xs_bl, range(K))
J_sum = np.sum(J, axis=-1)  # sum over the last axis (presumably the L channels)
J_bl_sum = np.sum(J_bl, axis=-1)
# Log-cost ratio: positive wherever the spike hypothesis has a smaller error
# than the baseline-only hypothesis (J_sum < J_bl_sum).
lcr = -0.5 * np.log(J_sum / J_bl_sum)
peaks, _ = find_peaks(lcr, height=0.041, distance=30)

# --- Plot ---------------------------------------------------------------------
fig, axs = plt.subplots(4, 1, figsize=(9, 8), gridspec_kw={'height_ratios': [1, 3, 1, 1]}, sharex='all')

# Window
wins = lm.map_window(cost.window(segment_selection=[1, 1, 1]), peaks, K, merge_ks=True)
segment_labels = ('left segment (bl)', 'middle segment (bl+sp)', 'right segment (bl)')
axs[0].set(ylabel='$w_k$')
for seg_idx in range(len(segment_labels)):
    axs[0].plot(range(K), wins[seg_idx], lw=1)
axs[0].legend(segment_labels, loc=1, fontsize='small')

# Signals: stack the L observation channels CHANNEL_SPACING apart and draw the
# true spike train one spacing above the top channel. Deriving the offsets
# from L (instead of a literal [0, 2.5, 5] / 7.5) keeps this working for any L.
channel_offsets = CHANNEL_SPACING * np.arange(L)
axs[1].set(ylabel='$y_k$')
axs[1].plot(range(K), y_sp + CHANNEL_SPACING * L, c='g', lw=1)
axs[1].plot(range(K), y[:, 0] + channel_offsets, c='k', lw=1)
axs[1].legend(('true spikes', 'observations'), loc=1)

# LCR
axs[2].set(ylabel='lcr', ylim=[0, 0.15])
axs[2].plot(range(K), lcr, c='red', lw=0.7, label='LCR')
axs[2].scatter(peaks, lcr[peaks], marker=7)
axs[2].legend(loc=1)

# Error
axs[3].set(ylabel='$J_k$', xlabel='$k$')
axs[3].plot(range(K), J_sum, c='b', lw=0.5, label="$J^{bl+sp}_{sum}$")
axs[3].plot(range(K), J_bl_sum, c='r', lw=0.5, label="$J^{bl}_{sum}$")
axs[3].legend(loc=1)

plt.show()
Total running time of the script: (0 minutes 0.552 seconds)
Gallery generated by Sphinx-Gallery