141 lines
5.0 KiB
Python
141 lines
5.0 KiB
Python
import os
|
|
import sys
|
|
|
|
import PIL
|
|
import numpy as np
|
|
import packaging
|
|
import torch
|
|
from matplotlib import pyplot as plt
|
|
from packaging.version import Version
|
|
|
|
import DataLoader
|
|
from DataLoader import get_image_loader
|
|
from Net import ImageNN
|
|
from netio import save_model, eval_evalset
|
|
|
|
|
|
def get_train_device():
    """Choose the device to train on and report the choice on stdout.

    Returns:
        ``torch.device('cuda')`` when ``torch.cuda.is_available()`` is true,
        otherwise ``torch.device('cpu')``.
    """
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    print(f'Device to train net: {device}')
    return torch.device(device)
|
|
|
|
|
|
def _crossed_interval(previous, current, interval):
    """Return True if a multiple of ``interval`` lies in ``(previous, current]``."""
    return current // interval > previous // interval


def train_model(*, n_epochs=5, eval_interval=3000, plot_interval=100):
    """Train an ``ImageNN`` on the loaders returned for ``"training/"``.

    The function seeds numpy and torch with 0, builds the model, an AdamW
    optimizer and an MSE loss, and then runs ``n_epochs`` passes over the
    training loader. During training it periodically

    * evaluates on the test loader via ``eval_model`` and calls
      ``save_model`` whenever the evaluation loss improves on the best one
      seen so far, and
    * writes a plot of the first sample of the current batch to ``plots/``.

    After the last epoch ``eval_evalset()`` is called.

    Args:
        n_epochs: Number of passes over the training loader.
        eval_interval: Evaluate every time the per-epoch sample counter
            crosses a multiple of this many samples.
        plot_interval: Plot every time the per-epoch sample counter crosses
            a multiple of this many samples.

    Raises:
        ValueError: If ``eval_interval`` or ``plot_interval`` is not positive.

    Note:
        The sample counter advances by ``train_loader.batch_size`` per batch.
        The original code tested ``counter % interval == 0``, which never
        fires when the batch size does not divide the interval (so no
        evaluation and no checkpoint). Testing for a *crossing* is identical
        when the batch size divides the interval and still fires otherwise.
    """
    if eval_interval <= 0 or plot_interval <= 0:
        raise ValueError("eval_interval and plot_interval must be positive")

    # Set a known random seed for reproducibility
    np.random.seed(0)
    torch.manual_seed(0)
    device = get_train_device()

    # Prepare a path to plot to
    plotpath = "plots/"
    os.makedirs(plotpath, exist_ok=True)

    # Load datasets
    train_loader, test_loader = get_image_loader("training/", precision=np.float32)
    net = ImageNN(n_in_channels=3, precision=np.float32)  # TODO: tune net params
    net.train()  # start in train mode
    net.to(device)  # send net to the available device

    # TODO: adjust optimizer parameters and learning rate
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-5)
    loss_function = torch.nn.MSELoss()
    loss_function.to(device)

    batch_size = train_loader.batch_size
    # Nominal sample count shown in the progress line (batches * batch size).
    total_samples = len(train_loader) * batch_size
    best_eval_loss = np.inf
    for epoch in range(n_epochs):
        print(f"Epoch {epoch}/{n_epochs}\n")
        i = 0  # samples processed in this epoch (nominal)
        for input_tensor, target_tensor in train_loader:
            optimizer.zero_grad()  # reset gradients

            output = net(input_tensor.to(device))  # forward pass

            # compute loss given model output and true target
            loss = loss_function(output.to(device), target_tensor.to(device))
            loss.backward()  # compute gradients (backward pass)
            optimizer.step()  # perform gradient descent update step

            previous = i
            i += batch_size
            print(
                f'\rTraining epoch {epoch} [{i}/{total_samples}] (curr loss: {loss.item():.3})',
                end='')

            # Evaluate (and checkpoint on improvement) every `eval_interval` samples.
            if _crossed_interval(previous, i, eval_interval):
                print("\nEvaluating model")
                eval_loss = eval_model(net, test_loader, loss_function, device)
                print(f"Evalution loss={eval_loss}")
                if eval_loss < best_eval_loss:
                    best_eval_loss = eval_loss
                    save_model(net)

                # eval_model switches the net to eval mode; switch back.
                net.train()

            # Plot the first sample of the batch every `plot_interval` samples.
            if _crossed_interval(previous, i, plot_interval):
                plot(input_tensor.detach().cpu().numpy()[:1], target_tensor.detach().cpu().numpy()[:1],
                     output.detach().cpu().numpy()[:1],
                     plotpath, i, epoch)

    # Final evaluation via netio.eval_evalset (original note: "evaluate model
    # with submission pkl file").
    eval_evalset()
|
|
|
|
|
|
# func to evaluate our trained model
|
|
def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and is left in eval mode on return;
    callers that continue training must call ``model.train()`` themselves.
    Gradients are disabled for the whole pass. A progress line is written to
    stdout after every batch.

    Args:
        model: Network to evaluate.
        dataloader: Yields ``(input, target)`` batches; its ``batch_size`` is
            used for the progress counter.
        loss_fn: Callable ``loss_fn(output, target)`` whose result supports
            ``.item()``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader)``.

    Raises:
        ZeroDivisionError: If ``len(dataloader)`` is 0.
    """
    # switch to eval mode
    model.eval()

    batch_size = dataloader.batch_size
    n_batches = len(dataloader)
    n_total = n_batches * batch_size  # nominal sample count for the progress line
    total_loss = 0.0
    seen = 0

    # disable gradient calculations
    with torch.no_grad():
        for batch_input, batch_target in dataloader:
            batch_input = batch_input.to(device)
            batch_target = batch_target.to(device)

            prediction = model(batch_input)
            total_loss += loss_fn(prediction, batch_target).item()

            # Count first, then report, so the line ends at [total/total]
            # like the training progress line does.
            seen += batch_size
            print(f'\rEval prog[{seen}/{n_total}]', end='')
    print()
    return total_loss / n_batches
|
|
|
|
|
|
def plot(inputs, targets, predictions, path, update, epoch):
    """Plot inputs, targets and predictions side by side into directory `path`.

    One PNG is written per sample index ``i`` in ``range(len(inputs))``,
    named ``{epoch:02d}_{update:07d}_{i:02d}.png``. The directory is created
    if it does not exist.

    Args:
        inputs, targets, predictions: Indexable collections of arrays; each
            array is transposed with axes ``(1, 2, 0)`` before ``imshow``
            (presumably channel-first images - confirm against the loader).
            ``targets`` and ``predictions`` must have at least
            ``len(inputs)`` entries.
        path: Output directory.
        update: Integer used in the file name (the caller passes its sample
            counter).
        epoch: Integer used in the file name.
    """
    os.makedirs(path, exist_ok=True)
    fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
    panels = (("Input", inputs), ("Target", targets), ("Prediction", predictions))

    # Close the figure even if drawing or saving fails; otherwise pyplot
    # keeps it registered and repeated failures accumulate open figures.
    try:
        for i in range(len(inputs)):
            for ax, (title, data) in zip(axes, panels):
                ax.clear()
                ax.set_title(title)
                # Disabled alternative that post-processes before plotting:
                # ax.imshow(DataLoader.postprocess(np.transpose(data[i], (1, 2, 0))), interpolation="none")
                ax.imshow(np.transpose(data[i], (1, 2, 0)), interpolation="none")
                ax.set_axis_off()
            fig.savefig(os.path.join(path, f"{epoch:02d}_{update:07d}_{i:02d}.png"), dpi=100)
    finally:
        plt.close(fig)
|
|
|
|
|
|
def check_module_versions() -> None:
    """Print installed Python/numpy/PyTorch/PIL versions and enforce minimums.

    Each line is printed with a check mark when the requirement is met and a
    cross otherwise. The minimums are Python 3.8, numpy 1.18, PyTorch 1.6.0
    and PIL 6.0.0.

    Raises:
        AssertionError: If any of the requirements is not met. The original
            check used ``any(...)`` and therefore passed as soon as a single
            requirement held. The error is raised explicitly rather than via
            ``assert`` so it is not stripped under ``python -O``.
    """
    ok_mark = '(\u2713)'
    fail_mark = '(\u2717)'

    # (label, version string to display, requirement satisfied?)
    checks = [
        ('Python', f'{sys.version_info.major}.{sys.version_info.minor}', sys.version_info >= (3, 8)),
        ('numpy', np.__version__, Version(np.__version__) >= Version('1.18')),
        ('PyTorch', torch.__version__, Version(torch.__version__) >= Version('1.6.0')),
        ('PIL', PIL.__version__, Version(PIL.__version__) >= Version('6.0.0')),
    ]

    for label, installed, satisfied in checks:
        print(f'Installed {label} version: {installed} {ok_mark if satisfied else fail_mark}')

    failed = [label for label, _, satisfied in checks if not satisfied]
    if failed:
        raise AssertionError(f"Unsupported version(s) of: {', '.join(failed)}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: check the environment first, then run training.
    check_module_versions()
    train_model()
|