energy_loss
- kooplearn.jax.nn.energy_loss(x: ArrayLike, y: ArrayLike, grad_weight: float = 0.001) → Array
Energy-based loss function.
Computes an energy-based loss that incorporates derivative information (Jacobians of the features) into the learning process.
The loss is computed as:
\[\mathcal{L}(x, y) = \text{tr}(W^2) - 2\langle x, x \rangle \cdot L,\]
where
\[W = \frac{1}{N}(xx^\top + \lambda yy^\top)\]
and:
\(x \in \mathbb{R}^{N \times L}\) are input features
\(y \in \mathbb{R}^{N \times DL}\) are Jacobian features (reshaped from \((N, D, L)\))
\(\lambda\) is the `grad_weight` parameter controlling the Jacobian contribution
\(N\) is the batch size
\(D\) is the state space dimensionality
\(L\) is the latent space dimensionality
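A minimal JAX sketch of the formula above follows. It is a direct transcription, not kooplearn's implementation: `energy_loss_sketch` and `trace_w2_gram_free` are hypothetical names, \(\langle x, x \rangle\) is assumed to be the Frobenius inner product \(\sum_{n,l} x_{nl}^2\) (the library may normalize this term differently), and \(L\) is read as the latent dimensionality defined above. The second helper uses the identity \(\text{tr}(W^2) = \frac{1}{N^2}\left(\|x^\top x\|_F^2 + 2\lambda\|x^\top y\|_F^2 + \lambda^2\|y^\top y\|_F^2\right)\), which yields the same value without materializing the \(N \times N\) matrix \(W\).

```python
import jax
import jax.numpy as jnp


def energy_loss_sketch(x, y, grad_weight=1e-3):
    """Hypothetical transcription of tr(W^2) - 2 <x, x> * L."""
    # Same input checks as listed under Raises.
    assert x.ndim == 2, "x must have shape (N, L)"
    assert y.ndim == 3, "y must have shape (N, D, L)"
    assert grad_weight >= 0, "grad_weight must be non-negative"
    n, latent_dim = x.shape
    # Flatten Jacobian features from (N, D, L) to (N, D * L).
    y_flat = y.reshape(n, -1)
    # W = (x x^T + lambda * y y^T) / N is a symmetric (N, N) matrix.
    w = (x @ x.T + grad_weight * (y_flat @ y_flat.T)) / n
    # For symmetric W, tr(W^2) equals the sum of its squared entries.
    trace_w2 = jnp.sum(w * w)
    # Assumption: <x, x> is the Frobenius inner product (no normalization).
    inner_xx = jnp.sum(x * x)
    return trace_w2 - 2.0 * inner_xx * latent_dim


def trace_w2_gram_free(x, y, grad_weight=1e-3):
    """tr(W^2) computed from feature-space Gram matrices, never forming W."""
    n = x.shape[0]
    y_flat = y.reshape(n, -1)
    xx = x.T @ x            # (L, L)
    xy = x.T @ y_flat       # (L, D * L)
    yy = y_flat.T @ y_flat  # (D * L, D * L)
    return (
        jnp.sum(xx**2)
        + 2.0 * grad_weight * jnp.sum(xy**2)
        + grad_weight**2 * jnp.sum(yy**2)
    ) / n**2


# Example with random features: N = 16, D = 2, L = 4.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (16, 4))
y = jax.random.normal(key_y, (16, 2, 4))
loss = energy_loss_sketch(x, y, grad_weight=1e-3)
```

The direct form costs \(O(N^2(L + DL))\), while the Gram-free form costs \(O(N(L + DL)^2)\), so the latter is preferable when the batch size is large compared with the flattened feature dimension.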
Hint
Check out the Prinz Potential example for a practical use of this loss function.
- Parameters:
x (ArrayLike) – Input features of shape \((N, L)\), where \(N\) is the batch size and \(L\) is the dimensionality of the latent space.
y (ArrayLike) – Jacobian features of shape \((N, D, L)\), where \(N\) is the batch size, \(D\) is the state space dimensionality, and \(L\) is the latent space dimensionality.
grad_weight (float, optional) – Non-negative weight \(\lambda\) controlling how much the Jacobian term contributes to the total loss. Default is 1e-3.
- Returns:
Scalar loss value.
- Return type:
jax.Array
- Raises:
AssertionError – If \(x\) is not 2-dimensional, if \(y\) is not 3-dimensional, or if `grad_weight` < 0.
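The input contract above can be met with JAX's own transformations. The sketch below uses a hypothetical tanh feature map; `feature_map`, `features_and_jacobians`, and `check_inputs` are illustrative names, not part of kooplearn. It assumes that the \((N, D, L)\) layout stores \(\partial x_l / \partial s_d\) at index `[n, d, l]`, where \(s\) is the input state. `jax.jacfwd` places output axes before input axes, so each per-sample Jacobian has shape \((L, D)\) and is transposed to match.

```python
import jax
import jax.numpy as jnp


def feature_map(params, s):
    """Hypothetical feature map R^D -> R^L standing in for a network."""
    weight, bias = params  # weight: (L, D), bias: (L,)
    return jnp.tanh(weight @ s + bias)


def features_and_jacobians(params, states):
    """Return features x of shape (N, L) and Jacobians y of shape (N, D, L)."""
    # Evaluate the feature map on every state in the batch.
    x = jax.vmap(feature_map, in_axes=(None, 0))(params, states)
    # jacfwd w.r.t. the state gives one (L, D) matrix per sample: (N, L, D).
    jac = jax.vmap(jax.jacfwd(feature_map, argnums=1), in_axes=(None, 0))(params, states)
    # Swap the last two axes so that y[n, d, l] = d x_l / d s_d at sample n.
    y = jnp.transpose(jac, (0, 2, 1))
    return x, y


def check_inputs(x, y, grad_weight):
    """Mirror the documented checks, plus a batch/latent consistency check."""
    assert x.ndim == 2, "x must be 2-dimensional"
    assert y.ndim == 3, "y must be 3-dimensional"
    assert grad_weight >= 0, "grad_weight must be non-negative"
    # Extra check (an assumption, not listed under Raises): N and L agree.
    assert x.shape[0] == y.shape[0] and x.shape[1] == y.shape[2]


# Batch of N = 32 states in D = 2 dimensions, mapped to L = 4 features.
key_w, key_s = jax.random.split(jax.random.PRNGKey(0))
params = (jax.random.normal(key_w, (4, 2)), jnp.zeros(4))
states = jax.random.normal(key_s, (32, 2))
x, y = features_and_jacobians(params, states)
check_inputs(x, y, grad_weight=1e-3)
```

The resulting arrays match the documented call `energy_loss(x, y, grad_weight=1e-3)`. Reverse-mode `jax.jacrev` computes the same Jacobian and may be cheaper when \(L\) is much smaller than \(D\).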