import glob
import os
import random
from typing import Tuple

import numpy as np
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision import transforms
from PIL import Image

import ex4

IMG_SIZE = 100


class ImageDataset(Dataset):
    def __init__(self, image_dir, offsetrange: Tuple[int, int], spacingrange: Tuple[int, int],
                 transform_chain: transforms.Compose, precision: type = np.float32):
        self.image_files = sorted(glob.glob(os.path.join(image_dir, "**", "*.jpg"), recursive=True))
        self.precision = precision
        self.offsetrange = offsetrange
        self.spacingrange = spacingrange
        self.transform_chain = transform_chain

    def __getitem__(self, index):
        # Open the image file, resize/crop it, apply the augmentation chain and scale it to [0, 1]
        target_image = Image.open(self.image_files[index])
        target_image = crop_image(target_image)
        target_image = self.transform_chain(target_image)
        target_image = preprocess(target_image, self.precision)

        # Black out a grid of pixels with randomly chosen offset and spacing
        offset = (random.randint(*self.offsetrange), random.randint(*self.offsetrange))
        spacing = (random.randint(*self.spacingrange), random.randint(*self.spacingrange))
        doomed_image = ex4.ex4(target_image, offset, spacing)

        # Return the gridded image, the known-pixel array and the full target in CHW order
        return doomed_image[0], doomed_image[1], np.transpose(target_image, (2, 0, 1))

    def __len__(self):
        return len(self.image_files)


def crop_image(image: Image.Image) -> Image.Image:
    resize_transforms = transforms.Compose([
        transforms.Resize(size=IMG_SIZE),
        transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
    ])
    return resize_transforms(image)


def preprocess(image: Image.Image, precision: type) -> np.ndarray:
    # Convert to a float array and scale from [0, 255] to [0, 1]
    # https://www.geeksforgeeks.org/how-to-normalize-images-in-pytorch/
    target_image = np.array(image, dtype=precision)
    target_image = target_image / 255.0
    return target_image


def postprocess(input_array: np.ndarray) -> np.ndarray:
    # postprocess should be the inverse function of preprocess!
    # Model outputs may lie slightly outside [0, 1], so clip before scaling back to uint8
    target = np.clip(input_array, 0, 1)
    target_image = (target * 255.0).astype(np.uint8)
    return target_image


def get_image_loader(path: str, precision: type):
    # Offset and spacing ranges follow the project specification
    image_dataset = ImageDataset(path, offsetrange=(0, 8), spacingrange=(2, 6),
                                 transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
                                                                     transforms.RandomVerticalFlip()]),
                                 precision=precision)
    image_dataset_augmented = ImageDataset(path, offsetrange=(0, 8), spacingrange=(2, 6),
                                           transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
                                                                               transforms.RandomVerticalFlip(),
                                                                               transforms.GaussianBlur(3, 4)]),
                                           precision=precision)

    # Merge the plain and the blur-augmented datasets
    merged_dataset = ConcatDataset([image_dataset, image_dataset_augmented])

    # Hold out 10% of the samples as a test split
    totlen = len(merged_dataset)
    test_set_size = 0.1
    train_split, test_split = random_split(
        merged_dataset,
        lengths=(totlen - int(totlen * test_set_size), int(totlen * test_set_size)))

    train_loader = DataLoader(
        train_split,
        shuffle=True,   # shuffle the order of the samples
        batch_size=25,  # stack 25 samples into a minibatch
        num_workers=4   # load batches in 4 background worker processes
    )
    test_loader = DataLoader(
        test_split,
        shuffle=True,   # shuffle the order of the samples
        batch_size=5,   # stack 5 samples into a minibatch
        num_workers=0   # no background workers
    )
    return train_loader, test_loader
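

# A minimal usage sketch, not part of the project code: it assumes the images live in a
# hypothetical "data/images" directory and only iterates a single training batch to show
# the tensors the loaders yield (gridded image and known-pixel array from ex4.ex4, plus
# the full target image in CHW order).
if __name__ == "__main__":
    train_loader, test_loader = get_image_loader("data/images", precision=np.float32)
    for doomed, known, target in train_loader:
        print(doomed.shape, known.shape, target.shape)
        break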