2022-06-01 10:27:58 +00:00
|
|
|
import torch
|
|
|
|
|
|
|
|
from DataLoader import get_image_loader
|
|
|
|
from Net import ImageNN
|
|
|
|
|
|
|
|
|
2022-06-01 14:07:32 +00:00
|
|
|
# 01.05.22 -- 0.5h
|
|
|
|
from netio import save_model, load_model
|
|
|
|
|
|
|
|
|
2022-06-01 10:27:58 +00:00
|
|
|
def _evaluate(model, loader, loss_function):
    """Score *model* on *loader*; return ``(mean_loss, accuracy)``.

    ``mean_loss`` is the mean of the per-batch losses. ``accuracy`` is the
    fraction of targets whose class equals the arg-max of the output over
    dim 1 (the class dimension ``CrossEntropyLoss`` uses).

    Targets may be class indices. They may also be per-class probabilities,
    detected by having as many dimensions as the output; in that case their
    own arg-max over dim 1 is used. Assumes batched output, i.e.
    ``(N, C, ...)`` (TODO confirm for ImageNN).

    Returns ``(None, None)`` when *loader* yields no batches.
    """
    model.eval()  # disable training-only behaviour (dropout, batch-norm updates) if present
    total_loss = 0.0
    n_batches = 0
    n_correct = 0
    n_items = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for input_tensor, target_tensor in loader:
            output = model(input_tensor)
            total_loss += loss_function(output, target_tensor).item()
            n_batches += 1

            predicted = output.argmax(dim=1)
            if target_tensor.dim() == output.dim():
                # probability targets: reduce to the most likely class
                target_tensor = target_tensor.argmax(dim=1)
            n_correct += (predicted == target_tensor).sum().item()
            n_items += target_tensor.numel()

    if n_batches == 0:
        return None, None
    return total_loss / n_batches, n_correct / n_items


def train_model(image_dir="my/supercool/image/dir", n_epochs=15, lr=0.1):
    """Train an ``ImageNN``, evaluate it on the test split, and save it.

    ``get_image_loader(image_dir)`` must return a ``(train_loader,
    test_loader)`` pair whose iterables yield ``(input_tensor,
    target_tensor)`` batches. The network is trained with plain SGD on
    ``torch.nn.CrossEntropyLoss``. It is scored once on the test loader and
    then passed to ``save_model``.

    Args:
        image_dir: Directory handed to ``get_image_loader``. The default is
            the placeholder path the script previously hard-coded.
        n_epochs: Number of full passes over the training loader.
        lr: SGD learning rate.

    Returns:
        None. Progress is printed: mean training loss per epoch, then test
        loss and accuracy.
    """
    train_loader, test_loader = get_image_loader(image_dir)

    model = ImageNN()  # TODO: pass the input size once ImageNN accepts it
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)  # TODO: tune lr / SGD params
    loss_function = torch.nn.CrossEntropyLoss()

    for epoch in range(n_epochs):
        model.train()  # (re-)enter training mode at the start of every epoch
        epoch_losses = []
        for input_tensor, target_tensor in train_loader:
            optimizer.zero_grad()  # drop gradients left over from the previous step
            output = model(input_tensor)  # forward pass
            loss = loss_function(output, target_tensor)
            loss.backward()  # backward pass: compute gradients
            optimizer.step()  # gradient-descent update
            epoch_losses.append(loss.item())

        # Epochs are displayed 1-based so the last line reads "Epoch N/N".
        if epoch_losses:
            mean_loss = sum(epoch_losses) / len(epoch_losses)
            print(f"Epoch {epoch + 1}/{n_epochs}: mean train loss {mean_loss:.4f}")
        else:
            print(f"Epoch {epoch + 1}/{n_epochs}: training loader yielded no batches")

    test_loss, accuracy = _evaluate(model, test_loader, loss_function)
    if accuracy is None:
        print("Test loader yielded no batches; skipping evaluation")
    else:
        print(f"Test loss {test_loss:.4f}, accuracy {accuracy:.2%}")

    # Persist the trained network (netio.save_model; presumably writes a blob file).
    save_model(model)
|
2022-06-01 10:27:58 +00:00
|
|
|
|
|
|
|
|
|
|
|
def apply_model():
    """Load the persisted model via ``load_model`` and return it.

    Running the model on new data is not implemented yet. The model used to
    be loaded and then discarded; returning it lets callers run inference
    themselves.

    Returns:
        Whatever ``load_model()`` returns. This is presumably the network
        saved by ``train_model`` (TODO confirm against netio).
    """
    model = load_model()
    # TODO: switch the model to eval mode and apply it to real inputs once
    # the return type of load_model() is pinned down.
    return model
|