NnxFeatureMapEmbedder

class kooplearn.jax.NnxFeatureMapEmbedder(encoder: Module, decoder: Module | None = None)
Bases: BaseEstimator, TransformerMixin

scikit-learn-style transformer wrapping NNX Modules (encoder/decoder).
This class mirrors the original PyTorch-based kooplearn.torch.FeatureMapEmbedder, using Flax's NNX framework. It accepts stateful nnx.Module instances, JIT-compiles their forward passes, and runs them in eval mode for inference. The interface follows the scikit-learn TransformerMixin pattern, providing fit and transform methods.

Parameters:
- encoder (nnx.Module) – An NNX Module instance mapping input data to the latent space. Its __call__ method will be JIT-compiled.
- decoder (nnx.Module, optional) – An NNX Module instance mapping the latent space back to the input space. Its __call__ method will be JIT-compiled. If not provided, only encoding (transform) is supported.
Examples
Example usage with the modern Flax NNX API:
    import numpy as np
    from flax import nnx

    from kooplearn.jax import NnxFeatureMapEmbedder


    class SimpleEncoder(nnx.Module):
        def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):
            self.linear = nnx.Linear(in_features, out_features, rngs=rngs)

        def __call__(self, x):
            return self.linear(x)


    # 1. Initialize module and state
    rngs = nnx.Rngs(0)
    encoder_module = SimpleEncoder(in_features=5, out_features=10, rngs=rngs)

    # 2. Create the transformer
    transformer = NnxFeatureMapEmbedder(encoder=encoder_module)

    # 3. Use it
    data = np.random.rand(100, 5).astype(np.float32)
    latent_features = transformer.transform(data)
    print(latent_features.shape)  # (100, 10)
Notes
Internally, the encoder and decoder forward passes are wrapped in nnx.jit for efficient computation. Stateful nnx.Module instances are handled correctly by maintaining their parameter/state dictionaries.

Methods
- fit(X: ndarray | Array = None, y: ndarray | Array = None) → NnxFeatureMapEmbedder
No fitting is performed by this transformer. The encoder/decoder are assumed to be pre-trained.