import os
import sys

import numpy as np
import PIL
import torch
from matplotlib import pyplot as plt
from packaging.version import Version

import DataLoader
from DataLoader import get_image_loader
from Net import ImageNN
from netio import save_model, eval_evalset


def get_train_device():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Device to train net: {device}')
    return torch.device(device)


def train_model():
    # Set a known random seed for reproducibility
    np.random.seed(0)
    torch.manual_seed(0)

    device = get_train_device()

    # Prepare a path to plot to
    plotpath = "plots/"
    os.makedirs(plotpath, exist_ok=True)

    # Load datasets
    train_loader, test_loader = get_image_loader("training/", precision=np.float32)

    net = ImageNN(n_in_channels=6, precision=np.float32)  # todo: net params
    net.train()  # init with train mode
    net.to(device)  # send net to the available device

    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-5)  # todo: adjust parameters and lr
    loss_function = torch.nn.MSELoss()
    loss_function.to(device)

    n_epochs = 5  # todo: epochs here
    train_sample_size = len(train_loader)
    losses = []
    best_eval_loss = np.inf
    for epoch in range(n_epochs):
        print(f"Epoch {epoch}/{n_epochs}\n")
        i = 0
        for input_tensor, mask, target_tensor in train_loader:
            input_tensor = input_tensor.to(device)
            mask = mask.to(device)
            target_tensor = target_tensor.to(device)

            optimizer.zero_grad()  # reset gradients

            input_tensor = torch.cat((input_tensor, mask), 1)
            output = net(input_tensor)  # get model output (forward pass)

            # Compute the loss only on the masked-out (unknown) pixels
            output_flat = output * (1 - mask)
            output_flat = output_flat[1 - mask > 0]
            rest = target_tensor * (1 - mask)
            rest = rest[1 - mask > 0]

            loss = loss_function(output_flat, rest)  # compute loss given model output and true target
            loss.backward()  # compute gradients (backward pass)
            optimizer.step()  # perform gradient descent update step

            losses.append(loss.item())

            i += train_loader.batch_size
            print(f'\rTraining epoch {epoch} [{i}/{train_sample_size * train_loader.batch_size}] '
                  f'(curr loss: {loss.item():.3})', end='')

            # Evaluate the model every 3000th sample
            # (note: this only triggers if the batch size divides 3000 evenly)
            if i % 3000 == 0:
                print("\nEvaluating model")
                eval_loss = eval_model(net, test_loader, loss_function, device)
                print(f"Evaluation loss={eval_loss}")
                if eval_loss < best_eval_loss:
                    best_eval_loss = eval_loss
                    save_model(net)
                net.train()

            # Plot output every 100th sample
            if i % 100 == 0:
                plot(input_tensor.detach().cpu().numpy()[0],
                     target_tensor.detach().cpu().numpy()[0],
                     output.detach().cpu().numpy()[0],
                     plotpath, i, epoch)

    # Evaluate the trained model with the submission pkl file
    eval_evalset()


def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
    """Evaluate the trained model on `dataloader` and return the mean loss per batch."""
    # switch to eval mode
    model.eval()
    loss = 0.0
    # disable gradient calculations
    with torch.no_grad():
        i = 0
        for input, mask, target in dataloader:
            input = input.to(device)
            target = target.to(device)
            mask = mask.to(device)

            input = torch.cat((input, mask), 1)
            out = model(input)

            # As in training, only the masked-out pixels contribute to the loss
            out = out * (1 - mask)
            out = out[1 - mask > 0]
            rest = target * (1 - mask)
            rest = rest[1 - mask > 0]

            loss += loss_fn(out, rest).item()
            print(f'\rEval prog [{i}/{len(dataloader) * dataloader.batch_size}]', end='')
            i += dataloader.batch_size
        print()
    loss /= len(dataloader)
    return loss


def plot(inputs, targets, predictions, path, update, epoch):
    """Plot the inputs, targets and predictions to file `path`."""
    os.makedirs(path, exist_ok=True)
    fig, axes = plt.subplots(ncols=4, figsize=(15, 5))
    for ax, data, title in zip(axes,
                               [inputs, targets, predictions, predictions - targets],
                               ["Input", "Target", "Prediction", "Diff"]):
        ax.clear()
        ax.set_title(title)
        ax.imshow(DataLoader.postprocess(np.transpose(data[:3, :, :], (1, 2, 0))), interpolation="none")
        ax.set_axis_off()
    fig.savefig(os.path.join(path, f"{epoch:02d}_{update:07d}.png"), dpi=100)
    plt.close(fig)


def check_module_versions() -> None:
    python_check = '(\u2713)' if sys.version_info >= (3, 8) else '(\u2717)'
    numpy_check = '(\u2713)' if Version(np.__version__) >= Version('1.18') else '(\u2717)'
    torch_check = '(\u2713)' if Version(torch.__version__) >= Version('1.6.0') else '(\u2717)'
    pil_check = '(\u2713)' if Version(PIL.__version__) >= Version('6.0.0') else '(\u2717)'
    print(f'Installed Python version: {sys.version_info.major}.{sys.version_info.minor} {python_check}')
    print(f'Installed numpy version: {np.__version__} {numpy_check}')
    print(f'Installed PyTorch version: {torch.__version__} {torch_check}')
    print(f'Installed PIL version: {PIL.__version__} {pil_check}')
    # All version checks must pass, not just one of them
    assert all(x == '(\u2713)' for x in [python_check, numpy_check, torch_check, pil_check])


if __name__ == '__main__':
    check_module_versions()
    train_model()