basic training logic

This commit is contained in:
lukas-heiligenbrunner 2022-06-28 18:28:36 +02:00
parent b1bc3a2c64
commit 8cf208cfc9
7 changed files with 125 additions and 534 deletions

View File

@ -3,9 +3,14 @@ import os
import numpy as np import numpy as np
import torch.utils.data.dataset import torch.utils.data.dataset
from PIL import Image
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import ex4
IMG_SIZE = 100
class ImageDataset(Dataset): class ImageDataset(Dataset):
@ -17,10 +22,27 @@ class ImageDataset(Dataset):
def __getitem__(self, index): def __getitem__(self, index):
# Open image file, convert to numpy array and scale to [0, 1] # Open image file, convert to numpy array and scale to [0, 1]
image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255 target_image = Image.open(self.image_files[index])
# image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255
resize_transforms = transforms.Compose([
transforms.Resize(size=IMG_SIZE),
transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
])
target_image = resize_transforms(target_image)
# normalize image from 0-1
target_image = np.array(target_image, dtype=np.float64) / 255.0
# Perform normalization for each channel # Perform normalization for each channel
image = (image - self.norm_mean) / self.norm_std # image = (image - self.norm_mean) / self.norm_std
return image, index
# calculate image with black grid
doomed_image = ex4.ex4(target_image, (5, 5), (4, 4))
# convert image to grayscale
# target_image = rgb2gray(target_image) # todo look if gray image makes sense
return doomed_image[0], np.transpose(target_image, (2, 0, 1))
def __len__(self): def __len__(self):
return len(self.image_files) return len(self.image_files)
@ -35,15 +57,21 @@ def get_image_loader(path: str):
train_loader = DataLoader( train_loader = DataLoader(
trains, trains,
shuffle=True, # shuffle the order of our samples shuffle=True, # shuffle the order of our samples
batch_size=4, # stack 4 samples to a minibatch batch_size=5, # stack 4 samples to a minibatch
num_workers=0 # no background workers (see comment below) num_workers=2 # no background workers (see comment below)
) )
test_loader = DataLoader( test_loader = DataLoader(
tsts, tests,
shuffle=True, # shuffle the order of our samples shuffle=True, # shuffle the order of our samples
batch_size=4, # stack 4 samples to a minibatch batch_size=1, # stack 4 samples to a minibatch
num_workers=0 # no background workers (see comment below) num_workers=0 # no background workers (see comment below)
) )
return train_loader, test_loader return train_loader, test_loader
def rgb2gray(rgb_array: np.ndarray):
r, g, b = rgb_array[:, :, 0], rgb_array[:, :, 1], rgb_array[:, :, 2]
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
return gray

View File

@ -1,26 +1,48 @@
import numpy as np
import torch import torch
from DataLoader import get_image_loader from DataLoader import get_image_loader
from Net import ImageNN from Net import ImageNN
# 01.05.22 -- 0.5h # 01.05.22 -- 0.5h
from netio import save_model, load_model from netio import save_model, load_model
def get_train_device():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device to train net: {device}')
return torch.device(device)
def train_model(): def train_model():
train_loader, test_loader = get_image_loader("my/supercool/image/dir") # Set a known random seed for reproducibility
nn = ImageNN() # todo pass size ason. np.random.seed(0)
torch.manual_seed(0)
device = get_train_device()
# Load datasets
train_loader, test_loader = get_image_loader("training/")
nn = ImageNN(n_in_channels=3) # todo pass size ason.
nn.train() # init with train mode nn.train() # init with train mode
nn.to(device) # send net to device available
optimizer = torch.optim.SGD(nn.parameters(), lr=0.1) # todo adjust parameters and lr optimizer = torch.optim.SGD(nn.parameters(), lr=0.1) # todo adjust parameters and lr
loss_function = torch.nn.CrossEntropyLoss() loss_function = torch.nn.MSELoss()
n_epochs = 15 # todo epcchs here n_epochs = 15 # todo epcchs here
# todo look wtf is that
nn.double()
train_sample_size = len(train_loader)
losses = [] losses = []
best_eval_loss = 0
for epoch in range(n_epochs): for epoch in range(n_epochs):
print(f"Epoch {epoch}/{n_epochs}\n") print(f"Epoch {epoch}/{n_epochs}\n")
i = 0
for input_tensor, target_tensor in train_loader: for input_tensor, target_tensor in train_loader:
input_tensor.to(device)
target_tensor.to(device)
output = nn(input_tensor) # get model output (forward pass) output = nn(input_tensor) # get model output (forward pass)
loss = loss_function(output, target_tensor) # compute loss given model output and true target loss = loss_function(output, target_tensor) # compute loss given model output and true target
loss.backward() # compute gradients (backward pass) loss.backward() # compute gradients (backward pass)
@ -28,19 +50,45 @@ def train_model():
optimizer.zero_grad() # reset gradients optimizer.zero_grad() # reset gradients
losses.append(loss.item()) losses.append(loss.item())
# switch net to eval mode print(
nn.eval() f'\rTraining epoch {epoch} [{i}/{train_sample_size * train_loader.batch_size}] (curr loss: {loss.item():.3})',
with torch.no_grad(): end='')
for input_tensor, target_tensor in test_loader: i += train_loader.batch_size
# iterate testloader and we have to decide somhow now the goodness of the prediction
out = nn(input_tensor) # apply model
diff = out - target_tensor # eval model every 500th element
# todo evaluate trained model on testset loader if i % 500 == 0:
print(f"\nEvaluating model")
# todo save trained model to blob file eval_loss = eval_model(nn, test_loader, loss_function, device)
print(f"Evalution loss={eval_loss}")
if eval_loss > best_eval_loss:
best_eval_loss = eval_loss
save_model(nn) save_model(nn)
# switch net to eval mode
print(eval_model(nn, test_loader, loss_function, device=device))
# func to evaluate our trained model
def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
# switch to eval mode
model.eval()
loss = .0
# disable gradient calculations
with torch.no_grad():
i = 0
for input, target in dataloader:
input.to(device)
target.to(device)
out = model(input)
loss += loss_fn(out, target).item()
print(f'\rEval prog[{i}/{len(dataloader) * dataloader.batch_size}]', end='')
i += dataloader.batch_size
print()
loss /= len(dataloader)
model.train()
return loss
def apply_model(): def apply_model():
model = load_model() model = load_model()

31
Net.py
View File

@ -2,10 +2,31 @@ import torch
class ImageNN(torch.nn.Module): class ImageNN(torch.nn.Module):
def __init__(self): def __init__(self, n_in_channels: int = 1, n_hidden_layers: int = 3, n_kernels: int = 32, kernel_size: int = 7):
"""Simple CNN with `n_hidden_layers`, `n_kernels`, and `kernel_size` as hyperparameters"""
super().__init__() super().__init__()
# todo implement the nn structure
def forward(self, x: torch.Tensor): cnn = []
pass for i in range(n_hidden_layers):
# todo implement forward cnn.append(torch.nn.Conv2d(
in_channels=n_in_channels,
out_channels=n_kernels,
kernel_size=kernel_size,
padding=int(kernel_size / 2)
))
cnn.append(torch.nn.ReLU())
n_in_channels = n_kernels
self.hidden_layers = torch.nn.Sequential(*cnn)
self.output_layer = torch.nn.Conv2d(
in_channels=n_in_channels,
out_channels=3,
kernel_size=kernel_size,
padding=int(kernel_size / 2)
)
def forward(self, x):
"""Apply CNN to input `x` of shape (N, n_channels, X, Y), where N=n_samples and X, Y are spatial dimensions"""
cnn_out = self.hidden_layers(x) # apply hidden layers (N, n_in_channels, X, Y) -> (N, n_kernels, X, Y)
pred = self.output_layer(cnn_out) # apply output layer (N, n_kernels, X, Y) -> (N, 1, X, Y)
return pred

View File

@ -1,159 +0,0 @@
"""
Author -- Michael Widrich, Andreas Schörgenhumer
Contact -- schoergenhumer@ml.jku.at
Date -- 04.03.2022
###############################################################################
The following copyright statement applies to all code within this file.
Copyright statement:
This material, no matter whether in printed or electronic form,
may be used for personal and non-commercial educational use only.
Any reproduction of this manuscript, no matter whether as a whole or in parts,
no matter whether in printed or in electronic form, requires explicit prior
acceptance of the authors.
###############################################################################
Images taken from: https://pixabay.com/
"""
import hashlib
import os
import shutil
import sys
from glob import glob
import dill as pkl
def print_outs(outs, line_token="-"):
print(line_token * 40)
print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n")
print(line_token * 40)
ex_file = "ex2.py"
full_points = 15
points = full_points
python = sys.executable
solutions_dir = os.path.join("unittest", "solutions")
outputs_dir = os.path.join("unittest", "outputs")
# Remove previous outputs folder
shutil.rmtree(outputs_dir, ignore_errors=True)
inputs = sorted(glob(os.path.join("unittest", "unittest_input_*"), recursive=True))
if not len(inputs):
raise FileNotFoundError("Could not find unittest_input_* files")
with open(os.path.join(solutions_dir, "counts.pkl"), "rb") as f:
sol_counts = pkl.load(f)
for test_i, input_folder in enumerate(inputs):
comment = ""
fcall = ""
with open(os.devnull, "w") as null:
# sys.stdout = null
try:
from ex2 import validate_images
proper_import = True
except Exception as e:
outs = ""
errs = e
points -= full_points / len(inputs)
proper_import = False
finally:
sys.stdout.flush()
sys.stdout = sys.__stdout__
if proper_import:
with open(os.devnull, "w") as null:
# sys.stdout = null
try:
input_basename = os.path.basename(input_folder)
output_dir = os.path.join(outputs_dir, input_basename)
logfilepath = output_dir + ".log"
formatter = "06d"
counts = validate_images(input_dir=input_folder, output_dir=output_dir, log_file=logfilepath,
formatter=formatter)
fcall = f'validate_images(\n\tinput_dir="{input_folder}",\n\toutput_dir="{output_dir}",\n\tlog_file="{logfilepath}",\n\tformatter="{formatter}"\n)'
errs = ""
try:
with open(os.path.join(outputs_dir, f"{input_basename}.log"), "r") as lfh:
logfile = lfh.read()
except FileNotFoundError:
# two cases:
# 1) no invalid files and thus no log file -> ok -> equal to empty tlogfile
# 2) invalid files but no log file -> not ok -> will fail the comparison with tlogfile (below)
logfile = ""
with open(os.path.join(solutions_dir, f"{input_basename}.log"), "r") as lfh:
# must replace the separator that was used when creating the solution files
tlogfile = lfh.read().replace("\\", os.path.sep)
files = sorted(glob(os.path.join(outputs_dir, input_basename, "**", "*"), recursive=True))
hashing_function = hashlib.sha256()
for file in files:
with open(file, "rb") as fh:
hashing_function.update(fh.read())
hash = hashing_function.digest()
hashing_function = hashlib.sha256()
tfiles = sorted(glob(os.path.join(solutions_dir, input_basename, "**", "*"), recursive=True))
for file in tfiles:
with open(file, "rb") as fh:
hashing_function.update(fh.read())
thash = hashing_function.digest()
tcounts = sol_counts[input_basename]
if not counts == tcounts:
points -= full_points / len(inputs)
comment = f"Function should return {tcounts} but returned {counts}"
elif not [f.split(os.path.sep)[-2:] for f in files] == [f.split(os.path.sep)[-2:] for f in tfiles]:
points -= full_points / len(inputs)
comment = f"Contents of output directory do not match (see directory 'solutions')"
elif not hash == thash:
points -= full_points / len(inputs)
comment = f"Hash value of the files in the output directory do not match (see directory 'solutions')"
elif not logfile == tlogfile:
points -= full_points / len(inputs)
comment = f"Contents of logfiles do not match (see directory 'solutions')"
except Exception as e:
outs = ""
errs = e
points -= full_points / len(inputs)
finally:
sys.stdout.flush()
sys.stdout = sys.__stdout__
print()
print_outs(f"Test {test_i}", line_token="#")
print("Function call:")
print_outs(fcall)
if errs:
print(f"Some unexpected errors occurred:")
print_outs(f"{type(errs).__name__}: {errs}")
else:
print("Notes:")
print_outs("No issues found" if comment == "" else comment)
# due to floating point calculations it could happen that we get -0 here
if points < 0:
assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?"
points = abs(points)
print(f"Current points: {points:.2f}")
print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})")
if points < full_points:
print(f"Check the folder '{outputs_dir}' to see where your errors are")
else:
shutil.rmtree(os.path.join(outputs_dir))
print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle "
f"for common mistakes that can still lead to 0 points.")

View File

@ -1,204 +0,0 @@
"""
Author -- Michael Widrich, Andreas Schörgenhumer
Contact -- schoergenhumer@ml.jku.at
Date -- 02.03.2022
###############################################################################
The following copyright statement applies to all code within this file.
Copyright statement:
This material, no matter whether in printed or electronic form,
may be used for personal and non-commercial educational use only.
Any reproduction of this manuscript, no matter whether as a whole or in parts,
no matter whether in printed or in electronic form, requires explicit prior
acceptance of the authors.
###############################################################################
Images taken from: https://pixabay.com/
"""
import gzip
import os
import signal
import sys
from glob import glob
from types import GeneratorType
import dill as pkl
import numpy as np
def print_outs(outs, line_token="-"):
print(line_token * 40)
print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n")
print(line_token * 40)
time_given = int(15)
check_for_timeout = hasattr(signal, "SIGALRM")
if check_for_timeout:
def handler(signum, frame):
raise TimeoutError(f"Timeout after {time_given}sec")
signal.signal(signal.SIGALRM, handler)
ex_file = "ex3.py"
full_points = 15
points = full_points
python = sys.executable
inputs = sorted(glob(os.path.join("unittest", "unittest_input_*"), recursive=True))
if not len(inputs):
raise FileNotFoundError("Could not find unittest_input_* files")
for test_i, input_folder in enumerate(inputs):
comment = ""
fcall = ""
with open(os.devnull, "w") as null:
# sys.stdout = null
try:
if check_for_timeout:
signal.alarm(time_given)
from ex3 import ImageStandardizer
signal.alarm(0)
else:
from ex3 import ImageStandardizer
proper_import = True
except Exception as e:
errs = e
points -= full_points / len(inputs)
proper_import = False
finally:
sys.stdout.flush()
sys.stdout = sys.__stdout__
if proper_import:
with open(os.devnull, "w") as null:
# sys.stdout = null
try:
if check_for_timeout:
signal.alarm(time_given)
# check constructor
instance = ImageStandardizer(input_dir=input_folder)
fcall = f"ImageStandardizer(input_dir='{input_folder}')"
signal.alarm(0)
else:
# check constructor
instance = ImageStandardizer(input_dir=input_folder)
fcall = f"ImageStandardizer(input_dir='{input_folder}')"
errs = ""
# check correct file names + sorting
input_basename = os.path.basename(input_folder)
with open(os.path.join("unittest", "solutions", input_basename, f"filenames.txt"), "r") as f:
# must replace the separator that was used when creating the solution files
files_sol = f.read().replace("\\", os.path.sep).splitlines()
# for simplicity's sake, only compare relative paths here
common = os.path.commonprefix(instance.files)
rel_instance_files = [os.path.join(input_folder, f[len(common):]) for f in instance.files]
if not hasattr(instance, "files"):
points -= full_points / len(inputs) / 3
comment += f"Attributes 'files' missing.\n"
elif rel_instance_files != files_sol:
points -= full_points / len(inputs) / 3
comment += f"Attribute 'files' should be {files_sol} but is {instance.files} (see directory 'solutions').\n"
elif len(instance.files) != len(files_sol):
points -= full_points / len(inputs) / 3
comment += f"Number of files should be {len(files_sol)} but is {len(instance.files)} (see directory 'solutions').\n"
# check if class has method analyze_images
method = "analyze_images"
if not hasattr(instance, method):
comment += f"Method '{method}' missing.\n"
points -= full_points / len(inputs) / 3
else:
# check for correct data types
stats = instance.analyze_images()
if (type(stats) is not tuple) or (len(stats) != 2):
points -= full_points / len(inputs) / 3
comment += f"Incorrect return value of method '{method}' (should be tuple of length 2).\n"
else:
with open(os.path.join("unittest", "solutions", input_basename, f"mean_and_std.pkl"),
"rb") as fh:
data = pkl.load(fh)
m = data["mean"]
s = data["std"]
if not (isinstance(stats[0], np.ndarray) and isinstance(stats[1], np.ndarray) and
stats[0].dtype == np.float64 and stats[1].dtype == np.float64 and
stats[0].shape == (3,) and stats[1].shape == (3,)):
points -= full_points / len(inputs) / 3
comment += f"Incorrect return data type of method '{method}' (tuple entries should be np.ndarray of dtype np.float64 and shape (3,)).\n"
else:
if not np.isclose(stats[0], m, atol=0).all():
points -= full_points / len(inputs) / 6
comment += f"Mean should be {m} but is {stats[0]} (see directory 'solutions').\n"
if not np.isclose(stats[1], s, atol=0).all():
points -= full_points / len(inputs) / 6
comment += f"Std should be {s} but is {stats[1]} (see directory 'solutions').\n"
# check if class has method get_standardized_images
method = "get_standardized_images"
if not hasattr(instance, method):
comment += f"Method '{method}' missing.\n"
points -= full_points / len(inputs) / 3
# check for correct data types
elif not isinstance(instance.get_standardized_images(), GeneratorType):
points -= full_points / len(inputs) / 3
comment += f"'{method}' is not a generator.\n"
else:
# Read correct image solutions
with gzip.open(os.path.join("unittest", "solutions", input_basename, "images.pkl"), "rb") as fh:
ims_sol = pkl.load(file=fh)
# Get image submissions
ims_sub = list(instance.get_standardized_images())
if not len(ims_sub) == len(ims_sol):
points -= full_points / len(inputs) / 3
comment += f"{len(ims_sol)} image arrays should have been returned but got {len(ims_sub)}.\n"
elif any([im_sub.dtype.num != np.dtype(np.float32).num for im_sub in ims_sub]):
points -= full_points / len(inputs) / 3
comment += f"Returned image arrays should have datatype np.float32 but at least one array isn't.\n"
else:
equal = [np.all(np.isclose(im_sub, im_sol, atol=0)) for im_sub, im_sol in zip(ims_sub, ims_sol)]
if not all(equal):
points -= full_points / len(inputs) / 3
comment += f"Returned images {list(np.where(np.logical_not(equal))[0])} do not match solution (see images.pkl files for solution).\n"
except Exception as e:
errs = e
points -= full_points / len(inputs)
finally:
sys.stdout.flush()
sys.stdout = sys.__stdout__
print()
print_outs(f"Test {test_i}", line_token="#")
print("Function call:")
print_outs(fcall)
if errs:
print(f"Some unexpected errors occurred:")
print_outs(f"{type(errs).__name__}: {errs}")
else:
print("Notes:")
print_outs("No issues found" if comment == "" else comment)
# due to floating point calculations it could happen that we get -0 here
if points < 0:
assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?"
points = abs(points)
print(f"Current points: {points:.2f}")
print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})")
print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle "
f"for common mistakes that can still lead to 0 points.")
if not check_for_timeout:
print("\n!!Warning: Had to switch to Windows compatibility version and did not check for timeouts!!")

View File

@ -1,143 +0,0 @@
"""
Author -- Michael Widrich
Contact -- widrich@ml.jku.at
Date -- 01.10.2019
###############################################################################
The following copyright statement applies to all code within this file.
Copyright statement:
This material, no matter whether in printed or electronic form,
may be used for personal and non-commercial educational use only.
Any reproduction of this manuscript, no matter whether as a whole or in parts,
no matter whether in printed or in electronic form, requires explicit prior
acceptance of the authors.
###############################################################################
"""
import os
import sys
import traceback
import dill as pkl
import numpy as np
def print_outs(outs, line_token="-"):
print(line_token * 40)
print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n")
print(line_token * 40)
ex_file = 'ex4.py'
full_points = 20
points = full_points
python = sys.executable
with open(os.path.join("unittest", "unittest_inputs_outputs.pkl"), "rb") as ufh:
all_inputs_outputs = pkl.load(ufh)
all_inputs = all_inputs_outputs['inputs']
all_outputs = all_inputs_outputs['outputs']
feedback = ''
for test_i, (inputs, outputs) in enumerate(zip(all_inputs, all_outputs)):
comment = ''
fcall = ''
with open(os.devnull, 'w') as null:
# sys.stdout = null
try:
from ex4 import ex4
proper_import = True
except Exception:
outs = ''
errs = traceback.format_exc()
points -= full_points / len(all_inputs_outputs)
proper_import = False
finally:
sys.stdout.flush()
sys.stdout = sys.__stdout__
if proper_import:
with open(os.devnull, 'w') as null:
# sys.stdout = null
try:
errs = ''
fcall = f"ex4(image_array={inputs[0]}, offset={inputs[1]}, spacing={inputs[2]}))"
returns = ex4(image_array=inputs[0], offset=inputs[1],
spacing=inputs[2])
# Check if returns and outputs are of same type
if type(returns) != type(outputs):
comment = f"Output should be: {type(outputs).__name__} ('{outputs}'). \n" \
f" but is: {returns}"
points -= full_points / len(all_inputs)
else:
# Check input_array output
if (len(returns) != 3
or not isinstance(returns[0], np.ndarray)
or returns[0].dtype != outputs[0].dtype
or returns[0].shape != outputs[0].shape
or np.any(returns[0] != outputs[0])):
points -= (full_points / len(all_inputs)) / 3
comment = f"Incorrect 'input_array'. Output should be: " \
f"{outputs} \n" \
f"but is {returns}"
# Check known_array output
if (len(returns) != 3
or not isinstance(returns[1], np.ndarray)
or returns[1].dtype != outputs[1].dtype
or returns[1].shape != outputs[1].shape
or np.any(returns[1] != outputs[1])):
points -= (full_points / len(all_inputs)) / 3
comment = f"Incorrect 'known_array'. Output should be: " \
f"{outputs} \n" \
f"but is {returns}"
# Check target_array output
if (len(returns) != 3
or not isinstance(returns[2], np.ndarray)
or returns[2].dtype != outputs[2].dtype
or returns[2].shape != outputs[2].shape
or np.any(returns[2] != outputs[2])):
points -= (full_points / len(all_inputs)) / 3
comment = f"Incorrect 'target_array'. Output should be: " \
f"{outputs} \n" \
f"but is {returns}"
except Exception as e:
outs = ''
if not type(e) == type(outputs):
comment = f"Output should be: {type(outputs).__name__} ('{outputs}'). \n" \
f" but is:\n{traceback.format_exc()}"
points -= full_points / len(all_inputs)
finally:
sys.stdout.flush()
sys.stdout = sys.__stdout__
print()
print_outs(f"Test {test_i}", line_token="#")
print("Function call:")
print_outs(fcall)
if errs:
print(f"Some unexpected errors occurred:")
print_outs(errs)
else:
print("Notes:")
print_outs("No issues found" if comment == "" else comment)
# due to floating point calculations it could happen that we get -0 here
if points < 0:
assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?"
points = 0
print(f"Current points: {points:.2f}")
print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})")
print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle "
f"for common mistakes that can still lead to 0 points.")

View File

@ -4,11 +4,11 @@ from Net import ImageNN
def save_model(model: torch.nn.Module): def save_model(model: torch.nn.Module):
torch.save(model.state_dict(), 'impaintmodel.pth') torch.save(model, 'impaintmodel.pt')
def load_model(): def load_model():
model = ImageNN() model = ImageNN()
model.load_state_dict(torch.load('model_weights.pth')) model.load_state_dict(torch.load('impaintmodel.pt'))
model.eval() model.eval()
return model return model