lots of improvements
data augmentation plotting of intermediate pics
This commit is contained in:
parent
11640a6494
commit
c56e583f68
@ -15,11 +15,13 @@ def apply_model(filepath: str):
|
|||||||
model = load_model()
|
model = load_model()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
pic = DataLoader.preprocess(img, precision=np.float32)
|
pic = DataLoader.crop_image(img)
|
||||||
|
pic = DataLoader.preprocess(pic, precision=np.float32)
|
||||||
pic = ex4.ex4(pic, (5, 5), (4, 4))[0]
|
pic = ex4.ex4(pic, (5, 5), (4, 4))[0]
|
||||||
Image.fromarray((np.transpose(pic * 255.0, (1, 2, 0)).astype(np.uint8))).save("filename_grid.jpg")
|
Image.fromarray((np.transpose(DataLoader.postprocess(pic), (1, 2, 0)).astype(np.uint8))).save("filename_grid.jpg")
|
||||||
out = model(torch.from_numpy(pic).to(device))
|
out = model(torch.from_numpy(pic).to(device))
|
||||||
out = DataLoader.postprocess(out.cpu().detach().numpy())
|
out = out.cpu().detach().numpy()
|
||||||
|
out = DataLoader.postprocess(out)
|
||||||
out = np.transpose(out, (1, 2, 0))
|
out = np.transpose(out, (1, 2, 0))
|
||||||
im = Image.fromarray(out)
|
im = Image.fromarray(out)
|
||||||
im.save("filename.jpg", format="jpeg")
|
im.save("filename.jpg", format="jpeg")
|
||||||
|
@ -7,6 +7,7 @@ from torch.utils.data import Dataset
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import random
|
||||||
|
|
||||||
import ex4
|
import ex4
|
||||||
|
|
||||||
@ -14,21 +15,27 @@ IMG_SIZE = 100
|
|||||||
|
|
||||||
|
|
||||||
class ImageDataset(Dataset):
|
class ImageDataset(Dataset):
|
||||||
def __init__(self, image_dir, precision: np.float32 or np.float64):
|
def __init__(self, image_dir, offsetrange: (int, int), spacingrange: (int, int), transform_chain: transforms,
|
||||||
|
precision: np.float32 or np.float64 = np.float32):
|
||||||
self.image_files = sorted(glob.glob(os.path.join(image_dir, "**", "*.jpg"), recursive=True))
|
self.image_files = sorted(glob.glob(os.path.join(image_dir, "**", "*.jpg"), recursive=True))
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
|
self.offsetrange = offsetrange
|
||||||
|
self.spacingrange = spacingrange
|
||||||
|
self.transform_chain = transform_chain
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
# Open image file, convert to numpy array and scale to [0, 1]
|
# Open image file, convert to numpy array and scale to [0, 1]
|
||||||
target_image = Image.open(self.image_files[index])
|
target_image = Image.open(self.image_files[index])
|
||||||
|
target_image = crop_image(target_image)
|
||||||
|
|
||||||
|
target_image = self.transform_chain(target_image)
|
||||||
|
|
||||||
target_image = preprocess(target_image, self.precision)
|
target_image = preprocess(target_image, self.precision)
|
||||||
|
|
||||||
# calculate image with black grid
|
# calculate image with black grid
|
||||||
doomed_image = ex4.ex4(target_image, (5, 5), (4, 4))
|
offset = (random.randint(*self.offsetrange), random.randint(*self.offsetrange))
|
||||||
|
spacing = (random.randint(*self.spacingrange), random.randint(*self.spacingrange))
|
||||||
# convert image to grayscale
|
doomed_image = ex4.ex4(target_image, offset, spacing)
|
||||||
# target_image = rgb2gray(target_image) # todo look if gray image makes sense
|
|
||||||
|
|
||||||
return doomed_image[0], np.transpose(target_image, (2, 0, 1))
|
return doomed_image[0], np.transpose(target_image, (2, 0, 1))
|
||||||
|
|
||||||
@ -36,16 +43,20 @@ class ImageDataset(Dataset):
|
|||||||
return len(self.image_files)
|
return len(self.image_files)
|
||||||
|
|
||||||
|
|
||||||
def preprocess(input: np.array, precision: np.float32 or np.float64) -> np.array:
|
def crop_image(image: Image) -> np.array:
|
||||||
# image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255
|
|
||||||
resize_transforms = transforms.Compose([
|
resize_transforms = transforms.Compose([
|
||||||
transforms.Resize(size=IMG_SIZE),
|
transforms.Resize(size=IMG_SIZE),
|
||||||
transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
|
transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
|
||||||
])
|
])
|
||||||
input = resize_transforms(input)
|
return resize_transforms(image)
|
||||||
|
|
||||||
# normalize image from 0-1
|
|
||||||
target_image = np.array(input, dtype=precision) / 255.0
|
def preprocess(input: np.array, precision: np.float32 or np.float64) -> np.array:
|
||||||
|
# image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255
|
||||||
|
# https://www.geeksforgeeks.org/how-to-normalize-images-in-pytorch/
|
||||||
|
# normalize image from -1 - 1
|
||||||
|
target_image = np.array(input, dtype=precision)
|
||||||
|
target_image = target_image / 255.0
|
||||||
|
|
||||||
# Perform normalization for each channel
|
# Perform normalization for each channel
|
||||||
# image = (image - self.norm_mean) / self.norm_std
|
# image = (image - self.norm_mean) / self.norm_std
|
||||||
@ -55,28 +66,47 @@ def preprocess(input: np.array, precision: np.float32 or np.float64) -> np.array
|
|||||||
|
|
||||||
# postprocess should be the inverse function of preprocess!
|
# postprocess should be the inverse function of preprocess!
|
||||||
def postprocess(input: np.array) -> np.array:
|
def postprocess(input: np.array) -> np.array:
|
||||||
target_image = (input * 255.0).astype(np.uint8)
|
# todo clipping here correct? some values are >1 because of model
|
||||||
|
target = np.clip(input, 0, 1)
|
||||||
|
target_image = (target * 255.0).astype(np.uint8)
|
||||||
return target_image
|
return target_image
|
||||||
|
|
||||||
|
|
||||||
def get_image_loader(path: str, precision: np.float32 or np.float64):
|
def get_image_loader(path: str, precision: np.float32 or np.float64):
|
||||||
image_dataset = ImageDataset(path, precision)
|
# ranges due to project spec
|
||||||
totlen = len(image_dataset)
|
image_dataset = ImageDataset(path,
|
||||||
|
offsetrange=(0, 8),
|
||||||
|
spacingrange=(2, 6),
|
||||||
|
transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
|
||||||
|
transforms.RandomVerticalFlip()]),
|
||||||
|
precision=precision)
|
||||||
|
|
||||||
|
image_dataset_augmented = ImageDataset(path,
|
||||||
|
offsetrange=(0, 8),
|
||||||
|
spacingrange=(2, 6),
|
||||||
|
transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
|
||||||
|
transforms.RandomVerticalFlip(),
|
||||||
|
transforms.GaussianBlur(3, 4)]),
|
||||||
|
precision=precision)
|
||||||
|
|
||||||
|
# merge different datasets here!
|
||||||
|
merged_dataset = torch.utils.data.ConcatDataset([image_dataset, image_dataset_augmented])
|
||||||
|
|
||||||
|
totlen = len(merged_dataset)
|
||||||
test_set_size = .1
|
test_set_size = .1
|
||||||
trains, tests = torch.utils.data.dataset.random_split(image_dataset, lengths=(totlen - int(totlen * test_set_size),
|
train_split, test_split = torch.utils.data.dataset.random_split(merged_dataset,
|
||||||
int(totlen * test_set_size)),
|
lengths=(totlen - int(totlen * test_set_size),
|
||||||
generator=torch.Generator().manual_seed(0))
|
int(totlen * test_set_size)))
|
||||||
|
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
trains,
|
train_split,
|
||||||
shuffle=True, # shuffle the order of our samples
|
shuffle=True, # shuffle the order of our samples
|
||||||
batch_size=25, # stack 4 samples to a minibatch
|
batch_size=25, # stack 4 samples to a minibatch
|
||||||
num_workers=4 # no background workers (see comment below)
|
num_workers=4 # no background workers (see comment below)
|
||||||
)
|
)
|
||||||
|
|
||||||
test_loader = DataLoader(
|
test_loader = DataLoader(
|
||||||
tests,
|
test_split,
|
||||||
shuffle=True, # shuffle the order of our samples
|
shuffle=True, # shuffle the order of our samples
|
||||||
batch_size=5, # stack 4 samples to a minibatch
|
batch_size=5, # stack 4 samples to a minibatch
|
||||||
num_workers=0 # no background workers (see comment below)
|
num_workers=0 # no background workers (see comment below)
|
||||||
|
@ -1,13 +1,17 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import PIL
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import packaging
|
||||||
import torch
|
import torch
|
||||||
from PIL.Image import Image
|
from matplotlib import pyplot as plt
|
||||||
|
from packaging.version import Version
|
||||||
|
|
||||||
import DataLoader
|
import DataLoader
|
||||||
from DataLoader import get_image_loader
|
from DataLoader import get_image_loader
|
||||||
from Net import ImageNN
|
from Net import ImageNN
|
||||||
|
from netio import save_model, eval_evalset
|
||||||
# 01.05.22 -- 0.5h
|
|
||||||
from netio import save_model, load_model, eval_evalset
|
|
||||||
|
|
||||||
|
|
||||||
def get_train_device():
|
def get_train_device():
|
||||||
@ -22,16 +26,20 @@ def train_model():
|
|||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
device = get_train_device()
|
device = get_train_device()
|
||||||
|
|
||||||
|
# Prepare a path to plot to
|
||||||
|
plotpath = "plots/"
|
||||||
|
os.makedirs(plotpath, exist_ok=True)
|
||||||
|
|
||||||
# Load datasets
|
# Load datasets
|
||||||
train_loader, test_loader = get_image_loader("training/", precision=np.float32)
|
train_loader, test_loader = get_image_loader("training/", precision=np.float32)
|
||||||
nn = ImageNN(n_in_channels=3, precision=np.float32) # todo net params
|
nn = ImageNN(n_in_channels=3, precision=np.float32) # todo net params
|
||||||
nn.train() # init with train mode
|
nn.train() # init with train modeAdam
|
||||||
nn.to(device) # send net to device available
|
nn.to(device) # send net to device available
|
||||||
|
|
||||||
optimizer = torch.optim.AdamW(nn.parameters(), lr=0.1, weight_decay=1e-5) # todo adjust parameters and lr
|
optimizer = torch.optim.AdamW(nn.parameters(), lr=1e-3, weight_decay=1e-5) # todo adjust parameters and lr
|
||||||
loss_function = torch.nn.MSELoss()
|
loss_function = torch.nn.MSELoss()
|
||||||
loss_function.to(device)
|
loss_function.to(device)
|
||||||
n_epochs = 7 # todo epcchs here
|
n_epochs = 5 # todo epcchs here
|
||||||
|
|
||||||
train_sample_size = len(train_loader)
|
train_sample_size = len(train_loader)
|
||||||
losses = []
|
losses = []
|
||||||
@ -40,12 +48,15 @@ def train_model():
|
|||||||
print(f"Epoch {epoch}/{n_epochs}\n")
|
print(f"Epoch {epoch}/{n_epochs}\n")
|
||||||
i = 0
|
i = 0
|
||||||
for input_tensor, target_tensor in train_loader:
|
for input_tensor, target_tensor in train_loader:
|
||||||
|
optimizer.zero_grad() # reset gradients
|
||||||
|
|
||||||
output = nn(input_tensor.to(device)) # get model output (forward pass)
|
output = nn(input_tensor.to(device)) # get model output (forward pass)
|
||||||
|
|
||||||
loss = loss_function(output.to(device), target_tensor.to(device)) # compute loss given model output and true target
|
loss = loss_function(output.to(device),
|
||||||
|
target_tensor.to(device)) # compute loss given model output and true target
|
||||||
loss.backward() # compute gradients (backward pass)
|
loss.backward() # compute gradients (backward pass)
|
||||||
optimizer.step() # perform gradient descent update step
|
optimizer.step() # perform gradient descent update step
|
||||||
optimizer.zero_grad() # reset gradients
|
|
||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
|
|
||||||
i += train_loader.batch_size
|
i += train_loader.batch_size
|
||||||
@ -64,6 +75,12 @@ def train_model():
|
|||||||
|
|
||||||
nn.train()
|
nn.train()
|
||||||
|
|
||||||
|
# Plot output
|
||||||
|
if i % 100 == 0:
|
||||||
|
plot(input_tensor.detach().cpu().numpy()[:1], target_tensor.detach().cpu().numpy()[:1],
|
||||||
|
output.detach().cpu().numpy()[:1],
|
||||||
|
plotpath, i, epoch)
|
||||||
|
|
||||||
# evaluate model with submission pkl file
|
# evaluate model with submission pkl file
|
||||||
eval_evalset()
|
eval_evalset()
|
||||||
|
|
||||||
@ -89,6 +106,35 @@ def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader,
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def plot(inputs, targets, predictions, path, update, epoch):
    """Plot the inputs, targets and predictions to PNG files in `path`.

    For every sample index ``i`` a three-panel figure ("Input", "Target",
    "Prediction") is drawn and saved as
    ``{epoch:02d}_{update:07d}_{i:02d}.png`` inside `path`.

    Parameters
    ----------
    inputs, targets, predictions
        Sequences indexable by sample. Each sample is transposed with axes
        ``(1, 2, 0)`` before being shown -- presumably channel-first images
        of shape (C, H, W); TODO confirm against the caller.
    path
        Output directory; created if it does not exist.
    update
        Integer update counter used in the file name.
    epoch
        Integer epoch number used in the file name.
    """
    os.makedirs(path, exist_ok=True)
    fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
    panels = list(zip(axes, (inputs, targets, predictions), ("Input", "Target", "Prediction")))

    # Close the figure even when drawing or saving fails, so repeated calls
    # from the training loop cannot accumulate open figures.
    try:
        for i in range(len(inputs)):
            # The same three axes are reused for every sample, hence the clear().
            for ax, data, title in panels:
                ax.clear()
                ax.set_title(title)
                ax.imshow(np.transpose(data[i], (1, 2, 0)), interpolation="none")
                ax.set_axis_off()
            fig.savefig(os.path.join(path, f"{epoch:02d}_{update:07d}_{i:02d}.png"), dpi=100)
    finally:
        plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def check_module_versions() -> None:
    """Print the installed Python, numpy, PyTorch and PIL versions and verify them.

    Each installed version is compared against a minimum (Python 3.8,
    numpy 1.18, PyTorch 1.6.0, PIL 6.0.0) and printed together with a check
    mark or a cross.

    Raises:
        AssertionError: If at least one installed version is below its minimum.
    """
    python_check = '(\u2713)' if sys.version_info >= (3, 8) else '(\u2717)'
    numpy_check = '(\u2713)' if Version(np.__version__) >= Version('1.18') else '(\u2717)'
    torch_check = '(\u2713)' if Version(torch.__version__) >= Version('1.6.0') else '(\u2717)'
    pil_check = '(\u2713)' if Version(PIL.__version__) >= Version('6.0.0') else '(\u2717)'
    print(f'Installed Python version: {sys.version_info.major}.{sys.version_info.minor} {python_check}')
    print(f'Installed numpy version: {np.__version__} {numpy_check}')
    print(f'Installed PyTorch version: {torch.__version__} {torch_check}')
    print(f'Installed PIL version: {PIL.__version__} {pil_check}')

    # Every requirement must hold. The earlier `any(...)` accepted the
    # environment as soon as a single package was recent enough. An explicit
    # raise (instead of `assert`) keeps the check active under `python -O`.
    checks = [python_check, numpy_check, torch_check, pil_check]
    if not all(x == '(\u2713)' for x in checks):
        raise AssertionError('At least one installed module version is below the required minimum')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
check_module_versions()
|
||||||
train_model()
|
train_model()
|
||||||
|
@ -70,7 +70,7 @@ def scoring_file(prediction_file: str, target_file: str):
|
|||||||
"""Computes the mean RMSE loss on two lists of numpy arrays stored in pickle files prediction_file and targets_file
|
"""Computes the mean RMSE loss on two lists of numpy arrays stored in pickle files prediction_file and targets_file
|
||||||
|
|
||||||
Computation of mean RMSE loss, as used in the challenge for exercise 5. See files "example_testset.pkl" and
|
Computation of mean RMSE loss, as used in the challenge for exercise 5. See files "example_testset.pkl" and
|
||||||
"example_submission_random.pkl" for an example test set and example targets, respectively. The real test set
|
"example_submission_random.pkl" for an example testing set and example targets, respectively. The real testing set
|
||||||
(without targets) will be available as download (see assignment sheet 2).
|
(without targets) will be available as download (see assignment sheet 2).
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
Loading…
Reference in New Issue
Block a user