lots of improvements

data augmentation
plotting of intermediate pics
This commit is contained in:
lukas-heiligenbrunner 2022-07-02 16:11:27 +02:00
parent 11640a6494
commit c56e583f68
4 changed files with 109 additions and 31 deletions

View File

@ -15,11 +15,13 @@ def apply_model(filepath: str):
model = load_model()
model.to(device)
pic = DataLoader.preprocess(img, precision=np.float32)
pic = DataLoader.crop_image(img)
pic = DataLoader.preprocess(pic, precision=np.float32)
pic = ex4.ex4(pic, (5, 5), (4, 4))[0]
Image.fromarray((np.transpose(pic * 255.0, (1, 2, 0)).astype(np.uint8))).save("filename_grid.jpg")
Image.fromarray((np.transpose(DataLoader.postprocess(pic), (1, 2, 0)).astype(np.uint8))).save("filename_grid.jpg")
out = model(torch.from_numpy(pic).to(device))
out = DataLoader.postprocess(out.cpu().detach().numpy())
out = out.cpu().detach().numpy()
out = DataLoader.postprocess(out)
out = np.transpose(out, (1, 2, 0))
im = Image.fromarray(out)
im.save("filename.jpg", format="jpeg")

View File

@ -7,6 +7,7 @@ from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import random
import ex4
@ -14,21 +15,27 @@ IMG_SIZE = 100
class ImageDataset(Dataset):
def __init__(self, image_dir, precision: np.float32 or np.float64):
def __init__(self, image_dir, offsetrange: (int, int), spacingrange: (int, int), transform_chain: transforms,
precision: np.float32 or np.float64 = np.float32):
self.image_files = sorted(glob.glob(os.path.join(image_dir, "**", "*.jpg"), recursive=True))
self.precision = precision
self.offsetrange = offsetrange
self.spacingrange = spacingrange
self.transform_chain = transform_chain
def __getitem__(self, index):
# Open image file, convert to numpy array and scale to [0, 1]
target_image = Image.open(self.image_files[index])
target_image = crop_image(target_image)
target_image = self.transform_chain(target_image)
target_image = preprocess(target_image, self.precision)
# calculate image with black grid
doomed_image = ex4.ex4(target_image, (5, 5), (4, 4))
# convert image to grayscale
# target_image = rgb2gray(target_image) # todo look if gray image makes sense
offset = (random.randint(*self.offsetrange), random.randint(*self.offsetrange))
spacing = (random.randint(*self.spacingrange), random.randint(*self.spacingrange))
doomed_image = ex4.ex4(target_image, offset, spacing)
return doomed_image[0], np.transpose(target_image, (2, 0, 1))
@ -36,16 +43,20 @@ class ImageDataset(Dataset):
return len(self.image_files)
def preprocess(input: np.array, precision: np.float32 or np.float64) -> np.array:
# image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255
def crop_image(image: Image) -> np.array:
resize_transforms = transforms.Compose([
transforms.Resize(size=IMG_SIZE),
transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
])
input = resize_transforms(input)
return resize_transforms(image)
# normalize image from 0-1
target_image = np.array(input, dtype=precision) / 255.0
def preprocess(input: np.array, precision: np.float32 or np.float64) -> np.array:
# image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255
# https://www.geeksforgeeks.org/how-to-normalize-images-in-pytorch/
# normalize image from -1 - 1
target_image = np.array(input, dtype=precision)
target_image = target_image / 255.0
# Perform normalization for each channel
# image = (image - self.norm_mean) / self.norm_std
@ -55,28 +66,47 @@ def preprocess(input: np.array, precision: np.float32 or np.float64) -> np.array
# postprecess should be the inverese function of preprocess!
def postprocess(input: np.array) -> np.array:
target_image = (input * 255.0).astype(np.uint8)
# todo clipping here correct? some values are >1 because of model
target = np.clip(input, 0, 1)
target_image = (target * 255.0).astype(np.uint8)
return target_image
def get_image_loader(path: str, precision: np.float32 or np.float64):
image_dataset = ImageDataset(path, precision)
totlen = len(image_dataset)
# ranges due to project spec
image_dataset = ImageDataset(path,
offsetrange=(0, 8),
spacingrange=(2, 6),
transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip()]),
precision=precision)
image_dataset_augmented = ImageDataset(path,
offsetrange=(0, 8),
spacingrange=(2, 6),
transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.GaussianBlur(3, 4)]),
precision=precision)
# merge different datasets here!
merged_dataset = torch.utils.data.ConcatDataset([image_dataset, image_dataset_augmented])
totlen = len(merged_dataset)
test_set_size = .1
trains, tests = torch.utils.data.dataset.random_split(image_dataset, lengths=(totlen - int(totlen * test_set_size),
int(totlen * test_set_size)),
generator=torch.Generator().manual_seed(0))
train_split, test_split = torch.utils.data.dataset.random_split(merged_dataset,
lengths=(totlen - int(totlen * test_set_size),
int(totlen * test_set_size)))
train_loader = DataLoader(
trains,
train_split,
shuffle=True, # shuffle the order of our samples
batch_size=25, # stack 4 samples to a minibatch
num_workers=4 # no background workers (see comment below)
)
test_loader = DataLoader(
tests,
test_split,
shuffle=True, # shuffle the order of our samples
batch_size=5, # stack 4 samples to a minibatch
num_workers=0 # no background workers (see comment below)

View File

@ -1,13 +1,17 @@
import os
import sys
import PIL
import numpy as np
import packaging
import torch
from PIL.Image import Image
from matplotlib import pyplot as plt
from packaging.version import Version
import DataLoader
from DataLoader import get_image_loader
from Net import ImageNN
# 01.05.22 -- 0.5h
from netio import save_model, load_model, eval_evalset
from netio import save_model, eval_evalset
def get_train_device():
@ -22,16 +26,20 @@ def train_model():
torch.manual_seed(0)
device = get_train_device()
# Prepare a path to plot to
plotpath = "plots/"
os.makedirs(plotpath, exist_ok=True)
# Load datasets
train_loader, test_loader = get_image_loader("training/", precision=np.float32)
nn = ImageNN(n_in_channels=3, precision=np.float32) # todo net params
nn.train() # init with train mode
nn.train() # init with train modeAdam
nn.to(device) # send net to device available
optimizer = torch.optim.AdamW(nn.parameters(), lr=0.1, weight_decay=1e-5) # todo adjust parameters and lr
optimizer = torch.optim.AdamW(nn.parameters(), lr=1e-3, weight_decay=1e-5) # todo adjust parameters and lr
loss_function = torch.nn.MSELoss()
loss_function.to(device)
n_epochs = 7 # todo epcchs here
n_epochs = 5 # todo epcchs here
train_sample_size = len(train_loader)
losses = []
@ -40,12 +48,15 @@ def train_model():
print(f"Epoch {epoch}/{n_epochs}\n")
i = 0
for input_tensor, target_tensor in train_loader:
optimizer.zero_grad() # reset gradients
output = nn(input_tensor.to(device)) # get model output (forward pass)
loss = loss_function(output.to(device), target_tensor.to(device)) # compute loss given model output and true target
loss = loss_function(output.to(device),
target_tensor.to(device)) # compute loss given model output and true target
loss.backward() # compute gradients (backward pass)
optimizer.step() # perform gradient descent update step
optimizer.zero_grad() # reset gradients
losses.append(loss.item())
i += train_loader.batch_size
@ -64,6 +75,12 @@ def train_model():
nn.train()
# Plot output
if i % 100 == 0:
plot(input_tensor.detach().cpu().numpy()[:1], target_tensor.detach().cpu().numpy()[:1],
output.detach().cpu().numpy()[:1],
plotpath, i, epoch)
# evaluate model with submission pkl file
eval_evalset()
@ -89,6 +106,35 @@ def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader,
return loss
def plot(inputs, targets, predictions, path, update, epoch):
"""Plotting the inputs, targets and predictions to file `path`"""
os.makedirs(path, exist_ok=True)
fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
for i in range(len(inputs)):
for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]):
ax.clear()
ax.set_title(title)
# ax.imshow(DataLoader.postprocess(np.transpose(data[i], (1, 2, 0))), interpolation="none")
ax.imshow(np.transpose((data[i]), (1, 2, 0)), interpolation="none")
ax.set_axis_off()
fig.savefig(os.path.join(path, f"{epoch:02d}_{update:07d}_{i:02d}.png"), dpi=100)
plt.close(fig)
def check_module_versions() -> None:
python_check = '(\u2713)' if sys.version_info >= (3, 8) else '(\u2717)'
numpy_check = '(\u2713)' if Version(np.__version__) >= Version('1.18') else '(\u2717)'
torch_check = '(\u2713)' if Version(torch.__version__) >= Version('1.6.0') else '(\u2717)'
pil_check = '(\u2713)' if Version(PIL.__version__) >= Version('6.0.0') else '(\u2717)'
print(f'Installed Python version: {sys.version_info.major}.{sys.version_info.minor} {python_check}')
print(f'Installed numpy version: {np.__version__} {numpy_check}')
print(f'Installed PyTorch version: {torch.__version__} {torch_check}')
print(f'Installed PIL version: {PIL.__version__} {pil_check}')
assert any(x == '(\u2713)' for x in [python_check, numpy_check, torch_check, pil_check])
if __name__ == '__main__':
check_module_versions()
train_model()

View File

@ -70,7 +70,7 @@ def scoring_file(prediction_file: str, target_file: str):
"""Computes the mean RMSE loss on two lists of numpy arrays stored in pickle files prediction_file and targets_file
Computation of mean RMSE loss, as used in the challenge for exercise 5. See files "example_testset.pkl" and
"example_submission_random.pkl" for an example test set and example targets, respectively. The real test set
"example_submission_random.pkl" for an example testing set and example targets, respectively. The real testing set
(without targets) will be available as download (see assignment sheet 2).
Parameters