diff --git a/ApplyModel.py b/ApplyModel.py index bf53f94..374f042 100644 --- a/ApplyModel.py +++ b/ApplyModel.py @@ -15,11 +15,13 @@ def apply_model(filepath: str): model = load_model() model.to(device) - pic = DataLoader.preprocess(img, precision=np.float32) + pic = DataLoader.crop_image(img) + pic = DataLoader.preprocess(pic, precision=np.float32) pic = ex4.ex4(pic, (5, 5), (4, 4))[0] - Image.fromarray((np.transpose(pic * 255.0, (1, 2, 0)).astype(np.uint8))).save("filename_grid.jpg") + Image.fromarray((np.transpose(DataLoader.postprocess(pic), (1, 2, 0)).astype(np.uint8))).save("filename_grid.jpg") out = model(torch.from_numpy(pic).to(device)) - out = DataLoader.postprocess(out.cpu().detach().numpy()) + out = out.cpu().detach().numpy() + out = DataLoader.postprocess(out) out = np.transpose(out, (1, 2, 0)) im = Image.fromarray(out) im.save("filename.jpg", format="jpeg") diff --git a/DataLoader.py b/DataLoader.py index edb0f28..a8110bf 100644 --- a/DataLoader.py +++ b/DataLoader.py @@ -7,6 +7,7 @@ from torch.utils.data import Dataset from torch.utils.data import DataLoader from torchvision import transforms from PIL import Image +import random import ex4 @@ -14,21 +15,27 @@ IMG_SIZE = 100 class ImageDataset(Dataset): - def __init__(self, image_dir, precision: np.float32 or np.float64): + def __init__(self, image_dir, offsetrange: (int, int), spacingrange: (int, int), transform_chain: transforms, + precision: np.float32 or np.float64 = np.float32): self.image_files = sorted(glob.glob(os.path.join(image_dir, "**", "*.jpg"), recursive=True)) self.precision = precision + self.offsetrange = offsetrange + self.spacingrange = spacingrange + self.transform_chain = transform_chain def __getitem__(self, index): # Open image file, convert to numpy array and scale to [0, 1] target_image = Image.open(self.image_files[index]) + target_image = crop_image(target_image) + + target_image = self.transform_chain(target_image) target_image = preprocess(target_image, self.precision) # calculate image with black grid - doomed_image = ex4.ex4(target_image, (5, 5), (4, 4)) - - # convert image to grayscale - # target_image = rgb2gray(target_image) # todo look if gray image makes sense + offset = (random.randint(*self.offsetrange), random.randint(*self.offsetrange)) + spacing = (random.randint(*self.spacingrange), random.randint(*self.spacingrange)) + doomed_image = ex4.ex4(target_image, offset, spacing) return doomed_image[0], np.transpose(target_image, (2, 0, 1)) @@ -36,16 +43,20 @@ class ImageDataset(Dataset): return len(self.image_files) -def preprocess(input: np.array, precision: np.float32 or np.float64) -> np.array: - # image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255 +def crop_image(image: Image) -> np.array: resize_transforms = transforms.Compose([ transforms.Resize(size=IMG_SIZE), transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)), ]) - input = resize_transforms(input) + return resize_transforms(image) - # normalize image from 0-1 - target_image = np.array(input, dtype=precision) / 255.0 + +def preprocess(input: np.array, precision: np.float32 or np.float64) -> np.array: + # image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255 + # https://www.geeksforgeeks.org/how-to-normalize-images-in-pytorch/ + # normalize image from -1 - 1 + target_image = np.array(input, dtype=precision) + target_image = target_image / 255.0 # Perform normalization for each channel # image = (image - self.norm_mean) / self.norm_std @@ -55,28 +66,47 @@ def preprocess(input: np.array, precision: np.float32 or np.float64) -> np.array # postprecess should be the inverese function of preprocess! def postprocess(input: np.array) -> np.array: - target_image = (input * 255.0).astype(np.uint8) + # todo clipping here correct? some values are >1 because of model + target = np.clip(input, 0, 1) + target_image = (target * 255.0).astype(np.uint8) return target_image def get_image_loader(path: str, precision: np.float32 or np.float64): - image_dataset = ImageDataset(path, precision) - totlen = len(image_dataset) + # ranges due to project spec + image_dataset = ImageDataset(path, + offsetrange=(0, 8), + spacingrange=(2, 6), + transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip()]), + precision=precision) + image_dataset_augmented = ImageDataset(path, + offsetrange=(0, 8), + spacingrange=(2, 6), + transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.GaussianBlur(3, 4)]), + precision=precision) + + # merge different datasets here! + merged_dataset = torch.utils.data.ConcatDataset([image_dataset, image_dataset_augmented]) + + totlen = len(merged_dataset) test_set_size = .1 - trains, tests = torch.utils.data.dataset.random_split(image_dataset, lengths=(totlen - int(totlen * test_set_size), - int(totlen * test_set_size)), - generator=torch.Generator().manual_seed(0)) + train_split, test_split = torch.utils.data.dataset.random_split(merged_dataset, + lengths=(totlen - int(totlen * test_set_size), + int(totlen * test_set_size))) train_loader = DataLoader( - trains, + train_split, shuffle=True, # shuffle the order of our samples batch_size=25, # stack 4 samples to a minibatch num_workers=4 # no background workers (see comment below) ) test_loader = DataLoader( - tests, + test_split, shuffle=True, # shuffle the order of our samples batch_size=5, # stack 4 samples to a minibatch num_workers=0 # no background workers (see comment below) diff --git a/ImageImpaint.py b/ImageImpaint.py index bba021f..667eaaa 100644 --- a/ImageImpaint.py +++ b/ImageImpaint.py @@ -1,13 +1,17 @@ +import os +import sys + +import PIL import numpy as np +import packaging import torch -from PIL.Image import Image +from matplotlib import pyplot as plt +from packaging.version import Version import DataLoader from DataLoader import get_image_loader from Net import ImageNN - -# 01.05.22 -- 0.5h -from netio import save_model, load_model, eval_evalset +from netio import save_model, eval_evalset def get_train_device(): @@ -22,16 +26,20 @@ def train_model(): torch.manual_seed(0) device = get_train_device() + # Prepare a path to plot to + plotpath = "plots/" + os.makedirs(plotpath, exist_ok=True) + # Load datasets train_loader, test_loader = get_image_loader("training/", precision=np.float32) nn = ImageNN(n_in_channels=3, precision=np.float32) # todo net params - nn.train() # init with train mode + nn.train() # init with train modeAdam nn.to(device) # send net to device available - optimizer = torch.optim.AdamW(nn.parameters(), lr=0.1, weight_decay=1e-5) # todo adjust parameters and lr + optimizer = torch.optim.AdamW(nn.parameters(), lr=1e-3, weight_decay=1e-5) # todo adjust parameters and lr loss_function = torch.nn.MSELoss() loss_function.to(device) - n_epochs = 7 # todo epcchs here + n_epochs = 5 # todo epcchs here train_sample_size = len(train_loader) losses = [] @@ -40,12 +48,15 @@ def train_model(): print(f"Epoch {epoch}/{n_epochs}\n") i = 0 for input_tensor, target_tensor in train_loader: + optimizer.zero_grad() # reset gradients + output = nn(input_tensor.to(device)) # get model output (forward pass) - loss = loss_function(output.to(device), target_tensor.to(device)) # compute loss given model output and true target + loss = loss_function(output.to(device), + target_tensor.to(device)) # compute loss given model output and true target loss.backward() # compute gradients (backward pass) optimizer.step() # perform gradient descent update step - optimizer.zero_grad() # reset gradients + losses.append(loss.item()) i += train_loader.batch_size @@ -64,6 +75,12 @@ def train_model(): nn.train() + # Plot output + if i % 100 == 0: + plot(input_tensor.detach().cpu().numpy()[:1], target_tensor.detach().cpu().numpy()[:1], + output.detach().cpu().numpy()[:1], + plotpath, i, epoch) + # evaluate model with submission pkl file eval_evalset() @@ -89,6 +106,35 @@ def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, return loss +def plot(inputs, targets, predictions, path, update, epoch): + """Plotting the inputs, targets and predictions to file `path`""" + os.makedirs(path, exist_ok=True) + fig, axes = plt.subplots(ncols=3, figsize=(15, 5)) + + for i in range(len(inputs)): + for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]): + ax.clear() + ax.set_title(title) + # ax.imshow(DataLoader.postprocess(np.transpose(data[i], (1, 2, 0))), interpolation="none") + ax.imshow(np.transpose((data[i]), (1, 2, 0)), interpolation="none") + ax.set_axis_off() + fig.savefig(os.path.join(path, f"{epoch:02d}_{update:07d}_{i:02d}.png"), dpi=100) + + plt.close(fig) + + +def check_module_versions() -> None: + python_check = '(\u2713)' if sys.version_info >= (3, 8) else '(\u2717)' + numpy_check = '(\u2713)' if Version(np.__version__) >= Version('1.18') else '(\u2717)' + torch_check = '(\u2713)' if Version(torch.__version__) >= Version('1.6.0') else '(\u2717)' + pil_check = '(\u2713)' if Version(PIL.__version__) >= Version('6.0.0') else '(\u2717)' + print(f'Installed Python version: {sys.version_info.major}.{sys.version_info.minor} {python_check}') + print(f'Installed numpy version: {np.__version__} {numpy_check}') + print(f'Installed PyTorch version: {torch.__version__} {torch_check}') + print(f'Installed PIL version: {PIL.__version__} {pil_check}') + assert any(x == '(\u2713)' for x in [python_check, numpy_check, torch_check, pil_check]) + if __name__ == '__main__': + check_module_versions() train_model() diff --git a/Scoring.py b/Scoring.py index a01c0b4..e8002f7 100644 --- a/Scoring.py +++ b/Scoring.py @@ -70,7 +70,7 @@ def scoring_file(prediction_file: str, target_file: str): """Computes the mean RMSE loss on two lists of numpy arrays stored in pickle files prediction_file and targets_file Computation of mean RMSE loss, as used in the challenge for exercise 5. See files "example_testset.pkl" and - "example_submission_random.pkl" for an example test set and example targets, respectively. The real test set + "example_submission_random.pkl" for an example testing set and example targets, respectively. The real testing set (without targets) will be available as download (see assignment sheet 2). Parameters