# ImageImpaint_Python_II/ImageImpaint.py

import numpy as np
import torch
from DataLoader import get_image_loader
from Net import ImageNN
# 01.05.22 -- 0.5h
from netio import save_model, eval_evalset
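# Note: DataLoader, Net and netio are project-local modules. From their use
# below it is assumed that get_image_loader returns a (train_loader, test_loader)
# pair of torch DataLoaders yielding (input, target) tensor batches, that
# Net.ImageNN is the inpainting network, and that save_model / eval_evalset
# handle checkpointing and the submission-set evaluation respectively.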
def get_train_device():
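    """Return the torch device to train on: CUDA if available, else CPU."""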
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Device to train net: {device}')
    return torch.device(device)


def train_model():
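    """Train the ImageNN inpainting model.

    Runs the training loop, periodically evaluates on the test split, and
    checkpoints the model whenever the evaluation loss improves.
    """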
    # Set a known random seed for reproducibility
    np.random.seed(0)
    torch.manual_seed(0)
    device = get_train_device()

    # Load datasets
    train_loader, test_loader = get_image_loader("training/", precision=np.float32)
    nn = ImageNN(n_in_channels=3, precision=np.float32)  # todo net params
    nn.train()  # start in train mode
    nn.to(device)  # move the net to the available device

    optimizer = torch.optim.AdamW(nn.parameters(), lr=0.1, weight_decay=1e-5)  # todo adjust parameters and lr (0.1 is unusually high for AdamW)
    loss_function = torch.nn.MSELoss()
    loss_function.to(device)
    n_epochs = 7  # todo adjust number of epochs

    n_batches = len(train_loader)  # number of training batches per epoch
    losses = []
    best_eval_loss = np.inf
    for epoch in range(n_epochs):
        print(f"Epoch {epoch}/{n_epochs}\n")
        i = 0  # samples processed so far in this epoch
        for input_tensor, target_tensor in train_loader:
            output = nn(input_tensor.to(device))  # get model output (forward pass)
            loss = loss_function(output, target_tensor.to(device))  # compute loss given output and true target

            loss.backward()  # compute gradients (backward pass)
            optimizer.step()  # perform gradient descent update step
            optimizer.zero_grad()  # reset gradients
            losses.append(loss.item())

            i += train_loader.batch_size
            print(
                f'\rTraining epoch {epoch} [{i}/{n_batches * train_loader.batch_size}] (curr loss: {loss.item():.3})',
                end='')
            # eval the model roughly every 3000 samples; note this only fires
            # when batch_size divides 3000 evenly
            if i % 3000 == 0:
                print("\nEvaluating model")
                eval_loss = eval_model(nn, test_loader, loss_function, device)
                print(f"Evaluation loss={eval_loss}")
                if eval_loss < best_eval_loss:
                    best_eval_loss = eval_loss
                    save_model(nn)
                nn.train()  # eval_model switched the net to eval mode; switch back

    # evaluate model with submission pkl file
    eval_evalset()


# func to evaluate our trained model
def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
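    """Return the loss of `model` averaged over all batches of `dataloader`."""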
    # switch to eval mode
    model.eval()
    loss = 0.0
    # disable gradient calculations
    with torch.no_grad():
        i = 0
        for inp, target in dataloader:
            inp = inp.to(device)
            target = target.to(device)

            out = model(inp)
            loss += loss_fn(out, target).item()
            i += dataloader.batch_size
            print(f'\rEval prog[{i}/{len(dataloader) * dataloader.batch_size}]', end='')
        print()
    loss /= len(dataloader)  # average loss per batch
    return loss


if __name__ == '__main__':
    train_model()