32 lines
1.1 KiB
Python
32 lines
1.1 KiB
Python
import glob
|
|
import os
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
from torch.utils.data import Dataset
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
class ImageDataset(Dataset):
    """Dataset of normalized images found recursively under a directory.

    Each item is an ``(image, index)`` pair. ``image`` is a float32 NumPy
    array of shape ``(H, W, 3)`` (channels-last). Its pixel values are scaled
    to [0, 1] and then normalized per channel with ``norm_mean`` /
    ``norm_std``. ``index`` is the position of the file in ``image_files``.
    """

    def __init__(self, image_dir: str, pattern: str = "*.jpg"):
        """Collect the image paths under ``image_dir``.

        Args:
            image_dir: Root directory; it and all of its subdirectories are
                searched.
            pattern: Glob pattern matched against file names in every
                searched directory. Defaults to ``"*.jpg"``. Glob matching is
                case-sensitive on POSIX systems, so e.g. ``.JPG`` or
                ``.jpeg`` files are not picked up by the default.
        """
        # "**" together with recursive=True matches image_dir itself plus
        # every nested subdirectory. Sorting makes the index -> file mapping
        # deterministic across runs and platforms.
        self.image_files = sorted(
            glob.glob(os.path.join(image_dir, "**", pattern), recursive=True)
        )

        # Per-channel (R, G, B) normalization statistics. They match the
        # commonly used ImageNet values (presumably intentional; confirm
        # against the model that consumes these images).
        self.norm_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.norm_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __getitem__(self, index: int):
        """Load, scale and normalize the image stored at position ``index``.

        Args:
            index: Position in ``image_files``.

        Returns:
            A tuple ``(image, index)``. ``image`` is a float32 array of shape
            ``(H, W, 3)`` in channels-last order.
        """
        # The context manager closes the underlying file handle right away
        # instead of leaving it to garbage collection.
        # Converting to RGB guarantees exactly 3 channels, so grayscale or
        # CMYK JPEGs line up with the 3-element mean/std arrays.
        with Image.open(self.image_files[index]) as img:
            image = np.array(img.convert("RGB"), dtype=np.float32) / 255

        # Broadcasting over the last (channel) axis normalizes each channel.
        image = (image - self.norm_mean) / self.norm_std
        return image, index

    def __len__(self) -> int:
        """Return the number of image files that were found."""
        return len(self.image_files)
|
|
|
|
|
|
def get_image_loader(
    path: str,
    *,
    batch_size: int = 10,
    shuffle: bool = True,
    num_workers: int = 0,
) -> DataLoader:
    """Build a DataLoader over the images found recursively under ``path``.

    Args:
        path: Root directory passed to ``ImageDataset``.
        batch_size: Number of samples per batch (default 10).
        shuffle: Whether to reshuffle the samples every epoch (default True).
        num_workers: Number of worker subprocesses used for loading. 0, the
            default, loads in the main process.

    Returns:
        A ``DataLoader`` whose batches are collated from the dataset's
        ``(image, index)`` samples.

    Note:
        PyTorch's default collate function stacks the samples of a batch, so
        every image within one batch must have the same height and width.
        With ``shuffle=True``, an empty dataset makes the DataLoader's random
        sampler raise ``ValueError``.
    """
    image_dataset = ImageDataset(path)
    return DataLoader(
        image_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
|