ImageImpaint_Python_II/DataLoader.py

92 lines
3.0 KiB
Python

import glob
import os
import numpy as np
import torch.utils.data.dataset
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import ex4
# Edge length in pixels of the square images produced by ImageDataset
# (used for both the resize and the center crop).
IMG_SIZE = 100
class ImageDataset(Dataset):
    """Dataset of JPEG images prepared for image inpainting.

    Every ``*.jpg`` file found recursively below ``image_dir`` is one sample.
    Indexing yields a pair ``(model_input, target)``:

    * ``target`` is the image resized and center-cropped to
      ``IMG_SIZE x IMG_SIZE``, scaled by ``preprocess`` and transposed from
      (height, width, channel) to (channel, height, width).
    * ``model_input`` is the first element of the result of ``ex4.ex4`` applied
      to the scaled image (the "image with black grid").
    """

    def __init__(self, image_dir):
        """Collect all ``*.jpg`` paths below ``image_dir`` and set up resizing."""
        # Sorted so that an index refers to the same file on every run.
        self.image_files = sorted(glob.glob(os.path.join(image_dir, "**", "*.jpg"), recursive=True))
        # Built once here instead of on every __getitem__ call.
        self.resize_transforms = transforms.Compose([
            transforms.Resize(size=IMG_SIZE),
            transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
        ])

    def __getitem__(self, index):
        """Return ``(model_input, target)`` for the image at ``index``."""
        # The context manager releases the file handle as soon as the pixel
        # data has been read.
        with Image.open(self.image_files[index]) as image_file:
            # Force three channels: a grayscale or CMYK JPEG would otherwise
            # not fit the (2, 0, 1) transpose applied to the target below.
            target_image = image_file.convert("RGB")
        target_image = self.resize_transforms(target_image)
        target_image = preprocess(target_image)
        # Calculate the image with the black grid.
        # NOTE(review): (5, 5) and (4, 4) are presumably the grid offset and
        # spacing expected by ex4.ex4 -- confirm against that module.
        doomed_image = ex4.ex4(target_image, (5, 5), (4, 4))
        # TODO: check whether training on grayscale targets (rgb2gray) makes sense.
        return doomed_image[0], np.transpose(target_image, (2, 0, 1))

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_files)
def preprocess(input: np.array) -> np.array:
    """Convert an image to a float64 array and divide every value by 255.

    ``input`` may be anything numpy can convert to an array; the dataset in
    this file passes a PIL image. For 8-bit pixel data the result lies in
    the range [0, 1]. No per-channel mean/std normalization is applied.
    """
    pixels = np.asarray(input, dtype=np.float64)
    return pixels / 255.0
# postprocess should be the inverse function of preprocess!
def postprocess(input: np.ndarray) -> np.ndarray:
    """Convert a [0, 1]-scaled image back to 8-bit pixel values.

    This is the inverse of ``preprocess``: values are multiplied by 255 and
    returned as ``np.uint8``. Values are rounded to the nearest integer and
    clamped to [0, 255], so inputs slightly outside [0, 1] (for example model
    predictions) saturate at black/white.

    Args:
        input: Array (or array-like) of image values nominally in [0, 1].

    Returns:
        A ``np.uint8`` array with the same shape as ``input``.
    """
    scaled = np.asarray(input) * 255.0
    # Round and clamp before the cast: a bare astype(np.uint8) truncates
    # toward zero and does not saturate values outside the uint8 range.
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
def get_image_loader(path: str, *, test_set_size: float = 0.001, batch_size: int = 5,
                     num_workers: int = 4, seed: int = 42):
    """Build the training and test data loaders for the images below ``path``.

    The dataset is split randomly but reproducibly (seeded generator) into a
    training part and a test part. The test part holds
    ``int(len(dataset) * test_set_size)`` samples, so with the default
    fraction it is empty for datasets of fewer than 1000 images.

    Args:
        path: Directory that is searched recursively for ``*.jpg`` files.
        test_set_size: Fraction of the samples reserved for testing, in [0, 1].
        batch_size: Number of samples per training minibatch.
        num_workers: Number of background worker processes of the training loader.
        seed: Seed of the generator that drives the train/test split.

    Returns:
        Tuple ``(train_loader, test_loader)``.

    Raises:
        ValueError: If ``test_set_size`` is outside [0, 1].
    """
    if not 0.0 <= test_set_size <= 1.0:
        raise ValueError(f"test_set_size must be within [0, 1], got {test_set_size}")

    image_dataset = ImageDataset(path)
    total = len(image_dataset)
    n_test = int(total * test_set_size)
    trains, tests = torch.utils.data.dataset.random_split(
        image_dataset,
        lengths=(total - n_test, n_test),
        generator=torch.Generator().manual_seed(seed),
    )
    train_loader = DataLoader(
        trains,
        shuffle=True,  # new sample order in every epoch
        batch_size=batch_size,
        num_workers=num_workers,
    )
    test_loader = DataLoader(
        tests,
        # Evaluation gains nothing from shuffling; a fixed order keeps test
        # results reproducible between runs.
        shuffle=False,
        batch_size=1,  # evaluate one image at a time
        num_workers=0,  # load in the main process, no background workers
    )
    return train_loader, test_loader
def rgb2gray(rgb_array: np.ndarray):
    """Return the weighted sum of the first three channels of ``rgb_array``.

    Channels are taken along the third axis and combined as
    ``0.2989 * R + 0.5870 * G + 0.1140 * B``; the channel axis is dropped
    from the result.
    """
    red, green, blue = (rgb_array[:, :, channel] for channel in range(3))
    weighted = 0.2989 * red + 0.5870 * green
    return weighted + 0.1140 * blue