basic training logic
This commit is contained in:
parent
b1bc3a2c64
commit
8cf208cfc9
@ -3,9 +3,14 @@ import os
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.utils.data.dataset
|
import torch.utils.data.dataset
|
||||||
from PIL import Image
|
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import ex4
|
||||||
|
|
||||||
|
IMG_SIZE = 100
|
||||||
|
|
||||||
|
|
||||||
class ImageDataset(Dataset):
|
class ImageDataset(Dataset):
|
||||||
@ -17,10 +22,27 @@ class ImageDataset(Dataset):
|
|||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
# Open image file, convert to numpy array and scale to [0, 1]
|
# Open image file, convert to numpy array and scale to [0, 1]
|
||||||
image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255
|
target_image = Image.open(self.image_files[index])
|
||||||
|
# image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255
|
||||||
|
resize_transforms = transforms.Compose([
|
||||||
|
transforms.Resize(size=IMG_SIZE),
|
||||||
|
transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
|
||||||
|
])
|
||||||
|
target_image = resize_transforms(target_image)
|
||||||
|
|
||||||
|
# normalize image from 0-1
|
||||||
|
target_image = np.array(target_image, dtype=np.float64) / 255.0
|
||||||
|
|
||||||
# Perform normalization for each channel
|
# Perform normalization for each channel
|
||||||
image = (image - self.norm_mean) / self.norm_std
|
# image = (image - self.norm_mean) / self.norm_std
|
||||||
return image, index
|
|
||||||
|
# calculate image with black grid
|
||||||
|
doomed_image = ex4.ex4(target_image, (5, 5), (4, 4))
|
||||||
|
|
||||||
|
# convert image to grayscale
|
||||||
|
# target_image = rgb2gray(target_image) # todo look if gray image makes sense
|
||||||
|
|
||||||
|
return doomed_image[0], np.transpose(target_image, (2, 0, 1))
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.image_files)
|
return len(self.image_files)
|
||||||
@ -35,15 +57,21 @@ def get_image_loader(path: str):
|
|||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
trains,
|
trains,
|
||||||
shuffle=True, # shuffle the order of our samples
|
shuffle=True, # shuffle the order of our samples
|
||||||
batch_size=4, # stack 4 samples to a minibatch
|
batch_size=5, # stack 4 samples to a minibatch
|
||||||
num_workers=0 # no background workers (see comment below)
|
num_workers=2 # no background workers (see comment below)
|
||||||
)
|
)
|
||||||
|
|
||||||
test_loader = DataLoader(
|
test_loader = DataLoader(
|
||||||
tsts,
|
tests,
|
||||||
shuffle=True, # shuffle the order of our samples
|
shuffle=True, # shuffle the order of our samples
|
||||||
batch_size=4, # stack 4 samples to a minibatch
|
batch_size=1, # stack 4 samples to a minibatch
|
||||||
num_workers=0 # no background workers (see comment below)
|
num_workers=0 # no background workers (see comment below)
|
||||||
)
|
)
|
||||||
|
|
||||||
return train_loader, test_loader
|
return train_loader, test_loader
|
||||||
|
|
||||||
|
|
||||||
|
def rgb2gray(rgb_array: np.ndarray):
|
||||||
|
r, g, b = rgb_array[:, :, 0], rgb_array[:, :, 1], rgb_array[:, :, 2]
|
||||||
|
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
|
||||||
|
return gray
|
||||||
|
@ -1,26 +1,48 @@
|
|||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from DataLoader import get_image_loader
|
from DataLoader import get_image_loader
|
||||||
from Net import ImageNN
|
from Net import ImageNN
|
||||||
|
|
||||||
|
|
||||||
# 01.05.22 -- 0.5h
|
# 01.05.22 -- 0.5h
|
||||||
from netio import save_model, load_model
|
from netio import save_model, load_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_train_device():
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
print(f'Device to train net: {device}')
|
||||||
|
return torch.device(device)
|
||||||
|
|
||||||
|
|
||||||
def train_model():
|
def train_model():
|
||||||
train_loader, test_loader = get_image_loader("my/supercool/image/dir")
|
# Set a known random seed for reproducibility
|
||||||
nn = ImageNN() # todo pass size ason.
|
np.random.seed(0)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
device = get_train_device()
|
||||||
|
|
||||||
|
# Load datasets
|
||||||
|
train_loader, test_loader = get_image_loader("training/")
|
||||||
|
nn = ImageNN(n_in_channels=3) # todo pass size ason.
|
||||||
nn.train() # init with train mode
|
nn.train() # init with train mode
|
||||||
|
nn.to(device) # send net to device available
|
||||||
|
|
||||||
optimizer = torch.optim.SGD(nn.parameters(), lr=0.1) # todo adjust parameters and lr
|
optimizer = torch.optim.SGD(nn.parameters(), lr=0.1) # todo adjust parameters and lr
|
||||||
loss_function = torch.nn.CrossEntropyLoss()
|
loss_function = torch.nn.MSELoss()
|
||||||
n_epochs = 15 # todo epcchs here
|
n_epochs = 15 # todo epcchs here
|
||||||
|
|
||||||
|
# todo look wtf is that
|
||||||
|
nn.double()
|
||||||
|
|
||||||
|
train_sample_size = len(train_loader)
|
||||||
losses = []
|
losses = []
|
||||||
|
best_eval_loss = 0
|
||||||
for epoch in range(n_epochs):
|
for epoch in range(n_epochs):
|
||||||
print(f"Epoch {epoch}/{n_epochs}\n")
|
print(f"Epoch {epoch}/{n_epochs}\n")
|
||||||
|
i = 0
|
||||||
for input_tensor, target_tensor in train_loader:
|
for input_tensor, target_tensor in train_loader:
|
||||||
|
input_tensor.to(device)
|
||||||
|
target_tensor.to(device)
|
||||||
|
|
||||||
output = nn(input_tensor) # get model output (forward pass)
|
output = nn(input_tensor) # get model output (forward pass)
|
||||||
loss = loss_function(output, target_tensor) # compute loss given model output and true target
|
loss = loss_function(output, target_tensor) # compute loss given model output and true target
|
||||||
loss.backward() # compute gradients (backward pass)
|
loss.backward() # compute gradients (backward pass)
|
||||||
@ -28,19 +50,45 @@ def train_model():
|
|||||||
optimizer.zero_grad() # reset gradients
|
optimizer.zero_grad() # reset gradients
|
||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
|
|
||||||
# switch net to eval mode
|
print(
|
||||||
nn.eval()
|
f'\rTraining epoch {epoch} [{i}/{train_sample_size * train_loader.batch_size}] (curr loss: {loss.item():.3})',
|
||||||
with torch.no_grad():
|
end='')
|
||||||
for input_tensor, target_tensor in test_loader:
|
i += train_loader.batch_size
|
||||||
# iterate testloader and we have to decide somhow now the goodness of the prediction
|
|
||||||
out = nn(input_tensor) # apply model
|
|
||||||
|
|
||||||
diff = out - target_tensor
|
# eval model every 500th element
|
||||||
# todo evaluate trained model on testset loader
|
if i % 500 == 0:
|
||||||
|
print(f"\nEvaluating model")
|
||||||
# todo save trained model to blob file
|
eval_loss = eval_model(nn, test_loader, loss_function, device)
|
||||||
|
print(f"Evalution loss={eval_loss}")
|
||||||
|
if eval_loss > best_eval_loss:
|
||||||
|
best_eval_loss = eval_loss
|
||||||
save_model(nn)
|
save_model(nn)
|
||||||
|
|
||||||
|
# switch net to eval mode
|
||||||
|
print(eval_model(nn, test_loader, loss_function, device=device))
|
||||||
|
|
||||||
|
|
||||||
|
# func to evaluate our trained model
|
||||||
|
def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
|
||||||
|
# switch to eval mode
|
||||||
|
model.eval()
|
||||||
|
loss = .0
|
||||||
|
# disable gradient calculations
|
||||||
|
with torch.no_grad():
|
||||||
|
i = 0
|
||||||
|
for input, target in dataloader:
|
||||||
|
input.to(device)
|
||||||
|
target.to(device)
|
||||||
|
|
||||||
|
out = model(input)
|
||||||
|
loss += loss_fn(out, target).item()
|
||||||
|
print(f'\rEval prog[{i}/{len(dataloader) * dataloader.batch_size}]', end='')
|
||||||
|
i += dataloader.batch_size
|
||||||
|
print()
|
||||||
|
loss /= len(dataloader)
|
||||||
|
model.train()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def apply_model():
|
def apply_model():
|
||||||
model = load_model()
|
model = load_model()
|
||||||
|
31
Net.py
31
Net.py
@ -2,10 +2,31 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
class ImageNN(torch.nn.Module):
|
class ImageNN(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, n_in_channels: int = 1, n_hidden_layers: int = 3, n_kernels: int = 32, kernel_size: int = 7):
|
||||||
|
"""Simple CNN with `n_hidden_layers`, `n_kernels`, and `kernel_size` as hyperparameters"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# todo implement the nn structure
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
cnn = []
|
||||||
pass
|
for i in range(n_hidden_layers):
|
||||||
# todo implement forward
|
cnn.append(torch.nn.Conv2d(
|
||||||
|
in_channels=n_in_channels,
|
||||||
|
out_channels=n_kernels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
padding=int(kernel_size / 2)
|
||||||
|
))
|
||||||
|
cnn.append(torch.nn.ReLU())
|
||||||
|
n_in_channels = n_kernels
|
||||||
|
self.hidden_layers = torch.nn.Sequential(*cnn)
|
||||||
|
|
||||||
|
self.output_layer = torch.nn.Conv2d(
|
||||||
|
in_channels=n_in_channels,
|
||||||
|
out_channels=3,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
padding=int(kernel_size / 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Apply CNN to input `x` of shape (N, n_channels, X, Y), where N=n_samples and X, Y are spatial dimensions"""
|
||||||
|
cnn_out = self.hidden_layers(x) # apply hidden layers (N, n_in_channels, X, Y) -> (N, n_kernels, X, Y)
|
||||||
|
pred = self.output_layer(cnn_out) # apply output layer (N, n_kernels, X, Y) -> (N, 1, X, Y)
|
||||||
|
return pred
|
||||||
|
159
ex2_unittest.py
159
ex2_unittest.py
@ -1,159 +0,0 @@
|
|||||||
"""
|
|
||||||
Author -- Michael Widrich, Andreas Schörgenhumer
|
|
||||||
Contact -- schoergenhumer@ml.jku.at
|
|
||||||
Date -- 04.03.2022
|
|
||||||
|
|
||||||
###############################################################################
|
|
||||||
|
|
||||||
The following copyright statement applies to all code within this file.
|
|
||||||
|
|
||||||
Copyright statement:
|
|
||||||
This material, no matter whether in printed or electronic form,
|
|
||||||
may be used for personal and non-commercial educational use only.
|
|
||||||
Any reproduction of this manuscript, no matter whether as a whole or in parts,
|
|
||||||
no matter whether in printed or in electronic form, requires explicit prior
|
|
||||||
acceptance of the authors.
|
|
||||||
|
|
||||||
###############################################################################
|
|
||||||
|
|
||||||
Images taken from: https://pixabay.com/
|
|
||||||
"""
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import sys
|
|
||||||
from glob import glob
|
|
||||||
|
|
||||||
import dill as pkl
|
|
||||||
|
|
||||||
|
|
||||||
def print_outs(outs, line_token="-"):
|
|
||||||
print(line_token * 40)
|
|
||||||
print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n")
|
|
||||||
print(line_token * 40)
|
|
||||||
|
|
||||||
|
|
||||||
ex_file = "ex2.py"
|
|
||||||
full_points = 15
|
|
||||||
points = full_points
|
|
||||||
python = sys.executable
|
|
||||||
|
|
||||||
solutions_dir = os.path.join("unittest", "solutions")
|
|
||||||
outputs_dir = os.path.join("unittest", "outputs")
|
|
||||||
|
|
||||||
# Remove previous outputs folder
|
|
||||||
shutil.rmtree(outputs_dir, ignore_errors=True)
|
|
||||||
|
|
||||||
inputs = sorted(glob(os.path.join("unittest", "unittest_input_*"), recursive=True))
|
|
||||||
if not len(inputs):
|
|
||||||
raise FileNotFoundError("Could not find unittest_input_* files")
|
|
||||||
|
|
||||||
with open(os.path.join(solutions_dir, "counts.pkl"), "rb") as f:
|
|
||||||
sol_counts = pkl.load(f)
|
|
||||||
|
|
||||||
for test_i, input_folder in enumerate(inputs):
|
|
||||||
comment = ""
|
|
||||||
fcall = ""
|
|
||||||
|
|
||||||
with open(os.devnull, "w") as null:
|
|
||||||
# sys.stdout = null
|
|
||||||
try:
|
|
||||||
from ex2 import validate_images
|
|
||||||
|
|
||||||
proper_import = True
|
|
||||||
except Exception as e:
|
|
||||||
outs = ""
|
|
||||||
errs = e
|
|
||||||
points -= full_points / len(inputs)
|
|
||||||
proper_import = False
|
|
||||||
finally:
|
|
||||||
sys.stdout.flush()
|
|
||||||
sys.stdout = sys.__stdout__
|
|
||||||
|
|
||||||
if proper_import:
|
|
||||||
with open(os.devnull, "w") as null:
|
|
||||||
# sys.stdout = null
|
|
||||||
try:
|
|
||||||
input_basename = os.path.basename(input_folder)
|
|
||||||
output_dir = os.path.join(outputs_dir, input_basename)
|
|
||||||
logfilepath = output_dir + ".log"
|
|
||||||
formatter = "06d"
|
|
||||||
counts = validate_images(input_dir=input_folder, output_dir=output_dir, log_file=logfilepath,
|
|
||||||
formatter=formatter)
|
|
||||||
fcall = f'validate_images(\n\tinput_dir="{input_folder}",\n\toutput_dir="{output_dir}",\n\tlog_file="{logfilepath}",\n\tformatter="{formatter}"\n)'
|
|
||||||
errs = ""
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(os.path.join(outputs_dir, f"{input_basename}.log"), "r") as lfh:
|
|
||||||
logfile = lfh.read()
|
|
||||||
except FileNotFoundError:
|
|
||||||
# two cases:
|
|
||||||
# 1) no invalid files and thus no log file -> ok -> equal to empty tlogfile
|
|
||||||
# 2) invalid files but no log file -> not ok -> will fail the comparison with tlogfile (below)
|
|
||||||
logfile = ""
|
|
||||||
with open(os.path.join(solutions_dir, f"{input_basename}.log"), "r") as lfh:
|
|
||||||
# must replace the separator that was used when creating the solution files
|
|
||||||
tlogfile = lfh.read().replace("\\", os.path.sep)
|
|
||||||
|
|
||||||
files = sorted(glob(os.path.join(outputs_dir, input_basename, "**", "*"), recursive=True))
|
|
||||||
hashing_function = hashlib.sha256()
|
|
||||||
for file in files:
|
|
||||||
with open(file, "rb") as fh:
|
|
||||||
hashing_function.update(fh.read())
|
|
||||||
hash = hashing_function.digest()
|
|
||||||
hashing_function = hashlib.sha256()
|
|
||||||
tfiles = sorted(glob(os.path.join(solutions_dir, input_basename, "**", "*"), recursive=True))
|
|
||||||
for file in tfiles:
|
|
||||||
with open(file, "rb") as fh:
|
|
||||||
hashing_function.update(fh.read())
|
|
||||||
thash = hashing_function.digest()
|
|
||||||
|
|
||||||
tcounts = sol_counts[input_basename]
|
|
||||||
|
|
||||||
if not counts == tcounts:
|
|
||||||
points -= full_points / len(inputs)
|
|
||||||
comment = f"Function should return {tcounts} but returned {counts}"
|
|
||||||
elif not [f.split(os.path.sep)[-2:] for f in files] == [f.split(os.path.sep)[-2:] for f in tfiles]:
|
|
||||||
points -= full_points / len(inputs)
|
|
||||||
comment = f"Contents of output directory do not match (see directory 'solutions')"
|
|
||||||
elif not hash == thash:
|
|
||||||
points -= full_points / len(inputs)
|
|
||||||
comment = f"Hash value of the files in the output directory do not match (see directory 'solutions')"
|
|
||||||
elif not logfile == tlogfile:
|
|
||||||
points -= full_points / len(inputs)
|
|
||||||
comment = f"Contents of logfiles do not match (see directory 'solutions')"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
outs = ""
|
|
||||||
errs = e
|
|
||||||
points -= full_points / len(inputs)
|
|
||||||
finally:
|
|
||||||
sys.stdout.flush()
|
|
||||||
sys.stdout = sys.__stdout__
|
|
||||||
|
|
||||||
print()
|
|
||||||
print_outs(f"Test {test_i}", line_token="#")
|
|
||||||
print("Function call:")
|
|
||||||
print_outs(fcall)
|
|
||||||
|
|
||||||
if errs:
|
|
||||||
print(f"Some unexpected errors occurred:")
|
|
||||||
print_outs(f"{type(errs).__name__}: {errs}")
|
|
||||||
else:
|
|
||||||
print("Notes:")
|
|
||||||
print_outs("No issues found" if comment == "" else comment)
|
|
||||||
|
|
||||||
# due to floating point calculations it could happen that we get -0 here
|
|
||||||
if points < 0:
|
|
||||||
assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?"
|
|
||||||
points = abs(points)
|
|
||||||
print(f"Current points: {points:.2f}")
|
|
||||||
|
|
||||||
print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})")
|
|
||||||
if points < full_points:
|
|
||||||
print(f"Check the folder '{outputs_dir}' to see where your errors are")
|
|
||||||
else:
|
|
||||||
shutil.rmtree(os.path.join(outputs_dir))
|
|
||||||
print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle "
|
|
||||||
f"for common mistakes that can still lead to 0 points.")
|
|
204
ex3_unittest.py
204
ex3_unittest.py
@ -1,204 +0,0 @@
|
|||||||
"""
|
|
||||||
Author -- Michael Widrich, Andreas Schörgenhumer
|
|
||||||
Contact -- schoergenhumer@ml.jku.at
|
|
||||||
Date -- 02.03.2022
|
|
||||||
|
|
||||||
###############################################################################
|
|
||||||
|
|
||||||
The following copyright statement applies to all code within this file.
|
|
||||||
|
|
||||||
Copyright statement:
|
|
||||||
This material, no matter whether in printed or electronic form,
|
|
||||||
may be used for personal and non-commercial educational use only.
|
|
||||||
Any reproduction of this manuscript, no matter whether as a whole or in parts,
|
|
||||||
no matter whether in printed or in electronic form, requires explicit prior
|
|
||||||
acceptance of the authors.
|
|
||||||
|
|
||||||
###############################################################################
|
|
||||||
|
|
||||||
Images taken from: https://pixabay.com/
|
|
||||||
"""
|
|
||||||
|
|
||||||
import gzip
|
|
||||||
import os
|
|
||||||
import signal
|
|
||||||
import sys
|
|
||||||
from glob import glob
|
|
||||||
from types import GeneratorType
|
|
||||||
|
|
||||||
import dill as pkl
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def print_outs(outs, line_token="-"):
|
|
||||||
print(line_token * 40)
|
|
||||||
print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n")
|
|
||||||
print(line_token * 40)
|
|
||||||
|
|
||||||
|
|
||||||
time_given = int(15)
|
|
||||||
|
|
||||||
check_for_timeout = hasattr(signal, "SIGALRM")
|
|
||||||
|
|
||||||
if check_for_timeout:
|
|
||||||
def handler(signum, frame):
|
|
||||||
raise TimeoutError(f"Timeout after {time_given}sec")
|
|
||||||
|
|
||||||
|
|
||||||
signal.signal(signal.SIGALRM, handler)
|
|
||||||
|
|
||||||
ex_file = "ex3.py"
|
|
||||||
full_points = 15
|
|
||||||
points = full_points
|
|
||||||
python = sys.executable
|
|
||||||
|
|
||||||
inputs = sorted(glob(os.path.join("unittest", "unittest_input_*"), recursive=True))
|
|
||||||
|
|
||||||
if not len(inputs):
|
|
||||||
raise FileNotFoundError("Could not find unittest_input_* files")
|
|
||||||
|
|
||||||
for test_i, input_folder in enumerate(inputs):
|
|
||||||
comment = ""
|
|
||||||
fcall = ""
|
|
||||||
|
|
||||||
with open(os.devnull, "w") as null:
|
|
||||||
# sys.stdout = null
|
|
||||||
try:
|
|
||||||
if check_for_timeout:
|
|
||||||
signal.alarm(time_given)
|
|
||||||
from ex3 import ImageStandardizer
|
|
||||||
|
|
||||||
signal.alarm(0)
|
|
||||||
else:
|
|
||||||
from ex3 import ImageStandardizer
|
|
||||||
proper_import = True
|
|
||||||
except Exception as e:
|
|
||||||
errs = e
|
|
||||||
points -= full_points / len(inputs)
|
|
||||||
proper_import = False
|
|
||||||
finally:
|
|
||||||
sys.stdout.flush()
|
|
||||||
sys.stdout = sys.__stdout__
|
|
||||||
|
|
||||||
if proper_import:
|
|
||||||
with open(os.devnull, "w") as null:
|
|
||||||
# sys.stdout = null
|
|
||||||
try:
|
|
||||||
if check_for_timeout:
|
|
||||||
signal.alarm(time_given)
|
|
||||||
# check constructor
|
|
||||||
instance = ImageStandardizer(input_dir=input_folder)
|
|
||||||
fcall = f"ImageStandardizer(input_dir='{input_folder}')"
|
|
||||||
signal.alarm(0)
|
|
||||||
else:
|
|
||||||
# check constructor
|
|
||||||
instance = ImageStandardizer(input_dir=input_folder)
|
|
||||||
fcall = f"ImageStandardizer(input_dir='{input_folder}')"
|
|
||||||
errs = ""
|
|
||||||
|
|
||||||
# check correct file names + sorting
|
|
||||||
input_basename = os.path.basename(input_folder)
|
|
||||||
with open(os.path.join("unittest", "solutions", input_basename, f"filenames.txt"), "r") as f:
|
|
||||||
# must replace the separator that was used when creating the solution files
|
|
||||||
files_sol = f.read().replace("\\", os.path.sep).splitlines()
|
|
||||||
# for simplicity's sake, only compare relative paths here
|
|
||||||
common = os.path.commonprefix(instance.files)
|
|
||||||
rel_instance_files = [os.path.join(input_folder, f[len(common):]) for f in instance.files]
|
|
||||||
if not hasattr(instance, "files"):
|
|
||||||
points -= full_points / len(inputs) / 3
|
|
||||||
comment += f"Attributes 'files' missing.\n"
|
|
||||||
elif rel_instance_files != files_sol:
|
|
||||||
points -= full_points / len(inputs) / 3
|
|
||||||
comment += f"Attribute 'files' should be {files_sol} but is {instance.files} (see directory 'solutions').\n"
|
|
||||||
elif len(instance.files) != len(files_sol):
|
|
||||||
points -= full_points / len(inputs) / 3
|
|
||||||
comment += f"Number of files should be {len(files_sol)} but is {len(instance.files)} (see directory 'solutions').\n"
|
|
||||||
|
|
||||||
# check if class has method analyze_images
|
|
||||||
method = "analyze_images"
|
|
||||||
if not hasattr(instance, method):
|
|
||||||
comment += f"Method '{method}' missing.\n"
|
|
||||||
points -= full_points / len(inputs) / 3
|
|
||||||
else:
|
|
||||||
# check for correct data types
|
|
||||||
stats = instance.analyze_images()
|
|
||||||
if (type(stats) is not tuple) or (len(stats) != 2):
|
|
||||||
points -= full_points / len(inputs) / 3
|
|
||||||
comment += f"Incorrect return value of method '{method}' (should be tuple of length 2).\n"
|
|
||||||
else:
|
|
||||||
with open(os.path.join("unittest", "solutions", input_basename, f"mean_and_std.pkl"),
|
|
||||||
"rb") as fh:
|
|
||||||
data = pkl.load(fh)
|
|
||||||
m = data["mean"]
|
|
||||||
s = data["std"]
|
|
||||||
if not (isinstance(stats[0], np.ndarray) and isinstance(stats[1], np.ndarray) and
|
|
||||||
stats[0].dtype == np.float64 and stats[1].dtype == np.float64 and
|
|
||||||
stats[0].shape == (3,) and stats[1].shape == (3,)):
|
|
||||||
points -= full_points / len(inputs) / 3
|
|
||||||
comment += f"Incorrect return data type of method '{method}' (tuple entries should be np.ndarray of dtype np.float64 and shape (3,)).\n"
|
|
||||||
else:
|
|
||||||
if not np.isclose(stats[0], m, atol=0).all():
|
|
||||||
points -= full_points / len(inputs) / 6
|
|
||||||
comment += f"Mean should be {m} but is {stats[0]} (see directory 'solutions').\n"
|
|
||||||
if not np.isclose(stats[1], s, atol=0).all():
|
|
||||||
points -= full_points / len(inputs) / 6
|
|
||||||
comment += f"Std should be {s} but is {stats[1]} (see directory 'solutions').\n"
|
|
||||||
|
|
||||||
# check if class has method get_standardized_images
|
|
||||||
method = "get_standardized_images"
|
|
||||||
if not hasattr(instance, method):
|
|
||||||
comment += f"Method '{method}' missing.\n"
|
|
||||||
points -= full_points / len(inputs) / 3
|
|
||||||
# check for correct data types
|
|
||||||
elif not isinstance(instance.get_standardized_images(), GeneratorType):
|
|
||||||
points -= full_points / len(inputs) / 3
|
|
||||||
comment += f"'{method}' is not a generator.\n"
|
|
||||||
else:
|
|
||||||
# Read correct image solutions
|
|
||||||
with gzip.open(os.path.join("unittest", "solutions", input_basename, "images.pkl"), "rb") as fh:
|
|
||||||
ims_sol = pkl.load(file=fh)
|
|
||||||
|
|
||||||
# Get image submissions
|
|
||||||
ims_sub = list(instance.get_standardized_images())
|
|
||||||
|
|
||||||
if not len(ims_sub) == len(ims_sol):
|
|
||||||
points -= full_points / len(inputs) / 3
|
|
||||||
comment += f"{len(ims_sol)} image arrays should have been returned but got {len(ims_sub)}.\n"
|
|
||||||
elif any([im_sub.dtype.num != np.dtype(np.float32).num for im_sub in ims_sub]):
|
|
||||||
points -= full_points / len(inputs) / 3
|
|
||||||
comment += f"Returned image arrays should have datatype np.float32 but at least one array isn't.\n"
|
|
||||||
else:
|
|
||||||
equal = [np.all(np.isclose(im_sub, im_sol, atol=0)) for im_sub, im_sol in zip(ims_sub, ims_sol)]
|
|
||||||
if not all(equal):
|
|
||||||
points -= full_points / len(inputs) / 3
|
|
||||||
comment += f"Returned images {list(np.where(np.logical_not(equal))[0])} do not match solution (see images.pkl files for solution).\n"
|
|
||||||
except Exception as e:
|
|
||||||
errs = e
|
|
||||||
points -= full_points / len(inputs)
|
|
||||||
finally:
|
|
||||||
sys.stdout.flush()
|
|
||||||
sys.stdout = sys.__stdout__
|
|
||||||
|
|
||||||
print()
|
|
||||||
print_outs(f"Test {test_i}", line_token="#")
|
|
||||||
print("Function call:")
|
|
||||||
print_outs(fcall)
|
|
||||||
|
|
||||||
if errs:
|
|
||||||
print(f"Some unexpected errors occurred:")
|
|
||||||
print_outs(f"{type(errs).__name__}: {errs}")
|
|
||||||
else:
|
|
||||||
print("Notes:")
|
|
||||||
print_outs("No issues found" if comment == "" else comment)
|
|
||||||
|
|
||||||
# due to floating point calculations it could happen that we get -0 here
|
|
||||||
if points < 0:
|
|
||||||
assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?"
|
|
||||||
points = abs(points)
|
|
||||||
print(f"Current points: {points:.2f}")
|
|
||||||
|
|
||||||
print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})")
|
|
||||||
print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle "
|
|
||||||
f"for common mistakes that can still lead to 0 points.")
|
|
||||||
if not check_for_timeout:
|
|
||||||
print("\n!!Warning: Had to switch to Windows compatibility version and did not check for timeouts!!")
|
|
143
ex4_unittest.py
143
ex4_unittest.py
@ -1,143 +0,0 @@
|
|||||||
"""
|
|
||||||
Author -- Michael Widrich
|
|
||||||
Contact -- widrich@ml.jku.at
|
|
||||||
Date -- 01.10.2019
|
|
||||||
|
|
||||||
###############################################################################
|
|
||||||
|
|
||||||
The following copyright statement applies to all code within this file.
|
|
||||||
|
|
||||||
Copyright statement:
|
|
||||||
This material, no matter whether in printed or electronic form,
|
|
||||||
may be used for personal and non-commercial educational use only.
|
|
||||||
Any reproduction of this manuscript, no matter whether as a whole or in parts,
|
|
||||||
no matter whether in printed or in electronic form, requires explicit prior
|
|
||||||
acceptance of the authors.
|
|
||||||
|
|
||||||
###############################################################################
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import dill as pkl
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def print_outs(outs, line_token="-"):
|
|
||||||
print(line_token * 40)
|
|
||||||
print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n")
|
|
||||||
print(line_token * 40)
|
|
||||||
|
|
||||||
|
|
||||||
ex_file = 'ex4.py'
|
|
||||||
full_points = 20
|
|
||||||
points = full_points
|
|
||||||
python = sys.executable
|
|
||||||
|
|
||||||
with open(os.path.join("unittest", "unittest_inputs_outputs.pkl"), "rb") as ufh:
|
|
||||||
all_inputs_outputs = pkl.load(ufh)
|
|
||||||
all_inputs = all_inputs_outputs['inputs']
|
|
||||||
all_outputs = all_inputs_outputs['outputs']
|
|
||||||
|
|
||||||
feedback = ''
|
|
||||||
|
|
||||||
for test_i, (inputs, outputs) in enumerate(zip(all_inputs, all_outputs)):
|
|
||||||
|
|
||||||
comment = ''
|
|
||||||
fcall = ''
|
|
||||||
with open(os.devnull, 'w') as null:
|
|
||||||
# sys.stdout = null
|
|
||||||
try:
|
|
||||||
from ex4 import ex4
|
|
||||||
proper_import = True
|
|
||||||
except Exception:
|
|
||||||
outs = ''
|
|
||||||
errs = traceback.format_exc()
|
|
||||||
points -= full_points / len(all_inputs_outputs)
|
|
||||||
proper_import = False
|
|
||||||
finally:
|
|
||||||
sys.stdout.flush()
|
|
||||||
sys.stdout = sys.__stdout__
|
|
||||||
|
|
||||||
if proper_import:
|
|
||||||
with open(os.devnull, 'w') as null:
|
|
||||||
# sys.stdout = null
|
|
||||||
try:
|
|
||||||
errs = ''
|
|
||||||
fcall = f"ex4(image_array={inputs[0]}, offset={inputs[1]}, spacing={inputs[2]}))"
|
|
||||||
returns = ex4(image_array=inputs[0], offset=inputs[1],
|
|
||||||
spacing=inputs[2])
|
|
||||||
|
|
||||||
# Check if returns and outputs are of same type
|
|
||||||
if type(returns) != type(outputs):
|
|
||||||
comment = f"Output should be: {type(outputs).__name__} ('{outputs}'). \n" \
|
|
||||||
f" but is: {returns}"
|
|
||||||
points -= full_points / len(all_inputs)
|
|
||||||
else:
|
|
||||||
# Check input_array output
|
|
||||||
if (len(returns) != 3
|
|
||||||
or not isinstance(returns[0], np.ndarray)
|
|
||||||
or returns[0].dtype != outputs[0].dtype
|
|
||||||
or returns[0].shape != outputs[0].shape
|
|
||||||
or np.any(returns[0] != outputs[0])):
|
|
||||||
points -= (full_points / len(all_inputs)) / 3
|
|
||||||
comment = f"Incorrect 'input_array'. Output should be: " \
|
|
||||||
f"{outputs} \n" \
|
|
||||||
f"but is {returns}"
|
|
||||||
|
|
||||||
# Check known_array output
|
|
||||||
if (len(returns) != 3
|
|
||||||
or not isinstance(returns[1], np.ndarray)
|
|
||||||
or returns[1].dtype != outputs[1].dtype
|
|
||||||
or returns[1].shape != outputs[1].shape
|
|
||||||
or np.any(returns[1] != outputs[1])):
|
|
||||||
points -= (full_points / len(all_inputs)) / 3
|
|
||||||
comment = f"Incorrect 'known_array'. Output should be: " \
|
|
||||||
f"{outputs} \n" \
|
|
||||||
f"but is {returns}"
|
|
||||||
|
|
||||||
# Check target_array output
|
|
||||||
if (len(returns) != 3
|
|
||||||
or not isinstance(returns[2], np.ndarray)
|
|
||||||
or returns[2].dtype != outputs[2].dtype
|
|
||||||
or returns[2].shape != outputs[2].shape
|
|
||||||
or np.any(returns[2] != outputs[2])):
|
|
||||||
points -= (full_points / len(all_inputs)) / 3
|
|
||||||
comment = f"Incorrect 'target_array'. Output should be: " \
|
|
||||||
f"{outputs} \n" \
|
|
||||||
f"but is {returns}"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
outs = ''
|
|
||||||
if not type(e) == type(outputs):
|
|
||||||
comment = f"Output should be: {type(outputs).__name__} ('{outputs}'). \n" \
|
|
||||||
f" but is:\n{traceback.format_exc()}"
|
|
||||||
points -= full_points / len(all_inputs)
|
|
||||||
finally:
|
|
||||||
sys.stdout.flush()
|
|
||||||
sys.stdout = sys.__stdout__
|
|
||||||
|
|
||||||
print()
|
|
||||||
print_outs(f"Test {test_i}", line_token="#")
|
|
||||||
print("Function call:")
|
|
||||||
print_outs(fcall)
|
|
||||||
|
|
||||||
if errs:
|
|
||||||
print(f"Some unexpected errors occurred:")
|
|
||||||
print_outs(errs)
|
|
||||||
else:
|
|
||||||
print("Notes:")
|
|
||||||
print_outs("No issues found" if comment == "" else comment)
|
|
||||||
|
|
||||||
# due to floating point calculations it could happen that we get -0 here
|
|
||||||
if points < 0:
|
|
||||||
assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?"
|
|
||||||
points = 0
|
|
||||||
print(f"Current points: {points:.2f}")
|
|
||||||
|
|
||||||
print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})")
|
|
||||||
print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle "
|
|
||||||
f"for common mistakes that can still lead to 0 points.")
|
|
4
netio.py
4
netio.py
@ -4,11 +4,11 @@ from Net import ImageNN
|
|||||||
|
|
||||||
|
|
||||||
def save_model(model: torch.nn.Module):
|
def save_model(model: torch.nn.Module):
|
||||||
torch.save(model.state_dict(), 'impaintmodel.pth')
|
torch.save(model, 'impaintmodel.pt')
|
||||||
|
|
||||||
|
|
||||||
def load_model():
|
def load_model():
|
||||||
model = ImageNN()
|
model = ImageNN()
|
||||||
model.load_state_dict(torch.load('model_weights.pth'))
|
model.load_state_dict(torch.load('impaintmodel.pt'))
|
||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model
|
||||||
|
Loading…
Reference in New Issue
Block a user