Latent circuit inference from heterogeneous neural responses during cognitive tasks

Fitting a latent circuit model

We fit the latent circuit model Eqs. (1)–(3) to neural response data y by minimizing the mean squared error loss function,

$$L=\sum _{k}\sum _{t}{\parallel {y}_{kt}-Q{x}_{kt}\parallel }^{2}+{\parallel {z}_{kt}-{w}_{{\rm{out}}}{x}_{kt}\parallel }^{2},$$

(7)

using custom Python code59. Here, k indexes the trials, t indexes the time within a trial, and Q is an orthonormal embedding matrix. Because the variable x depends implicitly on the latent circuit parameters wrec and win, the minimization of L is a nonlinear least squares optimization problem60 in which we simultaneously search for a behaviorally relevant projection of the high-dimensional activity and a low-dimensional neural circuit that generates dynamics in this projection. Because orthonormal matrices define a nonlinear submanifold within the space of all matrices, minimizing L corresponds to solving a constrained optimization problem over this submanifold. To transform it into an unconstrained problem, we use the Cayley transform to parameterize orthonormal matrices by the linear space of skew symmetric matrices61,

$$Q={\pi }_{n}\left[(I-A){(I+A)}^{-1}\right],$$

(8)

where πn represents projection onto the first n columns, and A is skew symmetric. We parameterize A by an arbitrary square N × N matrix B,

$$A=B-{B}^{{\rm{T}}}.$$

(9)

With these reparameterizations, we can minimize L over the vector space of square matrices B. The parameterization of the skew symmetric matrix A by the auxiliary matrix B is degenerate because A has only N(N − 1)/2 distinct elements, whereas B has N². We did not attempt to eliminate this degeneracy because B is an auxiliary matrix, and we did not observe any degeneracy arising in the matrix Q during fitting.
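As an illustration, this reparameterization can be sketched in PyTorch as follows (a minimal sketch; the function and variable names are ours, not taken from the original code):

```python
import torch

def cayley_embedding(B: torch.Tensor, n: int) -> torch.Tensor:
    """Map an arbitrary square matrix B to an N x n orthonormal matrix Q.

    A = B - B^T is skew symmetric, so (I - A)(I + A)^{-1} is orthogonal;
    pi_n then keeps its first n columns (Eqs. (8) and (9)).
    """
    N = B.shape[0]
    A = B - B.T                                 # skew-symmetric parameterization
    I = torch.eye(N, dtype=B.dtype)
    Q_full = torch.linalg.solve(I + A, I - A)   # equals (I - A)(I + A)^{-1}
    return Q_full[:, :n]

# Sanity check: Q has orthonormal columns for any B.
B = torch.rand(50, 50, dtype=torch.float64)
Q = cayley_embedding(B, n=8)
assert torch.allclose(Q.T @ Q, torch.eye(8, dtype=torch.float64), atol=1e-10)
```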

At each step of the optimization, we generate a set of trajectories x from the latent circuit dynamics and embed these trajectories using the matrix Q. The parameters B, wrec, win and wout are then updated to minimize L. We perform this minimization using PyTorch and the Adam optimizer with default values 0.9 and 0.999 for the decay rate of the first and second moment estimates, respectively, a learning rate of 0.02 and a weight decay of 0.001. We use a minibatch size of 128 trials. We stop the optimization when the loss has not improved by a threshold of 0.001 after a patience of 25 epochs. We used the Python software package Seaborn for visualizing model parameters and responses after training.
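The overall structure of one fitting run can be sketched as follows (a simplified toy version on synthetic data, reusing the cayley_embedding helper above; it omits the output term of the loss and the stopping criterion for brevity, and all dimensions are illustrative):

```python
import torch

torch.manual_seed(0)
N, n, T, K, alpha = 20, 4, 30, 128, 0.2           # toy sizes; K = minibatch size

y = torch.randn(K, T, N)                          # stand-in for neural responses
u = torch.rand(K, T, n)                           # stand-in task inputs

B = torch.zeros(N, N, requires_grad=True)         # auxiliary embedding parameter
w_rec = (torch.randn(n, n) / n).requires_grad_()  # latent recurrent weights
w_in = torch.eye(n).requires_grad_()              # latent input weights

opt = torch.optim.Adam([B, w_rec, w_in], lr=0.02,
                       betas=(0.9, 0.999), weight_decay=0.001)

for epoch in range(100):
    Q = cayley_embedding(B, n)                    # re-embed at each step
    x = torch.zeros(K, n)
    loss = torch.tensor(0.0)
    for t in range(T):                            # generate latent trajectories
        x = (1 - alpha) * x + alpha * torch.relu(x @ w_rec.T + u[:, t] @ w_in.T)
        loss = loss + ((y[:, t] - x @ Q.T) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
```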

We initialize the recurrent matrix wrec from a uniform distribution centered on 0 with a standard deviation of 1/n. We initialize win with zeros except for positive entries along the diagonal on connections from inputs u to their corresponding nodes and wout with zeros except for positive entries on connections from choice nodes to their corresponding outputs z. We initialize the entries of matrix B from a uniform distribution on [0, 1].
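A sketch of these initializations (the number of inputs and outputs and the assignment of choice nodes to specific latent nodes are hypothetical here):

```python
import torch

n, n_in, n_out = 8, 6, 2
choice_nodes = [0, 1]                    # hypothetical indices of choice nodes

# uniform, centered on 0, standard deviation 1/n (half-width sqrt(3)/n)
w_rec = (torch.rand(n, n) - 0.5) * (12 ** 0.5) / n

w_in = torch.zeros(n, n_in)              # inputs drive their corresponding nodes
w_in[torch.arange(n_in), torch.arange(n_in)] = torch.rand(n_in)

w_out = torch.zeros(n_out, n)            # choice nodes drive their outputs
w_out[torch.arange(n_out), torch.tensor(choice_nodes)] = torch.rand(n_out)

B = torch.rand(n, n)                     # auxiliary embedding matrix on [0, 1]
```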

When fitting the latent circuit model, we found some variability in solutions across optimization runs with different initializations. To control for this variability, we fitted a large ensemble of latent circuit models (n = 100) with different initializations of the latent connectivity parameters and embedding Q for each RNN model. This ensemble of latent circuit models for a single RNN has variable fit quality because many optimization runs do not converge to the optimal solution (as is typical for nonconvex optimization). Therefore, we selected the ten best latent circuit models from this ensemble in terms of fit quality on held-out test data, which formed a set of converged solutions (Fig. 5a). To quantify the uniqueness of the latent circuit solution in each RNN, we computed the correlation coefficients between the recurrent connectivity weights of the best model and the remaining nine converged models. We can use the correlation coefficient because the identity of each node in the latent circuit is defined by its input and output connectivity, eliminating any permutation symmetries.
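The model-selection and uniqueness analysis can be sketched as follows (with random stand-ins for the test losses and fitted weights that would come from the fitting runs):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for (held-out test loss, fitted w_rec) from 100 optimization runs.
ensemble = [(rng.uniform(0.1, 1.0), rng.normal(size=(8, 8))) for _ in range(100)]

ensemble.sort(key=lambda run: run[0])             # rank runs by test loss
best_w = ensemble[0][1]
converged = [w for _, w in ensemble[1:10]]        # remaining nine best models

# Correlations between flattened recurrent weights; valid because node
# identities are pinned by input/output connectivity (no permutation symmetry).
r = [np.corrcoef(best_w.ravel(), w.ravel())[0, 1] for w in converged]
print(np.round(r, 2))
```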

Testing the dependence of latent connectivity on neural responses

To determine whether the inferred latent circuit connectivity significantly depends on the neural response data beyond the constraints imposed by the task alone, we performed a permutation test (Extended Data Figs. 1, 7c and 8c), which proceeds in three steps. First, we fit N latent circuit models to neural responses and select the best model in terms of fit quality on held-out test data. We then compute the correlation coefficients between the recurrent connectivity of the best model and all other models. The distribution of these correlation coefficients estimates how variable the latent connectivity is across models fitted to the original neural responses. Second, we shuffle the neural responses N times and fit a latent circuit model to each shuffle, resulting in N latent circuit models. Our shuffling procedure randomly permutes neural responses with respect to trial conditions while preserving the input–output relationship on each trial, so that the fitted latent circuit models can still perform the task. We confirm that the latent circuit models fitted to the shuffled RNN responses perform the task at high accuracy (Extended Data Fig. 1). These latent models serve as a control to assess whether the inferred latent connectivity emerges from the task constraints alone, and they should not be viewed as models of any specific high-dimensional network. We then compute the correlation coefficients between the connectivity of all models fitted to the shuffled data and the best model from the original data fit. Third, we use a Mann–Whitney U-test to determine whether the correlation coefficients are significantly smaller for models fitted to the shuffled responses than to the original neural responses. This outcome would indicate that models fitted to shuffled neural responses use more diverse connectivity to perform the task than models fitted to the original data; thus, neural responses constrain the inferred connectivity above and beyond the effects of the task. We used the same test for both RNNs (N = 500; Extended Data Fig. 1) and PFC data (N = 1,000; Extended Data Figs. 7c and 8c).
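The third step can be sketched as follows (the correlation values are random stand-ins for those produced by the fits to original and shuffled responses):

```python
import numpy as np
from scipy.stats import mannwhitneyu

rng = np.random.default_rng(1)

# Stand-ins for correlations with the best model's w_rec: fits to original
# responses vs fits to shuffled responses.
r_original = rng.normal(0.8, 0.05, size=499).clip(-1, 1)
r_shuffled = rng.normal(0.3, 0.20, size=500).clip(-1, 1)

# One-sided test: correlations for shuffled fits are smaller than for original.
stat, p = mannwhitneyu(r_shuffled, r_original, alternative="less")
print(f"U = {stat:.0f}, p = {p:.3g}")
```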

Relationship between connectivity of the RNN and latent circuit

We consider RNNs of the form

$$\tau \dot{y}=-y+{\left[{W}_{{\rm{rec}}}\,y+{W}_{{\rm{in}}}u\right]}_{+}.$$

(10)

Here, [⋅]+ is a rectified linear (ReLU) activation function, τ is a time constant, and u are external task inputs. Wrec and Win are the recurrent and input connectivity matrices, respectively. We read out a set of task outputs z from the network activity via the output connectivity matrix Wout,

$$z={W}_{{\rm{out}}}\,y.$$

(11)

We derive a relationship between the connectivity matrices of the RNN and latent circuit, which allows us to interpret the latent circuit connectivity as a latent connectivity structure in the RNN. To derive this relationship, we differentiate the embedding Eq. (1) with respect to time and obtain the relationship between the vector fields of the RNN and latent circuit,

$$V(Qx)=Qv(x).$$

Here, the vector fields \(\dot{y}=V(y)\) of the RNN and \(\dot{x}=v(x)\) of the latent circuit are given by Eqs. (10) and (2), respectively. This equation states that the subspace spanned by the columns of Q is an invariant subspace of the high-dimensional system; that is, the vector field at any point in this subspace lies entirely within this subspace. We then use the orthonormality condition QTQ = I to obtain

$$v(x)={Q}^{{\rm{T}}}V(Qx).$$

Substituting the vector fields Eq. (10) and Eq. (2) in this relation gives us the equality

$${Q}^{{\rm{T}}}{[{W}_{{\rm{rec}}}Qx+{W}_{{\rm{in}}}u]}_{+}={[{w}_{{\rm{rec}}}x+{w}_{{\rm{in}}}u]}_{+}.$$

(12)

Because this is an equality of two piecewise-linear systems, it holds for each local linear piece individually. In particular, assuming that both the inputs u and Win are positive, we take x sufficiently near 0, where the argument of the nonlinearity is positive for all units. In this local linear piece, we have the equality of linear systems

$${Q}^{{\rm{T}}}{W}_{{\rm{rec}}}Qx+{Q}^{{\rm{T}}}{W}_{{\rm{in}}}u={w}_{{\rm{rec}}}x+{w}_{{\rm{in}}}u.$$

(13)

If we assume that this equality holds in some open set, then we can equate the terms to obtain an equality of connectivity matrices:

$${w}_{{\rm{rec}}}={Q}^{{\rm{T}}}{W}_{{\rm{rec}}}Q,$$

(14)

$${w}_{{\rm{in}}}={Q}^{{\rm{T}}}{W}_{{\rm{in}}}.$$

(15)

This assumption is likely not fully satisfied in the setting of cognitive tasks because the sets of inputs u and latent states x are typically low dimensional. Therefore, the above equalities may hold only approximately. In addition, the equality of piecewise-linear dynamical systems Eq. (12) depends on the correspondence between trajectories of the RNN and latent circuit Eq. (1). Because, in practice, we search for the latent circuit by minimizing the loss function L, if L is not exactly equal to 0, then Eq. (1) and consequently Eqs. (14) and (15) hold only approximately.
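The following numerical sketch illustrates Eqs. (14) and (15) with hypothetical random matrices; as discussed above, in practice these equalities hold only approximately, whereas in the linear regime of the toy example they are exact:

```python
import numpy as np

rng = np.random.default_rng(2)
N, n, n_in = 50, 8, 6

Q, _ = np.linalg.qr(rng.normal(size=(N, n)))      # random orthonormal embedding
W_rec = rng.normal(size=(N, N)) / np.sqrt(N)      # hypothetical RNN connectivity
W_in = rng.uniform(size=(N, n_in))

w_rec = Q.T @ W_rec @ Q                           # Eq. (14)
w_in = Q.T @ W_in                                 # Eq. (15)

# In the linear regime, latent dynamics match projected RNN dynamics exactly.
x, u = rng.normal(size=n), rng.uniform(size=n_in)
assert np.allclose(Q.T @ (W_rec @ (Q @ x) + W_in @ u), w_rec @ x + w_in @ u)
```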

We derived the analytical relations between connectivity in the latent circuit and RNN (Eqs. (14) and (15)) assuming that the latent circuit provides a good fit of RNN responses and that their dynamical equations (Eqs. (2) and (10)) have the same nonlinearity. In general, it is unclear whether a latent circuit model can satisfactorily fit responses of a high-dimensional network that has a different nonlinearity and to what extent the relation between their connectivity will hold in this case. To test whether our results extend to networks with a different biologically plausible nonlinearity, we trained RNNs that had a Softplus activation function \(f(x)=\frac{1}{\beta }\log (1+{e}^{\beta x})\) for a range of values of the parameter β and also with varying gain g across units. We fitted responses of these RNNs with our latent circuit model that had a rectified linear (ReLU) activation function and found that this architecture mismatch did not significantly affect the fit quality or the relationship between connectivity (Extended Data Figs. 5 and 6).

To understand how perturbations of connectivity in the latent circuit map onto the RNN, we view perturbations as vectors in the space of matrices. We denote by A ⋅ B the dot product between the matrices A and B represented as vectors in the space of matrices; that is, \(A\cdot B={\sum }_{i}{\sum }_{j}{A}_{ij}{B}_{ij}\). Using Eqs. (14) and (15), we then translate connectivity perturbations from the latent circuit to the RNN:

$$w\cdot {\delta }_{ji}={({Q}^{{\rm{T}}}WQ)}_{ji}$$

(16)

$$=\mathop{\sum }\limits_{k=1}^{N}{Q}_{kj}\left(\mathop{\sum }\limits_{l=1}^{N}{W}_{kl}{Q}_{li}\right)$$

(17)

$$=\mathop{\sum }\limits_{k=1}^{N}\mathop{\sum }\limits_{l=1}^{N}{W}_{kl}{Q}_{kj}{Q}_{li}$$

(18)

$$=\mathop{\sum }\limits_{k=1}^{N}\mathop{\sum }\limits_{l=1}^{N}{W}_{kl}{({q}_{j}{q}_{i}^{{\rm{T}}})}_{kl},$$

(19)

where \({q}_{i}=Q{e}_{i}\) is the ith column of Q, and ei is the ith standard unit vector, so that \({\delta }_{ji}={e}_{j}{e}_{i}^{{\rm{T}}}\) and \({q}_{j}{q}_{i}^{{\rm{T}}}=Q{\delta }_{ji}{Q}^{{\rm{T}}}\). This chain of equalities shows how to translate perturbations of the latent circuit connectivity in the direction δji onto rank-one connectivity perturbations in the RNN,

$$w\cdot {\delta }_{ji}=W\cdot Q{\delta }_{ji}{Q}^{{\rm{T}}}.$$

(23)

Thus, to perturb the latent connection wji, we perturb the matrix W in the direction QδjiQT. In other words, to increase the dot product between W and QδjiQT in the space of matrices, we add multiples of QδjiQT to W. Any perturbation orthogonal to QδjiQT does not change the dot product and hence has no effect on the latent connection wji.
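A sketch of translating a latent connectivity perturbation into a rank-one RNN perturbation (Eq. (23)), with hypothetical matrices:

```python
import numpy as np

rng = np.random.default_rng(3)
N, n = 50, 8
Q, _ = np.linalg.qr(rng.normal(size=(N, n)))
W = rng.normal(size=(N, N)) / np.sqrt(N)

j, i, eps = 2, 5, 0.1                        # target latent connection w_ji
delta = np.zeros((n, n))
delta[j, i] = 1.0                            # perturbation direction delta_ji

W_pert = W + eps * Q @ delta @ Q.T           # rank-one perturbation of the RNN

# The induced latent connectivity changes by eps exactly in entry (j, i),
# because Q^T (Q delta Q^T) Q = delta for orthonormal Q.
assert np.allclose(Q.T @ W_pert @ Q - Q.T @ W @ Q, eps * delta)
```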

RNN simulations

We simulate dynamics of time-discretized RNNs using the general framework for modeling cognitive tasks22. We consider RNNs with positive activity and N = 50 recurrent units. We obtained the same results with networks consisting of N = 150 units. We discretize the RNN dynamics Eq. (10) using the first-order Euler scheme with a time step Δt and add a noise term to obtain

$${y}_{t+1}=(1-\alpha ){y}_{t}+\alpha {\left[{W}_{{\rm{rec}}}{y}_{t}+{W}_{{\rm{in}}}{u}_{t}+\sqrt{2{\alpha }^{-1}}{\sigma }_{{\rm{rec}}}{\xi }_{t}\right]}_{+}.$$

(24)

Here, α = Δt/τ and \({\xi }_{t}\sim {\mathcal{N}}(0,1)\) is a random variable sampled from the standard normal distribution. We set the time constant τ = 200 ms, the discretization time step Δt = 40 ms, and the noise magnitude σrec = 0.15. When fitting RNN responses with the latent circuit model, we discretize the latent circuit dynamics Eq. (2) using the same hyperparameter α and the same noise magnitude as were used when training the RNN. The input and output matrices are constrained to have positive entries. The recurrent matrix is constrained to satisfy Dale's law with 80% excitatory units and 20% inhibitory units. For RNNs shown in the main text, the concatenation of the input and output matrices is constrained to be orthogonal. However, our conclusions do not depend on this constraint, and we find similar latent circuit fits and the same inhibitory mechanism in RNNs trained with unconstrained inputs (Supplementary Fig. 6). The RNN simulation and training were implemented in Python using the software package PyTorch.
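One discretized update of Eq. (24) can be sketched as follows (the noise scaling follows the equation as reconstructed above; the trial length and connectivity are illustrative):

```python
import numpy as np

def rnn_step(y, u, W_rec, W_in, alpha=0.2, sigma_rec=0.15, rng=None):
    """One Euler step of Eq. (24) with additive recurrent noise."""
    rng = rng or np.random.default_rng()
    xi = rng.standard_normal(y.shape)
    drive = W_rec @ y + W_in @ u + np.sqrt(2 / alpha) * sigma_rec * xi
    return (1 - alpha) * y + alpha * np.maximum(drive, 0.0)

# Example: simulate one trial of 60 steps for a toy network of N = 50 units.
rng = np.random.default_rng(0)
N, n_in = 50, 6
W_rec = rng.normal(size=(N, N)) / np.sqrt(N)
W_in = rng.uniform(size=(N, n_in))
y = np.zeros(N)
for _ in range(60):
    y = rnn_step(y, np.full(n_in, 0.2), W_rec, W_in, rng=rng)
```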

Context-dependent decision-making task

In the context-dependent decision-making task, at the beginning of each trial, a context cue briefly appears to indicate either the color or motion context for the current trial. After a short delay, a sensory stimulus appears that consists of motion and color features. The right motion and red color are associated with the right choice, and the left motion and green color are associated with the left choice. The strength of motion and color stimuli varies from trial to trial as quantified by the motion and color coherence. In the color context, the choice should be made according to the color, ignoring the motion stimulus, and vice versa in the motion context.

To model the context-dependent decision-making task, the network receives six inputs u corresponding to two context cues (um: motion context; uc: color context) and sensory evidence streams for motion (um,L: motion left; um,R: motion right) and color (uc,R: color red; uc,G: color green). The network has two outputs, z1 and z2, for which we define two targets ztarget,1 and ztarget,2. Each trial begins with a presentation of a context cue from t = 320 to t = 1,000 ms. On motion context trials, the cue input is set to um = 1.2 and uc = 0.2 and vice versa on color context trials. During this epoch, we require that the network does not respond on the outputs by setting ztarget = 0.2. After a delay of 200 ms, imposed so that the network must maintain a memory of the context cue, the inputs corresponding to motion and color sensory evidence are presented at t = 1,200 ms for the remaining duration of the trial. From 2,250 ms after the start of the trial to the end of the trial, the targets are defined by ztarget,1 = 1.2 and ztarget,2 = 0.2 for right choices and vice versa for left choices. The strength of sensory evidence for motion and color varies randomly from trial to trial, controlled by the stimulus coherence. We use motion coherence mc and color coherence cc chosen from a fixed set of values ranging from −0.2 to 0.2. For each coherence level, the motion and color inputs are given by

$$\begin{array}{ll}{u}_{{\rm{m}},{\rm{R}}}=\frac{1+{m}_{{\rm{c}}}}{2}, & {u}_{{\rm{m}},{\rm{L}}}=\frac{1-{m}_{{\rm{c}}}}{2},\\ {u}_{{\rm{c}},{\rm{R}}}=\frac{1+{c}_{{\rm{c}}}}{2}, & {u}_{{\rm{c}},{\rm{G}}}=\frac{1-{c}_{{\rm{c}}}}{2}.\end{array}$$

With these definitions, positive motion and color coherence provide evidence for the right choice, and negative motion and color coherence provide evidence for the left choice. At each simulation time step, we add an independent noise term to each of the inputs, \({u}_{{\rm{noise}}}=\sqrt{2{\alpha }^{-1}}{\sigma }_{{\rm{in}}}{\xi }_{t}\), where \({\xi }_{t}\sim {\mathcal{N}}(0,1)\) is a random variable sampled from the standard normal distribution. The input noise strength is σin = 0.01. A baseline input u0 = 0.2 is added to each of the inputs at each time step.
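A sketch of constructing the four sensory inputs for one time step, following the reconstructed input equations and noise term above (the function name is ours):

```python
import numpy as np

def stimulus_inputs(m_c, c_c, u0=0.2, sigma_in=0.01, alpha=0.2, rng=None):
    """Noisy input levels for one time step given signed coherences."""
    rng = rng or np.random.default_rng()
    u = np.array([(1 + m_c) / 2, (1 - m_c) / 2,     # motion right, motion left
                  (1 + c_c) / 2, (1 - c_c) / 2])    # color red, color green
    noise = np.sqrt(2 / alpha) * sigma_in * rng.standard_normal(4)
    return u + noise + u0                           # add noise and baseline

print(stimulus_inputs(m_c=0.2, c_c=-0.1, rng=np.random.default_rng(0)))
```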

RNN training

To train the RNN, we minimize the mean squared error between the output z(t) and the target ztarget(t):

$${\mathcal{L}}:=\sum _{ikt}{({z}_{ikt}-{z}_{{\rm{target}},ikt})}^{2}+{\lambda }_{r}\sum _{ikt}{y}_{ikt}^{2}.$$

(25)

Here, k is the trial number, t is the time step within a trial, zikt is the ith output on trial k at time t, and yikt is the response of the ith RNN unit on trial k at time t. The first term is the task error, and the second term regularizes the solution by penalizing the magnitude of the firing rates. To encourage the network to integrate sensory evidence over time and not to respond during the context cue, the task error is penalized only in the last 750 ms of each trial and during the presentation of the contextual cue. The training is performed with the Adam algorithm. We used the default values 0.9 and 0.999 for the decay rates of the first and second moment estimates, respectively. We used a learning rate of 0.01 and a weight decay of 0.001 and set the hyperparameter λr = 0.05.
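The masked training loss of Eq. (25) can be sketched as follows (a minimal sketch; the boolean mask marks the time steps where the task error is penalized):

```python
import torch

def training_loss(z, z_target, y, mask, lam_r=0.05):
    """Eq. (25): masked task error plus firing-rate regularization.

    z, z_target: (trials, time, outputs); y: (trials, time, units);
    mask: (time,) boolean, True during the context cue and the last 750 ms.
    """
    task_error = ((z - z_target) ** 2)[:, mask].sum()
    rate_penalty = lam_r * (y ** 2).sum()
    return task_error + rate_penalty
```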

We control the degree of correlation between the input and output vectors in the RNN by adding an L2 penalty

$${\lambda }_{{\rm{orth}}}{\parallel {B}^{{\rm{T}}}B-\,{\rm{diag}}\,({B}^{{\rm{T}}}B)\parallel }^{2}$$

(26)

to the loss function in Eq. (25) during training. Here, B is the matrix corresponding to the concatenation of Win and \({W}_{{\rm{out}}}^{{\rm{T}}}\) along their second dimension, with columns normalized to unit length. The hyperparameter λorth controls the penalty weight. For RNNs in the main text, we set λorth = 1, which results in nearly orthogonal input vectors (Supplementary Fig. 6). We fit responses of these RNNs with latent circuit models in which the matrix B is constrained to be diagonal during fitting by setting its off-diagonal elements to 0 after each gradient update. By setting λorth to a smaller value during RNN training, the input vectors in the trained RNN become slightly correlated (Supplementary Fig. 6). To test the effect of these correlations on the latent circuit model, we add the penalty Eq. (26) to the loss function Eq. (7) during latent circuit fitting (Supplementary Fig. 6). These correlations can be captured by latent circuit models fitted with smaller values of the corresponding λorth hyperparameter. Allowing for these input correlations in the RNN and the latent circuit does not have a strong effect on either the fits or the underlying circuit mechanism (Supplementary Fig. 6).
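A sketch of the penalty in Eq. (26) (here B denotes the concatenation of the input and transposed output matrices, as in the text, not the auxiliary Cayley matrix):

```python
import torch

def orth_penalty(W_in, W_out, lam_orth=1.0):
    """Eq. (26): penalize correlations between input and output vectors."""
    B = torch.cat([W_in, W_out.T], dim=1)       # concatenate along columns
    B = B / B.norm(dim=0, keepdim=True)         # normalize columns to unit length
    G = B.T @ B                                 # pairwise cosine similarities
    off_diag = G - torch.diag(torch.diagonal(G))
    return lam_orth * (off_diag ** 2).sum()
```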

The recurrent connection matrix Wrec is initialized so that excitatory connections are independent Gaussian random variables with mean \(1/\sqrt{N}\) and variance 1/N. Inhibitory connections are initialized with mean \(4/\sqrt{N}\) and variance 1/N. The matrix is then scaled so that its spectral radius is 1.5. To implement Dale's law, connections are clipped to 0 after each training step if they change sign. During training, we used minibatches of 128 trials with 1,800 trials total.
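A sketch of this initialization and of the sign clipping used to enforce Dale's law (taking absolute values at initialization is our assumption, made so that the initial signs are consistent with Dale's law):

```python
import numpy as np

rng = np.random.default_rng(4)
N, n_exc = 50, 40                                   # 80% excitatory units

exc = np.abs(rng.normal(1 / np.sqrt(N), 1 / np.sqrt(N), size=(N, n_exc)))
inh = np.abs(rng.normal(4 / np.sqrt(N), 1 / np.sqrt(N), size=(N, N - n_exc)))
W = np.hstack([exc, -inh])                          # columns = presynaptic units

W *= 1.5 / np.abs(np.linalg.eigvals(W)).max()       # set spectral radius to 1.5

def clip_dale(W, n_exc):
    """Zero connections that changed sign after a gradient step."""
    W[:, :n_exc] = np.maximum(W[:, :n_exc], 0.0)    # excitatory stay >= 0
    W[:, n_exc:] = np.minimum(W[:, n_exc:], 0.0)    # inhibitory stay <= 0
    return W
```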

To assess performance, a choice for the RNN was defined as the sign of the difference between output units at the end of the trial. Psychometric functions were then computed as the percentage of choices to the right for each combination of context, motion coherence and color coherence.

Linear decoding

To decode motion coherence from RNN responses, we fit a linear regression model

$$c={\beta }^{{\rm{T}}}y+b,$$

where \(\beta \in {{\mathbb{R}}}^{N}\) is the vector of regression coefficients, \(c\in {{\mathbb{R}}}^{KT}\) is the motion coherence on each trial, \(b\in {\mathbb{R}}\) is a bias term, and \(y\in {{\mathbb{R}}}^{N\times KT}\) is the RNN responses at each time step during the stimulus epoch of each trial. Here, K is the number of trials, and T is the number of time points within a trial. We split the data into training and test sets and fit the model on the training set. The training and test scores were similar (r2 = 0.535 and r2 = 0.531), suggesting that the model did not overfit. After fitting, we used the vector of regression coefficients β to define the decoder axis onto which we project RNN responses.
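A sketch of this decoding analysis with scikit-learn on toy data (the coherence-encoding axis and noise level are illustrative stand-ins for RNN responses):

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(5)
K, T, N = 200, 20, 50                    # trials, time points, units
coh = rng.choice([-0.2, -0.1, 0.1, 0.2], size=K)

axis = rng.normal(size=N)                # toy coherence-encoding direction
Y = coh[:, None, None] * axis + rng.normal(scale=0.5, size=(K, T, N))

X = Y.reshape(K * T, N)                  # one sample per trial and time step
c = np.repeat(coh, T)
X_tr, X_te, c_tr, c_te = train_test_split(X, c, random_state=0)
model = LinearRegression().fit(X_tr, c_tr)
print(model.score(X_tr, c_tr), model.score(X_te, c_te))   # train/test r^2
beta = model.coef_                       # decoder axis for projecting responses
```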

Analysis of PFC data

We analyzed a publicly available dataset of neural activity recordings from the PFC (in and around the frontal eye field) from two monkeys performing a context-dependent decision-making task19. This dataset consisted of 762 units from monkey A and 640 units from monkey F (including single neurons and multiunits). To facilitate comparison with previous studies analyzing the same dataset19,21, we used identical initial preprocessing of the neural data (using the publicly available code at https://www.ini.uzh.ch/en/research/groups/mante/data.html). Because stimulus coherence levels varied across monkeys and days, to equate performance in the motion and color contexts, we replaced the coherences on each trial with their average values for each stimulus difficulty (average motion coherences: 0.05, 0.15 and 0.50 in monkey A and 0.07, 0.19 and 0.54 in monkey F; average color coherences: 0.06, 0.18 and 0.50 in monkey A and 0.12, 0.30 and 0.75 in monkey F). Monkeys reported their choice with a saccade to one of two targets presented shortly after fixation for the entire trial duration. The monkeys were rewarded for saccades to the target location corresponding to the motion direction in the motion context and to the target whose color matched the dominant color of the dots in the color context. The stimulus coherence was assigned a sign (positive or negative) according to the target location indicated by the stimulus. Because the color of the targets was randomized between locations on each trial, the sign of the color coherence reflects both the dominant color of the dots and the location of the red and green targets. The task therefore had 72 unique stimulus conditions defined by all combinations of six motion coherence levels, six color coherence levels and two contexts.

We fitted the latent circuit model to trial-averaged neural responses on correct trials. In our analyses, we included neurons that had at least four correct trials for each of the 72 unique trial conditions, which produced 483 neurons for monkey A and 323 neurons for monkey F. For cross-validation, we then split the trials into two equal disjoint sets and computed the trial-averaged response of each neuron for each trial condition within each set. We used the training set for model fitting and the validation set for visualizing projections of neural responses and quantifying the fit quality. For the analysis of error trials (Extended Data Figs. 7d and 8d), we considered the set of error trial conditions for which all analyzed neurons had at least one trial, which resulted in 16 conditions for monkey A and 26 conditions for monkey F. We then computed the trial-averaged response of each neuron for each trial condition within this set of error trials.

We analyzed neural responses during the presentation of the random dots stimulus because the available data consisted of neural responses starting at 100 ms after stimulus onset for a duration of 750 ms. For each trial, we computed time-varying firing rates by counting spikes in a 50-ms sliding square window (50-ms steps). The first window was centered at 100 ms after the onset of the stimulus, and the last window was centered at 100 ms after stimulus offset. Within the training and test sets, we z-scored and smoothed (Gaussian kernel, σ = 40 ms) the response of each unit. Following previous studies19, from the activity of each unit we subtracted a condition-independent term corresponding to the mean response at each time across trial conditions. To construct population responses, we combined the single-neuron responses for each trial condition. This resulted in 72 neural trajectories, one for each combination of context, motion coherence and color coherence. Last, to denoise these trajectories, we projected them onto the principal components explaining 50% of their total variance (corresponding to the first 40 and 31 principal components for monkeys A and F, respectively).
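The preprocessing pipeline can be sketched as follows (a simplified sketch: bin and kernel widths follow the text, so with 50-ms bins a 40-ms Gaussian kernel corresponds to a sigma of 0.8 bins; the function name is ours):

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d

def preprocess(rates, var_frac=0.5):
    """rates: (conditions, time, neurons) trial-averaged firing rates."""
    z = (rates - rates.mean(axis=(0, 1))) / rates.std(axis=(0, 1))   # z-score
    z = gaussian_filter1d(z, sigma=0.8, axis=1)       # smooth over time bins
    z -= z.mean(axis=0, keepdims=True)                # condition-independent term
    X = z.reshape(-1, z.shape[-1])                    # (conditions*time, neurons)
    Xc = X - X.mean(axis=0)
    _, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    k = np.searchsorted(np.cumsum(S**2 / (S**2).sum()), var_frac) + 1
    return Xc @ Vt[:k].T                              # denoised PC projections
```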

We fitted latent circuit models to the PFC data following procedures similar to those used for RNNs. For each of the 72 conditions, we constructed input to the latent circuit from the context, motion coherence and color coherence corresponding to that condition. In the experimental task, the stimulus is presented 650 ms after the context cue for a duration of 750 ms. Neural recordings correspond to the period from 100 ms after stimulus onset to 100 ms after stimulus offset. We thus modeled the task with 150 time steps (10 ms in length) extending from the initial presentation of the contextual cue to 100 ms after stimulus offset. Contextual input was given to the model from t = 0 to t = 1,500 ms. Stimulus input was given to the model from t = 750 ms to t = 1,500 ms. We constructed two target outputs (z1 and z2) for each trial such that on trials for which the monkey chose the right target, the first target output was high (z1 = 1.2) and the second target output was low (z2 = 0.2), and vice versa for left choice trials. We penalized errors between target and model outputs only in the last 250 ms of each trial. Responses of the latent circuit were fitted to the PFC data only on the last 15 time steps of each trajectory, for which PFC data were available. The latent circuit model was fitted with hyperparameter α = 0.2 and a recurrent noise term of magnitude σrec = 0.15 added to each unit at each time step (Eq. (24)). Because neural responses were centered, we additionally fit an intercept term b, so that the resulting model for PFC data was

$$\tau \dot{x}=-x+{[{w}_{{\rm{rec}}}x+{w}_{{\rm{in}}}u]}_{+},\quad y=Qx+b.$$

(29)

Because of the high dimensionality of PFC responses (40 and 31 principal components are required to account for ~50% of the total variance in PFC activity for monkeys A and F, respectively), we find a notable tradeoff between the task fit and the data fit when fitting the low-dimensional latent circuit model to the PFC data. To control this tradeoff, we used a modified loss function when fitting PFC data,

$$L=\sum _{k}\sum _{t}\lambda \frac{{\parallel {y}_{kt}-Q{Q}^{{\rm{T}}}{y}_{kt}\parallel }^{2}}{{\parallel {y}_{kt}\parallel }^{2}}+\frac{{\parallel {Q}^{{\rm{T}}}(\,{y}_{kt}-b)-{x}_{kt}\parallel }^{2}}{{\parallel {Q}^{{\rm{T}}}{y}_{kt}\parallel }^{2}}+\frac{{\parallel {z}_{{\rm{target}},kt}-{w}_{{\rm{out}}}{x}_{kt}\parallel }^{2}}{{\parallel {z}_{{\rm{target}},kt}\parallel }^{2}}$$

(31)

$$=\lambda {L}_{{\rm{subspace}}}^{2}+{L}_{{\rm{fit}}}^{2}+{L}_{{\rm{task}}}^{2},$$

(32)

designed to balance the variance explained by the task-relevant subspace (the term \({L}_{{\rm{subspace}}}^{2}\)), the fit between projected PFC responses and the latent circuit trajectories in this subspace (\({L}_{{\rm{fit}}}^{2}\)) and the performance of the latent circuit on the task (\({L}_{{\rm{task}}}^{2}\)). The hyperparameter λ = 0.5 was chosen via a grid search over the range λ ∈ [0, 1.5]. We found that near this value, the variance explained by the task-relevant subspace and the fit quality were maximized under the constraint that the latent model still performed well on the task (Extended Data Figs. 7a and 8a).
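A sketch of this three-term loss, following the reconstructed Eqs. (31) and (32) (a simplified sketch, not the authors' exact code; for brevity it sums the normalizations globally rather than per time step):

```python
import torch

def pfc_loss(y, x, z_target, Q, w_out, b, lam=0.5):
    """y: (K, T, N) PFC responses; x: (K, T, n) latent trajectories;
    z_target: (K, T, 2) targets; Q: (N, n); w_out: (2, n); b: (N,) intercept."""
    y_sub = (y @ Q) @ Q.T                               # projection onto col(Q)
    L_subspace = ((y - y_sub) ** 2).sum() / (y ** 2).sum()
    L_fit = (((y - b) @ Q - x) ** 2).sum() / ((y @ Q) ** 2).sum()
    z = x @ w_out.T
    L_task = ((z_target - z) ** 2).sum() / (z_target ** 2).sum()
    return lam * L_subspace + L_fit + L_task
```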

Statistics and reproducibility

We analyzed data from 200 RNN models trained with random initializations. Results were consistent across networks; therefore, we found this sample size to be sufficient for our study. No statistical method was used to predetermine sample size. For each of these networks, we trained 100 latent circuit models. This sample size was chosen so that the top ten latent models converged to a high fit accuracy. For PFC data, we fitted 200 latent circuit models to neural responses from each monkey. Neural recording data were previously described in Mante et al.19; no randomization or blinding was performed because there was only one experimental group. All recorded units that had at least four correct trials in each task condition were included in the analysis.

Reporting summary

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.
