
Utilities

Gaussian Helpers

`gaussian_pdf(x, mu, sigma)`

Calculates the multivariate Gaussian PDF at x.

\[\mathcal{N}(x \mid \mu, \Sigma) = \frac{1}{\sqrt{(2\pi)^D |\Sigma|}} \exp\!\left(-\frac{1}{2}(x - \mu)^\top \Sigma^{-1} (x - \mu)\right)\]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Array` | The position at which to evaluate the pdf | required |
| `mu` | `Array` | The mean of the distribution | required |
| `sigma` | `Array` | The covariance of the distribution | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `pdf` | `float` | The probability density at x |

Source code in src/simpl/utils.py
def gaussian_pdf(
    x: jax.Array,
    mu: jax.Array,
    sigma: jax.Array,
) -> jax.Array:
    """Calculates the multivariate Gaussian PDF at x.

    $$\\mathcal{N}(x \\mid \\mu, \\Sigma) = \\frac{1}{\\sqrt{(2\\pi)^D |\\Sigma|}}
    \\exp\\!\\left(-\\frac{1}{2}(x - \\mu)^\\top \\Sigma^{-1} (x - \\mu)\\right)$$

    Parameters
    ----------

    x: (D,) array
        The position at which to evaluate the pdf
    mu: (D,) array
        The mean of the distribution
    sigma: (D, D) array
        The covariance of the distribution

    Returns
    -------
    pdf: float
        The probability density at x
    """
    assert x.ndim == 1
    assert mu.ndim == 1
    assert sigma.ndim == 2
    assert x.shape[0] == mu.shape[0]
    assert x.shape[0] == sigma.shape[0]
    assert sigma.shape[0] == sigma.shape[1]

    x = x - mu
    norm_const = gaussian_norm_const(sigma)
    return norm_const * jnp.exp(-0.5 * jnp.sum(x @ jnp.linalg.inv(sigma) * x, axis=-1))
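As a sanity check, the formula can be mirrored in plain NumPy (an illustrative sketch, not the library implementation) and evaluated at a point where the density is known in closed form:

```python
import numpy as np

def gaussian_pdf_np(x, mu, sigma):
    # NumPy mirror of gaussian_pdf above, for a quick numeric check
    d = x - mu
    D = x.shape[0]
    norm_const = 1.0 / np.sqrt((2 * np.pi) ** D * np.linalg.det(sigma))
    return norm_const * np.exp(-0.5 * d @ np.linalg.inv(sigma) @ d)

# standard 2D normal evaluated at the origin: density is 1 / (2*pi)
p = gaussian_pdf_np(np.zeros(2), np.zeros(2), np.eye(2))
```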

`log_gaussian_pdf(x, mu, sigma)`

Calculates the log of the multivariate Gaussian PDF at x.

\[\log \mathcal{N}(x \mid \mu, \Sigma) = -\frac{D}{2}\log(2\pi) - \frac{1}{2}\log|\Sigma| - \frac{1}{2}(x - \mu)^\top \Sigma^{-1} (x - \mu)\]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Array` | The position at which to evaluate the pdf | required |
| `mu` | `Array` | The mean of the distribution | required |
| `sigma` | `Array` | The covariance of the distribution | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `log_pdf` | `float` | The log probability density at x |

Source code in src/simpl/utils.py
def log_gaussian_pdf(
    x: jax.Array,
    mu: jax.Array,
    sigma: jax.Array,
) -> jax.Array:
    """Calculates the log of the multivariate Gaussian PDF at x.

    $$\\log \\mathcal{N}(x \\mid \\mu, \\Sigma) = -\\frac{D}{2}\\log(2\\pi)
    - \\frac{1}{2}\\log|\\Sigma|
    - \\frac{1}{2}(x - \\mu)^\\top \\Sigma^{-1} (x - \\mu)$$

    Parameters
    ----------
    x: (D,) array
        The position at which to evaluate the pdf
    mu: (D,) array
        The mean of the distribution
    sigma: (D, D) array
        The covariance of the distribution

    Returns
    -------
    log_pdf: float
        The log probability density at x
    """
    assert x.ndim == 1
    assert mu.ndim == 1
    assert sigma.ndim == 2
    assert x.shape[0] == mu.shape[0]
    assert x.shape[0] == sigma.shape[0]
    assert sigma.shape[0] == sigma.shape[1]

    x = x - mu
    norm_const = gaussian_norm_const(sigma)
    return jnp.log(norm_const) - 0.5 * jnp.sum(x @ jnp.linalg.inv(sigma) * x)
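Taking the log of the normalizing constant can underflow for large D or wide covariances. A common numerically safer variant (a sketch, not simpl's implementation) computes the log density directly via `slogdet` and `solve`, avoiding both the explicit determinant and the explicit inverse:

```python
import numpy as np

def log_gaussian_pdf_stable(x, mu, sigma):
    # compute the log density directly; slogdet returns (sign, log|sigma|)
    # and stays finite where det(sigma) would under/overflow
    d = x - mu
    D = x.shape[0]
    _, logdet = np.linalg.slogdet(sigma)
    maha = d @ np.linalg.solve(sigma, d)  # solve() instead of inv()
    return -0.5 * (D * np.log(2 * np.pi) + logdet + maha)

# standard 2D normal at the origin: log density is -log(2*pi)
lp = log_gaussian_pdf_stable(np.zeros(2), np.zeros(2), np.eye(2))
```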

`gaussian_norm_const(sigma)`

Calculates the normalizing constant of a multivariate normal distribution with covariance sigma.

\[Z = \frac{1}{\sqrt{(2\pi)^D |\Sigma|}}\]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `sigma` | `Array` | The covariance matrix of the distribution | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `norm_const` | `ndarray, shape ()` | The normalizing constant |

Source code in src/simpl/utils.py
def gaussian_norm_const(sigma: jax.Array) -> jax.Array:
    """Calculates the normalizing constant of a multivariate normal distribution with covariance sigma.

    $$Z = \\frac{1}{\\sqrt{(2\\pi)^D |\\Sigma|}}$$

    Parameters
    ----------
    sigma: jnp.ndarray, shape (D, D)
        The covariance matrix of the distribution

    Returns
    -------
    norm_const: jnp.ndarray, shape ()
        The normalizing constant
    """
    assert sigma.ndim == 2
    D = sigma.shape[0]
    return 1 / jnp.sqrt((2 * jnp.pi) ** D * jnp.linalg.det(sigma))

Gaussian Fitting

`fit_gaussian(x, likelihoods)`

Fits a multivariate Gaussian to each of T likelihood distributions over spatial bins.

For each timestep, computes the weighted mean, mode, and covariance of the spatial bin coordinates x under the likelihood weights:

\[\mu_t = \frac{\sum_i x_i \, p_{t,i}}{\sum_i p_{t,i}}, \qquad \Sigma_t = \mathbb{E}_t[x x^\top] - \mu_t \mu_t^\top\]

The covariance uses the identity Cov = E[xx^T] - mu mu^T. This lets us precompute x x^T once as a small (N_bins, D, D) array and contract it with the (T, N_bins) likelihoods via a single einsum, rather than materialising a (T, N_bins, D) intermediate as the naive formula would.

The function is JIT-compiled so the XLA computation is traced once and reused on subsequent calls.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `ndarray, shape (N_bins, D)` | The position bins (shared across all time steps). | required |
| `likelihoods` | `ndarray, shape (T, N_bins)` | Likelihood values (not log) at each bin for each time step. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `means` | `ndarray, shape (T, D)` | The weighted mean position at each time step. |
| `modes` | `ndarray, shape (T, D)` | The bin coordinate with the highest likelihood at each time step. |
| `covariances` | `ndarray, shape (T, D, D)` | The weighted covariance at each time step. |

Source code in src/simpl/utils.py
@jax.jit
def fit_gaussian(x: jax.Array, likelihoods: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
    """Fits a multivariate-Gaussian to each of T likelihood distributions over spatial bins.

    For each timestep, computes the weighted mean, mode, and covariance of the
    spatial bin coordinates ``x`` under the likelihood weights:

    $$\\mu_t = \\frac{\\sum_i x_i \\, p_{t,i}}{\\sum_i p_{t,i}}, \\qquad
    \\Sigma_t = \\mathbb{E}_t[x x^\\top] - \\mu_t \\mu_t^\\top$$

    The covariance uses the identity ``Cov = E[xx^T] - mu mu^T``.  This lets us
    precompute ``x x^T`` once as a small ``(N_bins, D, D)`` array and contract it
    with the ``(T, N_bins)`` likelihoods via a single einsum, rather than
    materialising a ``(T, N_bins, D)`` intermediate as the naive formula would.

    The function is JIT-compiled so the XLA computation is traced once and
    reused on subsequent calls.

    Parameters
    ----------
    x : jnp.ndarray, shape (N_bins, D)
        The position bins (shared across all time steps).
    likelihoods : jnp.ndarray, shape (T, N_bins)
        Likelihood values (not log) at each bin for each time step.

    Returns
    -------
    means : jnp.ndarray, shape (T, D)
        The weighted mean position at each time step.
    modes : jnp.ndarray, shape (T, D)
        The bin coordinate with the highest likelihood at each time step.
    covariances : jnp.ndarray, shape (T, D, D)
        The weighted covariance at each time step.
    """
    sums = likelihoods.sum(axis=1)  # (T,)

    # Mean: weighted average via matmul
    mu = (likelihoods @ x) / sums[:, None]  # (T, D)

    # Mode: position of max likelihood
    mode = x[jnp.argmax(likelihoods, axis=1)]  # (T, D)

    # Covariance: E[xx^T] - mu mu^T (avoids (T, N_bins, D) intermediate)
    x_outer = x[:, :, None] * x[:, None, :]  # (N_bins, D, D)
    E_xxT = jnp.einsum("tb,bij->tij", likelihoods, x_outer) / sums[:, None, None]  # (T, D, D)
    cov = E_xxT - mu[:, :, None] * mu[:, None, :]  # (T, D, D)

    return mu, mode, cov
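The moment identities can be checked on a toy one-dimensional grid (plain NumPy, mirroring the einsum above). Note the recovered mean and variance come out slightly below 1 because the Gaussian weights are truncated at the grid edge:

```python
import numpy as np

# one timestep of Gaussian-shaped likelihood over a 1D grid of bins
x = np.linspace(-3, 3, 61)[:, None]           # (N_bins, 1) bin coordinates
w = np.exp(-0.5 * (x[:, 0] - 1.0) ** 2)       # unnormalised N(1, 1) weights
lik = w[None, :]                              # (T=1, N_bins)

sums = lik.sum(axis=1)                        # (T,)
mu = (lik @ x) / sums[:, None]                # weighted mean, as in fit_gaussian
x_outer = x[:, :, None] * x[:, None, :]       # (N_bins, 1, 1) precomputed outer products
E_xxT = np.einsum("tb,bij->tij", lik, x_outer) / sums[:, None, None]
cov = E_xxT - mu[:, :, None] * mu[:, None, :]  # Cov = E[xx^T] - mu mu^T
```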

`gaussian_sample(key, mu, sigma)`

Samples from a multivariate normal distribution with mean mu and covariance sigma.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key` | `PRNGKey` | The random key | required |
| `mu` | `ndarray, shape (D,)` | The mean of the distribution | required |
| `sigma` | `ndarray, shape (D, D)` | The covariance of the distribution | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `sample` | `ndarray, shape (D,)` | The sample |

Source code in src/simpl/utils.py
def gaussian_sample(key: jax.Array, mu: jax.Array, sigma: jax.Array) -> jax.Array:
    """Samples from a multivariate normal distribution with mean mu and covariance sigma.

    Parameters
    ----------
    key : PRNGKey
        The random key
    mu : jnp.ndarray, shape (D,)
        The mean of the distribution
    sigma : jnp.ndarray, shape (D, D)
        The covariance of the distribution

    Returns
    -------
    sample : jnp.ndarray, shape (D,)
        The sample
    """
    assert mu.ndim == 1
    assert sigma.ndim == 2
    sample = random.multivariate_normal(key, mu, sigma)
    return sample
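Sampling behaviour can be sanity-checked with NumPy's equivalent generator method (illustrative only; simpl uses `jax.random.multivariate_normal`): with enough draws, the empirical mean and covariance should recover `mu` and `sigma`:

```python
import numpy as np

rng = np.random.default_rng(0)
mu = np.array([1.0, -2.0])
sigma = np.array([[2.0, 0.3], [0.3, 0.5]])

samples = rng.multivariate_normal(mu, sigma, size=50_000)  # (50000, 2)
emp_mu = samples.mean(axis=0)
emp_cov = np.cov(samples.T)
```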

Statistical and Analysis Helpers

`coefficient_of_determination(X, Y)`

Calculates the coefficient of determination (\(R^2\)) between X and Y.

This reflects the proportion of the variance in Y that is predictable from X.

\[R^2 = 1 - \frac{SS_{\textrm{res}}}{SS_{\textrm{tot}}} = 1 - \frac{\sum_i (Y_i - X_i)^2}{\sum_i (Y_i - \bar{Y})^2}\]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `ndarray, shape (N, D)` | The predicted latent positions | required |
| `Y` | `ndarray, shape (N, D)` | The true latent positions | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `R2` | `Array, scalar` | The coefficient of determination. 1.0 indicates a perfect prediction; 0.0 indicates the model explains no more variance than the mean of Y; negative values indicate worse-than-mean predictions. |

Source code in src/simpl/utils.py
def coefficient_of_determination(
    X: jax.Array,
    Y: jax.Array,
) -> jax.Array:
    """Calculates the coefficient of determination (\\(R^2\\)) between X and Y.

    This reflects the proportion of the variance in Y that is predictable from X.

    $$R^2 = 1 - \\frac{SS_{\\textrm{res}}}{SS_{\\textrm{tot}}}
    = 1 - \\frac{\\sum_i (Y_i - X_i)^2}{\\sum_i (Y_i - \\bar{Y})^2}$$

    Parameters
    ----------
    X : jnp.ndarray, shape (N, D)
        The predicted latent positions
    Y : jnp.ndarray, shape (N, D)
        The true latent positions

    Returns
    -------
    R2 : jax.Array, scalar
        The coefficient of determination.  1.0 indicates a perfect prediction;
        0.0 indicates the model explains no more variance than the mean of *Y*;
        negative values indicate worse-than-mean predictions."""
    assert X.shape == Y.shape, "The predicted and true latent positions must have the same shape."
    SST = jnp.sum((Y - jnp.mean(Y, axis=0)) ** 2)
    SSR = jnp.sum((Y - X) ** 2)
    R2 = 1 - SSR / SST
    return R2
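A worked example of the two reference points in the Returns description, using a plain NumPy mirror of the formula above:

```python
import numpy as np

def r2(X, Y):
    # NumPy mirror of coefficient_of_determination
    sst = np.sum((Y - Y.mean(axis=0)) ** 2)
    ssr = np.sum((Y - X) ** 2)
    return 1 - ssr / sst

Y = np.array([[0.0], [1.0], [2.0], [3.0]])
X_perfect = Y.copy()                  # exact predictions -> R^2 = 1
X_mean = np.full_like(Y, Y.mean())    # always predicting mean(Y) -> R^2 = 0
```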

`cca(X, Y)`

Uses canonical correlation between X and Y (the "target") to establish the best linear mapping from X to Y.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `ndarray, shape (N, D)` | The inputs | required |
| `Y` | `ndarray, shape (N, D)` | The targets | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `coef` | `ndarray, shape (D, D)` | The coefficients of the linear mapping from X to Y such that Y ≈ Y_pred = X @ coef.T + intercept |
| `intercept` | `ndarray, shape (D,)` | The intercept of the linear mapping from X to Y such that Y ≈ Y_pred = X @ coef.T + intercept |

Source code in src/simpl/utils.py
def cca(X: jax.Array, Y: jax.Array) -> tuple[np.ndarray, np.ndarray]:
    """Uses canonical correlation between X and Y (the "target") to establish the best linear mapping from X to Y.

    Parameters
    ----------
    X : jnp.ndarray, shape (N, D)
        The inputs
    Y : jnp.ndarray, shape (N, D)
        The targets

    Returns:
    -------
    coef : jnp.ndarray, shape (D, D)
        The coefficients of the linear mapping from X to Y such that Y ~= Y_pred = X @ coef.T + intercept
    intercept : jnp.ndarray, shape (D,)
        The intercept of the linear mapping from X to Y such that Y ~= Y_pred = X @ coef.T + intercept
    """
    assert X.shape == Y.shape, "The predicted and true latent positions must have the same shape."
    D = X.shape[1]

    cca = sklearn.cross_decomposition.CCA(n_components=D, max_iter=2000)
    cca.fit(X, Y)
    coef = cca.coef_  # NB: sklearn's coef_ scaling convention has changed across versions (older versions needed / cca._x_std)
    intercept = cca.intercept_ - cca._x_mean @ coef.T
    return coef, intercept
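As the inline comment notes, sklearn's `coef_` scaling has shifted between versions. For the equal-dimension, full-rank case used here, the same "best linear map" idea can be sketched dependency-free with ordinary least squares (names below are illustrative, not part of simpl):

```python
import numpy as np

# recover an affine map Y ≈ X @ coef.T + intercept by least squares
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))
true_coef = np.array([[2.0, 0.0], [1.0, -1.0]])
true_intercept = np.array([0.5, -0.5])
Y = X @ true_coef.T + true_intercept

X_aug = np.hstack([X, np.ones((X.shape[0], 1))])  # append a bias column
W, *_ = np.linalg.lstsq(X_aug, Y, rcond=None)     # (D+1, D) solution
coef = W[:-1].T                                    # (D, D)
intercept = W[-1]                                  # (D,)
```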

`cca_angular(X, Y, n_angles=360)`

Align 1D circular trajectories by a pure rotation (no scaling).

Searches rotation angles in [-pi, pi) and returns the angle that minimises mean squared wrapped angular error. Unlike cca, this only performs a rotation (no shift or scaling), which is the correct transform for angular data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `ndarray, shape (N, 1) or (N,)` | Source trajectory in radians. | required |
| `Y` | `ndarray, shape (N, 1) or (N,)` | Target trajectory in radians. | required |
| `n_angles` | `int` | Number of candidate angles in [-pi, pi). | `360` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `best_angle` | `ndarray, shape ()` | Rotation angle (radians) that minimises circular error. |
| `best_error` | `ndarray, shape ()` | Minimum mean squared wrapped angular error. |

Source code in src/simpl/utils.py
def cca_angular(
    X: jax.Array,
    Y: jax.Array,
    n_angles: int = 360,
) -> tuple[jax.Array, jax.Array]:
    """Align 1D circular trajectories by a pure rotation (no scaling).

    Searches rotation angles in [-pi, pi) and returns the angle that minimises
    mean squared wrapped angular error. Unlike ``cca``, this only performs a rotation
    (no shift or scaling), which is the correct transform for angular data.

    Parameters
    ----------
    X : jnp.ndarray, shape (N, 1) or (N,)
        Source trajectory in radians.
    Y : jnp.ndarray, shape (N, 1) or (N,)
        Target trajectory in radians.
    n_angles : int, optional
        Number of candidate angles in [-pi, pi), by default 360.

    Returns
    -------
    best_angle : jnp.ndarray, shape ()
        Rotation angle (radians) that minimises circular error.
    best_error : jnp.ndarray, shape ()
        Minimum mean squared wrapped angular error.
    """
    X = jnp.asarray(X).reshape(-1)
    Y = jnp.asarray(Y).reshape(-1)
    assert X.shape == Y.shape, "The predicted and target circular trajectories must have the same shape."

    angles = jnp.linspace(-jnp.pi, jnp.pi, n_angles, endpoint=False)
    diffs = _wrap_minuspi_pi(X[:, None] + angles[None, :] - Y[:, None])
    errs = jnp.mean(diffs**2, axis=0)
    idx = jnp.argmin(errs)
    best_angle = angles[idx]
    return best_angle, errs[idx]
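A self-contained NumPy sketch of the grid search (the wrap helper below is an assumed equivalent of the private `_wrap_minuspi_pi`, which is not shown on this page). Rotating a trajectory by a known angle and searching should recover that angle to within the grid spacing:

```python
import numpy as np

def wrap_minuspi_pi(a):
    # wrap angles into [-pi, pi); assumed equivalent of _wrap_minuspi_pi
    return (a + np.pi) % (2 * np.pi) - np.pi

rng = np.random.default_rng(0)
Y = rng.uniform(-np.pi, np.pi, size=500)   # target circular trajectory
X = wrap_minuspi_pi(Y - 0.7)               # source: Y rotated by -0.7 rad

n_angles = 360
angles = np.linspace(-np.pi, np.pi, n_angles, endpoint=False)
diffs = wrap_minuspi_pi(X[:, None] + angles[None, :] - Y[:, None])
errs = np.mean(diffs**2, axis=0)
best_angle = angles[np.argmin(errs)]       # should be ~ +0.7
```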

`correlation_at_lag(X1, X2, lag)`

Calculates the correlation between X1 and X2[lag:].

If X is D-dimensional, calculates the average correlation across dimensions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X1` | `ndarray, shape (T, D)` | The first time series (remains fixed) | required |
| `X2` | `ndarray, shape (T, D)` | The second time series | required |
| `lag` | `int` | The lag to calculate the correlation at | required |

Returns:

| Type | Description |
| --- | --- |
| `float` | The average correlation across dimensions |

Source code in src/simpl/utils.py
def correlation_at_lag(X1: jax.Array, X2: jax.Array, lag: int) -> jax.Array:
    """Calculates the correlation between X1 and X2[lag:].

    If X is D-dimensional, calculates the average correlation across dimensions.

    Parameters
    ----------

    X1 : jnp.ndarray, shape (T, D)
        The first time series - remains fixed
    X2 : jnp.ndarray, shape (T, D)
        The second time series
    lag : int
        The lag to calculate the correlation at

    Returns
    -------
    float
        The average correlation across dimensions
    """
    T, D = X1.shape
    if lag >= 0:
        X2 = X2[lag:, :]
        X1 = X1[: T - lag]
    else:
        lag = -lag
        X1 = X1[lag:, :]
        X2 = X2[: T - lag]
    return jnp.mean(jnp.diag(jnp.corrcoef(X1.T, X2.T)[D:, :D]))
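A quick check, mirroring the slicing above in plain NumPy: a copy of X1 delayed by `lag` bins correlates perfectly at that lag:

```python
import numpy as np

T, D, lag = 100, 2, 3
rng = np.random.default_rng(0)
X1 = rng.normal(size=(T, D))
# X2 is X1 delayed by `lag` bins (first `lag` rows are unrelated noise)
X2 = np.concatenate([rng.normal(size=(lag, D)), X1[: T - lag]], axis=0)

# same alignment as correlation_at_lag for lag >= 0
A = X1[: T - lag]
B = X2[lag:]
corr = np.mean(np.diag(np.corrcoef(A.T, B.T)[D:, :D]))
```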

Data Preparation

`accumulate_spikes(Y, window)`

Causal rolling sum of spikes over a backward-looking window.

Each time bin accumulates spikes from the current bin and the previous `window - 1` bins. This is equivalent to smoothing the spikes with a causal rectangular kernel.

Warning

This changes the interpretation of the estimated receptive fields. Since each bin now contains on average window times more spikes, the fitted firing rates (and therefore F) will be approximately window times higher than the true single-bin rates. The receptive field shapes are unaffected, but their amplitudes should not be interpreted as physical firing rates.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `Y` | `ndarray, shape (T, N_neurons)` | Spike counts. | required |
| `window` | `int` | Number of bins to sum over (looking backwards). For example, window=5 sums the current bin and the 4 preceding bins. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Y_accumulated` | `ndarray, shape (T, N_neurons)` | Spike counts after causal rolling sum. |

Source code in src/simpl/utils.py
def accumulate_spikes(Y: np.ndarray, window: int) -> np.ndarray:
    """Causal rolling sum of spikes over a backward-looking window.

    Each time bin accumulates spikes from the current and previous
    ``window - 1`` bins. This is equivalent to smoothing the spikes with
    a causal rectangular kernel.

    !!! warning

        This changes the interpretation of the estimated receptive fields.
        Since each bin now contains on average ``window`` times more spikes,
        the fitted firing rates (and therefore ``F``) will be approximately
        ``window`` times higher than the true single-bin rates. The receptive
        field *shapes* are unaffected, but their *amplitudes* should not be
        interpreted as physical firing rates.

    Parameters
    ----------
    Y : np.ndarray, shape (T, N_neurons)
        Spike counts.
    window : int
        Number of bins to sum over (looking backwards). For example,
        ``window=5`` sums the current bin and the 4 preceding bins.

    Returns
    -------
    Y_accumulated : np.ndarray, shape (T, N_neurons)
        Spike counts after causal rolling sum.
    """
    Y_out = np.zeros_like(Y)
    for i in range(Y.shape[0]):
        start = max(0, i - window + 1)
        Y_out[i] = Y[start : i + 1].sum(axis=0)
    return Y_out
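The Python loop above costs O(T × window). An equivalent vectorised version (a sketch, not part of simpl) uses a cumulative sum so each output bin is the difference of two prefix sums:

```python
import numpy as np

def accumulate_spikes_fast(Y, window):
    # causal rolling sum via prefix sums: out[i] = csum[i] - csum[i - window]
    csum = np.cumsum(Y, axis=0)
    out = csum.copy()                       # first `window` bins: sum from t=0
    out[window:] = csum[window:] - csum[:-window]
    return out

def accumulate_spikes_loop(Y, window):
    # loop reference from the source code above, for comparison
    out = np.zeros_like(Y)
    for i in range(Y.shape[0]):
        out[i] = Y[max(0, i - window + 1) : i + 1].sum(axis=0)
    return out

Y = np.random.default_rng(0).poisson(1.0, size=(50, 4))
```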

`coarsen_dt(Y, Xb, time, dt_multiplier, Xt=None)`

Coarsen data by averaging over groups of dt_multiplier time bins.

Spikes are summed (not averaged) so that spike counts remain integers. Positions and time are averaged.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `Y` | `ndarray, shape (T, N_neurons)` | Spike counts. | required |
| `Xb` | `ndarray, shape (T, D)` | Behavioral positions. | required |
| `time` | `ndarray, shape (T,)` | Time stamps. | required |
| `dt_multiplier` | `int` | Factor by which to coarsen the data. | required |
| `Xt` | `ndarray or None, shape (T, D)` | Ground truth positions. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Y_coarse` | `ndarray` | Coarsened spike counts (summed). |
| `Xb_coarse` | `ndarray` | Coarsened behavioral positions (averaged). |
| `time_coarse` | `ndarray` | Coarsened time stamps (averaged). |
| `Xt_coarse` | `ndarray` | Coarsened ground truth positions (averaged; only returned if Xt was provided). |

Source code in src/simpl/utils.py
def coarsen_dt(
    Y: np.ndarray,
    Xb: np.ndarray,
    time: np.ndarray,
    dt_multiplier: int,
    Xt: np.ndarray | None = None,
) -> tuple:
    """Coarsen data by averaging over groups of ``dt_multiplier`` time bins.

    Spikes are summed (not averaged) so that spike counts remain integers.
    Positions and time are averaged.

    Parameters
    ----------
    Y : np.ndarray, shape (T, N_neurons)
        Spike counts.
    Xb : np.ndarray, shape (T, D)
        Behavioral positions.
    time : np.ndarray, shape (T,)
        Time stamps.
    dt_multiplier : int
        Factor by which to coarsen the data.
    Xt : np.ndarray or None, shape (T, D), optional
        Ground truth positions.

    Returns
    -------
    Y_coarse : np.ndarray
        Coarsened spike counts (summed).
    Xb_coarse : np.ndarray
        Coarsened behavioral positions (averaged).
    time_coarse : np.ndarray
        Coarsened time stamps (averaged).
    Xt_coarse : np.ndarray (only if Xt was provided)
        Coarsened ground truth positions (averaged).
    """
    T = Y.shape[0]
    T_new = T // dt_multiplier
    T_trim = T_new * dt_multiplier

    Y_coarse = Y[:T_trim].reshape(T_new, dt_multiplier, -1).sum(axis=1)
    Xb_coarse = Xb[:T_trim].reshape(T_new, dt_multiplier, -1).mean(axis=1)
    time_coarse = time[:T_trim].reshape(T_new, dt_multiplier).mean(axis=1)

    if Xt is not None:
        Xt_coarse = Xt[:T_trim].reshape(T_new, dt_multiplier, -1).mean(axis=1)
        return Y_coarse, Xb_coarse, time_coarse, Xt_coarse

    return Y_coarse, Xb_coarse, time_coarse
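A minimal worked example of the reshape trick (plain NumPy, mirroring the body of coarsen_dt): trailing bins that don't fill a group are trimmed, spikes are summed within each group, and positions/time are averaged:

```python
import numpy as np

T, dt_multiplier = 10, 2
Y = np.arange(T * 3).reshape(T, 3)            # fake spike counts, (10, 3)
Xb = np.arange(T, dtype=float)[:, None]       # positions, (10, 1)
time = np.arange(T, dtype=float) * 0.01       # time stamps, (10,)

T_new = T // dt_multiplier
T_trim = T_new * dt_multiplier                # drop any incomplete final group
Y_c = Y[:T_trim].reshape(T_new, dt_multiplier, -1).sum(axis=1)
Xb_c = Xb[:T_trim].reshape(T_new, dt_multiplier, -1).mean(axis=1)
time_c = time[:T_trim].reshape(T_new, dt_multiplier).mean(axis=1)
```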

`create_speckled_mask(size, sparsity=0.1, block_size=10, random_seed=0)`

TODO: Rewrite this in JAX.

Creates a boolean mask of shape size. The mask is all True except that, along each column, contiguous blocks of False of length block_size are placed at random positions, so that overall roughly sparsity of the mask is False. For example, if sparsity is 0.3, block size is 3 and size is (20, 4), a valid mask (shown transposed, so each printed row is one column of the mask) would be:

    [[T, T, T, T, T, T, T, T, F, F, F, T, F, F, F, T, T, T, T, T],
     [T, T, F, F, F, T, T, T, T, T, T, T, T, T, T, T, F, F, F, T],
     [T, T, T, T, T, T, T, T, T, F, F, F, T, T, F, F, F, T, T, T],
     [F, F, F, T, T, T, T, T, T, T, T, T, F, F, F, T, T, T, T, T]]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `size` | `tuple of int` | The dimensions of the mask to create. | required |
| `sparsity` | `float` | The fraction of the mask that should be False. | `0.1` |
| `block_size` | `int` | The size of the contiguous False blocks. | `10` |
| `random_seed` | `int` | Seed for the NumPy RNG used to place the blocks. | `0` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `mask` | `Array` | A boolean mask with the specified properties. |

Source code in src/simpl/utils.py
def create_speckled_mask(
    size: tuple[int, int],
    sparsity: float = 0.1,
    block_size: int = 10,
    random_seed: int = 0,
) -> jax.Array:
    """
    TODO : Rewrite this in JAX
    Creates a boolean mask of shape `size`. The mask is all True except that,
    along each column, contiguous blocks of False of length `block_size` are
    placed at random positions, so that overall ~`sparsity` of the mask is False.
    For example, if sparsity is 0.3, block size is 3 and size is (20, 4), a valid
    mask (shown transposed) would be:

    [[T, T, T, T, T, T, T, T, F, F, F, T, F, F, F, T, T, T, T, T],
     [T, T, F, F, F, T, T, T, T, T, T, T, T, T, T, T, F, F, F, T],
     [T, T, T, T, T, T, T, T, T, F, F, F, T, T, F, F, F, T, T, T],
     [F, F, F, T, T, T, T, T, T, T, T, T, F, F, F, T, T, T, T, T]]

    Parameters
    ----------
    size : tuple of int
        The dimensions of the mask to create.
    sparsity : float
        The fraction of the mask that should be False.
    block_size : int
        The size of the contiguous False blocks.
    random_seed : int
        Seed for NumPy's global RNG used to place the blocks.

    Returns
    -------
    mask : jax.Array
        A boolean mask with the specified properties.
    """
    if len(size) != 2 or size[0] <= 0 or size[1] <= 0:
        raise ValueError(f"size must be a pair of positive integers, got {size}")
    if not 0 <= sparsity <= 1:
        raise ValueError(f"sparsity must be between 0 and 1 (inclusive), got {sparsity}")
    if block_size < 0:
        raise ValueError(f"block_size cannot be negative, got {block_size}")
    if block_size == 0:
        return jnp.ones(size, dtype=bool)
    if block_size >= size[0]:
        raise ValueError(
            f"block_size must be smaller than the time dimension so the mask leaves training data, got {block_size}"
        )

    mask = np.ones(size, dtype=bool)
    num_blocks_per_row = int(sparsity * size[0] / block_size)
    np.random.seed(random_seed)
    for row in range(size[1]):
        for block in range(num_blocks_per_row):
            # Randomly choose starting positions within the bounds
            start_idx = np.random.randint(0, size[0] - block_size)
            end_idx = min(start_idx + block_size, size[0])
            mask[start_idx:end_idx, row] = False
    return jnp.array(mask)
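A quick NumPy mirror of the masking loop (illustrative; it uses `numpy.random.Generator` rather than the legacy global seed) confirms that the realised False fraction lands near, and never above, the target sparsity, since overlapping blocks can only reduce it:

```python
import numpy as np

size, sparsity, block_size = (200, 8), 0.3, 5
rng = np.random.default_rng(0)
mask = np.ones(size, dtype=bool)
num_blocks = int(sparsity * size[0] / block_size)  # blocks per column
for col in range(size[1]):
    for _ in range(num_blocks):
        # place each False block at a random start within the column
        start = rng.integers(0, size[0] - block_size)
        mask[start : start + block_size, col] = False

false_frac = 1 - mask.mean()  # <= sparsity; lower if blocks overlap
```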

Data I/O

`load_demo_data(name="gridcells_synthetic.npz", directory=None, force_download=False)`

Load a demo data file, downloading from GitHub releases if not cached.

Resolution order (skipped when force_download is True):

1. User-specified directory — if directory is given, look for <directory>/<name> first.
2. Local source tree — examples/data/ relative to the package root (available in editable / development installs).
3. User cache — ~/.simpl/data/.
4. Download — fetched from the latest GitHub release and saved to the user cache for next time.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `name` | `str` | Filename to load (e.g. "gridcells_synthetic.npz"). | `'gridcells_synthetic.npz'` |
| `directory` | `str or None` | Optional directory to search for name before the default locations. | `None` |
| `force_download` | `bool` | If True, skip local/cache lookups and always download from GitHub, overwriting any cached copy. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `NpzFile` | The loaded .npz archive. |

Raises:

| Type | Description |
| --- | --- |
| `FileNotFoundError` | If name cannot be found locally and is not a known release asset, or is not present in any GitHub release. |

Source code in src/simpl/utils.py
def load_demo_data(
    name: str = "gridcells_synthetic.npz",
    directory: str | None = None,
    force_download: bool = False,
) -> np.lib.npyio.NpzFile:
    """Load a demo data file, downloading from GitHub releases if not cached.

    Resolution order (skipped when *force_download* is ``True``):

    1. **User-specified directory** — if *directory* is given, look for
       ``<directory>/<name>`` first.
    2. **Local source tree** — ``examples/data/`` relative to the package root
       (available in editable / development installs).
    3. **User cache** — ``~/.simpl/data/``.
    4. **Download** — fetched from the latest GitHub release and saved to the
       user cache for next time.

    Parameters
    ----------
    name : str
        Filename to load (e.g. ``"gridcells_synthetic.npz"``).
    directory : str or None
        Optional directory to search for *name* before the default locations.
    force_download : bool
        If ``True``, skip local/cache lookups and always download from GitHub,
        overwriting any cached copy.

    Returns
    -------
    np.lib.npyio.NpzFile
        The loaded ``.npz`` archive.

    Raises
    ------
    FileNotFoundError
        If *name* cannot be found locally and is not a known release asset,
        or is not present in any GitHub release.
    """
    from pathlib import Path

    # 1. Check user-specified directory
    if not force_download and directory is not None:
        dir_path = Path(directory) / name
        if dir_path.is_file():
            print(f"Loaded {name} from user directory: {dir_path}")
            return np.load(dir_path)

    # 2. Check local source tree (editable installs)
    if not force_download:
        local_path = Path(__file__).resolve().parent.parent.parent / "examples" / "data" / name
        if local_path.is_file():
            print(f"Loaded {name} from local source tree: {local_path}")
            return np.load(local_path)

    # 3. Check user cache
    cache_dir = Path.home() / ".simpl" / "data"
    cache_dir.mkdir(parents=True, exist_ok=True)
    cached_path = cache_dir / name

    if not force_download and cached_path.exists():
        print(f"Loaded {name} from cache: {cached_path}")
        return np.load(cached_path)

    # File not found locally — check it's a known release asset before attempting download
    if name not in _AVAILABLE_DEMO_DATA:
        available = ", ".join(f'"{f}"' for f in _AVAILABLE_DEMO_DATA)
        raise FileNotFoundError(
            f'Could not find "{name}" locally and it is not a known release asset. Available for download: {available}'
        )

    # 4. Download from GitHub releases
    import json
    import os
    import sys
    import urllib.request

    def _reporthook(block_num, block_size, total_size):
        if total_size > 0:
            downloaded = block_num * block_size
            pct = min(100, downloaded * 100 // total_size)
            mb_done = downloaded / 1_000_000
            mb_total = total_size / 1_000_000
            print(f"\r  {pct:3d}% ({mb_done:.1f}/{mb_total:.1f} MB)", end="", file=sys.stderr)

    api_url = "https://api.github.com/repos/TomGeorge1234/SIMPL/releases"
    headers = {"Accept": "application/vnd.github+json"}
    token = os.environ.get("GITHUB_TOKEN")
    if token:
        headers["Authorization"] = f"Bearer {token}"
    req = urllib.request.Request(api_url, headers=headers)
    with urllib.request.urlopen(req) as resp:
        releases = json.loads(resp.read())

    download_url = None
    for release in releases:
        for asset in release.get("assets", []):
            if asset["name"] == name:
                download_url = asset["browser_download_url"]
                break
        if download_url:
            break

    if download_url is None:
        raise FileNotFoundError(f'Could not find "{name}" in any GitHub release at {api_url}')

    print(f"Downloading {name} from {download_url} ...", file=sys.stderr)
    try:
        urllib.request.urlretrieve(download_url, cached_path, reporthook=_reporthook)
    except Exception:
        cached_path.unlink(missing_ok=True)
        raise
    print(file=sys.stderr)  # newline after progress
    print(f"Loaded {name} from GitHub (saved to cache: {cached_path})")

    return np.load(cached_path)

`save_results_to_netcdf(results, path)`

Save a SIMPL results xr.Dataset to a netCDF file.

Before writing, the function performs several type conversions required by the netCDF4 format: boolean arrays (e.g. spike_mask) are cast to int32, boolean attrs are cast to int, and trial_slices (a list of Python slice objects) is serialised to a flat int64 array. Use load_results to reload and automatically reverse these conversions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `results` | `Dataset` | The results dataset (typically model.results_). | required |
| `path` | `str` | Destination file path (e.g. 'results.nc'). | required |
required
Source code in src/simpl/utils.py
def save_results_to_netcdf(results: xr.Dataset, path: str) -> None:
    """Save a SIMPL results ``xr.Dataset`` to a netCDF file.

    Before writing, the function performs several type conversions required by
    the netCDF4 format: boolean arrays (e.g. ``spike_mask``) are cast to
    ``int32``, boolean ``attrs`` are cast to ``int``, and ``trial_slices``
    (a list of Python ``slice`` objects) is serialised to a flat ``int64``
    array.  Use ``load_results`` to reload and automatically reverse
    these conversions.

    Parameters
    ----------
    results : xr.Dataset
        The results dataset (typically ``model.results_``).
    path : str
        Destination file path (e.g. ``'results.nc'``)."""
    results_to_save = results.copy(deep=True)
    if "spike_mask" in results_to_save:
        results_to_save["spike_mask"] = results_to_save["spike_mask"].astype("int32")
    # Convert boolean 'reshape' attrs to int (netCDF4 doesn't support bool attrs)
    for var in results_to_save.data_vars:
        if "reshape" in results_to_save[var].attrs:
            results_to_save[var].attrs["reshape"] = int(results_to_save[var].attrs["reshape"])
    results_to_save.to_netcdf(path)

Load results from a saved file. Some variables need to be converted back to their original form.

See the following issues for details: https://github.com/TomGeorge1234/SIMPL/issues/5 and https://github.com/TomGeorge1234/SIMPL/issues/8

Parameters:

Name Type Description Default
path str

Path to the saved file.

required

Returns:

Type Description
Dataset

Results.

Source code in src/simpl/utils.py
def load_results(path: str) -> xr.Dataset:
    """
    Load results from a saved file.
    Some variables need to be converted back to their original form.

    See the following issues for details:
    https://github.com/TomGeorge1234/SIMPL/issues/5
    https://github.com/TomGeorge1234/SIMPL/issues/8

    Parameters
    ----------
    path : str
        Path to the saved file.

    Returns
    -------
    xr.Dataset
        Results.
    """
    results = xr.load_dataset(path)
    # Convert int 'reshape' attrs back to bool
    for var in results.data_vars:
        if "reshape" in results[var].attrs:
            results[var].attrs["reshape"] = bool(results[var].attrs["reshape"])

    if "spike_mask" in results:
        results["spike_mask"] = results["spike_mask"].astype("bool")

    return results
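The bool/int casts applied by `save_results_to_netcdf` and `load_results` are exact inverses of each other. As a minimal sketch of that round trip (numpy only, no netCDF file involved):

```python
import numpy as np

# spike_mask is stored as int32 on disk and restored to bool on load;
# this mirrors the cast pair used by save_results_to_netcdf / load_results.
spike_mask = np.array([True, False, True])
on_disk = spike_mask.astype("int32")   # netCDF4-compatible representation
restored = on_disk.astype("bool")      # recovered on load

assert restored.dtype == np.bool_
assert np.array_equal(restored, spike_mask)
```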

Place-Field Analysis

Get argmax spatial position for each neuron's receptive field.

Parameters:

Name Type Description Default
F (ndarray, shape(N_neurons, N_bins))

Receptive fields.

required
coords (ndarray, shape(N_bins, D))

Spatial coordinates for each bin (e.g. environment.flattened_discretised_coords).

required

Returns:

Type Description
(ndarray, shape(N_neurons, D))

Peak spatial position for each neuron.

Source code in src/simpl/utils.py
def get_field_peaks(F: jax.Array, coords: jax.Array) -> jax.Array:
    """Get argmax spatial position for each neuron's receptive field.

    Parameters
    ----------
    F : jnp.ndarray, shape (N_neurons, N_bins)
        Receptive fields.
    coords : jnp.ndarray, shape (N_bins, D)
        Spatial coordinates for each bin (e.g. ``environment.flattened_discretised_coords``).

    Returns
    -------
    jnp.ndarray, shape (N_neurons, D)
        Peak spatial position for each neuron.
    """
    argmax_bins = np.argmax(F, axis=1)
    return coords[argmax_bins]
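`get_field_peaks` is a thin wrapper around `argmax` followed by coordinate lookup. A toy sketch with made-up rates for two neurons on a 1-D track (the data here are hypothetical, for illustration only):

```python
import numpy as np

# Hypothetical toy data: 2 neurons, 5 spatial bins on a 1-D track.
F = np.array([
    [0.1, 0.9, 0.3, 0.2, 0.1],   # neuron 0 peaks in bin 1
    [0.0, 0.1, 0.2, 0.8, 0.4],   # neuron 1 peaks in bin 3
])
coords = np.linspace(0.0, 1.0, 5)[:, None]  # (N_bins, D=1) bin centres

peaks = coords[np.argmax(F, axis=1)]  # same logic as get_field_peaks
# peaks -> [[0.25], [0.75]]
```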

Analyse tuning curves and return information about place fields.

Terminology: "field" is the whole tuning curve. "place field" (pf) is the portion of the whole tuning curve identified as a particular place field.

Parameters:

Name Type Description Default
F (ndarray, shape(N_neurons, N_bins))

The estimated place fields.

required
N_neurons int

Number of neurons.

required
N_PFmax int

Maximum number of place fields per neuron (for fixed-shape arrays).

required
D int

Dimensionality of the latent space.

required
xF_shape tuple

Shape of the discretised environment grid (e.g. (nx, ny)).

required
xF (ndarray, shape(N_bins, D))

Flattened discretised environment coordinates.

required
dt float

Time-step size (seconds).

required
bin_size float

Spatial bin size of the environment.

required
n_bins int

Total number of spatial bins.

required

Returns:

Type Description
dict

Place-field results dictionary with keys such as place_field_count, place_field_size, etc.

Source code in src/simpl/utils.py
def analyse_place_fields(
    F: jax.Array,
    N_neurons: int,
    N_PFmax: int,
    D: int,
    xF_shape: tuple,
    xF: jax.Array,
    dt: float,
    bin_size: float,
    n_bins: int,
) -> dict:
    """Analyse tuning curves and return information about place fields.

    Terminology: "field" is the *whole* tuning curve.  "place field" (pf) is the
    portion of the whole tuning curve identified as a particular place field.

    Parameters
    ----------
    F : jnp.ndarray, shape (N_neurons, N_bins)
        The estimated place fields.
    N_neurons : int
        Number of neurons.
    N_PFmax : int
        Maximum number of place fields per neuron (for fixed-shape arrays).
    D : int
        Dimensionality of the latent space.
    xF_shape : tuple
        Shape of the discretised environment grid (e.g. ``(nx, ny)``).
    xF : jnp.ndarray, shape (N_bins, D)
        Flattened discretised environment coordinates.
    dt : float
        Time-step size (seconds).
    bin_size : float
        Spatial bin size of the environment.
    n_bins : int
        Total number of spatial bins.

    Returns
    -------
    dict
        Place-field results dictionary with keys such as
        ``place_field_count``, ``place_field_size``, etc.
    """

    # Initialise arrays
    pf_count = np.zeros(N_neurons)
    pf_size = np.full((N_neurons, N_PFmax), np.nan)
    pf_position = np.full((N_neurons, N_PFmax, D), np.nan)
    pf_covariance = np.full((N_neurons, N_PFmax, D, D), np.nan)
    pf_maxfr = np.full((N_neurons, N_PFmax), np.nan)
    pf_edges = np.zeros((N_neurons, *xF_shape))
    pf_roundness = np.full((N_neurons, N_PFmax), np.nan)

    # Reshape the fields
    F_fields = F.reshape(N_neurons, *xF_shape)  # reshape F into fields

    # Threshold the fields
    F_1Hz = jnp.where(F_fields > 1.0 * dt, 1, 0)  # threshold at 1Hz

    # Total environment size
    volume_element = bin_size**D
    env_size = n_bins * volume_element

    # For each cell in turn, analyse the place fields
    for n in range(N_neurons):
        # Find contiguous field areas; 0/False is considered background and labelled 0.
        # Diagonal pixel-connections are not counted as connections.
        field = F_fields[n]
        field_thresh = F_1Hz[n]
        putative_pfs, putative_pfs_count = scipy.ndimage.label(field_thresh)
        n_pfs = 0  # some putative fields won't meet our criteria, so we keep our own counter
        combined_pf_mask = np.zeros_like(field)
        for f in range(1, min(N_PFmax, putative_pfs_count + 1)):
            pf_mask = jnp.where(putative_pfs == f, 1, 0)
            pf = jnp.where(putative_pfs == f, field, 0)
            # Check the field isn't too large
            size = pf_mask.sum() * volume_element
            if size > (1 / 2) * env_size:
                continue
            # Check max firing rate is over 2Hz
            maxfr = jnp.max(pf)
            if maxfr < 2.0 * dt:
                continue
            # Assuming it's passed these, it's a legit field. Now fit a Gaussian.
            perimeter = bin_size * skimage.measure.perimeter(pf_mask)
            perimeter_dilated = bin_size * skimage.measure.perimeter(scipy.ndimage.binary_dilation(pf_mask))
            perimeter = (perimeter + perimeter_dilated) / 2
            roundness = 4 * np.pi * size / perimeter**2
            combined_pf_mask += pf_mask
            mu, mode, cov = fit_gaussian(xF, pf.flatten()[None, :])
            mu, mode, cov = mu[0], mode[0], cov[0]
            pf_size[n, n_pfs] = size
            pf_position[n, n_pfs] = mu
            pf_covariance[n, n_pfs] = cov
            pf_maxfr[n, n_pfs] = maxfr
            pf_roundness[n, n_pfs] = roundness
            n_pfs += 1
        # Outline the combined place-field mask: dilation XOR mask gives the edge pixels
        is_pf = combined_pf_mask > 0
        pf_edges[n] = scipy.ndimage.binary_dilation(is_pf) ^ is_pf
        pf_count[n] = n_pfs

    place_field_results = {
        "place_field_count": jnp.array(pf_count),
        "place_field_size": jnp.array(pf_size),
        "place_field_position": jnp.array(pf_position),
        "place_field_covariance": jnp.array(pf_covariance),
        "place_field_max_firing_rate": jnp.array(pf_maxfr),
        "place_field_roundness": jnp.array(pf_roundness),
        "place_field_outlines": jnp.array(pf_edges),
    }

    return place_field_results
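The roundness score computed above is the isoperimetric quotient \(4\pi A / P^2\), which equals 1 for a perfect circle and falls below 1 for elongated fields. A quick numpy-only sanity check using an analytic circle (no mask or `skimage` needed, so this sidesteps pixelisation error):

```python
import numpy as np

r = 3.0                       # circle radius in bin units (arbitrary)
area = np.pi * r**2           # A = pi r^2
perimeter = 2 * np.pi * r     # P = 2 pi r
roundness = 4 * np.pi * area / perimeter**2
# A perfect circle scores exactly 1; an elongated field scores less.
```

In the function itself the perimeter is measured on a pixel mask, and averaging the mask perimeter with the dilated-mask perimeter reduces the systematic underestimate that pixel counting introduces.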

Calculate Skaggs spatial information per neuron (bits/s).

\[I = \sum_x r(x) \log_2 \frac{r(x)}{\bar{r}} \, P(x)\]

where \(r(x)\) is the firing rate at position \(x\), \(\bar{r}\) is the mean firing rate, and \(P(x)\) is the occupancy probability.

Parameters:

Name Type Description Default
r Array(N_neurons, N_bins)

Firing rate maps in Hz (spikes per second, not per bin).

required
PX Array(N_bins)

Occupancy probability over spatial bins (sums to 1).

required

Returns:

Name Type Description
spatial_info Array(N_neurons)

Spatial information per neuron in bits/s.

Source code in src/simpl/utils.py
def calculate_spatial_information(
    r: jax.Array,
    PX: jax.Array,
) -> jax.Array:
    """Calculate Skaggs spatial information per neuron (bits/s).

    $$I = \\sum_x r(x) \\log_2 \\frac{r(x)}{\\bar{r}} \\, P(x)$$

    where \\(r(x)\\) is the firing rate at position \\(x\\), \\(\\bar{r}\\) is the mean firing rate,
    and \\(P(x)\\) is the occupancy probability.

    Parameters
    ----------
    r : jax.Array (N_neurons, N_bins)
        Firing rate maps in Hz (spikes per second, not per bin).
    PX : jax.Array (N_bins,)
        Occupancy probability over spatial bins (sums to 1).

    Returns
    -------
    spatial_info : jax.Array (N_neurons,)
        Spatial information per neuron in bits/s.
    """
    r_mean = jnp.sum(r * PX[None, :], axis=1)  # mean firing rate (N_neurons,) Hz
    eps = 1e-10
    ratio = r / (r_mean[:, None] + eps)
    spatial_info = jnp.sum((r * jnp.log2(ratio + eps)) * PX[None, :], axis=1)
    return spatial_info
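As a worked sanity check of the formula (plain numpy, mirroring the implementation above): a neuron firing at 4 Hz in one of four equally occupied bins has mean rate \(\bar{r} = 0.25 \times 4 = 1\) Hz, so it carries \(0.25 \cdot 4 \cdot \log_2 4 = 2\) bits/s.

```python
import numpy as np

r = np.array([[4.0, 0.0, 0.0, 0.0]])  # one neuron, 4 bins, rates in Hz
PX = np.full(4, 0.25)                 # uniform occupancy, sums to 1

r_mean = np.sum(r * PX[None, :], axis=1)  # mean rate -> [1.0] Hz
eps = 1e-10
ratio = r / (r_mean[:, None] + eps)
spatial_info = np.sum(r * np.log2(ratio + eps) * PX[None, :], axis=1)
# spatial_info[0] ≈ 2.0 bits/s (silent bins contribute zero because r = 0 there)
```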