ImageImpaint_Python_II/DataLoader.py

import glob
import os
import random
from typing import Tuple, Type

import numpy as np
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision import transforms
from PIL import Image

import ex4  # project module that blacks out a pixel grid in the target image
IMG_SIZE = 100


class ImageDataset(Dataset):
    def __init__(self, image_dir: str, offsetrange: Tuple[int, int], spacingrange: Tuple[int, int],
                 transform_chain: transforms.Compose, precision: Type[np.floating] = np.float32):
        # Collect all .jpg files below image_dir (recursively) in a stable order.
        self.image_files = sorted(glob.glob(os.path.join(image_dir, "**", "*.jpg"), recursive=True))
        self.precision = precision
        self.offsetrange = offsetrange
        self.spacingrange = spacingrange
        self.transform_chain = transform_chain
    def __getitem__(self, index):
        # Open the image file, resize/crop and augment it, then convert it to a
        # numpy array scaled to [0, 1].
        target_image = Image.open(self.image_files[index])
        target_image = crop_image(target_image)
        target_image = self.transform_chain(target_image)
        target_image = preprocess(target_image, self.precision)
        # Black out a pixel grid with random offset and spacing (see ex4.ex4).
        offset = (random.randint(*self.offsetrange), random.randint(*self.offsetrange))
        spacing = (random.randint(*self.spacingrange), random.randint(*self.spacingrange))
        doomed_image = ex4.ex4(target_image, offset, spacing)
        # Return the two arrays produced by ex4.ex4 plus the target in CHW layout.
        return doomed_image[0], doomed_image[1], np.transpose(target_image, (2, 0, 1))

    def __len__(self):
        return len(self.image_files)
def crop_image(image: Image.Image) -> Image.Image:
    # Resize the shorter side to IMG_SIZE, then center-crop to IMG_SIZE x IMG_SIZE.
    resize_transforms = transforms.Compose([
        transforms.Resize(size=IMG_SIZE),
        transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
    ])
    return resize_transforms(image)
def preprocess(image: Image.Image, precision: Type[np.floating]) -> np.ndarray:
    # Convert the PIL image to a float array and scale it from [0, 255] to [0, 1].
    # https://www.geeksforgeeks.org/how-to-normalize-images-in-pytorch/
    target_image = np.array(image, dtype=precision)
    target_image = target_image / 255.0
    # Per-channel normalization is intentionally left out:
    # image = (image - self.norm_mean) / self.norm_std
    return target_image
# postprocess should be the inverse of preprocess.
def postprocess(image: np.ndarray) -> np.ndarray:
    # Clip to [0, 1] first: the model can produce values slightly outside that range.
    target = np.clip(image, 0, 1)
    target_image = (target * 255.0).astype(np.uint8)
    return target_image
def get_image_loader(path: str, precision: Type[np.floating]):
    # Offset and spacing ranges are fixed by the project specification.
    image_dataset = ImageDataset(path,
                                 offsetrange=(0, 8),
                                 spacingrange=(2, 6),
                                 transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
                                                                     transforms.RandomVerticalFlip()]),
                                 precision=precision)
    # An additional augmented dataset could be merged in here, e.g.:
    # image_dataset_augmented = ImageDataset(path,
    #                                        offsetrange=(0, 8),
    #                                        spacingrange=(2, 6),
    #                                        transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
    #                                                                            transforms.RandomVerticalFlip(),
    #                                                                            transforms.GaussianBlur(3, 4)]),
    #                                        precision=precision)
    # Merge all datasets that should contribute to training/testing.
    merged_dataset = ConcatDataset([image_dataset])

    # Split off 10% of the samples as a test set.
    totlen = len(merged_dataset)
    test_set_size = 0.1
    train_split, test_split = random_split(merged_dataset,
                                           lengths=(totlen - int(totlen * test_set_size),
                                                    int(totlen * test_set_size)))
    train_loader = DataLoader(
        train_split,
        shuffle=True,    # shuffle the order of the samples
        batch_size=25,   # stack 25 samples into a minibatch
        num_workers=4    # load batches in 4 background worker processes
    )
    test_loader = DataLoader(
        test_split,
        shuffle=True,    # shuffle the order of the samples
        batch_size=5,    # stack 5 samples into a minibatch
        num_workers=0    # no background workers: load in the main process
    )
    return train_loader, test_loader
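

# The block below is an optional usage sketch, not part of the original module:
# it assumes ex4.ex4 returns a tuple whose first two elements are numpy arrays
# (the gridded image and its companion array) and that "path/to/images" is a
# placeholder for a directory containing *.jpg files.
if __name__ == "__main__":
    # Sanity check: postprocess should undo preprocess up to uint8 rounding.
    rng = np.random.default_rng(0)
    original = rng.integers(0, 256, size=(IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
    restored = postprocess(preprocess(Image.fromarray(original), np.float32))
    assert np.abs(restored.astype(np.int16) - original.astype(np.int16)).max() <= 1

    # Build the loaders and inspect one training batch (only if the folder exists).
    data_dir = "path/to/images"  # placeholder path, replace with the real dataset root
    if os.path.isdir(data_dir):
        train_loader, test_loader = get_image_loader(data_dir, precision=np.float32)
        for batch in train_loader:
            # batch holds three stacked tensors per __getitem__: grid input, ex4 array, target.
            print([tuple(item.shape) for item in batch])
            break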