basic training logic
This commit is contained in:
parent
b1bc3a2c64
commit
8cf208cfc9
@ -3,9 +3,14 @@ import os
|
||||
|
||||
import numpy as np
|
||||
import torch.utils.data.dataset
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
|
||||
import ex4
|
||||
|
||||
IMG_SIZE = 100
|
||||
|
||||
|
||||
class ImageDataset(Dataset):
|
||||
@ -17,10 +22,27 @@ class ImageDataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
# Open image file, convert to numpy array and scale to [0, 1]
|
||||
image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255
|
||||
target_image = Image.open(self.image_files[index])
|
||||
# image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255
|
||||
resize_transforms = transforms.Compose([
|
||||
transforms.Resize(size=IMG_SIZE),
|
||||
transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
|
||||
])
|
||||
target_image = resize_transforms(target_image)
|
||||
|
||||
# normalize image from 0-1
|
||||
target_image = np.array(target_image, dtype=np.float64) / 255.0
|
||||
|
||||
# Perform normalization for each channel
|
||||
image = (image - self.norm_mean) / self.norm_std
|
||||
return image, index
|
||||
# image = (image - self.norm_mean) / self.norm_std
|
||||
|
||||
# calculate image with black grid
|
||||
doomed_image = ex4.ex4(target_image, (5, 5), (4, 4))
|
||||
|
||||
# convert image to grayscale
|
||||
# target_image = rgb2gray(target_image) # todo look if gray image makes sense
|
||||
|
||||
return doomed_image[0], np.transpose(target_image, (2, 0, 1))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_files)
|
||||
@ -35,15 +57,21 @@ def get_image_loader(path: str):
|
||||
train_loader = DataLoader(
|
||||
trains,
|
||||
shuffle=True, # shuffle the order of our samples
|
||||
batch_size=4, # stack 4 samples to a minibatch
|
||||
num_workers=0 # no background workers (see comment below)
|
||||
batch_size=5, # stack 4 samples to a minibatch
|
||||
num_workers=2 # no background workers (see comment below)
|
||||
)
|
||||
|
||||
test_loader = DataLoader(
|
||||
tsts,
|
||||
tests,
|
||||
shuffle=True, # shuffle the order of our samples
|
||||
batch_size=4, # stack 4 samples to a minibatch
|
||||
batch_size=1, # stack 4 samples to a minibatch
|
||||
num_workers=0 # no background workers (see comment below)
|
||||
)
|
||||
|
||||
return train_loader, test_loader
|
||||
|
||||
|
||||
def rgb2gray(rgb_array: np.ndarray):
|
||||
r, g, b = rgb_array[:, :, 0], rgb_array[:, :, 1], rgb_array[:, :, 2]
|
||||
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
|
||||
return gray
|
||||
|
@ -1,26 +1,48 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from DataLoader import get_image_loader
|
||||
from Net import ImageNN
|
||||
|
||||
|
||||
# 01.05.22 -- 0.5h
|
||||
from netio import save_model, load_model
|
||||
|
||||
|
||||
def get_train_device():
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
print(f'Device to train net: {device}')
|
||||
return torch.device(device)
|
||||
|
||||
|
||||
def train_model():
|
||||
train_loader, test_loader = get_image_loader("my/supercool/image/dir")
|
||||
nn = ImageNN() # todo pass size ason.
|
||||
# Set a known random seed for reproducibility
|
||||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
device = get_train_device()
|
||||
|
||||
# Load datasets
|
||||
train_loader, test_loader = get_image_loader("training/")
|
||||
nn = ImageNN(n_in_channels=3) # todo pass size ason.
|
||||
nn.train() # init with train mode
|
||||
nn.to(device) # send net to device available
|
||||
|
||||
optimizer = torch.optim.SGD(nn.parameters(), lr=0.1) # todo adjust parameters and lr
|
||||
loss_function = torch.nn.CrossEntropyLoss()
|
||||
loss_function = torch.nn.MSELoss()
|
||||
n_epochs = 15 # todo epcchs here
|
||||
|
||||
# todo look wtf is that
|
||||
nn.double()
|
||||
|
||||
train_sample_size = len(train_loader)
|
||||
losses = []
|
||||
best_eval_loss = 0
|
||||
for epoch in range(n_epochs):
|
||||
print(f"Epoch {epoch}/{n_epochs}\n")
|
||||
i = 0
|
||||
for input_tensor, target_tensor in train_loader:
|
||||
input_tensor.to(device)
|
||||
target_tensor.to(device)
|
||||
|
||||
output = nn(input_tensor) # get model output (forward pass)
|
||||
loss = loss_function(output, target_tensor) # compute loss given model output and true target
|
||||
loss.backward() # compute gradients (backward pass)
|
||||
@ -28,18 +50,44 @@ def train_model():
|
||||
optimizer.zero_grad() # reset gradients
|
||||
losses.append(loss.item())
|
||||
|
||||
print(
|
||||
f'\rTraining epoch {epoch} [{i}/{train_sample_size * train_loader.batch_size}] (curr loss: {loss.item():.3})',
|
||||
end='')
|
||||
i += train_loader.batch_size
|
||||
|
||||
# eval model every 500th element
|
||||
if i % 500 == 0:
|
||||
print(f"\nEvaluating model")
|
||||
eval_loss = eval_model(nn, test_loader, loss_function, device)
|
||||
print(f"Evalution loss={eval_loss}")
|
||||
if eval_loss > best_eval_loss:
|
||||
best_eval_loss = eval_loss
|
||||
save_model(nn)
|
||||
|
||||
# switch net to eval mode
|
||||
nn.eval()
|
||||
print(eval_model(nn, test_loader, loss_function, device=device))
|
||||
|
||||
|
||||
# func to evaluate our trained model
|
||||
def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
|
||||
# switch to eval mode
|
||||
model.eval()
|
||||
loss = .0
|
||||
# disable gradient calculations
|
||||
with torch.no_grad():
|
||||
for input_tensor, target_tensor in test_loader:
|
||||
# iterate testloader and we have to decide somhow now the goodness of the prediction
|
||||
out = nn(input_tensor) # apply model
|
||||
i = 0
|
||||
for input, target in dataloader:
|
||||
input.to(device)
|
||||
target.to(device)
|
||||
|
||||
diff = out - target_tensor
|
||||
# todo evaluate trained model on testset loader
|
||||
|
||||
# todo save trained model to blob file
|
||||
save_model(nn)
|
||||
out = model(input)
|
||||
loss += loss_fn(out, target).item()
|
||||
print(f'\rEval prog[{i}/{len(dataloader) * dataloader.batch_size}]', end='')
|
||||
i += dataloader.batch_size
|
||||
print()
|
||||
loss /= len(dataloader)
|
||||
model.train()
|
||||
return loss
|
||||
|
||||
|
||||
def apply_model():
|
||||
|
31
Net.py
31
Net.py
@ -2,10 +2,31 @@ import torch
|
||||
|
||||
|
||||
class ImageNN(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, n_in_channels: int = 1, n_hidden_layers: int = 3, n_kernels: int = 32, kernel_size: int = 7):
|
||||
"""Simple CNN with `n_hidden_layers`, `n_kernels`, and `kernel_size` as hyperparameters"""
|
||||
super().__init__()
|
||||
# todo implement the nn structure
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
pass
|
||||
# todo implement forward
|
||||
cnn = []
|
||||
for i in range(n_hidden_layers):
|
||||
cnn.append(torch.nn.Conv2d(
|
||||
in_channels=n_in_channels,
|
||||
out_channels=n_kernels,
|
||||
kernel_size=kernel_size,
|
||||
padding=int(kernel_size / 2)
|
||||
))
|
||||
cnn.append(torch.nn.ReLU())
|
||||
n_in_channels = n_kernels
|
||||
self.hidden_layers = torch.nn.Sequential(*cnn)
|
||||
|
||||
self.output_layer = torch.nn.Conv2d(
|
||||
in_channels=n_in_channels,
|
||||
out_channels=3,
|
||||
kernel_size=kernel_size,
|
||||
padding=int(kernel_size / 2)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Apply CNN to input `x` of shape (N, n_channels, X, Y), where N=n_samples and X, Y are spatial dimensions"""
|
||||
cnn_out = self.hidden_layers(x) # apply hidden layers (N, n_in_channels, X, Y) -> (N, n_kernels, X, Y)
|
||||
pred = self.output_layer(cnn_out) # apply output layer (N, n_kernels, X, Y) -> (N, 1, X, Y)
|
||||
return pred
|
||||
|
159
ex2_unittest.py
159
ex2_unittest.py
@ -1,159 +0,0 @@
|
||||
"""
|
||||
Author -- Michael Widrich, Andreas Schörgenhumer
|
||||
Contact -- schoergenhumer@ml.jku.at
|
||||
Date -- 04.03.2022
|
||||
|
||||
###############################################################################
|
||||
|
||||
The following copyright statement applies to all code within this file.
|
||||
|
||||
Copyright statement:
|
||||
This material, no matter whether in printed or electronic form,
|
||||
may be used for personal and non-commercial educational use only.
|
||||
Any reproduction of this manuscript, no matter whether as a whole or in parts,
|
||||
no matter whether in printed or in electronic form, requires explicit prior
|
||||
acceptance of the authors.
|
||||
|
||||
###############################################################################
|
||||
|
||||
Images taken from: https://pixabay.com/
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from glob import glob
|
||||
|
||||
import dill as pkl
|
||||
|
||||
|
||||
def print_outs(outs, line_token="-"):
|
||||
print(line_token * 40)
|
||||
print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n")
|
||||
print(line_token * 40)
|
||||
|
||||
|
||||
ex_file = "ex2.py"
|
||||
full_points = 15
|
||||
points = full_points
|
||||
python = sys.executable
|
||||
|
||||
solutions_dir = os.path.join("unittest", "solutions")
|
||||
outputs_dir = os.path.join("unittest", "outputs")
|
||||
|
||||
# Remove previous outputs folder
|
||||
shutil.rmtree(outputs_dir, ignore_errors=True)
|
||||
|
||||
inputs = sorted(glob(os.path.join("unittest", "unittest_input_*"), recursive=True))
|
||||
if not len(inputs):
|
||||
raise FileNotFoundError("Could not find unittest_input_* files")
|
||||
|
||||
with open(os.path.join(solutions_dir, "counts.pkl"), "rb") as f:
|
||||
sol_counts = pkl.load(f)
|
||||
|
||||
for test_i, input_folder in enumerate(inputs):
|
||||
comment = ""
|
||||
fcall = ""
|
||||
|
||||
with open(os.devnull, "w") as null:
|
||||
# sys.stdout = null
|
||||
try:
|
||||
from ex2 import validate_images
|
||||
|
||||
proper_import = True
|
||||
except Exception as e:
|
||||
outs = ""
|
||||
errs = e
|
||||
points -= full_points / len(inputs)
|
||||
proper_import = False
|
||||
finally:
|
||||
sys.stdout.flush()
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
if proper_import:
|
||||
with open(os.devnull, "w") as null:
|
||||
# sys.stdout = null
|
||||
try:
|
||||
input_basename = os.path.basename(input_folder)
|
||||
output_dir = os.path.join(outputs_dir, input_basename)
|
||||
logfilepath = output_dir + ".log"
|
||||
formatter = "06d"
|
||||
counts = validate_images(input_dir=input_folder, output_dir=output_dir, log_file=logfilepath,
|
||||
formatter=formatter)
|
||||
fcall = f'validate_images(\n\tinput_dir="{input_folder}",\n\toutput_dir="{output_dir}",\n\tlog_file="{logfilepath}",\n\tformatter="{formatter}"\n)'
|
||||
errs = ""
|
||||
|
||||
try:
|
||||
with open(os.path.join(outputs_dir, f"{input_basename}.log"), "r") as lfh:
|
||||
logfile = lfh.read()
|
||||
except FileNotFoundError:
|
||||
# two cases:
|
||||
# 1) no invalid files and thus no log file -> ok -> equal to empty tlogfile
|
||||
# 2) invalid files but no log file -> not ok -> will fail the comparison with tlogfile (below)
|
||||
logfile = ""
|
||||
with open(os.path.join(solutions_dir, f"{input_basename}.log"), "r") as lfh:
|
||||
# must replace the separator that was used when creating the solution files
|
||||
tlogfile = lfh.read().replace("\\", os.path.sep)
|
||||
|
||||
files = sorted(glob(os.path.join(outputs_dir, input_basename, "**", "*"), recursive=True))
|
||||
hashing_function = hashlib.sha256()
|
||||
for file in files:
|
||||
with open(file, "rb") as fh:
|
||||
hashing_function.update(fh.read())
|
||||
hash = hashing_function.digest()
|
||||
hashing_function = hashlib.sha256()
|
||||
tfiles = sorted(glob(os.path.join(solutions_dir, input_basename, "**", "*"), recursive=True))
|
||||
for file in tfiles:
|
||||
with open(file, "rb") as fh:
|
||||
hashing_function.update(fh.read())
|
||||
thash = hashing_function.digest()
|
||||
|
||||
tcounts = sol_counts[input_basename]
|
||||
|
||||
if not counts == tcounts:
|
||||
points -= full_points / len(inputs)
|
||||
comment = f"Function should return {tcounts} but returned {counts}"
|
||||
elif not [f.split(os.path.sep)[-2:] for f in files] == [f.split(os.path.sep)[-2:] for f in tfiles]:
|
||||
points -= full_points / len(inputs)
|
||||
comment = f"Contents of output directory do not match (see directory 'solutions')"
|
||||
elif not hash == thash:
|
||||
points -= full_points / len(inputs)
|
||||
comment = f"Hash value of the files in the output directory do not match (see directory 'solutions')"
|
||||
elif not logfile == tlogfile:
|
||||
points -= full_points / len(inputs)
|
||||
comment = f"Contents of logfiles do not match (see directory 'solutions')"
|
||||
|
||||
except Exception as e:
|
||||
outs = ""
|
||||
errs = e
|
||||
points -= full_points / len(inputs)
|
||||
finally:
|
||||
sys.stdout.flush()
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
print()
|
||||
print_outs(f"Test {test_i}", line_token="#")
|
||||
print("Function call:")
|
||||
print_outs(fcall)
|
||||
|
||||
if errs:
|
||||
print(f"Some unexpected errors occurred:")
|
||||
print_outs(f"{type(errs).__name__}: {errs}")
|
||||
else:
|
||||
print("Notes:")
|
||||
print_outs("No issues found" if comment == "" else comment)
|
||||
|
||||
# due to floating point calculations it could happen that we get -0 here
|
||||
if points < 0:
|
||||
assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?"
|
||||
points = abs(points)
|
||||
print(f"Current points: {points:.2f}")
|
||||
|
||||
print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})")
|
||||
if points < full_points:
|
||||
print(f"Check the folder '{outputs_dir}' to see where your errors are")
|
||||
else:
|
||||
shutil.rmtree(os.path.join(outputs_dir))
|
||||
print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle "
|
||||
f"for common mistakes that can still lead to 0 points.")
|
204
ex3_unittest.py
204
ex3_unittest.py
@ -1,204 +0,0 @@
|
||||
"""
|
||||
Author -- Michael Widrich, Andreas Schörgenhumer
|
||||
Contact -- schoergenhumer@ml.jku.at
|
||||
Date -- 02.03.2022
|
||||
|
||||
###############################################################################
|
||||
|
||||
The following copyright statement applies to all code within this file.
|
||||
|
||||
Copyright statement:
|
||||
This material, no matter whether in printed or electronic form,
|
||||
may be used for personal and non-commercial educational use only.
|
||||
Any reproduction of this manuscript, no matter whether as a whole or in parts,
|
||||
no matter whether in printed or in electronic form, requires explicit prior
|
||||
acceptance of the authors.
|
||||
|
||||
###############################################################################
|
||||
|
||||
Images taken from: https://pixabay.com/
|
||||
"""
|
||||
|
||||
import gzip
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from glob import glob
|
||||
from types import GeneratorType
|
||||
|
||||
import dill as pkl
|
||||
import numpy as np
|
||||
|
||||
|
||||
def print_outs(outs, line_token="-"):
|
||||
print(line_token * 40)
|
||||
print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n")
|
||||
print(line_token * 40)
|
||||
|
||||
|
||||
time_given = int(15)
|
||||
|
||||
check_for_timeout = hasattr(signal, "SIGALRM")
|
||||
|
||||
if check_for_timeout:
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError(f"Timeout after {time_given}sec")
|
||||
|
||||
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
|
||||
ex_file = "ex3.py"
|
||||
full_points = 15
|
||||
points = full_points
|
||||
python = sys.executable
|
||||
|
||||
inputs = sorted(glob(os.path.join("unittest", "unittest_input_*"), recursive=True))
|
||||
|
||||
if not len(inputs):
|
||||
raise FileNotFoundError("Could not find unittest_input_* files")
|
||||
|
||||
for test_i, input_folder in enumerate(inputs):
|
||||
comment = ""
|
||||
fcall = ""
|
||||
|
||||
with open(os.devnull, "w") as null:
|
||||
# sys.stdout = null
|
||||
try:
|
||||
if check_for_timeout:
|
||||
signal.alarm(time_given)
|
||||
from ex3 import ImageStandardizer
|
||||
|
||||
signal.alarm(0)
|
||||
else:
|
||||
from ex3 import ImageStandardizer
|
||||
proper_import = True
|
||||
except Exception as e:
|
||||
errs = e
|
||||
points -= full_points / len(inputs)
|
||||
proper_import = False
|
||||
finally:
|
||||
sys.stdout.flush()
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
if proper_import:
|
||||
with open(os.devnull, "w") as null:
|
||||
# sys.stdout = null
|
||||
try:
|
||||
if check_for_timeout:
|
||||
signal.alarm(time_given)
|
||||
# check constructor
|
||||
instance = ImageStandardizer(input_dir=input_folder)
|
||||
fcall = f"ImageStandardizer(input_dir='{input_folder}')"
|
||||
signal.alarm(0)
|
||||
else:
|
||||
# check constructor
|
||||
instance = ImageStandardizer(input_dir=input_folder)
|
||||
fcall = f"ImageStandardizer(input_dir='{input_folder}')"
|
||||
errs = ""
|
||||
|
||||
# check correct file names + sorting
|
||||
input_basename = os.path.basename(input_folder)
|
||||
with open(os.path.join("unittest", "solutions", input_basename, f"filenames.txt"), "r") as f:
|
||||
# must replace the separator that was used when creating the solution files
|
||||
files_sol = f.read().replace("\\", os.path.sep).splitlines()
|
||||
# for simplicity's sake, only compare relative paths here
|
||||
common = os.path.commonprefix(instance.files)
|
||||
rel_instance_files = [os.path.join(input_folder, f[len(common):]) for f in instance.files]
|
||||
if not hasattr(instance, "files"):
|
||||
points -= full_points / len(inputs) / 3
|
||||
comment += f"Attributes 'files' missing.\n"
|
||||
elif rel_instance_files != files_sol:
|
||||
points -= full_points / len(inputs) / 3
|
||||
comment += f"Attribute 'files' should be {files_sol} but is {instance.files} (see directory 'solutions').\n"
|
||||
elif len(instance.files) != len(files_sol):
|
||||
points -= full_points / len(inputs) / 3
|
||||
comment += f"Number of files should be {len(files_sol)} but is {len(instance.files)} (see directory 'solutions').\n"
|
||||
|
||||
# check if class has method analyze_images
|
||||
method = "analyze_images"
|
||||
if not hasattr(instance, method):
|
||||
comment += f"Method '{method}' missing.\n"
|
||||
points -= full_points / len(inputs) / 3
|
||||
else:
|
||||
# check for correct data types
|
||||
stats = instance.analyze_images()
|
||||
if (type(stats) is not tuple) or (len(stats) != 2):
|
||||
points -= full_points / len(inputs) / 3
|
||||
comment += f"Incorrect return value of method '{method}' (should be tuple of length 2).\n"
|
||||
else:
|
||||
with open(os.path.join("unittest", "solutions", input_basename, f"mean_and_std.pkl"),
|
||||
"rb") as fh:
|
||||
data = pkl.load(fh)
|
||||
m = data["mean"]
|
||||
s = data["std"]
|
||||
if not (isinstance(stats[0], np.ndarray) and isinstance(stats[1], np.ndarray) and
|
||||
stats[0].dtype == np.float64 and stats[1].dtype == np.float64 and
|
||||
stats[0].shape == (3,) and stats[1].shape == (3,)):
|
||||
points -= full_points / len(inputs) / 3
|
||||
comment += f"Incorrect return data type of method '{method}' (tuple entries should be np.ndarray of dtype np.float64 and shape (3,)).\n"
|
||||
else:
|
||||
if not np.isclose(stats[0], m, atol=0).all():
|
||||
points -= full_points / len(inputs) / 6
|
||||
comment += f"Mean should be {m} but is {stats[0]} (see directory 'solutions').\n"
|
||||
if not np.isclose(stats[1], s, atol=0).all():
|
||||
points -= full_points / len(inputs) / 6
|
||||
comment += f"Std should be {s} but is {stats[1]} (see directory 'solutions').\n"
|
||||
|
||||
# check if class has method get_standardized_images
|
||||
method = "get_standardized_images"
|
||||
if not hasattr(instance, method):
|
||||
comment += f"Method '{method}' missing.\n"
|
||||
points -= full_points / len(inputs) / 3
|
||||
# check for correct data types
|
||||
elif not isinstance(instance.get_standardized_images(), GeneratorType):
|
||||
points -= full_points / len(inputs) / 3
|
||||
comment += f"'{method}' is not a generator.\n"
|
||||
else:
|
||||
# Read correct image solutions
|
||||
with gzip.open(os.path.join("unittest", "solutions", input_basename, "images.pkl"), "rb") as fh:
|
||||
ims_sol = pkl.load(file=fh)
|
||||
|
||||
# Get image submissions
|
||||
ims_sub = list(instance.get_standardized_images())
|
||||
|
||||
if not len(ims_sub) == len(ims_sol):
|
||||
points -= full_points / len(inputs) / 3
|
||||
comment += f"{len(ims_sol)} image arrays should have been returned but got {len(ims_sub)}.\n"
|
||||
elif any([im_sub.dtype.num != np.dtype(np.float32).num for im_sub in ims_sub]):
|
||||
points -= full_points / len(inputs) / 3
|
||||
comment += f"Returned image arrays should have datatype np.float32 but at least one array isn't.\n"
|
||||
else:
|
||||
equal = [np.all(np.isclose(im_sub, im_sol, atol=0)) for im_sub, im_sol in zip(ims_sub, ims_sol)]
|
||||
if not all(equal):
|
||||
points -= full_points / len(inputs) / 3
|
||||
comment += f"Returned images {list(np.where(np.logical_not(equal))[0])} do not match solution (see images.pkl files for solution).\n"
|
||||
except Exception as e:
|
||||
errs = e
|
||||
points -= full_points / len(inputs)
|
||||
finally:
|
||||
sys.stdout.flush()
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
print()
|
||||
print_outs(f"Test {test_i}", line_token="#")
|
||||
print("Function call:")
|
||||
print_outs(fcall)
|
||||
|
||||
if errs:
|
||||
print(f"Some unexpected errors occurred:")
|
||||
print_outs(f"{type(errs).__name__}: {errs}")
|
||||
else:
|
||||
print("Notes:")
|
||||
print_outs("No issues found" if comment == "" else comment)
|
||||
|
||||
# due to floating point calculations it could happen that we get -0 here
|
||||
if points < 0:
|
||||
assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?"
|
||||
points = abs(points)
|
||||
print(f"Current points: {points:.2f}")
|
||||
|
||||
print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})")
|
||||
print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle "
|
||||
f"for common mistakes that can still lead to 0 points.")
|
||||
if not check_for_timeout:
|
||||
print("\n!!Warning: Had to switch to Windows compatibility version and did not check for timeouts!!")
|
143
ex4_unittest.py
143
ex4_unittest.py
@ -1,143 +0,0 @@
|
||||
"""
|
||||
Author -- Michael Widrich
|
||||
Contact -- widrich@ml.jku.at
|
||||
Date -- 01.10.2019
|
||||
|
||||
###############################################################################
|
||||
|
||||
The following copyright statement applies to all code within this file.
|
||||
|
||||
Copyright statement:
|
||||
This material, no matter whether in printed or electronic form,
|
||||
may be used for personal and non-commercial educational use only.
|
||||
Any reproduction of this manuscript, no matter whether as a whole or in parts,
|
||||
no matter whether in printed or in electronic form, requires explicit prior
|
||||
acceptance of the authors.
|
||||
|
||||
###############################################################################
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import dill as pkl
|
||||
import numpy as np
|
||||
|
||||
|
||||
def print_outs(outs, line_token="-"):
|
||||
print(line_token * 40)
|
||||
print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n")
|
||||
print(line_token * 40)
|
||||
|
||||
|
||||
ex_file = 'ex4.py'
|
||||
full_points = 20
|
||||
points = full_points
|
||||
python = sys.executable
|
||||
|
||||
with open(os.path.join("unittest", "unittest_inputs_outputs.pkl"), "rb") as ufh:
|
||||
all_inputs_outputs = pkl.load(ufh)
|
||||
all_inputs = all_inputs_outputs['inputs']
|
||||
all_outputs = all_inputs_outputs['outputs']
|
||||
|
||||
feedback = ''
|
||||
|
||||
for test_i, (inputs, outputs) in enumerate(zip(all_inputs, all_outputs)):
|
||||
|
||||
comment = ''
|
||||
fcall = ''
|
||||
with open(os.devnull, 'w') as null:
|
||||
# sys.stdout = null
|
||||
try:
|
||||
from ex4 import ex4
|
||||
proper_import = True
|
||||
except Exception:
|
||||
outs = ''
|
||||
errs = traceback.format_exc()
|
||||
points -= full_points / len(all_inputs_outputs)
|
||||
proper_import = False
|
||||
finally:
|
||||
sys.stdout.flush()
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
if proper_import:
|
||||
with open(os.devnull, 'w') as null:
|
||||
# sys.stdout = null
|
||||
try:
|
||||
errs = ''
|
||||
fcall = f"ex4(image_array={inputs[0]}, offset={inputs[1]}, spacing={inputs[2]}))"
|
||||
returns = ex4(image_array=inputs[0], offset=inputs[1],
|
||||
spacing=inputs[2])
|
||||
|
||||
# Check if returns and outputs are of same type
|
||||
if type(returns) != type(outputs):
|
||||
comment = f"Output should be: {type(outputs).__name__} ('{outputs}'). \n" \
|
||||
f" but is: {returns}"
|
||||
points -= full_points / len(all_inputs)
|
||||
else:
|
||||
# Check input_array output
|
||||
if (len(returns) != 3
|
||||
or not isinstance(returns[0], np.ndarray)
|
||||
or returns[0].dtype != outputs[0].dtype
|
||||
or returns[0].shape != outputs[0].shape
|
||||
or np.any(returns[0] != outputs[0])):
|
||||
points -= (full_points / len(all_inputs)) / 3
|
||||
comment = f"Incorrect 'input_array'. Output should be: " \
|
||||
f"{outputs} \n" \
|
||||
f"but is {returns}"
|
||||
|
||||
# Check known_array output
|
||||
if (len(returns) != 3
|
||||
or not isinstance(returns[1], np.ndarray)
|
||||
or returns[1].dtype != outputs[1].dtype
|
||||
or returns[1].shape != outputs[1].shape
|
||||
or np.any(returns[1] != outputs[1])):
|
||||
points -= (full_points / len(all_inputs)) / 3
|
||||
comment = f"Incorrect 'known_array'. Output should be: " \
|
||||
f"{outputs} \n" \
|
||||
f"but is {returns}"
|
||||
|
||||
# Check target_array output
|
||||
if (len(returns) != 3
|
||||
or not isinstance(returns[2], np.ndarray)
|
||||
or returns[2].dtype != outputs[2].dtype
|
||||
or returns[2].shape != outputs[2].shape
|
||||
or np.any(returns[2] != outputs[2])):
|
||||
points -= (full_points / len(all_inputs)) / 3
|
||||
comment = f"Incorrect 'target_array'. Output should be: " \
|
||||
f"{outputs} \n" \
|
||||
f"but is {returns}"
|
||||
|
||||
except Exception as e:
|
||||
outs = ''
|
||||
if not type(e) == type(outputs):
|
||||
comment = f"Output should be: {type(outputs).__name__} ('{outputs}'). \n" \
|
||||
f" but is:\n{traceback.format_exc()}"
|
||||
points -= full_points / len(all_inputs)
|
||||
finally:
|
||||
sys.stdout.flush()
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
print()
|
||||
print_outs(f"Test {test_i}", line_token="#")
|
||||
print("Function call:")
|
||||
print_outs(fcall)
|
||||
|
||||
if errs:
|
||||
print(f"Some unexpected errors occurred:")
|
||||
print_outs(errs)
|
||||
else:
|
||||
print("Notes:")
|
||||
print_outs("No issues found" if comment == "" else comment)
|
||||
|
||||
# due to floating point calculations it could happen that we get -0 here
|
||||
if points < 0:
|
||||
assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?"
|
||||
points = 0
|
||||
print(f"Current points: {points:.2f}")
|
||||
|
||||
print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})")
|
||||
print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle "
|
||||
f"for common mistakes that can still lead to 0 points.")
|
4
netio.py
4
netio.py
@ -4,11 +4,11 @@ from Net import ImageNN
|
||||
|
||||
|
||||
def save_model(model: torch.nn.Module):
|
||||
torch.save(model.state_dict(), 'impaintmodel.pth')
|
||||
torch.save(model, 'impaintmodel.pt')
|
||||
|
||||
|
||||
def load_model():
|
||||
model = ImageNN()
|
||||
model.load_state_dict(torch.load('model_weights.pth'))
|
||||
model.load_state_dict(torch.load('impaintmodel.pt'))
|
||||
model.eval()
|
||||
return model
|
||||
|
Loading…
Reference in New Issue
Block a user