# ImageImpaint_Python_II/DataLoader.py
# (78 lines, 2.6 KiB, Python)

import glob
import os
import numpy as np
import torch.utils.data.dataset
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import ex4
# Side length, in pixels, of the square images produced by ImageDataset.
IMG_SIZE = 100


class ImageDataset(Dataset):
    """Dataset of ``.jpg`` images found recursively under ``image_dir``.

    Each item is a pair ``(doomed_image, target_image)``:

    * ``target_image`` is the source image converted to RGB, resized so its
      shorter side is ``IMG_SIZE``, center-cropped to ``IMG_SIZE x IMG_SIZE``,
      scaled to ``[0, 1]`` as ``float64`` and transposed to channel-first
      ``(3, IMG_SIZE, IMG_SIZE)``.
    * ``doomed_image`` is element ``[0]`` of ``ex4.ex4(...)`` applied to the
      channel-last target image. Per the original author this is the image
      with a black grid drawn over it; its exact content is defined by ``ex4``.
    """

    def __init__(self, image_dir):
        # Sorted because glob returns files in arbitrary order; a stable order
        # keeps indices (and any seeded random_split) reproducible.
        # NOTE(review): the pattern is case-sensitive, so ``*.JPG`` and
        # ``*.jpeg`` files are not picked up.
        self.image_files = sorted(
            glob.glob(os.path.join(image_dir, "**", "*.jpg"), recursive=True)
        )
        # Per-channel mean/std (the values match the widely used ImageNet
        # statistics). They are currently unused by __getitem__ and are kept
        # so the per-channel normalization can be re-enabled.
        self.norm_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.norm_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        # Built once here instead of on every __getitem__ call.
        self.resize_transforms = transforms.Compose([
            transforms.Resize(size=IMG_SIZE),
            transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
        ])

    def __getitem__(self, index):
        # The context manager closes the file handle. convert() returns an
        # in-memory copy that stays valid after the file is closed. Forcing
        # RGB also guarantees the three channels the (2, 0, 1) transpose below
        # needs; grayscale/CMYK JPEGs would otherwise give 2-D/4-channel arrays.
        with Image.open(self.image_files[index]) as source:
            image = source.convert("RGB")
        target_image = self.resize_transforms(image)
        # Scale 8-bit pixel values to [0, 1].
        target_image = np.array(target_image, dtype=np.float64) / 255.0
        # TODO: optionally normalize per channel with
        #   (target_image - self.norm_mean) / self.norm_std
        # Corrupt the channel-last target image; only the first element of
        # ex4's result is used as the network input.
        doomed_image = ex4.ex4(target_image, (5, 5), (4, 4))
        # TODO: evaluate whether training on grayscale images (rgb2gray) helps.
        return doomed_image[0], np.transpose(target_image, (2, 0, 1))

    def __len__(self):
        # Number of image files discovered in __init__.
        return len(self.image_files)
def get_image_loader(path: str, *, train_fraction: float = 0.7,
                     train_batch_size: int = 5, seed: int = 42):
    """Build train/test DataLoaders over the ``.jpg`` images under ``path``.

    The dataset is split with a seeded ``random_split``. The training part
    receives ``int(len(dataset) * train_fraction)`` samples and the test part
    the remainder.

    Args:
        path: Root directory searched recursively by ``ImageDataset``.
        train_fraction: Fraction of samples assigned to the training split;
            must lie in ``[0, 1]``.
        train_batch_size: Batch size of the training loader (the test loader
            always uses batches of 1).
        seed: Seed of the generator driving the split, so that the split is
            reproducible across runs.

    Returns:
        ``(train_loader, test_loader)``. Both loaders shuffle. Training uses
        batches of ``train_batch_size`` loaded by 2 worker processes; testing
        uses batches of 1 loaded in the main process.

    Raises:
        ValueError: if no images are found under ``path`` or
            ``train_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction!r}")
    image_dataset = ImageDataset(path)
    total = len(image_dataset)
    if total == 0:
        # Without this check the failure only surfaces later as an opaque
        # error from the shuffling sampler of an empty split.
        raise ValueError(f"no '*.jpg' images found under {path!r}")
    n_train = int(total * train_fraction)
    trains, tests = torch.utils.data.dataset.random_split(
        image_dataset,
        (n_train, total - n_train),
        generator=torch.Generator().manual_seed(seed),
    )
    train_loader = DataLoader(
        trains,
        shuffle=True,                  # reshuffle the samples every epoch
        batch_size=train_batch_size,   # samples stacked into one minibatch
        num_workers=2,                 # 2 background worker processes
    )
    test_loader = DataLoader(
        tests,
        shuffle=True,   # NOTE(review): shuffling is not needed for evaluation
        batch_size=1,   # one sample per batch
        num_workers=0,  # load in the main process (no background workers)
    )
    return train_loader, test_loader
def rgb2gray(rgb_array: np.ndarray) -> np.ndarray:
    """Convert a channel-last RGB image (or a stack of them) to grayscale.

    Computes ``0.2989 * R + 0.5870 * G + 0.1140 * B``. These are approximately
    the ITU-R BT.601 luma weights.

    Args:
        rgb_array: Array whose LAST axis holds the colour channels, e.g.
            ``(H, W, 3)`` or ``(N, H, W, 3)``. Channels beyond the third
            (such as alpha) are ignored.

    Returns:
        Array with the channel axis removed, e.g. ``(H, W)`` or ``(N, H, W)``.
    """
    # ``...`` selects along the last axis regardless of how many leading axes
    # there are; for a 3-D (H, W, C) input this equals ``[:, :, i]``.
    r, g, b = rgb_array[..., 0], rgb_array[..., 1], rgb_array[..., 2]
    return 0.2989 * r + 0.5870 * g + 0.1140 * b