import numpy as np
import torch
from sklearn.base import BaseEstimator, TransformerMixin
class FeatureMapEmbedder(BaseEstimator, TransformerMixin):
    """
    sklearn-style transformer wrapping a PyTorch encoder (and optional decoder).

    The encoder (and the decoder, when given) are moved to the selected
    device as soon as the embedder is constructed.

    Parameters
    ----------
    encoder : torch.nn.Module
        Neural network mapping input data to latent space.
    decoder : torch.nn.Module, optional
        Neural network mapping latent space back to input space.
    device : str, optional
        Device for computation ('cpu' or 'cuda'). When omitted (or falsy),
        'cuda' is used if ``torch.cuda.is_available()`` and 'cpu' otherwise.
    """

    def __init__(
        self,
        encoder: torch.nn.Module,
        decoder: torch.nn.Module | None = None,
        device: str | None = None,
    ):
        self.encoder = encoder
        self.decoder = decoder
        # Fall back to auto-detection when no device was requested.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Module.to() moves parameters in place, so the stored references
        # above keep pointing at the relocated modules.
        self.encoder.to(self.device)
        if self.decoder is not None:
            self.decoder.to(self.device)

    def fit(self, X=None, y=None):
        """
        No fitting needed unless encoder/decoder are trainable elsewhere.

        Parameters
        ----------
        X, y : ignored
            Accepted only for sklearn API compatibility.

        Returns
        -------
        self
            The embedder itself, unchanged.
        """
        # sklearn API requires fit(), so we return self.
        return self

    def _to_tensor(self, array: np.ndarray | torch.Tensor) -> torch.Tensor:
        """
        Helper: ensure input is a float tensor on the correct device.

        Parameters
        ----------
        array : numpy.ndarray, torch.Tensor, or array-like
            Data to convert. Anything that is not already a tensor (arrays,
            nested lists, ...) is first converted with ``numpy.array``.

        Returns
        -------
        torch.Tensor
            The data cast with ``.float()`` and moved to ``self.device``.
        """
        if isinstance(array, torch.Tensor):
            tensor = array.float()
        else:
            # np.array(..., order="C") always produces a fresh C-contiguous
            # copy: the resulting tensor never shares memory with the
            # caller's buffer, and layouts torch.from_numpy rejects (e.g.
            # negative strides) are normalised first.
            tensor = torch.from_numpy(np.array(array, order="C")).float()
        return tensor.to(self.device)

    def __repr__(self):
        """Return a short summary: encoder/decoder class names and device."""
        # Compare against None explicitly: container modules such as an empty
        # torch.nn.Sequential define __len__ and are therefore falsy even
        # though a decoder is present.
        decoder_name = (
            None if self.decoder is None else self.decoder.__class__.__name__
        )
        return (
            f"FeatureMapEmbedder(encoder={self.encoder.__class__.__name__}, "
            f"decoder={decoder_name}, "
            f"device='{self.device}')"
        )