implmente save/load, eval structure
This commit is contained in:
parent
2966df9a39
commit
b1bc3a2c64
@ -2,6 +2,7 @@ import glob
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch.utils.data.dataset
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -27,5 +28,22 @@ class ImageDataset(Dataset):
|
|||||||
|
|
||||||
def get_image_loader(path: str):
|
def get_image_loader(path: str):
|
||||||
image_dataset = ImageDataset(path)
|
image_dataset = ImageDataset(path)
|
||||||
image_loader = DataLoader(image_dataset, shuffle=True, batch_size=10)
|
totlen = len(image_dataset)
|
||||||
return image_loader
|
trains, tests = torch.utils.data.dataset.random_split(image_dataset, (int(totlen * .7), totlen - int(totlen * .7)),
|
||||||
|
generator=torch.Generator().manual_seed(42))
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
trains,
|
||||||
|
shuffle=True, # shuffle the order of our samples
|
||||||
|
batch_size=4, # stack 4 samples to a minibatch
|
||||||
|
num_workers=0 # no background workers (see comment below)
|
||||||
|
)
|
||||||
|
|
||||||
|
test_loader = DataLoader(
|
||||||
|
tsts,
|
||||||
|
shuffle=True, # shuffle the order of our samples
|
||||||
|
batch_size=4, # stack 4 samples to a minibatch
|
||||||
|
num_workers=0 # no background workers (see comment below)
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_loader, test_loader
|
||||||
|
@ -4,19 +4,23 @@ from DataLoader import get_image_loader
|
|||||||
from Net import ImageNN
|
from Net import ImageNN
|
||||||
|
|
||||||
|
|
||||||
|
# 01.05.22 -- 0.5h
|
||||||
|
from netio import save_model, load_model
|
||||||
|
|
||||||
|
|
||||||
def train_model():
|
def train_model():
|
||||||
image_loader = get_image_loader("my/supercool/image/dir")
|
train_loader, test_loader = get_image_loader("my/supercool/image/dir")
|
||||||
# todo split to train and test (maybe evaluation sets)
|
nn = ImageNN() # todo pass size ason.
|
||||||
nn = ImageNN() # todo pass size ason.
|
nn.train() # init with train mode
|
||||||
|
|
||||||
optimizer = torch.optim.SGD(nn.parameters(), lr=0.1) # todo adjust parameters and lr
|
optimizer = torch.optim.SGD(nn.parameters(), lr=0.1) # todo adjust parameters and lr
|
||||||
loss_function = torch.nn.CrossEntropyLoss()
|
loss_function = torch.nn.CrossEntropyLoss()
|
||||||
n_epochs = 15 # todo epcchs here
|
n_epochs = 15 # todo epcchs here
|
||||||
|
|
||||||
# Training
|
|
||||||
losses = []
|
losses = []
|
||||||
for epoch in range(n_epochs):
|
for epoch in range(n_epochs):
|
||||||
for input_tensor, target_tensor in image_loader:
|
print(f"Epoch {epoch}/{n_epochs}\n")
|
||||||
|
for input_tensor, target_tensor in train_loader:
|
||||||
output = nn(input_tensor) # get model output (forward pass)
|
output = nn(input_tensor) # get model output (forward pass)
|
||||||
loss = loss_function(output, target_tensor) # compute loss given model output and true target
|
loss = loss_function(output, target_tensor) # compute loss given model output and true target
|
||||||
loss.backward() # compute gradients (backward pass)
|
loss.backward() # compute gradients (backward pass)
|
||||||
@ -24,9 +28,21 @@ def train_model():
|
|||||||
optimizer.zero_grad() # reset gradients
|
optimizer.zero_grad() # reset gradients
|
||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
|
|
||||||
# todo evaluate trained model
|
# switch net to eval mode
|
||||||
|
nn.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
for input_tensor, target_tensor in test_loader:
|
||||||
|
# iterate testloader and we have to decide somhow now the goodness of the prediction
|
||||||
|
out = nn(input_tensor) # apply model
|
||||||
|
|
||||||
|
diff = out - target_tensor
|
||||||
|
# todo evaluate trained model on testset loader
|
||||||
|
|
||||||
# todo save trained model to blob file
|
# todo save trained model to blob file
|
||||||
|
save_model(nn)
|
||||||
|
|
||||||
|
|
||||||
def apply_model():
|
def apply_model():
|
||||||
|
model = load_model()
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user