vamp_loss
- kooplearn.jax.nn.vamp_loss(x: ArrayLike, y: ArrayLike, schatten_norm: int = 2, center_covariances: bool = True) → Array
Variational Approach for learning Markov Processes (VAMP) score.
Computes the negative VAMP-p score as introduced by Wu and Noé [1]. The VAMP score measures the quality of a feature transformation by quantifying how well it captures slow processes in the dynamics.
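\[\mathcal{L}(x, y) = -\sum_{i} \sigma_{i}(A)^{p}\]
where
\[A = (x^{\top}x)^{\dagger/2} x^{\top}y (y^{\top}y)^{\dagger/2}\]
and \(\sigma_i(A)\) are the singular values of \(A\). Here \(p\) is the Schatten norm order (schatten_norm) and \((\cdot)^{\dagger/2}\) denotes the square root of the pseudoinverse.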
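The definition above can be sketched directly in JAX. This is a minimal reference sketch for illustration, not the library's implementation: the names `inv_sqrt_psd` and `vamp_loss_reference`, the `rtol` cutoff, and the eigendecomposition-based pseudoinverse square root are all assumptions made here.

```python
import jax
import jax.numpy as jnp


def inv_sqrt_psd(c, rtol=1e-6):
    """Square root of the pseudoinverse of a symmetric PSD matrix (hypothetical helper)."""
    evals, evecs = jnp.linalg.eigh(c)
    # Eigenvalues below a relative threshold are treated as zero (assumed cutoff).
    keep = evals > rtol * jnp.max(evals)
    safe = jnp.where(keep, evals, 1.0)
    inv_sqrt = jnp.where(keep, 1.0 / jnp.sqrt(safe), 0.0)
    # Rebuild V diag(inv_sqrt) V^T.
    return (evecs * inv_sqrt) @ evecs.T


def vamp_loss_reference(x, y, p=2, center=True):
    """Negative VAMP-p score computed literally from the formula (illustrative only)."""
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    if center:
        # Centered covariances: subtract the per-feature means.
        x = x - x.mean(axis=0, keepdims=True)
        y = y - y.mean(axis=0, keepdims=True)
    # A = (x^T x)^{+/2} x^T y (y^T y)^{+/2}
    a = inv_sqrt_psd(x.T @ x) @ (x.T @ y) @ inv_sqrt_psd(y.T @ y)
    sigma = jnp.linalg.svd(a, compute_uv=False)
    return -jnp.sum(sigma ** p)


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (200, 3))
loss = vamp_loss_reference(x, x, p=2)
```

When y equals x and the features are linearly independent, A reduces to the identity on the feature space, so every singular value is 1 and the loss should be close to the negative feature dimension (here, about -3 up to floating-point error).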
Hint
Check out the Ordered MNIST example for a practical use of this loss function.
- Parameters:
x (ArrayLike) – Input features of shape (N, D_x), where N is the number of samples and D_x is the input feature dimension.
y (ArrayLike) – Output features of shape (N, D_y), where N is the number of samples and D_y is the output feature dimension.
schatten_norm (int, optional) – Order p of the Schatten norm. Computes the VAMP-p score. Currently supports p=1 (nuclear norm) and p=2 (Frobenius norm). Default is 2.
center_covariances (bool, optional) – If True, use centered covariances (subtract means). If False, use uncentered covariances. Default is True.
- Returns:
Scalar loss value (negative VAMP-p score).
- Return type:
jax.Array
- Raises:
NotImplementedError – If schatten_norm is not 1 or 2.
Notes
For p=2, a numerically stable least-squares formulation is used instead of direct pseudoinverse computation.
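One way such a least-squares formulation could look is sketched below. It relies on the identity \(\sum_i \sigma_i(A)^2 = \operatorname{tr}\big((x^{\top}x)^{\dagger} x^{\top}y \, (y^{\top}y)^{\dagger} y^{\top}x\big)\), and obtains each pseudoinverse product as a minimum-norm least-squares solution. The function name `vamp2_loss_lstsq` is hypothetical, and this is not necessarily how the library implements it.

```python
import jax
import jax.numpy as jnp


def vamp2_loss_lstsq(x, y, center=True):
    """Negative VAMP-2 score via least-squares solves (illustrative sketch)."""
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    if center:
        x = x - x.mean(axis=0, keepdims=True)
        y = y - y.mean(axis=0, keepdims=True)
    cov_x = x.T @ x
    cov_y = y.T @ y
    cov_xy = x.T @ y
    # Solve cov_x @ m_x = cov_xy and cov_y @ m_y = cov_xy^T in the
    # least-squares sense instead of forming pseudoinverses explicitly.
    m_x = jnp.linalg.lstsq(cov_x, cov_xy)[0]
    m_y = jnp.linalg.lstsq(cov_y, cov_xy.T)[0]
    # tr(Cx^+ Cxy Cy^+ Cyx) equals the squared Frobenius norm of A.
    return -jnp.trace(m_x @ m_y)


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (200, 3))
y = x @ jax.random.normal(key_w, (3, 2))
loss = vamp2_loss_lstsq(x, y)
```

The covariance normalization (for example a factor 1/N) is omitted here because it cancels between the covariance and cross-covariance terms.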
References
[1] Hao Wu and Frank Noé. Variational approach for learning Markov processes from time series data. Journal of Nonlinear Science, 30(1):23–66, August 2019. URL: https://doi.org/10.1007/s00332-019-09567-y.