import glob
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split


class ImageDataset(Dataset):
    def __init__(self, image_dir):
        # Collect all .jpg files below image_dir (recursively) in a deterministic order
        self.image_files = sorted(glob.glob(os.path.join(image_dir, "**", "*.jpg"), recursive=True))
        # Mean and std arrays could also be defined as class attributes
        self.norm_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.norm_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __getitem__(self, index):
        # Open image file, convert to numpy array and scale to [0, 1]
        image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255
        # Perform normalization for each channel
        image = (image - self.norm_mean) / self.norm_std
        return image, index

    def __len__(self):
        return len(self.image_files)


def get_image_loader(path: str):
    image_dataset = ImageDataset(path)
    total_len = len(image_dataset)
    train_len = int(total_len * 0.7)
    # Split into 70% training and 30% test data; the seeded generator makes the split reproducible
    train_set, test_set = random_split(
        image_dataset,
        (train_len, total_len - train_len),
        generator=torch.Generator().manual_seed(42)
    )
    train_loader = DataLoader(
        train_set,
        shuffle=True,   # shuffle the order of our samples
        batch_size=4,   # stack 4 samples to a minibatch
        num_workers=0   # no background workers (see comment below)
    )
    test_loader = DataLoader(
        test_set,
        shuffle=True,   # shuffle the order of our samples
        batch_size=4,   # stack 4 samples to a minibatch
        num_workers=0   # no background workers (see comment below)
    )
    return train_loader, test_loader
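

# A minimal usage sketch (assumptions: "path/to/images" is a placeholder directory
# containing RGB .jpg files of identical spatial size; grayscale or differently
# sized images would break the per-channel normalization and the default collation).
if __name__ == "__main__":
    train_loader, test_loader = get_image_loader("path/to/images")
    for images, indices in train_loader:
        # The default collate function converts the numpy arrays to tensors and
        # stacks them, yielding a float32 batch of shape (4, H, W, C) plus the
        # corresponding sample indices. Note the channels-last layout; a CNN
        # would typically expect images.permute(0, 3, 1, 2).
        print(images.shape, indices)
        break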