import glob
import os
import numpy as np
import torch
import torch.utils.data.dataset
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import ex4

IMG_SIZE = 100


class ImageDataset(Dataset):
    def __init__(self, image_dir):
        self.image_files = sorted(glob.glob(os.path.join(image_dir, "**", "*.jpg"), recursive=True))
        # ImageNet channel means/stds; could also be defined as class attributes
        self.norm_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.norm_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __getitem__(self, index):
        # Open the image file; resizing, conversion to a numpy array and
        # scaling to [0, 1] follow below
        target_image = Image.open(self.image_files[index])
        # image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255

        # Resize the shorter side to IMG_SIZE, then center-crop to IMG_SIZE x IMG_SIZE
        resize_transforms = transforms.Compose([
            transforms.Resize(size=IMG_SIZE),
            transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
        ])
        target_image = resize_transforms(target_image)

        # Convert to a numpy array and scale values to [0, 1]
        target_image = np.array(target_image, dtype=np.float64) / 255.0

        # Per-channel normalization could be applied here (currently disabled):
        # image = (image - self.norm_mean) / self.norm_std

        # Black out a grid of pixels via ex4.ex4; the result's first element is
        # used as the masked network input below
        doomed_image = ex4.ex4(target_image, (5, 5), (4, 4))

        # Optional grayscale conversion (currently disabled)
        # target_image = rgb2gray(target_image)  # TODO: check whether a grayscale target makes sense

        # Return the masked input and the target image in CHW layout
        return doomed_image[0], np.transpose(target_image, (2, 0, 1))

    def __len__(self):
        return len(self.image_files)
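
# Illustrative usage sketch of ImageDataset (not part of the original file;
# "path/to/images" is a placeholder directory of .jpg files). Each sample is a
# pair of the grid-masked image produced by ex4.ex4 and the resized target:
#
#     dataset = ImageDataset("path/to/images")
#     masked, target = dataset[0]
#     # for an RGB image, target.shape == (3, IMG_SIZE, IMG_SIZE)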


def get_image_loader(path: str):
    image_dataset = ImageDataset(path)
    totlen = len(image_dataset)
    # Deterministic 70/30 split into training and test subsets
    trains, tests = torch.utils.data.dataset.random_split(
        image_dataset,
        (int(totlen * .7), totlen - int(totlen * .7)),
        generator=torch.Generator().manual_seed(42)
    )

    train_loader = DataLoader(
        trains,
        shuffle=True,   # shuffle the order of the training samples
        batch_size=5,   # stack 5 samples into a minibatch
        num_workers=2   # load minibatches with 2 background worker processes
    )

    test_loader = DataLoader(
        tests,
        shuffle=True,   # shuffle the order of the test samples
        batch_size=1,   # one sample per minibatch
        num_workers=0   # no background workers; load in the main process
    )

    return train_loader, test_loader
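
# Illustrative iteration over the returned loaders (not part of the original
# file; "path/to/images" is a placeholder):
#
#     train_loader, test_loader = get_image_loader("path/to/images")
#     for masked_batch, target_batch in train_loader:
#         ...  # target_batch has shape (5, 3, IMG_SIZE, IMG_SIZE) for the batch size of 5 above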


def rgb2gray(rgb_array: np.ndarray):
    # Convert an (H, W, 3) RGB array to grayscale using ITU-R BT.601 luma weights
    r, g, b = rgb_array[:, :, 0], rgb_array[:, :, 1], rgb_array[:, :, 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
    return gray
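

# --- Minimal self-check sketch (not part of the original module) ---
# Runs only when the file is executed directly and uses synthetic data, so no
# image directory is required. The commented-out loader example assumes a
# placeholder path that would need to point at a real folder of .jpg files.
if __name__ == "__main__":
    dummy_rgb = np.random.rand(IMG_SIZE, IMG_SIZE, 3).astype(np.float32)
    gray = rgb2gray(dummy_rgb)
    print("rgb2gray output shape:", gray.shape)  # expected: (IMG_SIZE, IMG_SIZE)

    # Loading real data (requires a folder containing .jpg images):
    # train_loader, test_loader = get_image_loader("path/to/images")
    # masked_batch, target_batch = next(iter(train_loader))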