ImageImpaint_Python_II/ImageImpaint.py

33 lines
1.0 KiB
Python
Raw Normal View History

2022-06-01 10:27:58 +00:00
import torch
from DataLoader import get_image_loader
from Net import ImageNN
def train_model(
    *,
    image_dir="my/supercool/image/dir",
    n_epochs=15,
    lr=0.1,
    model=None,
    image_loader=None,
):
    """Train an image model with plain SGD and return it with its loss history.

    All arguments are keyword-only and optional; calling ``train_model()``
    with no arguments reproduces the original hard-coded setup.

    Args:
        image_dir: Directory passed to ``get_image_loader`` when
            ``image_loader`` is not supplied.
        n_epochs: Number of full passes over ``image_loader``.
        lr: Learning rate for ``torch.optim.SGD``.
        model: Model to train. Defaults to a fresh ``ImageNN()``. It must be
            callable on an input batch and expose ``parameters()``.
        image_loader: Iterable of ``(input_tensor, target_tensor)`` pairs that
            can be iterated once per epoch. Defaults to
            ``get_image_loader(image_dir)``.

    Returns:
        ``(model, losses)``: the trained model and a list with one float loss
        per batch, in training order.
    """
    if image_loader is None:
        # TODO: split into train and test (maybe evaluation) sets
        image_loader = get_image_loader(image_dir)
    if model is None:
        model = ImageNN()  # TODO: pass size to ImageNN
    # NOTE(review): if ImageNN is a torch.nn.Module using dropout/batch-norm,
    # call model.train() here - confirm against Net.ImageNN.

    # TODO: tune optimizer parameters and learning rate
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    # NOTE(review): CrossEntropyLoss expects unnormalized class scores and
    # class-index (or class-probability) targets; confirm this matches
    # ImageNN's output and the loader's targets.
    loss_function = torch.nn.CrossEntropyLoss()

    losses = []
    for _epoch in range(n_epochs):
        for input_tensor, target_tensor in image_loader:
            # Clear gradients before the backward pass so that any stale
            # gradients on an injected model do not leak into the first step.
            optimizer.zero_grad()
            output = model(input_tensor)  # forward pass
            loss = loss_function(output, target_tensor)
            loss.backward()  # backward pass: compute gradients
            optimizer.step()  # gradient descent update step
            losses.append(loss.item())
    # TODO: evaluate trained model
    # TODO: save trained model to a file
    return model, losses
def apply_model():
    """Run a trained model on new images (not implemented yet).

    This is currently a placeholder: it performs no work and returns ``None``.
    """
    return None