import glob
import os
import random
from typing import Tuple

import numpy as np
import torch.utils.data.dataset
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image

import ex4

# Edge length (in pixels) of the square images fed to the network
IMG_SIZE = 100

class ImageDataset(Dataset):
    def __init__(self, image_dir: str, offsetrange: Tuple[int, int], spacingrange: Tuple[int, int],
                 transform_chain: transforms.Compose, precision: type = np.float32):
        # Collect all .jpg files below image_dir in a deterministic order
        self.image_files = sorted(glob.glob(os.path.join(image_dir, "**", "*.jpg"), recursive=True))
        self.precision = precision  # np.float32 or np.float64
        self.offsetrange = offsetrange
        self.spacingrange = spacingrange
        self.transform_chain = transform_chain

    def __getitem__(self, index):
        # Open the image file, crop/resize it, apply the augmentations and scale it to [0, 1]
        target_image = Image.open(self.image_files[index])
        target_image = crop_image(target_image)
        target_image = self.transform_chain(target_image)
        target_image = preprocess(target_image, self.precision)

        # Create the grid-masked input image with a random offset and spacing
        offset = (random.randint(*self.offsetrange), random.randint(*self.offsetrange))
        spacing = (random.randint(*self.spacingrange), random.randint(*self.spacingrange))
        doomed_image = ex4.ex4(target_image, offset, spacing)

        return doomed_image[0], np.transpose(target_image, (2, 0, 1))

    def __len__(self):
        return len(self.image_files)


def crop_image(image: Image.Image) -> Image.Image:
    # Resize the shorter side to IMG_SIZE, then center-crop to IMG_SIZE x IMG_SIZE
    resize_transforms = transforms.Compose([
        transforms.Resize(size=IMG_SIZE),
        transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
    ])
    return resize_transforms(image)


def preprocess(input_image: Image.Image, precision: type) -> np.ndarray:
    # Convert the PIL image to a numpy array and scale it from [0, 255] to [0, 1]
    # https://www.geeksforgeeks.org/how-to-normalize-images-in-pytorch/
    target_image = np.array(input_image, dtype=precision)
    target_image = target_image / 255.0

    # Per-channel normalization is intentionally not applied; it would look like:
    # target_image = (target_image - norm_mean) / norm_std
    return target_image


# postprocess should be the inverse function of preprocess!
def postprocess(input_array: np.ndarray) -> np.ndarray:
    # Clip to [0, 1] because the model may predict values slightly outside that range,
    # then scale back to [0, 255] and convert to uint8
    target = np.clip(input_array, 0, 1)
    target_image = (target * 255.0).astype(np.uint8)
    return target_image

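
# Sketch of how postprocess could be used on a model prediction (hypothetical names):
# a predicted array of shape (3, H, W) with values in [0, 1] is converted back to
# uint8 and reordered to (H, W, 3) so that PIL can save or display it, e.g.:
#   prediction = model(doomed_batch)[0].detach().cpu().numpy()  # "model" is hypothetical
#   Image.fromarray(np.transpose(postprocess(prediction), (1, 2, 0))).save("out.jpg")
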
def get_image_loader(path: str, precision: type):
    # Offset and spacing ranges as required by the project specification
    image_dataset = ImageDataset(path,
                                 offsetrange=(0, 8),
                                 spacingrange=(2, 6),
                                 transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
                                                                     transforms.RandomVerticalFlip()]),
                                 precision=precision)

    # Second dataset with an additional Gaussian blur for data augmentation
    image_dataset_augmented = ImageDataset(path,
                                           offsetrange=(0, 8),
                                           spacingrange=(2, 6),
                                           transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
                                                                               transforms.RandomVerticalFlip(),
                                                                               transforms.GaussianBlur(3, 4)]),
                                           precision=precision)

    # Merge the plain and the augmented dataset into one dataset
    merged_dataset = torch.utils.data.ConcatDataset([image_dataset, image_dataset_augmented])

    # Hold out 10% of the samples as a test set
    totlen = len(merged_dataset)
    test_set_size = 0.1
    train_split, test_split = torch.utils.data.dataset.random_split(merged_dataset,
                                                                    lengths=(totlen - int(totlen * test_set_size),
                                                                             int(totlen * test_set_size)))

    train_loader = DataLoader(
        train_split,
        shuffle=True,   # shuffle the order of the samples
        batch_size=25,  # stack 25 samples into a minibatch
        num_workers=4   # load minibatches with 4 background worker processes
    )

    test_loader = DataLoader(
        test_split,
        shuffle=True,   # shuffle the order of the samples
        batch_size=5,   # stack 5 samples into a minibatch
        num_workers=0   # no background workers
    )

    return train_loader, test_loader
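

# Example usage (a minimal sketch): builds the loaders and fetches one minibatch to
# check the shapes. The directory "images/" is a placeholder, and ex4.ex4 is assumed
# to return the grid-masked input array as its first element, as used in
# ImageDataset.__getitem__ above.
if __name__ == "__main__":
    train_loader, test_loader = get_image_loader("images/", precision=np.float32)
    for doomed_batch, target_batch in train_loader:
        print(f"input batch: {tuple(doomed_batch.shape)}, target batch: {tuple(target_batch.shape)}")
        break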