
SIMPL


SIMPL is a JAX-based Python package for jointly decoding latent neural representations and optimising tuning curves from spike data. It fits both with an expectation-maximisation (EM) algorithm that alternates between decoding the latent trajectory with a Kalman smoother and re-estimating tuning curves by kernel density estimation. Published at ICLR 2025.
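The EM loop described above can be illustrated with a deliberately simplified 1D toy, written in NumPy rather than JAX for brevity. It is not SIMPL's implementation. In this sketch the E-step takes the per-time-bin maximum-likelihood position under independent Poisson neurons and smooths it with a random-walk Kalman/Rauch-Tung-Striebel smoother, and the M-step re-estimates each neuron's rate with a Gaussian kernel. The function names, the `obs_var` observation noise, and the mapping from a speed prior to a step variance of `(speed * dt)**2` are all assumptions made for illustration.

```python
import numpy as np


def kde_tuning_curves(Y, X, grid, bandwidth, dt):
    """M-step sketch: Gaussian-kernel estimate of each neuron's rate on a 1D grid.

    rate_n(x) = sum_t K(x - X_t) * Y[t, n] / (dt * sum_t K(x - X_t))
    """
    K = np.exp(-0.5 * ((grid[:, None] - X[None, :]) / bandwidth) ** 2)  # (G, T)
    occupancy_s = K.sum(axis=1) * dt + 1e-12  # kernel-weighted time spent near each bin
    spikes = K @ Y                            # kernel-weighted spike counts, (G, N)
    return (spikes / occupancy_s[:, None]).T  # rates in Hz, (N, G)


def decode_ml(Y, F, grid, dt):
    """Per-time-bin maximum-likelihood position under independent Poisson neurons."""
    lam = F * dt + 1e-12                        # expected counts per bin, (N, G)
    loglik = Y @ np.log(lam) - lam.sum(axis=0)  # (T, G); log(y!) dropped as a constant
    return grid[np.argmax(loglik, axis=1)]


def rts_smooth(z, q, r):
    """Kalman filter + Rauch-Tung-Striebel smoother for a 1D random walk.

    Model (assumed): x_t = x_{t-1} + N(0, q),  z_t = x_t + N(0, r).
    """
    T = len(z)
    m, P = np.zeros(T), np.zeros(T)            # filtered mean / variance
    m_pred, P_pred = np.zeros(T), np.zeros(T)  # one-step-ahead predictions
    for t in range(T):
        # vague prior at t = 0, random-walk prediction afterwards
        mp, Pp = (z[0], 1e6) if t == 0 else (m[t - 1], P[t - 1] + q)
        m_pred[t], P_pred[t] = mp, Pp
        k = Pp / (Pp + r)                      # Kalman gain
        m[t] = mp + k * (z[t] - mp)
        P[t] = (1 - k) * Pp
    ms = m.copy()
    for t in range(T - 2, -1, -1):             # backward (smoothing) pass
        g = P[t] / P_pred[t + 1]
        ms[t] = m[t] + g * (ms[t + 1] - m_pred[t + 1])
    return ms


def toy_em(Y, Xb, dt, grid, bandwidth, speed, obs_var, n_iterations=5):
    """Alternate tuning-curve fitting and smoothed decoding, starting from Xb (T,)."""
    q = (speed * dt) ** 2  # assumption: per-step random-walk variance from a speed prior
    X = np.asarray(Xb, dtype=float).copy()
    for _ in range(n_iterations):
        F = kde_tuning_curves(Y, X, grid, bandwidth, dt)       # M-step: refit tuning curves
        X = rts_smooth(decode_ml(Y, F, grid, dt), q, obs_var)  # E-step: decode, then smooth
    return X, kde_tuning_curves(Y, X, grid, bandwidth, dt)
```

Here `Xb` is a 1D array of shape (T,); the real package accepts (T, D) and discretises the environment itself.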

✨ Key Features

  • Fast: fits 100 neurons over 1 hour of data in under 10 seconds on CPU. A GPU is optional but rarely needed.
  • 🎯 Simple: scikit-learn-style fit() / predict() API with a small set of intuitive hyperparameters. Get started in fewer than 10 lines of code.
  • 🧠 Flexible: works with 1D angular data (e.g. head direction), 2D spatial data (e.g. place/grid cells), and higher-dimensional latents. Examples and a demo are provided.
  • 📊 Rich outputs: results are stored as an xarray.Dataset with per-iteration metrics, units, baselines, and diagnostics.
  • 📈 Visual: built-in plotting for trajectories, receptive fields, spike rasters, and fitting summaries.
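For 1D angular data such as head direction, any kernel or distance computed over the latent has to wrap around 2π, otherwise samples just either side of 0 rad look maximally far apart. The sketch below shows a wrapped Gaussian kernel, assuming angles in radians; the function names are hypothetical, and how SIMPL itself handles periodic latents is not described here.

```python
import numpy as np


def wrapped_diff(a, b):
    """Signed angular difference a - b in radians, mapped into [-pi, pi)."""
    return (np.asarray(a) - np.asarray(b) + np.pi) % (2 * np.pi) - np.pi


def circular_gaussian_kernel(theta_grid, theta_samples, bandwidth):
    """Kernel weights of shape (G, T) using wrapped rather than linear distance."""
    d = wrapped_diff(theta_grid[:, None], theta_samples[None, :])
    return np.exp(-0.5 * (d / bandwidth) ** 2)
```

With this kernel, a sample at 2π - 0.05 rad contributes to the 0 rad bin exactly as much as a sample at 0.05 rad.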


Neural data analysis in < 5 seconds

🚀 Installation

pip install simpl-neuro

To install the optional demo dependencies and download the demo notebook:

pip install "simpl-neuro[demos]"
simpl demo                # downloads the demo notebook into the current directory

🔧 API

SIMPL follows scikit-learn conventions: configure hyperparameters at initialisation, then pass data to fit().

from simpl import SIMPL

# 1. Configure the model (no data, no computation)
model = SIMPL(
    speed_prior=0.4,        # prior on agent speed (m/s) — controls Kalman smoothing
    kernel_bandwidth=0.02,  # KDE bandwidth for fitting receptive fields
    bin_size=0.02,          # spatial bin size for environment discretisation
    env_pad=0.0,            # padding around data bounds
)

# 2. Fit
model.fit(
    Y,                      # spike counts (T, N_neurons)
    Xb,                     # behavioural initialisation positions (T, D)
    time,                   # timestamps (T,)
    n_iterations=5,
)

# 3. Access results
model.X_           # final decoded latent positions, shape (T, D)
model.F_           # final receptive fields, shape (N_neurons, *env_dims)
model.results_     # full xarray.Dataset with per-iteration metrics, likelihoods, and baselines

# 4. Plot results
model.plot_fitting_summary()  # shows the bits-per-spike metric and spike-latent mutual information

# (optional) Resume training if not yet converged
model.fit(Y, Xb, time, n_iterations=5, resume=True)
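fit() expects spike counts of shape (T, N_neurons), behavioural initialisation positions of shape (T, D), and timestamps of shape (T,). To try the API without a recording, a hypothetical helper like the one below simulates Gaussian place fields on a 1 m track with Poisson spiking. The function name, peak rate, field width, and noise levels are illustrative choices, not part of SIMPL.

```python
import numpy as np


def make_synthetic_data(T=6000, N=30, dt=0.1, speed=0.2, seed=0):
    """Hypothetical toy dataset: Gaussian place fields on a 1 m track."""
    rng = np.random.default_rng(seed)
    time = np.arange(T) * dt                    # timestamps (T,), seconds
    x = np.empty(T)
    x[0] = 0.5
    for t in range(1, T):                       # random walk reflected at 0 and 1 m
        step = x[t - 1] + rng.normal(0.0, speed * dt)
        x[t] = 2.0 - step if step > 1.0 else abs(step)
    centres = np.linspace(0.0, 1.0, N)          # one place-field centre per neuron
    rates = 10.0 * np.exp(-0.5 * ((x[:, None] - centres[None, :]) / 0.05) ** 2)  # Hz
    Y = rng.poisson(rates * dt)                 # spike counts (T, N)
    Xb = (x + rng.normal(0.0, 0.05, T))[:, None]  # noisy behavioural initialisation (T, 1)
    return Y, Xb, time, x[:, None]              # last item: ground-truth latent (T, 1)
```

For example, `Y, Xb, time, X_true = make_synthetic_data()` followed by `model.fit(Y, Xb, time, n_iterations=5)` as above; `X_true` is kept only for comparison with `model.X_`.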

See the Getting Started guide for a full walkthrough, or jump to the API Reference.
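The bits-per-spike metric shown by plot_fitting_summary() is, in its common form, the Poisson log-likelihood gain of the fitted rates over a constant-rate baseline, divided by the number of spikes and converted to bits. The sketch below computes that standard quantity from expected counts per time bin. SIMPL's exact definition (for example, whether it is evaluated on held-out data or per neuron) is not specified here, and the function name is hypothetical.

```python
import numpy as np


def bits_per_spike(Y, expected_counts):
    """Poisson log-likelihood gain over a constant-rate model, in bits per spike.

    Y: observed spike counts (T, N). expected_counts: model's expected counts (T, N), all > 0.
    The log(y!) terms are identical for both models and cancel.
    """
    Y = np.asarray(Y, dtype=float)
    baseline = Y.mean(axis=0, keepdims=True)  # each neuron's mean count per bin, (1, N)
    ll_model = np.sum(Y * np.log(expected_counts) - expected_counts)
    ll_base = np.sum(Y * np.log(baseline + 1e-12) - baseline)
    return (ll_model - ll_base) / (Y.sum() * np.log(2))
```

A model that only predicts each neuron's mean rate scores 0; positive values mean the tuning curves explain spiking better than a constant rate.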

📝 Cite

If you use SIMPL in your work, please cite it as:

Tom George, Pierre Glaser, Kim Stachenfeld, Caswell Barry, & Claudia Clopath (2025). SIMPL: Scalable and hassle-free optimisation of neural representations from behaviour. In The Thirteenth International Conference on Learning Representations.

@inproceedings{
    george2025simpl,
    title={{SIMPL}: Scalable and hassle-free optimisation of neural representations from behaviour},
    author={Tom George and Pierre Glaser and Kim Stachenfeld and Caswell Barry and Claudia Clopath},
    booktitle={The Thirteenth International Conference on Learning Representations},
    year={2025},
    url={https://openreview.net/forum?id=9kFaNwX6rv}
}