NnxFeatureMapEmbedder

class kooplearn.jax.NnxFeatureMapEmbedder(encoder: Module, decoder: Module | None = None)

Bases: BaseEstimator, TransformerMixin

scikit-learn-style transformer wrapping Flax NNX modules (an encoder and an optional decoder).

This class mirrors the original PyTorch-based kooplearn.torch.FeatureMapEmbedder, using Flax's NNX framework instead. It accepts stateful nnx.Module instances, JIT-compiles their forward passes, and runs them in evaluation mode (for example, with deterministic dropout) for inference.

The interface follows the scikit-learn TransformerMixin pattern, providing fit, transform and (when a decoder is given) inverse_transform; fit_transform is inherited from TransformerMixin.

Parameters:
  • encoder (nnx.Module) – An NNX Module instance mapping input data to the latent space. Its __call__ method will be JIT-compiled.

  • decoder (nnx.Module, optional) – An NNX Module instance mapping the latent space back to the input space. Its __call__ method will be JIT-compiled. If not provided, only encoding (transform) is supported.

Examples

Example usage with the modern Flax NNX API:

import numpy as np
from flax import nnx

from kooplearn.jax import NnxFeatureMapEmbedder

class SimpleEncoder(nnx.Module):
    def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(in_features, out_features, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

# 1. Initialize the module (NNX keeps parameters and state inside the module object)
rngs = nnx.Rngs(0)
encoder_module = SimpleEncoder(in_features=5, out_features=10, rngs=rngs)

# 2. Create the transformer
transformer = NnxFeatureMapEmbedder(encoder=encoder_module)

# 3. Use it
data = np.random.rand(100, 5).astype(np.float32)
latent_features = transformer.transform(data)

print(latent_features.shape)
# (100, 10)

Notes

Internally, the encoder and decoder forward passes are wrapped in nnx.jit for efficient computation. Stateful nnx.Module instances are supported: their parameters and other state variables are tracked across the jitted calls.

Methods

fit(X: ndarray | Array = None, y: ndarray | Array = None) → NnxFeatureMapEmbedder

No fitting is performed by this transformer. The encoder/decoder are assumed to be pre-trained.

inverse_transform(Z: ndarray | Array) → ndarray

Decode latent data back to the input space using the NNX decoder, if one was provided.

transform(X: ndarray | Array) → ndarray

Encode data using the JAX NNX encoder.