vamp_loss

kooplearn.jax.nn.vamp_loss(x: ArrayLike, y: ArrayLike, schatten_norm: int = 2, center_covariances: bool = True) -> Array

Variational Approach for learning Markov Processes (VAMP) score.

Computes the negative VAMP-p score introduced by Wu and Noé [1]. The VAMP score measures the quality of a feature transformation by quantifying how well it captures the slow processes of the dynamics. Because the returned value is the negative score, minimizing this loss maximizes the VAMP score.

\[\mathcal{L}(x, y) = -\sum_{i} \sigma_{i}(A)^{p}\]

where

\[A = (x^{\top}x)^{\dagger/2} x^{\top}y (y^{\top}y)^{\dagger/2}\]

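and \(\sigma_i(A)\) are the singular values of \(A\). Here \((\cdot)^{\dagger/2}\) denotes the square root of the Moore–Penrose pseudoinverse, and \(x\) and \(y\) are mean-centered first when center_covariances is True.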
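To make the formula concrete, here is a minimal NumPy sketch that builds \(A\) from explicit pseudoinverse square roots and sums the powered singular values. It is a hypothetical reference written for illustration (the names vamp_loss_reference and _pinv_sqrt are not part of kooplearn), it uses NumPy instead of JAX so it runs on its own, and it is not the library's implementation.

```python
import numpy as np


def _pinv_sqrt(cov: np.ndarray, rtol: float = 1e-10) -> np.ndarray:
    """Hypothetical helper: (cov)^{dagger/2} for a symmetric PSD matrix."""
    evals, evecs = np.linalg.eigh(cov)
    # Drop numerically zero eigenvalues, as the pseudoinverse does.
    keep = evals > rtol * evals.max()
    inv_sqrt = np.zeros_like(evals)
    inv_sqrt[keep] = 1.0 / np.sqrt(evals[keep])
    # V diag(inv_sqrt) V^T
    return (evecs * inv_sqrt) @ evecs.T


def vamp_loss_reference(x, y, schatten_norm=2, center_covariances=True):
    """Hypothetical NumPy reference for the negative VAMP-p score (not kooplearn's code)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if schatten_norm not in (1, 2):
        raise NotImplementedError("Only p=1 and p=2 are sketched here.")
    if center_covariances:
        # Subtract per-feature means before forming covariances.
        x = x - x.mean(axis=0)
        y = y - y.mean(axis=0)
    # A = (x^T x)^{dagger/2} x^T y (y^T y)^{dagger/2}
    a = _pinv_sqrt(x.T @ x) @ (x.T @ y) @ _pinv_sqrt(y.T @ y)
    sigma = np.linalg.svd(a, compute_uv=False)
    return -np.sum(sigma ** schatten_norm)


rng = np.random.default_rng(0)
x = rng.normal(size=(500, 3))
# With y = x (full-rank features), A is the identity, so every sigma_i is 1.
print(vamp_loss_reference(x, x, schatten_norm=2))  # approximately -3.0
```

Because each singular value of \(A\) lies in \([0, 1]\), the loss is bounded below by \(-\min(D_x, D_y)\), which is reached in the example above where the output features equal the input features.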

Hint

Check out the Ordered MNIST example for a practical use of this loss function.

Parameters:
  • x (ArrayLike) – Input features of shape (N, D_x), where N is the number of samples and D_x is the input feature dimension.

  • y (ArrayLike) – Output features of shape (N, D_y), where N is the number of samples and D_y is the output feature dimension.

  • schatten_norm (int, optional) – Order p of the Schatten norm; the function computes the VAMP-p score. Currently supports p=1 (nuclear norm) and p=2 (squared Frobenius norm). Default is 2.

  • center_covariances (bool, optional) – If True, subtract the feature means before computing covariances. If False, use uncentered covariances. Default is True.
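The VAMP-p score \(\sum_i \sigma_i(A)^p\) is the p-th power of the Schatten p-norm of \(A\), which is why p=1 corresponds to the nuclear norm and p=2 to the squared Frobenius norm. The short NumPy check below illustrates this on an arbitrary random matrix standing in for \(A\); the matrix is an assumption for illustration only.

```python
import numpy as np

rng = np.random.default_rng(42)
a = rng.normal(size=(4, 3))  # arbitrary stand-in for the matrix A
sigma = np.linalg.svd(a, compute_uv=False)

vamp1 = np.sum(sigma)       # sum of sigma_i^1
vamp2 = np.sum(sigma ** 2)  # sum of sigma_i^2

# p=1 equals the nuclear norm; p=2 equals the squared Frobenius norm.
print(np.isclose(vamp1, np.linalg.norm(a, ord="nuc")))      # True
print(np.isclose(vamp2, np.linalg.norm(a, ord="fro") ** 2))  # True
```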

Returns:

Scalar loss value (negative VAMP-p score).

Return type:

jax.Array

Raises:

NotImplementedError – If schatten_norm is not 1 or 2.

Notes

For p=2, a numerically stable least-squares formulation is used instead of direct pseudoinverse computation.
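One way a least-squares route can work for p=2, sketched under the assumption that it resembles (but is not) the library's internal code: by the cyclic property of the trace, \(\sum_i \sigma_i(A)^2 = \operatorname{tr}(C_{xx}^{\dagger} C_{xy} C_{yy}^{\dagger} C_{yx})\), and the products \(C_{xx}^{\dagger} C_{xy}\) and \(C_{yy}^{\dagger} C_{yx}\) are exactly the minimum-norm solutions returned by a least-squares solver, so no inverse square root is needed. The function name vamp2_score_lstsq is hypothetical, and NumPy stands in for JAX.

```python
import numpy as np


def vamp2_score_lstsq(x, y, center_covariances=True):
    """Hypothetical sketch: VAMP-2 score via least-squares solves (loss = -score)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if center_covariances:
        x = x - x.mean(axis=0)
        y = y - y.mean(axis=0)
    n = x.shape[0]
    c_xx, c_yy, c_xy = x.T @ x / n, y.T @ y / n, x.T @ y / n
    # Minimum-norm least-squares solutions: m = C_xx^+ C_xy, k = C_yy^+ C_yx.
    m = np.linalg.lstsq(c_xx, c_xy, rcond=None)[0]
    k = np.linalg.lstsq(c_yy, c_xy.T, rcond=None)[0]
    # tr(C_xx^+ C_xy C_yy^+ C_yx) = sum of sigma_i(A)^2
    return np.trace(m @ k)


rng = np.random.default_rng(1)
x = rng.normal(size=(1000, 4))
y = x @ rng.normal(size=(4, 3)) + 0.5 * rng.normal(size=(1000, 3))


# Cross-check against sum(sigma_i(A)^2) with A built from explicit inverse square roots
# (the covariances here are full rank, so the plain inverse square root suffices).
def _inv_sqrt(c):
    w, v = np.linalg.eigh(c)
    return (v / np.sqrt(w)) @ v.T


xc, yc = x - x.mean(axis=0), y - y.mean(axis=0)
a = _inv_sqrt(xc.T @ xc) @ (xc.T @ yc) @ _inv_sqrt(yc.T @ yc)
score_svd = np.sum(np.linalg.svd(a, compute_uv=False) ** 2)
print(np.isclose(vamp2_score_lstsq(x, y), score_svd))  # True
```

The 1/n normalization of the covariances does not change the result, since the scale factors cancel inside \(A\).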

References

[1] Hao Wu and Frank Noé. Variational approach for learning Markov processes from time series data. Journal of Nonlinear Science, 30(1):23–66, August 2019. doi:10.1007/s00332-019-09567-y. URL: https://doi.org/10.1007/s00332-019-09567-y.