ImageImpaint_Python_II/DataLoader.py

86 lines
2.8 KiB
Python
Raw Normal View History

2022-06-01 10:27:58 +00:00
import glob
import os
import numpy as np
2022-06-01 14:07:32 +00:00
import torch.utils.data.dataset
2022-06-01 10:27:58 +00:00
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
2022-06-28 16:28:36 +00:00
from torchvision import transforms
from PIL import Image
import ex4
# Side length, in pixels, of the square images produced by preprocess()
# (shorter edge resized to this value, then centre-cropped to a square).
IMG_SIZE: int = 100
2022-06-01 10:27:58 +00:00
class ImageDataset(Dataset):
    """Dataset of JPEG images paired with a copy overlaid by a black grid.

    Every ``*.jpg`` file below ``image_dir`` (searched recursively and sorted,
    so the sample order is deterministic) is one sample.
    """

    def __init__(self, image_dir, precision: type):
        """Collect the image paths.

        Args:
            image_dir: Root directory searched recursively for ``*.jpg`` files.
            precision: NumPy float type used for the image arrays, e.g.
                ``np.float32`` or ``np.float64``.
        """
        self.image_files = sorted(glob.glob(os.path.join(image_dir, "**", "*.jpg"), recursive=True))
        self.precision = precision

    def __getitem__(self, index):
        """Return ``(grid_image, target_image)`` for the sample at ``index``.

        ``target_image`` is the preprocessed image (values scaled to [0, 1])
        transposed to channels-first layout ``(C, H, W)``. ``grid_image`` is
        the first element of ``ex4.ex4``'s result when applied to the
        channels-last target with arguments ``(5, 5)`` and ``(4, 4)``
        (presumably grid offset and spacing -- confirm against ``ex4``).
        """
        # Use a context manager so the file handle is released promptly (many
        # DataLoader workers would otherwise accumulate open files). Convert to
        # RGB so grayscale / CMYK JPEGs yield the same (H, W, 3) layout as
        # colour ones; otherwise the channels-first transpose below fails or
        # batch collation sees mismatched channel counts.
        with Image.open(self.image_files[index]) as img:
            target_image = preprocess(img.convert("RGB"), self.precision)
        # calculate image with black grid
        doomed_image = ex4.ex4(target_image, (5, 5), (4, 4))
        # TODO: evaluate whether grayscale targets (rgb2gray) make sense
        return doomed_image[0], np.transpose(target_image, (2, 0, 1))

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_files)
2022-07-01 13:35:12 +00:00
def preprocess(input, precision: type, size=None) -> np.ndarray:
    """Resize/crop an image to a square and scale its pixels to [0, 1].

    Args:
        input: Image accepted by torchvision's ``Resize``/``CenterCrop``
            (in this file, a PIL image).
        precision: NumPy float type of the result, e.g. ``np.float32`` or
            ``np.float64``.
        size: Side length of the square output; defaults to ``IMG_SIZE``.

    Returns:
        Array of dtype ``precision`` with the pixel values divided by 255.
        For an RGB PIL image the shape is ``(size, size, 3)``.
    """
    if size is None:
        size = IMG_SIZE
    # Scale the shorter edge to `size`, then centre-crop to a square so every
    # sample has the same spatial shape and can be batched.
    resize_transforms = transforms.Compose([
        transforms.Resize(size=size),
        transforms.CenterCrop(size=(size, size)),
    ])
    resized = resize_transforms(input)
    # Map 8-bit pixel values [0, 255] into [0, 1] (assumes 8-bit input).
    return np.array(resized, dtype=precision) / 255.0
# postprocess should be the inverse function of preprocess!
def postprocess(input: np.ndarray) -> np.ndarray:
    """Convert a [0, 1] float image back to 8-bit pixel values.

    This inverts the scaling step of ``preprocess`` (not its resize/crop).

    Args:
        input: Array-like of floats, nominally in [0, 1].

    Returns:
        ``np.uint8`` array of the same shape as ``input``.
    """
    # Round instead of truncating: float error makes (k / 255) * 255 land just
    # below k, which truncation would turn into k - 1.
    scaled = np.rint(np.asarray(input) * 255.0)
    # Clip so out-of-range values (e.g. raw model outputs) saturate at 0/255
    # instead of wrapping around modulo 256 in the uint8 cast.
    return np.clip(scaled, 0, 255).astype(np.uint8)
2022-07-01 13:35:12 +00:00
def get_image_loader(path: str, precision: type, test_set_size: float = 0.1,
                     train_batch_size: int = 25, test_batch_size: int = 5,
                     num_workers: int = 4, seed: int = 0):
    """Split the images under ``path`` into train/test sets and wrap each in a DataLoader.

    Args:
        path: Root directory searched recursively for ``*.jpg`` files.
        precision: NumPy float type for the image arrays, e.g. ``np.float32``
            or ``np.float64``.
        test_set_size: Fraction of samples held out for testing. The test
            split gets ``int(len * test_set_size)`` samples and the train
            split gets the rest. Must be in [0, 1).
        train_batch_size: Samples per training minibatch.
        test_batch_size: Samples per test minibatch.
        num_workers: Background worker processes for the training loader
            (the test loader always loads in the main process).
        seed: Seed of the generator used for the split, so the split is
            reproducible across runs.

    Returns:
        ``(train_loader, test_loader)``. Both shuffle their samples; an empty
        test split is iterated without shuffling.

    Raises:
        ValueError: If ``test_set_size`` is outside [0, 1) or no images are
            found under ``path``.
    """
    if not 0.0 <= test_set_size < 1.0:
        raise ValueError(f"test_set_size must be in [0, 1), got {test_set_size!r}")
    image_dataset = ImageDataset(path, precision)
    total_len = len(image_dataset)
    if total_len == 0:
        raise ValueError(f"no images found under {path!r}")
    test_len = int(total_len * test_set_size)
    # A fixed-seed generator makes the train/test split identical across runs.
    trains, tests = torch.utils.data.dataset.random_split(
        image_dataset,
        lengths=(total_len - test_len, test_len),
        generator=torch.Generator().manual_seed(seed),
    )
    train_loader = DataLoader(
        trains,
        shuffle=True,  # shuffle the order of our samples
        batch_size=train_batch_size,  # samples stacked into one minibatch
        num_workers=num_workers,  # background worker processes for loading
    )
    test_loader = DataLoader(
        tests,
        # RandomSampler rejects an empty dataset, which happens whenever there
        # are fewer than 1 / test_set_size images; don't shuffle in that case.
        shuffle=test_len > 0,
        batch_size=test_batch_size,  # samples stacked into one minibatch
        num_workers=0,  # load in the main process
    )
    return train_loader, test_loader