# ImageImpaint_Python_II/ImageImpaint.py
# Snapshot retrieved 2022-07-11 23:38:51 +02:00 — 159 lines, 5.5 KiB, Python.
# (Web-view header converted to comments; the bare lines were not valid Python.)
import os
import sys
import PIL
import numpy as np
import torch
from matplotlib import pyplot as plt
from packaging.version import Version
import DataLoader
from DataLoader import get_image_loader
from Net import ImageNN
from netio import save_model, eval_evalset
def get_train_device():
    """Return the torch device to train on: CUDA when available, else CPU."""
    if torch.cuda.is_available():
        chosen = 'cuda'
    else:
        chosen = 'cpu'
    print(f'Device to train net: {chosen}')
    return torch.device(chosen)
def train_model():
    """Train the inpainting network on the "training/" dataset.

    Periodically evaluates on the held-out split, checkpoints the best model
    (via ``save_model``), plots sample outputs to ``plots/``, and finally runs
    ``eval_evalset`` to produce the submission predictions.
    """
    # Set a known random seed for reproducibility
    np.random.seed(0)
    torch.manual_seed(0)
    device = get_train_device()

    # Prepare a path to plot to
    plotpath = "plots/"
    os.makedirs(plotpath, exist_ok=True)

    # Load datasets
    train_loader, test_loader = get_image_loader("training/", precision=np.float32)

    # `net` instead of `nn` — the original name shadowed the conventional
    # alias for torch.nn and invited confusion.
    net = ImageNN(n_in_channels=6, precision=np.float32)  # todo net params
    net.train()  # init with train mode
    net.to(device)  # send net to device available

    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-5)  # todo adjust parameters and lr
    loss_function = torch.nn.MSELoss()
    loss_function.to(device)

    n_epochs = 5  # todo epochs here
    train_sample_size = len(train_loader)
    losses = []
    best_eval_loss = np.inf
    for epoch in range(n_epochs):
        print(f"Epoch {epoch}/{n_epochs}\n")
        i = 0
        # Bug fix: the original used `i % 3000 == 0` / `i % 100 == 0`, which
        # never fires when the batch size does not divide 3000/100 evenly —
        # evaluation (and therefore checkpointing) could silently never run.
        # Sample counters fire for any batch size.
        since_eval = 0
        since_plot = 0
        for input_tensor, mask, target_tensor in train_loader:
            input_tensor = input_tensor.to(device)
            mask = mask.to(device)
            target_tensor = target_tensor.to(device)

            optimizer.zero_grad()  # reset gradients
            # Network input is the image concatenated with its mask (6 channels).
            input_tensor = torch.cat((input_tensor, mask), 1)

            output = net(input_tensor)  # get model output (forward pass)
            # Score only positions where mask == 0.
            # NOTE(review): assumes mask is binary {0, 1} — confirm in DataLoader.
            selected = 1 - mask > 0
            output_flat = (output * (1 - mask))[selected]
            rest = (target_tensor * (1 - mask))[selected]

            loss = loss_function(output_flat, rest)  # compute loss given model output and true target
            loss.backward()  # compute gradients (backward pass)
            optimizer.step()  # perform gradient descent update step

            losses.append(loss.item())
            batch = train_loader.batch_size
            i += batch
            since_eval += batch
            since_plot += batch
            print(
                f'\rTraining epoch {epoch} [{i}/{train_sample_size * train_loader.batch_size}] (curr loss: {loss.item():.3})',
                end='')

            # eval model roughly every 3000 samples
            if since_eval >= 3000:
                since_eval = 0
                print(f"\nEvaluating model")
                eval_loss = eval_model(net, test_loader, loss_function, device)
                print(f"Evaluation loss={eval_loss}")
                if eval_loss < best_eval_loss:
                    best_eval_loss = eval_loss
                    save_model(net)  # keep only the best checkpoint so far
                net.train()  # eval_model switched to eval mode; restore train mode

            # Plot output roughly every 100 samples
            if since_plot >= 100:
                since_plot = 0
                plot(input_tensor.detach().cpu().numpy()[0], target_tensor.detach().cpu().numpy()[0],
                     output.detach().cpu().numpy()[0],
                     plotpath, i, epoch)

    # evaluate model with submission pkl file
    eval_evalset()
# func to evaluate our trained model
def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
    """Evaluate `model` on `dataloader` and return the mean per-batch loss.

    The loss is computed only on positions where ``mask == 0``; the model is
    switched to eval mode (the caller must restore train mode afterwards).

    :param model: network taking the image concatenated with its mask.
    :param dataloader: yields (input, mask, target) batches.
    :param loss_fn: loss taking (prediction, target) tensors.
    :param device: device to run the evaluation on.
    :return: sum of per-batch losses divided by the number of batches.
    """
    # switch to eval mode
    model.eval()
    total_loss = 0.0
    # disable gradient calculations
    with torch.no_grad():
        seen = 0
        # Renamed `input` -> `inputs`: the original shadowed the builtin.
        for inputs, mask, target in dataloader:
            inputs = inputs.to(device)
            target = target.to(device)
            mask = mask.to(device)
            inputs = torch.cat((inputs, mask), 1)
            out = model(inputs)
            # Score only positions where mask == 0.
            # NOTE(review): assumes mask is binary {0, 1} — confirm in DataLoader.
            selected = 1 - mask > 0
            out = (out * (1 - mask))[selected]
            rest = (target * (1 - mask))[selected]
            total_loss += loss_fn(out, rest).item()
            # Bug fix: increment before printing so progress does not lag by
            # one batch (the original always started at 0).
            seen += dataloader.batch_size
            print(f'\rEval prog[{seen}/{len(dataloader) * dataloader.batch_size}]', end='')
    print()
    return total_loss / len(dataloader)
def plot(inputs, targets, predictions, path, update, epoch):
    """Plotting the inputs, targets and predictions to file `path`"""
    os.makedirs(path, exist_ok=True)
    fig, axes = plt.subplots(ncols=4, figsize=(15, 5))
    panels = [inputs, targets, predictions, predictions - targets]
    titles = ["Input", "Target", "Prediction", "diff"]
    for ax, data, title in zip(axes, panels, titles):
        ax.clear()
        ax.set_title(title)
        # Only the first three channels are the RGB image; transpose CHW -> HWC
        # before handing off to the project's postprocess step.
        rgb = np.transpose(data[:3, :, :], (1, 2, 0))
        ax.imshow(DataLoader.postprocess(rgb), interpolation="none")
        ax.set_axis_off()
    fig.savefig(os.path.join(path, f"{epoch:02d}_{update:07d}.png"), dpi=100)
    plt.close(fig)
def check_module_versions() -> None:
    """Print installed Python/numpy/torch/PIL versions and assert minimums.

    Raises AssertionError when any of the four requirements is not met.
    """
    python_check = '(\u2713)' if sys.version_info >= (3, 8) else '(\u2717)'
    numpy_check = '(\u2713)' if Version(np.__version__) >= Version('1.18') else '(\u2717)'
    torch_check = '(\u2713)' if Version(torch.__version__) >= Version('1.6.0') else '(\u2717)'
    pil_check = '(\u2713)' if Version(PIL.__version__) >= Version('6.0.0') else '(\u2717)'
    print(f'Installed Python version: {sys.version_info.major}.{sys.version_info.minor} {python_check}')
    print(f'Installed numpy version: {np.__version__} {numpy_check}')
    print(f'Installed PyTorch version: {torch.__version__} {torch_check}')
    print(f'Installed PIL version: {PIL.__version__} {pil_check}')
    # Bug fix: the original used any(), which passed as long as a single
    # requirement was satisfied; ALL four checks must succeed.
    assert all(x == '(\u2713)' for x in [python_check, numpy_check, torch_check, pil_check])
if __name__ == '__main__':
    # Verify the environment meets minimum versions before starting training.
    check_module_versions()
    train_model()