import numpy as np
import matplotlib.pyplot as plt
import quantities as pq
from quantities import s, ms, mV, Hz, kHz
import neo, elephant

# Synthetic example for elephant.sta.spike_field_coherence:
# two sinusoidal channels (f1 and f2) sampled at 1 kHz for 10 s, plus a
# perfectly regular spike train firing every 50 ms.
sampling_rate = 1. * kHz
n_samples = 10000  # 10 s of data at 1 kHz
# Sample times derived from the sampling rate (0, 1, ..., 9999 ms).
t = (np.arange(n_samples) / sampling_rate).rescale(ms)
f1, f2 = 20. * Hz, 23. * Hz
# Reduce f*t to a plain dimensionless number before taking the sine instead
# of relying on quantities to accept Hz*s as a valid angle unit.
phase1 = 2. * np.pi * (f1 * t).simplified.magnitude
phase2 = 2. * np.pi * (f2 * t).simplified.magnitude
# One column per channel: shape (n_samples, 2).
signal = neo.AnalogSignal(
    np.column_stack((np.sin(phase1), np.sin(phase2))),
    units=mV, sampling_rate=sampling_rate)

# Regular spikes every 50 ms. np.arange works on float bounds (range() would
# need integer bounds), and the train spans exactly the signal's time range
# so both objects are aligned sample-for-sample.
spike_interval = 50. * ms
spike_times = np.arange(
    float(signal.t_start.rescale(ms).magnitude),
    float(signal.t_stop.rescale(ms).magnitude),
    float(spike_interval.rescale(ms).magnitude))
spiketrain = neo.SpikeTrain(
    spike_times, units='ms',
    t_start=signal.t_start, t_stop=signal.t_stop)

# sfc is indexed per channel along its second axis; freqs are the matching
# frequency bins.
sfc, freqs = elephant.sta.spike_field_coherence(
    signal, spiketrain, window='boxcar')

# Plot one spike-field-coherence curve per channel, labelled by the frequency
# of the sine on that channel, so the two overlapping curves can be told apart.
for channel, freq in enumerate((f1, f2)):
    plt.plot(
        np.asarray(freqs), np.asarray(sfc[:, channel]),
        label='{:g} Hz channel'.format(float(freq.rescale(Hz).magnitude)))
plt.xlabel('Frequency [Hz]')
plt.ylabel('SFC')
plt.xlim((0, 60))
plt.legend()
plt.show()