Inferring single-trial neural population dynamics using sequential auto-encoders

Chethan Pandarinath, Daniel J O'Shea, Jasmine Collins, Rafal Jozefowicz, Sergey D Stavisky, Jonathan C Kao, Eric M Trautmann, Matthew T Kaufman, Stephen I Ryu, Leigh R Hochberg, Jaimie M Henderson, Krishna V Shenoy, L F Abbott, David Sussillo

Abstract

Neuroscience is experiencing a revolution in which simultaneous recording of thousands of neurons is revealing population dynamics that are not apparent from single-neuron responses. This structure is typically extracted from data averaged across many trials, but deeper understanding requires studying phenomena detected in single trials, which is challenging due to incomplete sampling of the neural population, trial-to-trial variability, and fluctuations in action potential timing. We introduce latent factor analysis via dynamical systems, a deep learning method to infer latent dynamics from single-trial neural spiking data. When applied to a variety of macaque and human motor cortical datasets, latent factor analysis via dynamical systems accurately predicts observed behavioral variables, extracts precise firing rate estimates of neural dynamics on single trials, infers perturbations to those dynamics that correlate with behavioral choices, and combines data from non-overlapping recording sessions spanning months to improve inference of underlying dynamics.

Figures

Figure 1.
LFADS is a generative model that assumes that observed single-trial spiking activity is generated by an underlying dynamical system. (a) LFADS takes a given recording (far left), reduces it to a latent code consisting of an inferred initial condition (middle), and then attempts to infer rates that are consistent with the observed data (right, pink panel) from that latent code. That is, LFADS auto-encodes the trial via a sequential auto-encoder. Working from right to left in the panel, for the ith neuron, LFADS infers rates at time t, r_t,i, for each of 202 channels, and the observed spike counts (blue panel) are assumed to be Poisson-distributed count observations of these underlying rates. The likelihood of the observed spikes given the inferred rates serves as the cost function used to optimize the weights of the model. The rates are linear readouts from a set of low-dimensional factors f_t (40 in this example) via a readout matrix W_rate. The factors are defined as linear readouts from a dynamical generator (an RNN), via a readout matrix W_fac. Activity of the generator is determined by its per-trial component, the initial condition (g_0), and its recurrent connectivity, which is fixed for all trials. The initial condition g_0 is determined for individual trials via an encoder RNN. (b) Example spiking activity recorded from M1/PMd as a monkey performed a reaching task, as well as the corresponding rates r_t and factors f_t inferred by LFADS (7 example trials are shown). Circles denote time of movement onset.
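The generative path in (a) — initial condition, generator RNN, factors, rates, Poisson spike counts — can be sketched in a few lines of Python. The 202 neurons and 40 factors match the caption, but the generator size, weight scales, tanh dynamics, and exponential rate link below are illustrative stand-ins, not the trained LFADS model:

```python
import numpy as np
from scipy.special import gammaln

rng = np.random.default_rng(0)

# Illustrative dimensions: 202 neurons and 40 factors as in the caption;
# generator size, time steps, and weight scales are arbitrary choices
n_neurons, n_factors, n_gen, n_steps, dt = 202, 40, 100, 50, 0.01

W_rec = rng.normal(scale=1.0 / np.sqrt(n_gen), size=(n_gen, n_gen))  # recurrent weights
W_fac = rng.normal(scale=0.1, size=(n_factors, n_gen))      # generator -> factors readout
W_rate = rng.normal(scale=0.1, size=(n_neurons, n_factors))  # factors -> log-rates readout

def generate_trial(g0):
    """Run the generator forward from a per-trial initial condition g0,
    returning factors f_t, rates r_t (spikes/s), and Poisson spike counts."""
    g = g0
    factors, rates = [], []
    for _ in range(n_steps):
        g = np.tanh(W_rec @ g)        # autonomous generator dynamics
        f = W_fac @ g                 # low-dimensional factors f_t
        r = np.exp(W_rate @ f)        # rates r_t via an exponential link
        factors.append(f)
        rates.append(r)
    factors, rates = np.array(factors), np.array(rates)
    spikes = rng.poisson(rates * dt)  # observed counts ~ Poisson(r_t * dt)
    return factors, rates, spikes

def poisson_log_likelihood(spikes, rates):
    """Log-likelihood of observed counts given inferred rates -- the cost
    optimized when training the model's weights."""
    lam = rates * dt
    return np.sum(spikes * np.log(lam) - lam - gammaln(spikes + 1))

g0 = rng.normal(size=n_gen)  # in LFADS, g0 is inferred per trial by the encoder RNN
factors, rates, spikes = generate_trial(g0)
```

In the full model the encoder RNN produces a per-trial posterior over g_0, and training maximizes the Poisson likelihood above (plus a variational regularization term) by backpropagation.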
Figure 2.
Application of LFADS to a “Maze” reaching task. (a) A monkey was trained to perform arm reaching movements to guide a cursor in a 2-D plane from a starting location (center of the workspace) to peripheral targets. Individual reaches are colored by target location. Virtual barriers in the workspace facilitated instruction of curved (or straight) reaches on a per-condition basis (see (d) for examples). (b) Comparison of condition-averaged (left) and single-trial (right) rates for 4 individual neurons (columns) for three different methods (rows). Left: Each trace represents a different reach condition (8 selected of 108 total). Right: Each trace represents an individual trial (same color scheme as the condition-averaged panels). Top row: PSTHs created by smoothing observed spikes with a Gaussian kernel (30 ms s.d.). Middle row: LFADS-inferred rates. Bottom row: GPFA-inferred firing rates, created by fitting a generalized linear model (GLM) to map the GPFA-inferred factor representations onto the true spiking activity. Horizontal scale bar represents 300 ms. Vertical scale bar denotes rate (spikes/sec). PSTHs for all neurons are shown in Supp. Data 1. (c) Application of t-SNE to the generator initial conditions (g_0). Each point represents the reduction of the g_0 vector into a 3-D t-SNE space for an individual trial (2296 trials total); a 2-D projection is shown here, and the full 3-D projection is shown in Supp. Video 1. Trials are color coded by the angle of the reach target [same as (a)]. (d) Decoding reaching kinematics using optimal linear estimation (OLE). Each row shows an example condition (3 shown, of 108 total). Column 1: true reach trajectories (black traces, 10 example trials per condition). Columns 2–4: examples of cross-validated reconstruction of these trajectories using OLE applied to the neural data, which was first de-noised either via LFADS, by smoothing with a Gaussian filter (40 ms s.d.), or using GPFA to reduce its dimensionality.
(e) Decoding accuracy was quantified by measuring variance explained (R2) between the true and decoded velocities for individual trials across the entire dataset (2296 trials), for all three techniques and additionally for simple binning of the neural data. Accuracy was also measured for random sub-samples from the full neural population of 202 neurons. Dotted lines connect the median R2 values for each population size. (f) LFADS-inferred factors are informative about neurons that are held out from model training. LFADS models and GPFA were fit to subsets of the full population of 202 neurons [same populations as in (e)]. We then used a GLM to map the latent state estimates produced by LFADS or GPFA onto the binned spike counts (20 ms bins) for the remaining held-out neurons; e.g., for a model trained with 25 neurons, there are 177 = 202 − 25 held-out neurons. We evaluated the cross-validated performance of the fit GLM models using log-likelihood (LL) per spike. Each point represents a given held-out neuron for a given random sampling of the population.
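The held-out-neuron evaluation in (f) can be illustrated with a toy version: fit a Poisson GLM from latent factors to one held-out neuron's binned counts and score it in log-likelihood per spike. The synthetic data, the scikit-learn GLM, and the baseline-relative definition of LL per spike are assumptions of this sketch; the paper's exact procedure (including its cross-validation) is described in its Methods:

```python
import numpy as np
from sklearn.linear_model import PoissonRegressor

rng = np.random.default_rng(1)

# Synthetic stand-ins: latent factors (e.g., from LFADS or GPFA) and binned
# spike counts for one held-out neuron, linked by a log-linear map
n_bins, n_factors = 2000, 10
factors = rng.normal(size=(n_bins, n_factors))
true_w = rng.normal(scale=0.3, size=n_factors)
counts = rng.poisson(np.exp(factors @ true_w - 1.0))

# Fit a Poisson GLM mapping latent factors -> held-out spike counts
glm = PoissonRegressor(alpha=1e-4, max_iter=300).fit(factors, counts)
pred = glm.predict(factors)  # predicted mean counts per bin (always > 0)

def ll_per_spike(counts, pred_rate):
    """Poisson log-likelihood of the GLM minus that of a constant mean-rate
    model, normalized by the total spike count (one common formulation)."""
    base = np.full_like(pred_rate, counts.mean())
    ll_model = np.sum(counts * np.log(pred_rate) - pred_rate)
    ll_base = np.sum(counts * np.log(base) - base)
    return (ll_model - ll_base) / counts.sum()

score = ll_per_spike(counts, pred)  # > 0: factors predict held-out spikes
```

A positive score means the latent factors carry information about the held-out neuron beyond its mean firing rate.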
Figure 3.
LFADS uncovers known rotational dynamics in monkey and human motor cortical activity on a single-trial basis. (a, c) Rotational dynamics underlying the neural population state accompany the transition between pre- and peri-movement activity, and have been previously described for monkey and human motor cortical activity by projecting condition-averaged activity into a low-dimensional plane using jPCA. Each trace shows the neural population state trajectories for a single task condition (monkey: 108 reaching conditions; human: 8 intended movement directions). (b, d) When the same low-dimensional projection is applied to the single-trial data, dynamics are less clear due to the inherent noise of single-trial neural population activity. (e, g) When LFADS is applied, the condition-averaged inferred rates exhibit similar underlying dynamic structure for monkey and human. (f, h) Additionally, the same dynamic structure is now clearly present on individual trials (monkey: 2296 trials; human: 114 trials). (i–k) Testing generalizability of the generator's dynamics to held-out conditions. (i) Conditions were binned by the angle of the reach target (black dashed lines), resulting in 19 sets. Nineteen LFADS models were then trained, each on 18 of the subsets with the remaining subset held out, and each was evaluated on its held-out subset. (j) LFADS-inferred rates for held-out conditions were combined across the 19 models and were projected into the jPCA space found by training an LFADS model on all conditions [i.e., panel (f)]. (k) Correspondence between the initial position in jPCA space when a trial is used in the training set for an LFADS model and when it is held out (Pearson's correlation coefficient r = 0.97 and 0.77 for jPC1 and jPC2, respectively). Each dot represents an individual trial (2296 trials).
Figure 4.
Using “dynamic neural stitching,” LFADS combines data from separately collected, non-overlapping recordings of the neural population by learning one consistent dynamical model. (a) Schematic of the LFADS architecture adapted for dynamic neural stitching. Per-session “readin” matrices W_s^input and “readout” matrices W_s^rate are used to map from each dataset’s rates to the input factors and from the factors back out to rates, respectively (pink areas). The encoder RNN, generator RNN, and factor readout matrix W_fac are shared among datasets (blue area). For this example, each W_s^rate was learned, whereas each W_s^input was set using a principal components regression approach (see Online Methods). A total of 44 individual recording sessions using 24-channel linear multielectrode arrays were used. (b) Locations of linear electrode array penetrations in the precentral gyrus from which each dataset was collected. Dashed lines indicate approximate locations of nearby sulcal features based on stereotaxic locations. Arc. Sp.: arcuate spur, PCd: precentral dimple, CS: central sulcus. (c) Example single-trial rasters for nearly identical upward reaches performed on a subset of 5 of the 44 recording sessions. Each raster has 24 rows corresponding to the 24 channels of the linear array, but the neurons recorded on each session are entirely distinct from each other. (d) After training, the multi-session stitched LFADS model produced consistent factor trajectories for each behavioral condition across recording sessions. Traces are condition-averaged factor trajectories for the multi-session stitched LFADS model projected into a subspace which spans the condition-independent signal (CIS) and the first jPCA plane (see Methods). LFADS factors are averaged over all trials in each reach direction for each recording session and projected into this subspace to produce a single trajectory; the color of each trajectory represents the reach direction.
The spatial proximity of the trajectories for a given direction across the sessions (44 trajectories of each color) illustrates the consistency of the representation across sessions. (e) R2 values between arm kinematics and decodes based on smoothed neural data, GPFA, single-session LFADS factors, or stitched LFADS factors. A single shared decoder was fit for the stitched model; a separate decoder was fit for each single-session model. “***” indicates significant improvement in median R2, p < 10^−8, Wilcoxon signed-rank test. (f) Actual recorded hand position traces for the center-out reaching task (left), alongside kinematic decodes for a representative single session (session 32), for smoothed neural data, GPFA, single-session LFADS, and stitched LFADS (left to right). Colors indicate reach direction. (g) Single-trial factor trajectories from the stitched LFADS model. Only the first seven of 44 sessions are shown for ease of presentation (see also Supp. Video 3).
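The stitching idea — distinct channels per session, one shared latent space — can be demonstrated with a toy alignment. Here the shared factors are known and each session's readin matrix is fit by least squares; the actual model learns these mappings jointly during training, seeding W_s^input with principal components regression (see Online Methods), so everything below is illustrative:

```python
import numpy as np

rng = np.random.default_rng(2)

# A shared latent trajectory (random walk) underlies all sessions, but each
# session observes it through a different set of 24 channels, as with the
# linear arrays; the per-session mixing simulates entirely distinct neurons
n_steps, n_factors, n_channels, n_sessions = 100, 8, 24, 5
shared_factors = np.cumsum(rng.normal(size=(n_steps, n_factors)), axis=0)

sessions = []
for _ in range(n_sessions):
    W_s = rng.normal(size=(n_channels, n_factors))  # session-specific mixing
    obs = shared_factors @ W_s.T + rng.normal(scale=0.5, size=(n_steps, n_channels))
    sessions.append(obs)

# Per-session "readin" by least squares: map each session's observations
# back into the common low-dimensional space (a stand-in for the paper's
# principal components regression seeding)
readins = []
for obs in sessions:
    W_in, *_ = np.linalg.lstsq(obs, shared_factors, rcond=None)
    readins.append(W_in)

# After alignment, every session maps into the same factor space
recovered = [obs @ W_in for obs, W_in in zip(sessions, readins)]
```

The recovered trajectories from all sessions closely match the shared factors, which is the property panel (d) illustrates for the trained model.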
Figure 5.
LFADS uncovers the presence, identity and timing of unexpected perturbations in the “Cursor Jump” task. (a) Schematic of the LFADS architecture adapted for inferring inputs to a neural population. As before, LFADS reduces individual trials to the initial state of the generator RNN (g_0). However, now the activity of the generator is additionally determined by a set of time-varying inferred inputs (u_t), modeled stochastically like g_0 with a mean and variance, which are fed to the generator at each time point. The inferred input u_t is output by a controller RNN, which receives time-varying input from the encoding network, as well as the factors representation at the preceding timestep. (b) Schematic depicting the “Cursor Jump” task. The position of a monkey’s hand was linked to the position of an on-screen cursor, and the monkey made reaching movements to steer the cursor toward upward or downward targets. In unperturbed trials (grey traces), the monkey made straight reaches to the target. In perturbed trials (orange traces), the cursor’s position was offset to the left or right during the course of the reaching movement, and the monkey made corrective movements to acquire the target. (c) Spiking activity from M1/PMd arrays during three example reach trials to downward targets for the unperturbed (top), perturb-right (middle), and perturb-left (bottom) conditions. Squares denote time of target onset, and triangles denote the time of an unexpected perturbation. (d) LFADS was allowed 4 inferred inputs to model the neural activity. For presentation, two trial alignments were used prior to averaging: the initial portion of the trials was aligned to the time of target onset, while the latter portion of the trials was aligned by perturbation time (or, for unperturbed trials, the time at which a perturbation would have occurred based on the cursor’s trajectory). The gap in the traces denotes the break in alignment.
Inferred input values were averaged across trials for upward (top) and downward (bottom) trials (mean ± s.e.m. is shown, grey: unperturbed trials, blue: perturb left trials, red: perturb right trials). Around the time of target onset, the identity of the target (up vs. down) is modeled by the inputs (e.g., dimension 1). Around the time of the perturbation, LFADS used specific inferred input patterns to model each perturbation type (e.g., dimensions 1 & 2). Input traces were smoothed with a causal Gaussian filter (20 ms s.d.). (e) The single-trial input patterns around the time of perturbation (all downward trials) were projected into a low-dimensional space using t-SNE and colored by the three perturbation types (unperturbed, left perturbation, right perturbation). Black boxes denote locations in t-SNE space for the example trials shown in panel c.
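The projection in (e) amounts to flattening each trial's inferred-input pattern (input dimensions × time) into a feature vector and embedding the collection with t-SNE. A toy sketch with three synthetic condition clusters (the cluster centers, feature dimension, and trial counts are invented for illustration):

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(4)

# Hypothetical single-trial input patterns around the perturbation time:
# e.g., 4 inferred-input dims x 25 time steps flattened to 100 features,
# with 60 trials for each of 3 conditions (unperturbed / left / right)
n_per, n_dims = 60, 100
centers = rng.normal(scale=3.0, size=(3, n_dims))  # 3 condition clusters
X = np.vstack([c + rng.normal(size=(n_per, n_dims)) for c in centers])
labels = np.repeat([0, 1, 2], n_per)

# Embed single-trial input patterns into 2-D with t-SNE
emb = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(X)
```

With clusters this well separated in the original space, the three perturbation conditions fall into distinct regions of the embedding, as in panel (e).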
Figure 6.
When inferred inputs are allowed (Fig. 5a), LFADS uncovers fast oscillatory structure in neural firing patterns. (a) Example single-trial spiking activity recorded from human M1 and monkey M1/PMd, as well as LFADS-inferred rates, and local field potentials. 400 ms of data are shown, beginning at the time of target presentation during an 8-target center-out-and-back movement paradigm. For T7, analyses were restricted to channels that showed significant modulation during movement attempts (78/192 channels). Dashed red lines overlaid on monkey data segregate the M1 array (upper halves) and PMd array (lower halves). Squares denote time of target onset. For Monkey J, where movement was measurable, circle denotes time of movement onset. (b) Cross-correlations between the local field potentials recorded on each electrode and the observed spiking activity (black traces; mean ± s.e.m.) or the LFADS-inferred rates (red traces) for several example channels (participant T7: 142 trials; monkey J: 373 trials). LFP were first low-pass filtered (75 Hz cutoff frequency). Randomly shuffling the trial identity (i.e., correlating spikes from one trial with LFP from another) largely removed the fast, oscillatory components in the cross-correlograms (blue traces).
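The shuffle control in (b) — correlating spikes from one trial with LFP from another — can be sketched on synthetic data in which spiking and LFP share a 30 Hz oscillation whose phase varies from trial to trial (all parameters below are illustrative, not the recorded data):

```python
import numpy as np

rng = np.random.default_rng(3)

# Synthetic trials: 400 ms at 1 kHz; spiking rate and LFP share a per-trial
# oscillation phase, so only same-trial pairing preserves their correlation
n_trials, n_samples, fs = 100, 400, 1000
t = np.arange(n_samples) / fs
phase = rng.uniform(0, 2 * np.pi, size=n_trials)
lfp = np.sin(2 * np.pi * 30 * t + phase[:, None])
rate = 20 + 15 * np.sin(2 * np.pi * 30 * t + phase[:, None])  # spikes/s
spikes = rng.poisson(rate / fs)

def xcorr(a, b, max_lag=50):
    """Trial-averaged cross-correlation of mean-subtracted signals."""
    a = a - a.mean(axis=1, keepdims=True)
    b = b - b.mean(axis=1, keepdims=True)
    lags = np.arange(-max_lag, max_lag + 1)
    out = np.zeros(len(lags))
    for i, lag in enumerate(lags):
        if lag >= 0:
            out[i] = np.mean(a[:, lag:] * b[:, :n_samples - lag])
        else:
            out[i] = np.mean(a[:, :lag] * b[:, -lag:])
    return lags, out

lags, cc_matched = xcorr(spikes, lfp)   # same-trial spike/LFP pairing
lags, cc_shuffled = xcorr(spikes, lfp[rng.permutation(n_trials)])
# The trial shuffle removes the fast oscillatory component, as in panel (b)
```

The matched cross-correlogram oscillates at the shared frequency, while the shuffled one is flat near zero, mirroring the black versus blue traces in (b).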


Source: PubMed
