ImageImpaint_Python_II/DataLoader.py

86 lines
2.8 KiB
Python

import glob
import os
import numpy as np
import torch.utils.data.dataset
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import ex4
# Side length, in pixels, of the square images produced by preprocess()
# (used for both the resize and the centre crop).
IMG_SIZE = 100
class ImageDataset(Dataset):
    """Dataset of JPEG images paired with a grid-masked copy of each image.

    Every ``*.jpg`` file found recursively below ``image_dir`` is one sample.
    """

    def __init__(self, image_dir, precision: type):
        """Collect the image paths and remember the requested float type.

        Args:
            image_dir: Root directory that is searched recursively for
                ``*.jpg`` files.
            precision: NumPy floating-point type (``np.float32`` or
                ``np.float64``) forwarded to ``preprocess`` as the dtype of
                the pixel array.
        """
        pattern = os.path.join(image_dir, "**", "*.jpg")
        # Sorted so that a given index always maps to the same file.
        self.image_files = sorted(glob.glob(pattern, recursive=True))
        self.precision = precision

    def __getitem__(self, index):
        """Return ``(masked_input, target)`` for the image at ``index``.

        Returns:
            A tuple of the first element returned by ``ex4.ex4`` (the image
            with the black grid applied) and the preprocessed target image
            with its channel axis moved to the front.
        """
        # Image.open is lazy, so the context manager guarantees the file
        # handle is released. convert("RGB") reads the pixel data before the
        # file closes and gives grayscale/CMYK JPEGs the three channels that
        # the (2, 0, 1) transpose below relies on.
        with Image.open(self.image_files[index]) as image_file:
            rgb_image = image_file.convert("RGB")
        target_image = preprocess(rgb_image, self.precision)
        # Calculate the image with the black grid.
        # NOTE(review): (5, 5) and (4, 4) are presumably the grid offset and
        # spacing expected by ex4.ex4 -- confirm against that module.
        doomed_image = ex4.ex4(target_image, (5, 5), (4, 4))
        # TODO: evaluate whether converting the target to grayscale makes sense.
        return doomed_image[0], np.transpose(target_image, (2, 0, 1))

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.image_files)
def preprocess(input: np.array, precision: np.float32 or np.float64) -> np.array:
    """Turn an image into a square ``IMG_SIZE`` x ``IMG_SIZE`` float array.

    The image is first passed through ``transforms.Resize(size=IMG_SIZE)`` and
    then centre-cropped to ``(IMG_SIZE, IMG_SIZE)``. The result is converted to
    a NumPy array of dtype ``precision`` and divided by 255. No per-channel
    mean/std normalisation is applied.

    Args:
        input: Image accepted by the torchvision transforms; the dataset in
            this file passes the result of ``PIL.Image.open``.
        precision: NumPy dtype of the returned array.

    Returns:
        The resized and cropped pixel data divided by 255.0 (i.e. in [0, 1]
        assuming 8-bit pixel values).
    """
    shrink = transforms.Resize(size=IMG_SIZE)
    crop = transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE))
    # Same order as before: resize first, then cut out the centre square.
    square_image = crop(shrink(input))
    pixels = np.array(square_image, dtype=precision)
    return pixels / 255.0
# postprocess should be the inverse function of preprocess!
def postprocess(input: np.ndarray) -> np.ndarray:
    """Convert a [0, 1]-scaled float image back to 8-bit pixel values.

    Values are multiplied by 255, rounded to the nearest integer and clipped
    to the range 0..255 before the cast to ``np.uint8``.

    Args:
        input: Array of float pixel values, nominally in [0, 1].

    Returns:
        Array of the same shape with dtype ``np.uint8``.
    """
    scaled = np.asarray(input) * 255.0
    # Round instead of truncating: k / 255 * 255 can land marginally below k
    # in floating point, and a plain cast would then yield k - 1.
    # Clip so that values outside [0, 1] saturate instead of wrapping around
    # in the unsigned cast.
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
def get_image_loader(path: str, precision: type, *,
                     test_fraction: float = 0.1,
                     train_batch_size: int = 25,
                     test_batch_size: int = 5,
                     train_workers: int = 4,
                     test_workers: int = 0,
                     seed: int = 0):
    """Build the training and test data loaders for the images below ``path``.

    The dataset is split randomly but reproducibly (fixed generator seed) into
    a training part and a test part.

    Args:
        path: Root directory passed to ``ImageDataset``.
        precision: NumPy floating-point type (``np.float32`` or
            ``np.float64``) passed to ``ImageDataset``.
        test_fraction: Share of the samples reserved for the test set; the
            test size is ``int(len(dataset) * test_fraction)``.
        train_batch_size: Mini-batch size of the training loader.
        test_batch_size: Mini-batch size of the test loader.
        train_workers: Number of background worker processes used by the
            training loader.
        test_workers: Number of background worker processes used by the test
            loader (0 loads in the calling process).
        seed: Seed of the generator that drives the random split.

    Returns:
        Tuple ``(train_loader, test_loader)``.

    Raises:
        ValueError: If ``test_fraction`` is not within [0, 1].
    """
    if not 0.0 <= test_fraction <= 1.0:
        raise ValueError(
            f"test_fraction must be within [0, 1], got {test_fraction!r}"
        )

    image_dataset = ImageDataset(path, precision)
    total = len(image_dataset)
    # Compute the test size once; the training set takes the remainder, so the
    # two lengths always sum to the dataset size.
    test_count = int(total * test_fraction)
    train_count = total - test_count

    trains, tests = torch.utils.data.dataset.random_split(
        image_dataset,
        lengths=(train_count, test_count),
        generator=torch.Generator().manual_seed(seed),
    )

    train_loader = DataLoader(
        trains,
        shuffle=True,  # shuffle the order of the samples
        batch_size=train_batch_size,  # samples stacked into one mini-batch
        num_workers=train_workers,  # background worker processes
    )
    test_loader = DataLoader(
        tests,
        shuffle=True,  # shuffle the order of the samples
        batch_size=test_batch_size,  # samples stacked into one mini-batch
        num_workers=test_workers,  # background worker processes
    )
    return train_loader, test_loader