MNISTDataModule
This class provides a PyTorch Lightning DataModule for the MNIST dataset. It downloads the data, sets up the training dataset, and provides a DataLoader for training. The batch size and the number of DataLoader worker processes are configurable.
Attributes
- root_dir: string
  - The root directory where the MNIST dataset is stored.
- batch_size: int = 64
  - The batch size for the DataLoader.
- dataloader_num_workers: int = 0
  - The number of worker processes to use for the DataLoader.
Constructors
Initializes the MNISTDataModule with a root directory, a batch size, and the number of DataLoader worker processes.
Parameters
- root_dir: string
  - The root directory for storing the MNIST dataset.
- batch_size: int
  - The batch size for the DataLoader.
- dataloader_num_workers: int
  - The number of worker processes for the DataLoader.
Methods
def prepare_data()
- Downloads the MNIST dataset if it does not already exist.
def setup(stage: string = None)
Sets up the training data by initializing the MNIST dataset with its transforms.
Parameters
- stage: string
  - The setup stage (e.g., 'fit' or 'test'). Not used in this implementation.
def train_dataloader()
Creates and returns the DataLoader for the training dataset.
Return Value: DataLoader
  - A PyTorch DataLoader for the training dataset.