ImageImpaint_Python_II/DataLoader.py

50 lines
1.7 KiB
Python

import glob
import os
import numpy as np
import torch.utils.data.dataset
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class ImageDataset(Dataset):
    """Dataset of normalized JPEG images found recursively under a directory.

    Each sample is a tuple ``(image, index)``. ``image`` is a float32 NumPy
    array that was scaled to [0, 1] and then normalized per channel with
    ``norm_mean`` and ``norm_std``; ``index`` is the position of the sample
    in the sorted list of image files.
    """

    def __init__(self, image_dir):
        """Collect every ``*.jpg`` file below *image_dir*.

        Args:
            image_dir: Root directory that is searched recursively.
        """
        super().__init__()
        pattern = os.path.join(image_dir, "**", "*.jpg")
        # Sorting makes the index -> file mapping reproducible between runs
        self.image_files = sorted(glob.glob(pattern, recursive=True))
        # Per-channel statistics (presumably the ImageNet values -- confirm);
        # they could also be defined as class attributes
        self.norm_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.norm_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __getitem__(self, index):
        """Load, scale and normalize the image at position *index*.

        Args:
            index: Position in ``self.image_files``.

        Returns:
            Tuple ``(image, index)`` with ``image`` as a float32 array whose
            last axis holds the three normalized color channels.

        Raises:
            IndexError: If *index* is outside the list of image files.
        """
        # The context manager closes the file handle as soon as the pixel
        # data has been read instead of leaving that to garbage collection
        with Image.open(self.image_files[index]) as img:
            # Force three channels so grayscale/CMYK files broadcast
            # correctly against the 3-element mean/std arrays
            pixels = np.array(img.convert("RGB"), dtype=np.float32) / 255
        # Perform normalization for each channel
        image = (pixels - self.norm_mean) / self.norm_std
        return image, index

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.image_files)
def get_image_loader(path: str, batch_size: int = 4, train_fraction: float = 0.7,
                     seed: int = 42, num_workers: int = 0):
    """Build shuffled train and test loaders over the images below *path*.

    The images are wrapped in an ``ImageDataset`` and split randomly, but
    reproducibly (fixed generator seed), into a training and a test subset.

    Args:
        path: Root directory handed to ``ImageDataset``.
        batch_size: Number of samples stacked into one minibatch.
        train_fraction: Share of the samples assigned to the training
            subset; the remainder becomes the test subset.
        seed: Seed of the generator that drives the random split.
        num_workers: Number of background worker processes per loader
            (0 loads the data in the calling process).

    Returns:
        Tuple ``(train_loader, test_loader)``.

    Raises:
        ValueError: If *train_fraction* is not within [0, 1].
    """
    if not 0 <= train_fraction <= 1:
        raise ValueError(f"train_fraction must be within [0, 1], got {train_fraction}")

    image_dataset = ImageDataset(path)
    total = len(image_dataset)
    # Compute the training size once; the test subset takes whatever is left
    # so that both lengths always add up to the dataset size
    n_train = int(total * train_fraction)
    trains, tests = torch.utils.data.dataset.random_split(
        image_dataset,
        [n_train, total - n_train],
        generator=torch.Generator().manual_seed(seed),
    )

    train_loader = DataLoader(
        trains,
        shuffle=True,  # shuffle the order of our samples
        batch_size=batch_size,
        num_workers=num_workers,
    )
    # NOTE(review): the test loader is shuffled as well, as in the original
    # code -- confirm whether a fixed evaluation order would be preferable
    test_loader = DataLoader(
        tests,
        shuffle=True,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    return train_loader, test_loader