Ordered MNIST (torch)¶
Authors: Pietro Novelli and Giacomo Turri
This example closely follows the experiment from “Learning invariant representations of time-homogeneous stochastic dynamical systems” [1], implemented using PyTorch.
Setup¶
We begin by loading the Ordered MNIST dataset from kooplearn and visualizing a small sample.
[1]:
import matplotlib.pyplot as plt
from kooplearn.datasets import fetch_ordered_mnist
# Restrict the dataset to the digits 0 to 4.
num_digits = 5
images, labels = fetch_ordered_mnist(num_digits=num_digits)

# Preview the first samples on a 2 x num_digits grid, one image per axes.
fig, axs = plt.subplots(
    nrows=2,
    ncols=num_digits,
    figsize=(0.8 * num_digits, 1.3),
)
for ax, img in zip(axs.flat, images):
    ax.imshow(img, cmap="Greys")
    ax.axis("off")
fig.suptitle("Ordered MNIST", fontsize=16)
plt.show()
We split the dataset into training, validation, and test sets, using 3,000 points for training, 1,000 for validation, and 1,000 for testing.
[2]:
import numpy as np
# Contiguous index ranges for the three splits: train images will be
# images[train_ids], and likewise for validation and test.
n_train, n_val, n_test = 3000, 1000, 1000
split_points = [n_train, n_train + n_val]
train_ids, val_ids, test_ids = np.split(
    np.arange(n_train + n_val + n_test), split_points
)
Training the Oracle¶
Each evolution operator model will be validated as follows: starting from a test image of digit \(c\), we predict the next image using model.predict. The prediction should resemble an MNIST-style image of digit \(c+1\) (modulo num_digits, the number of digit classes in use).
We then feed this predicted image to a pretrained MNIST classifier (the oracle) and evaluate how its accuracy changes over successive predictions.
We begin by defining the oracle classifier.
[3]:
import torch
from torch.utils.data import DataLoader, TensorDataset
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
class CNNEncoder(torch.nn.Module):
    """Two-stage convolutional network mapping images to ``num_classes`` outputs.

    Each stage is a size-preserving 5x5 convolution followed by ReLU and a 2x2
    max-pooling. The final linear layer takes ``32 * 7 * 7`` features, so the
    network expects 28x28 single-channel images (28 -> 14 -> 7 after the two
    poolings).

    Args:
        num_classes: Size of the output vector produced for every image.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(16, 32, 5, 1, 2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
        )
        # Fully connected layer, output num_classes classes
        self.out = torch.nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            X: Batch of images, either with a channel axis (4 dimensions) or
                without one (3 dimensions); in the latter case a singleton
                channel axis is inserted at position 1.

        Returns:
            Tensor with one row of ``num_classes`` values per input image.
        """
        if X.dim() == 3:
            X = X.unsqueeze(1)  # Add a channel dimension if needed
        X = self.conv1(X)
        X = self.conv2(X)
        # Flatten everything but the batch axis; unlike ``view`` this does not
        # require the activations to be contiguous in memory.
        X = X.flatten(start_dim=1)
        return self.out(X)
# Will be needed for autoencoder-based koopman operator learning.
class CNNDecoder(torch.nn.Module):
    """Convolutional decoder mapping latent vectors back to images.

    A linear layer expands every latent vector to a ``32 x 7 x 7`` feature
    map, which is then upsampled by a factor of two twice (7 -> 14 -> 28), each
    time followed by ReLU and a size-preserving 5x5 transposed convolution.

    Args:
        num_classes: Dimension of the latent vectors to decode.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(num_classes, 32 * 7 * 7)
        )
        self.conv1 = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(
                in_channels=32,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(16, 1, 5, 1, 2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a batch of latent vectors.

        Args:
            x: Tensor with one latent vector of size ``num_classes`` per row.

        Returns:
            Batch of 28x28 images without a channel axis.
        """
        x = self.fc(x)
        x = x.view(x.size(0), 32, 7, 7)
        x = self.conv1(x)
        x = self.conv2(x)
        # Remove the channel dimension
        return x.squeeze(1)
Now we train the oracle classifier.
[ ]:
def train_oracle_classifier():
    """Fit the CNN oracle on the training split and log its accuracy.

    Reads the module-level ``images``, ``labels``, ``train_ids``, ``val_ids``,
    ``num_digits`` and ``device``. Training and validation accuracy are
    printed after the first epoch and then every fifth epoch.

    Returns:
        The trained ``CNNEncoder`` classifier.
    """
    torch.manual_seed(42)
    num_epochs = 20

    def labelled_dataset(ids):
        # Images as float32 tensors paired with their integer class labels.
        return TensorDataset(
            torch.tensor(images[ids], dtype=torch.float32),
            torch.tensor(labels[ids], dtype=torch.long),
        )

    train_dataset = labelled_dataset(train_ids)
    val_dataset = labelled_dataset(val_ids)
    train_dl = DataLoader(train_dataset, batch_size=64, shuffle=True)
    # The whole validation split is evaluated as a single batch.
    val_dl = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False)

    oracle = CNNEncoder(num_classes=num_digits).to(device)
    optimizer = torch.optim.AdamW(oracle.parameters(), lr=8e-4)

    def step(batch_images, batch_labels, is_train: bool = True):
        # Process one batch and return its accuracy; weights are only
        # updated when ``is_train`` is set.
        oracle.train(is_train)
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        if is_train:
            optimizer.zero_grad()
        logits = oracle(batch_images)
        loss = torch.nn.functional.cross_entropy(logits, batch_labels)
        if is_train:
            loss.backward()
            optimizer.step()
        return (logits.argmax(1) == batch_labels).float().mean().item()

    for epoch in range(num_epochs):
        train_acc = [step(*batch) for batch in train_dl]
        with torch.no_grad():
            val_acc = [step(*batch, is_train=False) for batch in val_dl]
        if epoch == 0 or (epoch + 1) % 5 == 0:
            print(
                f"EPOCH {epoch + 1:>2} Accuracy: {np.mean(train_acc)*100:.1f}% (train) - "
                f"{np.mean(val_acc)*100:.1f}% (val)"
            )
    return oracle


oracle = train_oracle_classifier()
EPOCH 1 Accuracy: 78.5% (train) - 94.8% (val)
EPOCH 5 Accuracy: 97.9% (train) - 97.5% (val)
EPOCH 10 Accuracy: 99.5% (train) - 98.3% (val)
EPOCH 15 Accuracy: 99.8% (train) - 99.1% (val)
EPOCH 20 Accuracy: 100.0% (train) - 98.9% (val)
Evolution Operator Models¶
Next, we will train multiple evolution operator models that predict the next image given the current one. Each model will be stored in trained_models for later evaluation.
[5]:
# Global registry of the trained models, keyed by display name. Every entry
# stored in it by the cells of this notebook is a dict with a "model" key (the
# fitted estimator) and an "embedder" key (the matching feature transform).
trained_models = {}
Linear model¶
As a baseline evolution operator, we fit a simple linear Ridge model (equivalent to Principal Component Regression) on the flattened pixel features using kooplearn’s FeatureFlattener.
[6]:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from kooplearn.linear_model import Ridge
from kooplearn.preprocessing import FeatureFlattener
# Data preparation: flatten every image and standardize the resulting features
flattener = FeatureFlattener()
scaler = StandardScaler()
data_pipe = Pipeline(steps=[("flattener", flattener), ("scaler", scaler)])
train_images = images[train_ids]
data_pipe.fit(train_images)

# Fit the linear baseline on the preprocessed training trajectory
linear_model = Ridge(n_components=num_digits, eigen_solver="dense")
linear_model.fit(data_pipe.transform(train_images))
trained_models["Linear"] = dict(model=linear_model, embedder=data_pipe)
Classifier features (as in Sec. 6 of [2])¶
Here, we use the oracle classifier to extract feature embeddings and fit a Ridge model on these classifier features. kooplearn’s FeatureMapEmbedder makes this straightforward.
[7]:
from kooplearn.torch.utils import FeatureMapEmbedder
# Use the trained oracle as a fixed feature map
embedder = FeatureMapEmbedder(encoder=oracle)
train_images = images[train_ids]
images_embedded = embedder.transform(train_images)
# The raw images are passed as ``y`` alongside the classifier features
classifier_model = Ridge(n_components=num_digits).fit(
    images_embedded, y=train_images
)
trained_models["Classifier_Baseline"] = dict(
    model=classifier_model,
    embedder=embedder,
)
Encoder-only methods¶
We train encoder-only models using the SpectralContrastiveLoss and VampLoss objectives to learn latent spaces in which a linear evolution operator can operate effectively.
[ ]:
class FeatureMap(torch.nn.Module):
    """CNN feature map with an extra linear layer for lagged inputs.

    Args:
        num_digits: Output dimension of the backbone and size of the square,
            bias-free linear layer applied to lagged inputs.
        normalize_latents: Whether to normalize the backbone output along its
            last dimension.
    """

    def __init__(self, num_digits: int, normalize_latents: bool = True):
        super().__init__()
        self.normalize_latents = normalize_latents
        self.backbone = CNNEncoder(num_classes=num_digits)
        self.lin = torch.nn.Linear(num_digits, num_digits, bias=False)

    def forward(self, X, lagged: bool = False):
        """Embed ``X``; when ``lagged`` is set, also apply the linear layer."""
        latent = self.backbone(X)
        if self.normalize_latents:
            latent = torch.nn.functional.normalize(latent, dim=-1)
        return self.lin(latent) if lagged else latent
def train_encoder_only(criterion: torch.nn.Module):
    """Train a ``FeatureMap`` with ``criterion`` and fit a ``Ridge`` model on it.

    Reads the module-level ``images``, ``train_ids``, ``val_ids``,
    ``num_digits`` and ``device``. Training and validation losses are printed
    after the first epoch and then every fifth epoch.

    Args:
        criterion: Loss called as ``criterion(phi_X, phi_Y)`` on the features
            of the current snapshots and the lagged features of their
            successors.

    Returns:
        A dict with the fitted ``Ridge`` estimator under ``"model"`` and the
        ``FeatureMapEmbedder`` wrapping the trained encoder under
        ``"embedder"``.
    """
    num_epochs = 50
    batch_size = 64
    train_dataset = torch.from_numpy(images[train_ids]).float()
    val_dataset = torch.from_numpy(images[val_ids]).float()

    # Poor man's lagged dataloaders: each snapshot is paired with its successor.
    train_dl = DataLoader(
        TensorDataset(train_dataset[:-1], train_dataset[1:]),
        batch_size=batch_size,
        shuffle=True,
    )
    val_dl = DataLoader(
        TensorDataset(val_dataset[:-1], val_dataset[1:]),
        batch_size=batch_size,
    )

    torch.manual_seed(42)
    # Initialize model, loss and optimizer
    model = FeatureMap(num_digits).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3)

    def step(batch, is_train: bool = True):
        # Return the loss on one batch; weights are only updated when
        # ``is_train`` is set.
        batch_X, batch_Y = batch
        batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
        if is_train:
            optimizer.zero_grad()
        phi_X, phi_Y = model(batch_X), model(batch_Y, lagged=True)
        loss = criterion(phi_X, phi_Y)
        if is_train:
            loss.backward()
            optimizer.step()
        return loss.item()

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = []
        for batch in train_dl:
            train_loss.append(step(batch))
        # Validation phase
        model.eval()
        val_loss = []
        with torch.no_grad():
            for batch in val_dl:
                val_loss.append(step(batch, is_train=False))
        if (epoch + 1) % 5 == 0 or epoch == 0:
            # Adjacent f-strings are concatenated implicitly. The previous
            # "/" between them divided two strings and raised TypeError.
            print(
                f"EPOCH {epoch + 1:>2} Loss: {np.mean(train_loss):.2f} (train) - "
                f"{np.mean(val_loss):.2f} (val)"
            )

    embedder = FeatureMapEmbedder(encoder=model)
    evolution_operator_model = Ridge(n_components=num_digits).fit(
        embedder.transform(train_dataset), train_dataset.numpy(force=True)
    )
    return {
        "model": evolution_operator_model,
        "embedder": embedder,
    }
[ ]:
from kooplearn.torch.nn import SpectralContrastiveLoss, VampLoss
# One encoder-only model per training objective, keyed by display name.
criteria = {
    "VAMPNets": VampLoss(center_covariances=False),
    "Spectral Contrastive Loss": SpectralContrastiveLoss(),
}
for name, criterion in criteria.items():
    print(f"Fitting {name}")
    trained_models[name] = train_encoder_only(criterion)
Fitting VAMPNets
EPOCH 1 Loss: -3.79 (train) - -4.35 (val)
EPOCH 5 Loss: -4.81 (train) - -4.61 (val)
EPOCH 10 Loss: -4.93 (train) - -4.69 (val)
EPOCH 15 Loss: -4.97 (train) - -4.70 (val)
EPOCH 20 Loss: -4.97 (train) - -4.71 (val)
EPOCH 25 Loss: -4.98 (train) - -4.70 (val)
EPOCH 30 Loss: -4.98 (train) - -4.72 (val)
EPOCH 35 Loss: -4.98 (train) - -4.72 (val)
EPOCH 40 Loss: -4.99 (train) - -4.72 (val)
EPOCH 45 Loss: -4.99 (train) - -4.72 (val)
EPOCH 50 Loss: -4.99 (train) - -4.73 (val)
Fitting Spectral Contrastive Loss
EPOCH 1 Loss: -0.94 (train) - -1.00 (val)
EPOCH 5 Loss: -1.00 (train) - -1.00 (val)
EPOCH 10 Loss: -1.67 (train) - -1.90 (val)
EPOCH 15 Loss: -2.80 (train) - -2.78 (val)
EPOCH 20 Loss: -2.94 (train) - -2.90 (val)
EPOCH 25 Loss: -3.82 (train) - -3.91 (val)
EPOCH 30 Loss: -4.58 (train) - -4.59 (val)
EPOCH 35 Loss: -4.85 (train) - -4.81 (val)
EPOCH 40 Loss: -4.90 (train) - -4.91 (val)
EPOCH 45 Loss: -4.96 (train) - -4.93 (val)
EPOCH 50 Loss: -5.00 (train) - -4.97 (val)
Dynamical Autoencoder [3]¶
To complement the encoder-only methods, we also train a dynamical autoencoder that jointly learns an encoder, a decoder, and a linear evolution operator.
[ ]:
from kooplearn.torch.nn import AutoEncoderLoss
class DynamicalAutoEncoder(torch.nn.Module):
    """Encoder, decoder and bias-free linear evolution operator trained jointly.

    All three sub-modules are moved to the module-level ``device`` on
    construction.

    Args:
        num_digits: Latent dimension shared by the encoder output, the decoder
            input and the square evolution operator.
    """

    def __init__(self, num_digits):
        super().__init__()
        self.encoder = CNNEncoder(num_digits).to(device)
        self.decoder = CNNDecoder(num_digits).to(device)
        self.evolution_operator = torch.nn.Linear(
            num_digits, num_digits, bias=False
        ).to(device)
        self.loss = AutoEncoderLoss()

    def forward(self):
        """Intentionally a no-op; the model is driven through ``train_step``."""
        return None

    def train_step(self, batch, optimizer, is_eval: bool = False):
        """Compute the loss on one ``(X, Y)`` batch and return it as a float.

        Gradients are zeroed, back-propagated and applied through
        ``optimizer`` unless ``is_eval`` is set.
        """
        update_weights = not is_eval
        X, Y = (snapshots.to(device) for snapshots in batch)
        if update_weights:
            optimizer.zero_grad()

        X_enc = self.encoder(X)
        Y_enc = self.encoder(Y)
        X_rec = self.decoder(X_enc)  # reconstruction of the current snapshot
        X_evo = self.evolution_operator(X_enc)  # latent state advanced one step
        Y_pred = self.decoder(X_evo)  # decoded one-step-ahead prediction
        loss = self.loss(X, Y, X_rec, Y_enc, X_evo, Y_pred)

        if update_weights:
            loss.backward()
            optimizer.step()
        return loss.item()
def train_autoencoder():
    """Train a ``DynamicalAutoEncoder`` and fit a ``Ridge`` model on its latents.

    Reads the module-level ``images``, ``train_ids``, ``val_ids`` and
    ``num_digits``. Training and validation losses are printed after the first
    epoch and then every fifth epoch.

    Returns:
        A dict with the fitted ``Ridge`` estimator under ``"model"`` and a
        ``FeatureMapEmbedder`` built from the trained encoder and decoder
        under ``"embedder"``.
    """
    torch.manual_seed(42)
    num_epochs = 50
    batch_size = 64

    train_dataset = torch.from_numpy(images[train_ids]).float()
    val_dataset = torch.from_numpy(images[val_ids]).float()

    def lagged_loader(snapshots, shuffle):
        # Poor man's lagged dataloader: pair each snapshot with its successor.
        pairs = TensorDataset(snapshots[:-1], snapshots[1:])
        return DataLoader(pairs, batch_size=batch_size, shuffle=shuffle)

    train_dl = lagged_loader(train_dataset, shuffle=True)
    val_dl = lagged_loader(val_dataset, shuffle=False)

    # Model and optimizer
    model = DynamicalAutoEncoder(num_digits)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3)

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = [model.train_step(batch, optimizer) for batch in train_dl]
        # Validation phase
        model.eval()
        with torch.no_grad():
            val_loss = [
                model.train_step(batch, optimizer, is_eval=True)
                for batch in val_dl
            ]
        if epoch == 0 or (epoch + 1) % 5 == 0:
            print(
                f"EPOCH {epoch + 1:>2} Loss: {np.mean(train_loss):.2f} (train) - "
                f"{np.mean(val_loss):.2f} (val)"
            )

    embedder = FeatureMapEmbedder(encoder=model.encoder, decoder=model.decoder)
    latents = embedder.transform(train_dataset)
    evolution_operator_model = Ridge(n_components=num_digits).fit(latents)
    return {
        "model": evolution_operator_model,
        "embedder": embedder,
    }


trained_models["AutoEncoder"] = train_autoencoder()
EPOCH 1 Loss: 0.20 (train) - 0.14 (val)
EPOCH 5 Loss: 0.13 (train) - 0.12 (val)
EPOCH 10 Loss: 0.10 (train) - 0.10 (val)
EPOCH 15 Loss: 0.10 (train) - 0.10 (val)
EPOCH 20 Loss: 0.10 (train) - 0.10 (val)
EPOCH 25 Loss: 0.10 (train) - 0.10 (val)
EPOCH 30 Loss: 0.10 (train) - 0.10 (val)
EPOCH 35 Loss: 0.09 (train) - 0.10 (val)
EPOCH 40 Loss: 0.09 (train) - 0.10 (val)
EPOCH 45 Loss: 0.09 (train) - 0.10 (val)
Warning: The fitting algorithm discarded 1 dimensions of the 5 requested out of numerical instabilities.
The rank attribute has been updated to 4.
Consider decreasing the rank parameter.
EPOCH 50 Loss: 0.09 (train) - 0.10 (val)
Final Comparison¶
[ ]:
test_data = images[test_ids]
test_labels = labels[test_ids]


def evaluate_model(model, embedder, num_evaluation_steps=15):
    """Roll ``model`` forward from the test images and score it with the oracle.

    Starting from the module-level ``test_data``, the images are repeatedly
    mapped to their one-step prediction. After every step the module-level
    ``oracle`` classifies the predicted images, and the result is compared
    with ``(test_labels + t) % num_digits``.

    Args:
        model: Fitted estimator exposing ``predict``.
        embedder: Transform exposing ``transform``. If it has a ``decoder``
            attribute equal to ``None``, images are predicted directly with
            ``model.predict(..., observable=True)``; otherwise predictions are
            made in feature space and mapped back through
            ``embedder.inverse_transform``.
        num_evaluation_steps: Number of successive prediction steps.

    Returns:
        A dict of lists with one entry per step: ``'accuracy'`` (float),
        ``'label'`` (predicted labels), ``'image'`` (predicted images) and
        ``'times'`` (the step index, starting at 1).
    """
    report = {
        'accuracy': [],
        'label': [],
        'image': [],
        'times': []
    }
    # The embedder does not change during the rollout, so decide once.
    predicts_observable = (
        hasattr(embedder, "decoder") and embedder.decoder is None
    )
    img = test_data
    for t in range(1, num_evaluation_steps + 1):
        if predicts_observable:
            img = model.predict(embedder.transform(img), observable=True)
        else:
            _img = model.predict(embedder.transform(img))
            img = embedder.inverse_transform(_img)
        # Pure inference: skip building an autograd graph for every step.
        with torch.no_grad():
            logits = oracle(
                torch.tensor(img, device=device, dtype=torch.float)
            )
        pred_labels = logits.argmax(axis=1).numpy(force=True)
        accuracy = (pred_labels == (test_labels + t) % num_digits).mean()
        report['accuracy'].append(accuracy.item())
        report['image'].append(img)
        report['label'].append(pred_labels)
        report['times'].append(t)
    return report


report = {}
for model_name, result in trained_models.items():
    report[model_name] = evaluate_model(result["model"], result["embedder"])
[12]:
# Oracle accuracy of every model as a function of the prediction step.
fig, ax = plt.subplots()
for model_name, model_report in report.items():
    ax.plot(model_report['times'], model_report['accuracy'], label=model_name)

# Reference level of a uniformly random guess among the digit classes.
ax.axhline(1/num_digits, color='black', linestyle='--', label='Random')
ax.legend(frameon=False, bbox_to_anchor=(1, 1))
ax.margins(x=0)
ax.set(ylim=(0, 1.1), xlabel='Time steps', ylabel='Accuracy')
plt.show()
[ ]:
# Grid of predicted images: one row per model, first column is the seed image
# and column k (k >= 1) shows the prediction after k steps.
num_models = len(report)
# Every model is evaluated for the same number of steps, so any report gives
# the column count (no need to rely on a specific model name being present).
num_cols = len(next(iter(report.values()))['times'])
fig, axes = plt.subplots(
    num_models,
    num_cols,
    figsize=(num_cols, num_models),
    sharex=True,
    sharey=True,
    squeeze=False,  # keep ``axes`` two-dimensional even for a single model
)
test_seed_idx = 0
seed_label = test_labels[test_seed_idx]
# Remove margins between columns
plt.subplots_adjust(wspace=0)

for model_idx, (model_name, model_report) in enumerate(report.items()):
    # First column: the seed image, with the model name on its left
    seed_ax = axes[model_idx, 0]
    seed_ax.imshow(test_data[test_seed_idx], cmap='Greys')
    seed_ax.set_axis_off()
    seed_ax.text(
        -0.1,
        0.5,
        model_name.replace('_', ' '),
        fontsize=14,
        ha='right',
        va='center',
        transform=seed_ax.transAxes,
    )

    for prediction_step in range(num_cols - 1):
        pred_label = model_report['label'][prediction_step][test_seed_idx]
        true_label = (
            seed_label + model_report['times'][prediction_step]
        ) % num_digits
        img = model_report['image'][prediction_step][test_seed_idx]

        ax = axes[model_idx, prediction_step + 1]
        ax.imshow(img, cmap='Greys')
        # Switching the axes off hides ticks, frame and background at once.
        ax.set_axis_off()

        # Inset in the upper right corner showing the oracle's label, green
        # when it matches the expected digit and red otherwise.
        color = 'green' if pred_label == true_label else 'red'
        inset_ax = ax.inset_axes([0.75, 0.75, 0.25, 0.25])
        inset_ax.set_xlim(0, 1)
        inset_ax.set_ylim(0, 1)
        inset_ax.text(
            0.5, 0.4, f"{pred_label}",
            color=color, fontsize=9, ha='center', va='center',
        )
        inset_ax.set_xticks([])
        inset_ax.set_yticks([])
        inset_ax.set_facecolor('white')

# Column titles: the expected digit at every step, starting from the seed
for class_idx in range(num_cols):
    title = (seed_label + class_idx) % num_digits
    prefix = "Seed: " if class_idx == 0 else ""
    axes[0, class_idx].set_title(f"{prefix}{title}", fontsize=14)
plt.show()
The plots above illustrate that encoder-only methods outperform other approaches in predicting long-term dynamics, maintaining higher accuracy over extended forecast horizons.
Finally, we visualize the leading eigenfunctions corresponding to the two largest-magnitude eigenvalues of each trained evolution operator model.
[ ]:
from kooplearn._utils import stable_topk
# Scatter the test points in the plane spanned by the two leading left
# eigenfunctions of every model, colored by digit.
n_models = len(report)
num_cols = 3
num_rows = -(-n_models // num_cols)  # ceiling division: enough rows for all models
fig, axes = plt.subplots(
    num_rows, num_cols, figsize=(15, 4 * num_rows), squeeze=False
)
axes = axes.flatten()

for model_idx, model_name in enumerate(report.keys()):
    ax = axes[model_idx]
    ax.set_title(model_name.replace('_', ' '))
    fitted_model = trained_models[model_name]['model']
    embedder = trained_models[model_name]['embedder']
    test_features = embedder.transform(test_data)
    vals, lfuncs, rfuncs = fitted_model.eig(
        eval_right_on=test_features,
        eval_left_on=test_features,
    )
    # Keep one eigenvalue per distinct magnitude: np.unique returns the index
    # of the first occurrence of every unique value.
    _, idx_start = np.unique(np.abs(vals), return_index=True)
    vals, lfuncs = vals[idx_start], lfuncs[:, idx_start]
    # Indices of the two eigenvalues requested from stable_topk by magnitude
    _, top_indices = stable_topk(np.abs(vals), 2)
    fn_i = lfuncs[:, top_indices[0]].real
    fn_j = lfuncs[:, top_indices[1]].real
    scatter = ax.scatter(
        fn_i, fn_j, c=test_labels, cmap='tab10', vmax=10, alpha=0.7, linewidths=0
    )

# Attach the digit legend to the last used axes and drop the unused ones
ax = axes[n_models - 1]
legend = ax.legend(*scatter.legend_elements(num=4),
                   title="Digits",
                   frameon=True,
                   bbox_to_anchor=(1.3, 1))
ax.add_artist(legend)
for unused_ax in axes[n_models:]:
    fig.delaxes(unused_ax)
plt.tight_layout()
plt.show()