Source code for kooplearn.jax.nn._functional

"""Loss functions for Koopman operator learning with JAX.

This module provides differentiable loss functions commonly used in
Koopman operator learning, including:

- VAMP score variants
- Spectral contrastive loss
- Dynamic autoencoder losses
- Orthonormality regularization terms

All functions are JAX-compatible and support automatic differentiation.
"""

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from kooplearn.jax.nn._linalg import sqrtmh
from kooplearn.jax.nn._stats import cov_norm_squared_unbiased, covariance


[docs] def vamp_loss( x: ArrayLike, y: ArrayLike, schatten_norm: int = 2, center_covariances: bool = True, ) -> jax.Array: r"""Variational Approach for learning Markov Processes (VAMP) score. Computes the negative VAMP-p score as introduced by :cite:t:`vamp_loss-Wu2019`. The VAMP score measures the quality of a feature transformation by quantifying how well it captures slow processes in the dynamics. .. math:: \mathcal{L}(x, y) = -\sum_{i} \sigma_{i}(A)^{p} where .. math:: A = (x^{\top}x)^{\dagger/2} x^{\top}y (y^{\top}y)^{\dagger/2} and :math:`\sigma_i(A)` are the singular values of A. .. hint:: Check out the `Ordered MNIST <../examples/ordered_mnist_jax.html>`_ example for a practical use of this loss function. Parameters ---------- x : ArrayLike Input features of shape `(N, D_x)`, where N is the number of samples and D_x is the input feature dimension. y : ArrayLike Output features of shape `(N, D_y)`, where N is the number of samples and D_y is the output feature dimension. schatten_norm : int, optional Order p of the Schatten norm. Computes the VAMP-p score. Currently supports p=1 (nuclear norm) and p=2 (Frobenius norm). Default is 2. center_covariances : bool, optional If True, use centered covariances (subtract means). If False, use uncentered covariances. Default is True. Returns ------- jax.Array Scalar loss value (negative VAMP-p score). Raises ------ NotImplementedError If `schatten_norm` is not 1 or 2. Notes ----- For p=2, a numerically stable least-squares formulation is used instead of direct pseudoinverse computation. References ---------- .. cite:t:`vamp_loss-Wu2019` """ cov_x, cov_y, cov_xy = ( covariance(x, center=center_covariances), covariance(y, center=center_covariances), covariance(x, y, center=center_covariances), ) if schatten_norm == 2: # Using least squares in place of pinv for numerical stability M_x = jnp.linalg.lstsq(cov_x, cov_xy)[0] M_y = jnp.linalg.lstsq(cov_y, cov_xy.T)[0] return -jnp.trace(M_x @ M_y) elif schatten_norm == 1: sqrt_cov_x = sqrtmh(cov_x) sqrt_cov_y = sqrtmh(cov_y) M = jnp.linalg.multi_dot( [ jnp.linalg.pinv(sqrt_cov_x, hermitian=True), cov_xy, jnp.linalg.pinv(sqrt_cov_y, hermitian=True), ] ) return -jnp.linalg.norm(M, "nuc") else: raise NotImplementedError( f"Schatten norm {schatten_norm} not implemented. " "Supported values are 1 and 2." )
[docs] def spectral_contrastive_loss(x: ArrayLike, y: ArrayLike) -> jax.Array: r"""Spectral contrastive loss for self-supervised learning. Originally introduced by :cite:t:`spectral_contrastive_loss-haochen2021provable` and adapted for evolution operators in :cite:t:`spectral_contrastive_loss-turri2025self` and :cite:t:`spectral_contrastive_loss-jeong2025efficient`. The loss encourages alignment of paired samples (x_i, y_i) while discouraging alignment of unpaired samples: .. math:: \mathcal{L}(x, y) = \frac{1}{N(N-1)}\sum_{i \neq j}\langle x_{i}, y_{j} \rangle^2 - \frac{2}{N}\sum_{i=1}^N\langle x_{i}, y_{i} \rangle .. hint:: Check out the `Ordered MNIST <../examples/ordered_mnist_jax.html>`_ example for a practical use of this loss function. Parameters ---------- x : ArrayLike Input features of shape `(N, D)`, where N is the number of samples and D is the feature dimension. y : ArrayLike Output features of shape `(N, D)`. Must have the same shape as `x`. Returns ------- jax.Array Scalar loss value. Raises ------ ValueError If x and y do not have the same shape or if x is not 2-dimensional. References ---------- .. cite:t:`spectral_contrastive_loss-haochen2021provable` .. cite:t:`spectral_contrastive_loss-turri2025self` .. cite:t:`spectral_contrastive_loss-jeong2025efficient` """ x = jnp.asarray(x) y = jnp.asarray(y) if x.shape != y.shape: raise ValueError( f"x and y must have the same shape, got {x.shape} and {y.shape}" ) if x.ndim != 2: raise ValueError(f"x must be 2-dimensional, got {x.ndim} dimensions") npts, dim = x.shape diag = 2 * jnp.mean(x * y) * dim square_term = (x @ y.T) ** 2 off_diag = ( jnp.mean(jnp.triu(square_term, k=1) + jnp.tril(square_term, k=-1)) * npts / (npts - 1) ) return off_diag - diag
[docs] def autoencoder_loss( x: ArrayLike, y: ArrayLike, x_rec: ArrayLike, y_enc: ArrayLike, x_evo: ArrayLike, y_pred: ArrayLike, alpha_rec: float = 1.0, alpha_lin: float = 1.0, alpha_pred: float = 1.0, ) -> jax.Array: r"""Single-step Dynamic Autoencoder (DAE) loss. Introduced by :cite:t:`autoencoder_loss-Lusch2018`. This loss combines three objectives to train dynamic autoencoders for learning Koopman operators: 1. **Reconstruction loss**: Measures how well the autoencoder reconstructs inputs 2. **Linearity loss**: Enforces linear evolution in the latent space 3. **Prediction loss**: Measures prediction quality in the original space The total loss is: .. math:: \mathcal{L} = \alpha_\mathrm{rec} \|x - \phi^{-1}(\phi(x))\|^2 + \alpha_\mathrm{lin} \|\phi(y) - K\phi(x)\|^2 + \alpha_\mathrm{pred} \|y - \phi^{-1}(K\phi(x))\|^2 where :math:`\phi` is the encoder, :math:`\phi^{-1}` is the decoder, and :math:`K` is the Koopman operator in latent space. .. hint:: Check out the `Ordered MNIST <../examples/ordered_mnist_jax.html>`_ example for a practical use of this loss function. Parameters ---------- x : ArrayLike Input features of shape `(N, D)`, where N is the number of samples and D is the input dimension. y : ArrayLike Target output features of shape `(N, D)`. x_rec : ArrayLike Reconstructed input :math:`\phi^{-1}(\phi(x))` of shape `(N, D)`. y_enc : ArrayLike Encoded target :math:`\phi(y)` of shape `(N, d)`, where d is the latent dimension. x_evo : ArrayLike Evolved latent representation :math:`K\phi(x)` of shape `(N, d)`. y_pred : ArrayLike Predicted decoded output :math:`\phi^{-1}(K\phi(x))` of shape `(N, D)`. alpha_rec : float, optional Weight for the reconstruction term. Default is 1.0. alpha_lin : float, optional Weight for the linearity term. Default is 1.0. alpha_pred : float, optional Weight for the prediction term. Default is 1.0. Returns ------- jax.Array Scalar total loss value. References ---------- .. cite:t:`autoencoder_loss-Lusch2018` """ def mse(true, pred): return jnp.mean((true - pred) ** 2) rec_loss = mse(x, x_rec) lin_loss = mse(y_enc, x_evo) pred_loss = mse(y, y_pred) return alpha_rec * rec_loss + alpha_lin * lin_loss + alpha_pred * pred_loss
def orthonormal_fro_reg(x: ArrayLike, key: jax.random.PRNGKey) -> jax.Array: r"""Orthonormality regularization using Frobenius norm. Encourages the features to have an identity covariance matrix and zero mean. This regularization promotes orthonormal representations in the latent space. The regularization term is: .. math:: \frac{1}{D} \left( \|\mathbf{C}_{X} - I\|_F^2 + 2\|\mathbb{E}[x]\|^2 \right) where :math:`\mathbf{C}_X` is the covariance matrix of x and D is the feature dimension. Parameters ---------- x : ArrayLike Input features of shape `(N, D)`, where N is the number of samples and D is the feature dimension. key : jax.random.PRNGKey JAX random key used for unbiased covariance estimation via permutation. Returns ------- jax.Array Scalar regularization value. Notes ----- The covariance norm is computed using an unbiased estimator that requires random permutations, hence the need for a PRNG key. """ x = jnp.asarray(x) x_mean = x.mean(axis=0, keepdims=True) x_centered = x - x_mean Cx_fro_2 = cov_norm_squared_unbiased(x_centered, key=key) tr_Cx = jnp.einsum("ij,ij->", x_centered, x_centered) / x.shape[0] centering_loss = (x_mean**2).sum() D = x.shape[-1] reg = Cx_fro_2 - 2 * tr_Cx + D + 2 * centering_loss return reg / D def orthonormal_logfro_reg(x: ArrayLike) -> jax.Array: r"""Orthonormality regularization using log-Frobenius norm. An alternative to `orthonormal_fro_reg` that uses a logarithmic penalty on the eigenvalues of the covariance matrix. This can provide better conditioning and avoid issues with very small or large eigenvalues. The regularization term is: .. math:: \frac{1}{D}\text{Tr}(C_X^{2} - C_X - \ln(C_X)) + 2\|\mathbb{E}[x]\|^2 where :math:`C_X` is the covariance matrix of x. Parameters ---------- x : ArrayLike Input features of shape `(N, D)`, where N is the number of samples and D is the feature dimension. Returns ------- jax.Array Scalar regularization value. Notes ----- Eigenvalues below machine epsilon are clamped to avoid numerical issues with the logarithm. """ x = jnp.asarray(x) cov = covariance(x) # shape: (D, D) eps = jnp.finfo(cov.dtype).eps * cov.shape[0] vals_x = jnp.linalg.eigvalsh(cov) vals_x = jnp.where(vals_x > eps, vals_x, eps) orth_loss = jnp.mean(-jnp.log(vals_x) + vals_x * (vals_x - 1.0)) centering_loss = (x.mean(0, keepdims=True) ** 2).sum() reg = orth_loss + 2 * centering_loss return reg
[docs] def energy_loss(x: ArrayLike, y: ArrayLike, grad_weight: float = 1e-3) -> jax.Array: r"""Energy-based loss function. Computes an energy-based loss that incorporates second-order information (Jacobians) into the learning process. The loss is computed as: .. math:: \mathcal{L}(x, y) = \text{tr}(W^2) - 2\langle x, x \rangle \cdot L where .. math:: W = \frac{1}{N}(xx^\top + \lambda yy^\top) where: - :math:`x \in \mathbb{R}^{N \times L}` are input features - :math:`y \in \mathbb{R}^{N \times DL}` are Jacobian features (reshaped from :math:`(N, D, L)`) - :math:`\lambda` is the ```grad_weight``` parameter controlling Jacobian contribution - :math:`N` is the batch size - :math:`D` is the state space dimensionality - :math:`L` is the latent space dimensionality .. hint:: Check out the `Prinz Potential <../examples/prinz_potential.html>`_ example for a practical use of this loss function. Parameters ---------- x : ArrayLike Input features of shape :math:`(N, L)`, where :math:`N` is the batch size and :math:`L` is the dimensionality of the latent space. y : ArrayLike Jacobian features of shape :math:`(N, D, L)`, where :math:`N` is the batch size, :math:`D` is the state space dimensionality, and :math:`L` is the latent space dimensionality. grad_weight : float, optional Weight for the Jacobian contribution. Must be non-negative. Controls how much the Jacobian term contributes to the total loss. Default is 1e-3. Returns ------- jax.Array Scalar loss value. Raises ------ AssertionError If :math:`x` is not 2-dimensional or :math:`y` is not 3-dimensional. If :math:`\text{grad_weight} < 0`. """ x = jnp.asarray(x) y = jnp.asarray(y) assert x.ndim == 2, "x must be 2-dimensional" assert y.ndim == 3, "y must be 3-dimensional" assert grad_weight >= 0.0, "grad_weight must be non-negative" npts, dim = x.shape y_reshaped = y.reshape(npts, -1) W = jnp.matmul(x, x.T) + grad_weight * jnp.matmul(y_reshaped, y_reshaped.T) W = W / npts diag = 2 * jnp.mean(x * x) * dim square_term = jnp.sum(W**2) return square_term - diag