autoencoder_loss
- kooplearn.jax.nn.autoencoder_loss(x: ArrayLike, y: ArrayLike, x_rec: ArrayLike, y_enc: ArrayLike, x_evo: ArrayLike, y_pred: ArrayLike, alpha_rec: float = 1.0, alpha_lin: float = 1.0, alpha_pred: float = 1.0) → Array
Single-step Dynamic Autoencoder (DAE) loss.
Introduced by Lusch et al. [1]. This loss combines three objectives to train dynamic autoencoders for learning Koopman operators:
- Reconstruction loss: measures how well the autoencoder reconstructs its inputs.
- Linearity loss: enforces linear evolution in the latent space.
- Prediction loss: measures prediction quality in the original space.
The total loss is:
\[\mathcal{L} = \alpha_\mathrm{rec} \|x - \phi^{-1}(\phi(x))\|^2 + \alpha_\mathrm{lin} \|\phi(y) - K\phi(x)\|^2 + \alpha_\mathrm{pred} \|y - \phi^{-1}(K\phi(x))\|^2\]
where \(\phi\) is the encoder, \(\phi^{-1}\) is the decoder, and \(K\) is the Koopman operator in latent space.
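The formula above can be transcribed almost term by term into JAX. This is a minimal sketch, not kooplearn's implementation: the helper name `dae_loss_sketch` is hypothetical, and reducing each squared norm by averaging the per-sample values over the N samples is an assumption (the library may use a different reduction, such as a plain mean over all entries).

```python
import jax.numpy as jnp


def dae_loss_sketch(x, y, x_rec, y_enc, x_evo, y_pred,
                    alpha_rec=1.0, alpha_lin=1.0, alpha_pred=1.0):
    """Hypothetical helper: weighted sum of the three DAE loss terms."""

    def sq_norm(a, b):
        # Squared Euclidean norm of each row difference, averaged over
        # the N samples (this reduction is an assumption).
        return jnp.mean(jnp.sum((a - b) ** 2, axis=-1))

    rec = sq_norm(x, x_rec)      # ||x - phi^{-1}(phi(x))||^2
    lin = sq_norm(y_enc, x_evo)  # ||phi(y) - K phi(x)||^2
    pred = sq_norm(y, y_pred)    # ||y - phi^{-1}(K phi(x))||^2
    return alpha_rec * rec + alpha_lin * lin + alpha_pred * pred


# Toy inputs: N = 2 samples, D = 3 input dims, d = 2 latent dims.
x = jnp.zeros((2, 3))
y = jnp.ones((2, 3))
x_rec = jnp.zeros((2, 3))   # perfect reconstruction -> rec term is 0
y_enc = jnp.ones((2, 2))
x_evo = jnp.zeros((2, 2))   # each row contributes 2 -> lin term is 2
y_pred = jnp.zeros((2, 3))  # each row contributes 3 -> pred term is 3

loss = dae_loss_sketch(x, y, x_rec, y_enc, x_evo, y_pred)
print(float(loss))  # 5.0
```

Because each term is weighted separately, setting one `alpha_*` to zero removes that objective entirely, which can help when diagnosing which part of the model fails to train.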
Hint
Check out the Ordered MNIST example for a practical use of this loss function.
- Parameters:
x (ArrayLike) – Input features of shape (N, D), where N is the number of samples and D is the input dimension.
y (ArrayLike) – Target output features of shape (N, D), i.e. the states one time step after x.
x_rec (ArrayLike) – Reconstructed input \(\phi^{-1}(\phi(x))\) of shape (N, D).
y_enc (ArrayLike) – Encoded target \(\phi(y)\) of shape (N, d), where d is the latent dimension.
x_evo (ArrayLike) – Evolved latent representation \(K\phi(x)\) of shape (N, d).
y_pred (ArrayLike) – Predicted decoded output \(\phi^{-1}(K\phi(x))\) of shape (N, D).
alpha_rec (float, optional) – Weight for the reconstruction term. Default is 1.0.
alpha_lin (float, optional) – Weight for the linearity term. Default is 1.0.
alpha_pred (float, optional) – Weight for the prediction term. Default is 1.0.
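The six array arguments are all derived from an encoder \(\phi\), a decoder \(\phi^{-1}\) and a latent matrix \(K\). The sketch below shows one way to produce them with the expected shapes. Everything in it is a hypothetical stand-in: random data in place of real snapshot pairs, a one-layer `tanh` encoder, a linear decoder and a random `K`, where a real dynamic autoencoder would use trained neural networks. Since samples are stored as rows, \(K\phi(x)\) is computed as `x_enc @ K.T`.

```python
import jax
import jax.numpy as jnp

N, D, d = 8, 5, 3  # samples, input dimension, latent dimension
kx, ky, ke, kd, kk = jax.random.split(jax.random.PRNGKey(0), 5)

# Hypothetical snapshot pairs: y plays the role of the state one step after x.
x = jax.random.normal(kx, (N, D))
y = jax.random.normal(ky, (N, D))

# Hypothetical model parameters (a real DAE would learn these).
W_enc = jax.random.normal(ke, (D, d))
W_dec = jax.random.normal(kd, (d, D))
K = jax.random.normal(kk, (d, d))


def encode(z):
    # phi: maps (N, D) inputs to (N, d) latents.
    return jnp.tanh(z @ W_enc)


def decode(h):
    # phi^{-1}: maps (N, d) latents back to (N, D) inputs.
    return h @ W_dec


def evolve(h):
    # Applies K to each latent row vector: (K h_i)^T = h_i^T K^T.
    return h @ K.T


x_enc = encode(x)
x_rec = decode(x_enc)   # phi^{-1}(phi(x)), shape (N, D)
y_enc = encode(y)       # phi(y), shape (N, d)
x_evo = evolve(x_enc)   # K phi(x), shape (N, d)
y_pred = decode(x_evo)  # phi^{-1}(K phi(x)), shape (N, D)
```

These arrays can then be passed positionally as `autoencoder_loss(x, y, x_rec, y_enc, x_evo, y_pred)`. Wrapping their construction in a function of the model parameters would make the resulting scalar loss differentiable with `jax.grad`.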
- Returns:
Scalar total loss value.
- Return type:
jax.Array
References
[1] Bethany Lusch, J. Nathan Kutz, and Steven L. Brunton. Deep learning for universal linear embeddings of nonlinear dynamics. Nature Communications, November 2018. URL: https://doi.org/10.1038/s41467-018-07210-0, doi:10.1038/s41467-018-07210-0.