autoencoder_loss

kooplearn.jax.nn.autoencoder_loss(x: ArrayLike, y: ArrayLike, x_rec: ArrayLike, y_enc: ArrayLike, x_evo: ArrayLike, y_pred: ArrayLike, alpha_rec: float = 1.0, alpha_lin: float = 1.0, alpha_pred: float = 1.0) → Array

Single-step Dynamic Autoencoder (DAE) loss.

Introduced by Lusch et al. [1], this loss combines three objectives for training dynamic autoencoders that learn Koopman operators:

  1. Reconstruction loss: Measures how well the autoencoder reconstructs its inputs.

  2. Linearity loss: Encourages linear evolution in the latent space.

  3. Prediction loss: Measures how well the decoded one-step prediction matches the target in the original space.

The total loss is:

\[\mathcal{L} = \alpha_\mathrm{rec} \|x - \phi^{-1}(\phi(x))\|^2 + \alpha_\mathrm{lin} \|\phi(y) - K\phi(x)\|^2 + \alpha_\mathrm{pred} \|y - \phi^{-1}(K\phi(x))\|^2\]

where \(\phi\) is the encoder, \(\phi^{-1}\) is the decoder, and \(K\) is the Koopman operator in latent space.
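The formula above can be sketched directly in `jax.numpy`. This is a minimal reference sketch, not kooplearn's implementation: it assumes each squared norm is reduced as a mean over all entries (mean squared error). The library may reduce differently (for example, summing over features), so absolute values could differ by a constant factor. The function name `dae_loss_sketch` is hypothetical.

```python
import jax.numpy as jnp


def dae_loss_sketch(x, y, x_rec, y_enc, x_evo, y_pred,
                    alpha_rec=1.0, alpha_lin=1.0, alpha_pred=1.0):
    """Hypothetical reference for the DAE loss; assumes mean-squared-error terms."""
    # Reconstruction: x vs. phi^{-1}(phi(x)), both of shape (N, D)
    rec = jnp.mean((x - x_rec) ** 2)
    # Linearity: phi(y) vs. K phi(x), both of shape (N, d)
    lin = jnp.mean((y_enc - x_evo) ** 2)
    # Prediction: y vs. phi^{-1}(K phi(x)), both of shape (N, D)
    pred = jnp.mean((y - y_pred) ** 2)
    return alpha_rec * rec + alpha_lin * lin + alpha_pred * pred


# Toy check: perfect reconstruction and exactly linear latent dynamics give zero loss
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = 2.0 * x
print(float(dae_loss_sketch(x, y, x_rec=x, y_enc=y, x_evo=y, y_pred=y)))  # 0.0
```

Each weight scales one term independently, so setting, say, `alpha_lin=0.0` reduces the objective to a plain autoencoder plus a prediction penalty.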

Hint

Check out the Ordered MNIST example for a practical use of this loss function.

Parameters:
  • x (ArrayLike) – Input features of shape (N, D), where N is the number of samples and D is the input dimension.

  • y (ArrayLike) – Target features of shape (N, D), i.e. the states one time step after the corresponding rows of x.

  • x_rec (ArrayLike) – Reconstructed input \(\phi^{-1}(\phi(x))\) of shape (N, D).

  • y_enc (ArrayLike) – Encoded target \(\phi(y)\) of shape (N, d), where d is the latent dimension.

  • x_evo (ArrayLike) – Evolved latent representation \(K\phi(x)\) of shape (N, d).

  • y_pred (ArrayLike) – Predicted decoded output \(\phi^{-1}(K\phi(x))\) of shape (N, D).

  • alpha_rec (float, optional) – Weight for the reconstruction term. Default is 1.0.

  • alpha_lin (float, optional) – Weight for the linearity term. Default is 1.0.

  • alpha_pred (float, optional) – Weight for the prediction term. Default is 1.0.
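The six array arguments all come from one forward pass through an encoder, a decoder, and a latent matrix \(K\). The sketch below is illustrative only: it assumes a hypothetical linear encoder `W_enc` and decoder `W_dec` as stand-ins for trained networks, stores samples as rows, and therefore applies \(K\) as `z @ K.T`. It shows how each argument is derived and which shape it should have.

```python
import jax
import jax.numpy as jnp

N, D, d = 8, 5, 3  # samples, input dimension, latent dimension
k_x, k_enc, k_dec, k_K = jax.random.split(jax.random.PRNGKey(0), 4)

# Hypothetical snapshot pairs: y holds the states one step after x
x = jax.random.normal(k_x, (N, D))
y = 0.9 * x  # toy dynamics, for illustration only

# Hypothetical linear encoder/decoder and latent Koopman matrix
W_enc = jax.random.normal(k_enc, (D, d))
W_dec = jax.random.normal(k_dec, (d, D))
K = jax.random.normal(k_K, (d, d))


def encode(v):
    return v @ W_enc  # phi: (N, D) -> (N, d)


def decode(z):
    return z @ W_dec  # phi^{-1}: (N, d) -> (N, D)


x_enc = encode(x)
x_rec = decode(x_enc)   # phi^{-1}(phi(x)),   shape (N, D)
y_enc = encode(y)       # phi(y),             shape (N, d)
x_evo = x_enc @ K.T     # K phi(x) per row,   shape (N, d)
y_pred = decode(x_evo)  # phi^{-1}(K phi(x)), shape (N, D)

print(x_rec.shape, y_enc.shape, x_evo.shape, y_pred.shape)
# (8, 5) (8, 3) (8, 3) (8, 5)
```

These are the arrays one would pass as `autoencoder_loss(x, y, x_rec, y_enc, x_evo, y_pred)`.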

Returns:

Scalar total loss value.

Return type:

jax.Array
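Because the loss is a scalar `jax.Array`, it can be differentiated with `jax.grad` to train the encoder, decoder, and \(K\) jointly. The sketch below is hypothetical: it keeps linear maps in a parameter dict and re-implements the loss with mean-squared-error terms instead of calling kooplearn, so it runs without the library. In a real training loop one would typically compute the six arrays inside the differentiated function and pass them to `autoencoder_loss`.

```python
import jax
import jax.numpy as jnp


def forward_loss(params, x, y, alpha_rec=1.0, alpha_lin=1.0, alpha_pred=1.0):
    """Hypothetical DAE objective with a linear encoder/decoder; assumes MSE terms."""
    def encode(v):
        return v @ params["W_enc"]

    def decode(z):
        return z @ params["W_dec"]

    x_enc = encode(x)
    x_evo = x_enc @ params["K"].T  # K phi(x), samples as rows
    rec = jnp.mean((x - decode(x_enc)) ** 2)
    lin = jnp.mean((encode(y) - x_evo) ** 2)
    pred = jnp.mean((y - decode(x_evo)) ** 2)
    return alpha_rec * rec + alpha_lin * lin + alpha_pred * pred


keys = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "W_enc": 0.1 * jax.random.normal(keys[0], (5, 3)),
    "W_dec": 0.1 * jax.random.normal(keys[1], (3, 5)),
    "K": jnp.eye(3),
}
x = jax.random.normal(keys[2], (16, 5))
y = 0.9 * x  # toy next-step states

# One plain gradient-descent step on every parameter in the dict
loss, grads = jax.value_and_grad(forward_loss)(params, x, y)
params = jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, params, grads)
new_loss = forward_loss(params, x, y)
print(float(loss), float(new_loss))
```

Gradient descent is used here only to keep the example short; an optimizer library such as Optax would normally handle the update.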

References

[1] Bethany Lusch, J. Nathan Kutz, and Steven L. Brunton. Deep learning for universal linear embeddings of nonlinear dynamics. Nature Communications, November 2018. URL: https://doi.org/10.1038/s41467-018-07210-0.