energy_loss

kooplearn.jax.nn.energy_loss(x: ArrayLike, y: ArrayLike, grad_weight: float = 0.001) → Array

Energy-based loss function.

Computes an energy-based loss that incorporates derivative information about the features, through their Jacobians, into the learning process.

The loss is computed as:

\[\mathcal{L}(x, y) = \text{tr}(W^2) - 2\langle x, x \rangle \cdot L\]

where

\[W = \frac{1}{N}(xx^\top + \lambda yy^\top)\]

The symbols are defined as follows:

  • \(x \in \mathbb{R}^{N \times L}\) are input features

  • \(y \in \mathbb{R}^{N \times DL}\) are Jacobian features (reshaped from \((N, D, L)\))

  • \(\lambda\) is the `grad_weight` parameter, which controls the contribution of the Jacobian term

  • \(N\) is the batch size

  • \(D\) is the state space dimensionality

  • \(L\) is the latent space dimensionality
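The formula above can be transcribed almost term by term with `jax.numpy`. The helper below, `energy_loss_sketch`, is hypothetical and written only from the equations on this page; it is not kooplearn's source and may differ from it in normalization or other details. It assumes that \(\langle x, x \rangle\) denotes the Frobenius inner product (the sum of the squared entries of \(x\)) and that \(L\) in the second term is the latent dimensionality.

```python
import jax.numpy as jnp


def energy_loss_sketch(x, y, grad_weight=1e-3):
    """Hypothetical transcription of the documented formula (not kooplearn's code)."""
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    # Mirror the documented checks: x is (N, L), y is (N, D, L), grad_weight >= 0.
    assert x.ndim == 2, "x must be 2-dimensional"
    assert y.ndim == 3, "y must be 3-dimensional"
    assert grad_weight >= 0, "grad_weight must be non-negative"
    n, latent_dim = x.shape
    # Flatten the Jacobian features from (N, D, L) to (N, D * L).
    y_flat = y.reshape(n, -1)
    # W = (x x^T + lambda * y y^T) / N, an N x N matrix.
    w = (x @ x.T + grad_weight * (y_flat @ y_flat.T)) / n
    # Assumption: <x, x> is the Frobenius inner product (sum of squared entries).
    x_inner = jnp.sum(x * x)
    # tr(W^2) - 2 <x, x> * L, following the formula as written.
    return jnp.trace(w @ w) - 2.0 * x_inner * latent_dim
```

For example, with `x = jnp.ones((2, 1))`, `y = jnp.zeros((2, 1, 1))` and `grad_weight=0.0`, \(W\) is the \(2 \times 2\) matrix whose entries are all \(0.5\), so \(\text{tr}(W^2) = 1\), \(\langle x, x \rangle = 2\), and the sketch returns \(1 - 2 \cdot 2 \cdot 1 = -3\).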

Hint

Check out the Prinz Potential example for a practical use of this loss function.

Parameters:
  • x (ArrayLike) – Input features of shape \((N, L)\), where \(N\) is the batch size and \(L\) is the dimensionality of the latent space.

  • y (ArrayLike) – Jacobian features of shape \((N, D, L)\), where \(N\) is the batch size, \(D\) is the state space dimensionality, and \(L\) is the latent space dimensionality.

  • grad_weight (float, optional) – Weight for the Jacobian contribution. Must be non-negative. Controls how much the Jacobian term contributes to the total loss. Default is 1e-3.

Returns:

Scalar loss value.

Return type:

jax.Array

Raises:

AssertionError – If \(x\) is not 2-dimensional, if \(y\) is not 3-dimensional, or if \(\text{grad\_weight} < 0\).
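Arrays with the shapes expected for `x` and `y` can be produced from a differentiable feature map using `jax.vmap` and `jax.jacrev`. In the sketch below, the feature map `phi`, its weight matrix `A`, and the batch `states` of shape \((N, D)\) are hypothetical stand-ins. The sketch also assumes the convention `y[n, d, l]` \(= \partial \phi_l / \partial s_d\) evaluated at the \(n\)-th state; this is why the per-sample Jacobian returned by `jax.jacrev`, of shape \((L, D)\), is transposed to \((D, L)\).

```python
import jax
import jax.numpy as jnp

# Hypothetical setup: N states in R^D mapped to L latent features.
N, D, L = 8, 3, 4
key_a, key_s = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (L, D))       # hypothetical feature weights
states = jax.random.normal(key_s, (N, D))  # hypothetical batch of states


def phi(s):
    """Hypothetical feature map from R^D to R^L."""
    return jnp.tanh(A @ s)


# Input features: apply phi to every state, giving shape (N, L).
x = jax.vmap(phi)(states)

# jax.jacrev(phi)(s) has shape (L, D); vmapping over the batch gives (N, L, D).
jac = jax.vmap(jax.jacrev(phi))(states)

# Assumed convention y[n, d, l] = d phi_l / d s_d, so swap the last two axes to get (N, D, L).
y = jnp.swapaxes(jac, 1, 2)

print(x.shape, y.shape)  # (8, 4) (8, 3, 4)
```

The resulting `x` is 2-dimensional and `y` is 3-dimensional, matching the shape checks listed under Raises.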