Source code for kooplearn.datasets._ordered_mnist

import os
from typing import Optional, Union

import numpy as np
from sklearn.datasets import fetch_openml


def fetch_ordered_mnist(
    *,
    num_digits: int = 10,
    data_home: Optional[Union[str, os.PathLike]] = None,
    cache: bool = True,
    n_retries: int = 3,
    delay: float = 1.0,
):
    """
    Fetch MNIST and return an *ordered* subset interleaving the digit classes.

    This function wraps :func:`sklearn.datasets.fetch_openml` for the MNIST
    dataset (OpenML ID 554) and reorders the samples so that digits ``0``
    through ``num_digits - 1`` are **interleaved** in the output. This is
    useful for generating class-balanced or periodic sequences for Koopman
    operator regression experiments.

    The MNIST dataset contains 70,000 grayscale handwritten digits (60,000
    for training and 10,000 for testing) of size 28×28.

    Parameters
    ----------
    num_digits : int, default=10
        Number of digit classes to include, from 1 to 10. For example,
        ``num_digits=3`` returns only digits ``0``, ``1``, and ``2``.
    data_home : str or path-like, optional
        Specify an alternative download and cache folder for the dataset.
        By default, scikit-learn stores data in ``~/scikit_learn_data``.
    cache : bool, default=True
        Whether to cache the downloaded dataset.
    n_retries : int, default=3
        Number of times to retry downloading if network errors occur.
    delay : float, default=1.0
        Number of seconds between retries during download.

    Returns
    -------
    images : ndarray of shape (n_samples, 28, 28)
        Array of grayscale MNIST images, with values scaled to the [0, 1]
        range (float64).
    targets : ndarray of shape (n_samples,)
        Corresponding digit labels (integers in ``[0, num_digits - 1]``).

    Raises
    ------
    ValueError
        If ``num_digits`` is not between 1 and 10.
    TypeError
        If ``num_digits`` cannot be interpreted as an integer. This is
        checked before any data is downloaded.

    Notes
    -----
    The dataset is reordered so that classes are interleaved in the returned
    arrays. For example, with ``num_digits=3``, the ordering will be:
    ``[0, 1, 2, 0, 1, 2, 0, 1, 2, ...]``

    Every selected class is truncated to the size of the smallest selected
    class, so ``n_samples`` equals ``num_digits`` times that size and samples
    beyond it are dropped.

    Examples
    --------
    >>> import numpy as np
    >>> from kooplearn.datasets import fetch_ordered_mnist
    >>> images, targets = fetch_ordered_mnist(num_digits=3)
    >>> images.shape
    (20709, 28, 28)
    >>> np.unique(targets)
    array([0, 1, 2])
    """
    if not (1 <= num_digits <= 10):
        raise ValueError(f"`num_digits` must be between 1 and 10, got {num_digits}.")

    # Build the class range *before* downloading: a non-integer value such as
    # 2.5 passes the range check above and would otherwise only fail after
    # the whole dataset has been fetched.
    try:
        classes = range(num_digits)
    except TypeError as err:
        raise TypeError(
            f"`num_digits` must be an integer, got {num_digits!r}."
        ) from err

    # Fetch raw MNIST data (OpenML ID 554)
    X, y = fetch_openml(
        data_id=554,
        data_home=data_home,
        cache=cache,
        n_retries=n_retries,
        delay=delay,
        return_X_y=True,
        as_frame=False,
    )

    # Convert to arrays and reshape the flat pixels to (n_samples, 28, 28)
    images = np.asarray(X, dtype=np.uint8).reshape(-1, 28, 28)
    targets = np.asarray(y, dtype=int)

    # Positions of every sample of each requested digit, in dataset order
    per_class = [np.flatnonzero(targets == digit) for digit in classes]

    # Truncate every class to the smallest one so the classes stay balanced
    n_per_class = min(len(indices) for indices in per_class)

    # Stack the classes as columns and flatten row by row, which yields the
    # interleaved ordering [0_0, 1_0, 2_0, 0_1, 1_1, 2_1, ...]
    ordering_perm = np.stack(
        [indices[:n_per_class] for indices in per_class], axis=1
    ).reshape(-1)

    # Apply the new ordering
    images = images[ordering_perm]
    targets = targets[ordering_perm]

    # Scale to [0, 1]; copy=False skips the redundant second full-size copy
    # that a plain astype() would make when the quotient is already float64.
    images = (images / 255.0).astype(np.float64, copy=False)

    return images, targets