diff --git a/DataLoader.py b/DataLoader.py index 04af5ed..d6cedda 100644 --- a/DataLoader.py +++ b/DataLoader.py @@ -3,9 +3,14 @@ import os import numpy as np import torch.utils.data.dataset -from PIL import Image from torch.utils.data import Dataset from torch.utils.data import DataLoader +from torchvision import transforms +from PIL import Image + +import ex4 + +IMG_SIZE = 100 class ImageDataset(Dataset): @@ -17,10 +22,27 @@ class ImageDataset(Dataset): def __getitem__(self, index): # Open image file, convert to numpy array and scale to [0, 1] - image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255 + target_image = Image.open(self.image_files[index]) + # image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255 + resize_transforms = transforms.Compose([ + transforms.Resize(size=IMG_SIZE), + transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)), + ]) + target_image = resize_transforms(target_image) + + # normalize image from 0-1 + target_image = np.array(target_image, dtype=np.float64) / 255.0 + # Perform normalization for each channel - image = (image - self.norm_mean) / self.norm_std - return image, index + # image = (image - self.norm_mean) / self.norm_std + + # calculate image with black grid + doomed_image = ex4.ex4(target_image, (5, 5), (4, 4)) + + # convert image to grayscale + # target_image = rgb2gray(target_image) # todo look if gray image makes sense + + return doomed_image[0], np.transpose(target_image, (2, 0, 1)) def __len__(self): return len(self.image_files) @@ -35,15 +57,21 @@ def get_image_loader(path: str): train_loader = DataLoader( trains, shuffle=True, # shuffle the order of our samples - batch_size=4, # stack 4 samples to a minibatch - num_workers=0 # no background workers (see comment below) + batch_size=5, # stack 4 samples to a minibatch + num_workers=2 # no background workers (see comment below) ) test_loader = DataLoader( - tsts, + tests, shuffle=True, # shuffle the order of our samples - batch_size=4, # stack 4 samples to a minibatch + batch_size=1, # stack 4 samples to a minibatch num_workers=0 # no background workers (see comment below) ) return train_loader, test_loader + + +def rgb2gray(rgb_array: np.ndarray): + r, g, b = rgb_array[:, :, 0], rgb_array[:, :, 1], rgb_array[:, :, 2] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + return gray diff --git a/ImageImpaint.py b/ImageImpaint.py index e91bba8..fe1261b 100644 --- a/ImageImpaint.py +++ b/ImageImpaint.py @@ -1,26 +1,48 @@ +import numpy as np import torch from DataLoader import get_image_loader from Net import ImageNN - # 01.05.22 -- 0.5h from netio import save_model, load_model +def get_train_device(): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + print(f'Device to train net: {device}') + return torch.device(device) + + def train_model(): - train_loader, test_loader = get_image_loader("my/supercool/image/dir") - nn = ImageNN() # todo pass size ason. + # Set a known random seed for reproducibility + np.random.seed(0) + torch.manual_seed(0) + device = get_train_device() + + # Load datasets + train_loader, test_loader = get_image_loader("training/") + nn = ImageNN(n_in_channels=3) # todo pass size ason. nn.train() # init with train mode + nn.to(device) # send net to device available optimizer = torch.optim.SGD(nn.parameters(), lr=0.1) # todo adjust parameters and lr - loss_function = torch.nn.CrossEntropyLoss() + loss_function = torch.nn.MSELoss() n_epochs = 15 # todo epcchs here + # todo look wtf is that + nn.double() + + train_sample_size = len(train_loader) losses = [] + best_eval_loss = 0 for epoch in range(n_epochs): print(f"Epoch {epoch}/{n_epochs}\n") + i = 0 for input_tensor, target_tensor in train_loader: + input_tensor.to(device) + target_tensor.to(device) + output = nn(input_tensor) # get model output (forward pass) loss = loss_function(output, target_tensor) # compute loss given model output and true target loss.backward() # compute gradients (backward pass) @@ -28,18 +50,44 @@ def train_model(): optimizer.zero_grad() # reset gradients losses.append(loss.item()) + print( + f'\rTraining epoch {epoch} [{i}/{train_sample_size * train_loader.batch_size}] (curr loss: {loss.item():.3})', + end='') + i += train_loader.batch_size + + # eval model every 500th element + if i % 500 == 0: + print(f"\nEvaluating model") + eval_loss = eval_model(nn, test_loader, loss_function, device) + print(f"Evalution loss={eval_loss}") + if eval_loss > best_eval_loss: + best_eval_loss = eval_loss + save_model(nn) + # switch net to eval mode - nn.eval() + print(eval_model(nn, test_loader, loss_function, device=device)) + + +# func to evaluate our trained model +def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device): + # switch to eval mode + model.eval() + loss = .0 + # disable gradient calculations with torch.no_grad(): - for input_tensor, target_tensor in test_loader: - # iterate testloader and we have to decide somhow now the goodness of the prediction - out = nn(input_tensor) # apply model + i = 0 + for input, target in dataloader: + input.to(device) + target.to(device) - diff = out - target_tensor - # todo evaluate trained model on testset loader - - # todo save trained model to blob file - save_model(nn) + out = model(input) + loss += loss_fn(out, target).item() + print(f'\rEval prog[{i}/{len(dataloader) * dataloader.batch_size}]', end='') + i += dataloader.batch_size + print() + loss /= len(dataloader) + model.train() + return loss def apply_model(): diff --git a/Net.py b/Net.py index 6b66150..4766c9c 100644 --- a/Net.py +++ b/Net.py @@ -2,10 +2,31 @@ import torch class ImageNN(torch.nn.Module): - def __init__(self): + def __init__(self, n_in_channels: int = 1, n_hidden_layers: int = 3, n_kernels: int = 32, kernel_size: int = 7): + """Simple CNN with `n_hidden_layers`, `n_kernels`, and `kernel_size` as hyperparameters""" super().__init__() - # todo implement the nn structure - def forward(self, x: torch.Tensor): - pass - # todo implement forward + cnn = [] + for i in range(n_hidden_layers): + cnn.append(torch.nn.Conv2d( + in_channels=n_in_channels, + out_channels=n_kernels, + kernel_size=kernel_size, + padding=int(kernel_size / 2) + )) + cnn.append(torch.nn.ReLU()) + n_in_channels = n_kernels + self.hidden_layers = torch.nn.Sequential(*cnn) + + self.output_layer = torch.nn.Conv2d( + in_channels=n_in_channels, + out_channels=3, + kernel_size=kernel_size, + padding=int(kernel_size / 2) + ) + + def forward(self, x): + """Apply CNN to input `x` of shape (N, n_channels, X, Y), where N=n_samples and X, Y are spatial dimensions""" + cnn_out = self.hidden_layers(x) # apply hidden layers (N, n_in_channels, X, Y) -> (N, n_kernels, X, Y) + pred = self.output_layer(cnn_out) # apply output layer (N, n_kernels, X, Y) -> (N, 1, X, Y) + return pred diff --git a/ex2_unittest.py b/ex2_unittest.py deleted file mode 100644 index 28df537..0000000 --- a/ex2_unittest.py +++ /dev/null @@ -1,159 +0,0 @@ -""" -Author -- Michael Widrich, Andreas Schörgenhumer -Contact -- schoergenhumer@ml.jku.at -Date -- 04.03.2022 - -############################################################################### - -The following copyright statement applies to all code within this file. - -Copyright statement: -This material, no matter whether in printed or electronic form, -may be used for personal and non-commercial educational use only. -Any reproduction of this manuscript, no matter whether as a whole or in parts, -no matter whether in printed or in electronic form, requires explicit prior -acceptance of the authors. - -############################################################################### - -Images taken from: https://pixabay.com/ -""" - -import hashlib -import os -import shutil -import sys -from glob import glob - -import dill as pkl - - -def print_outs(outs, line_token="-"): - print(line_token * 40) - print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n") - print(line_token * 40) - - -ex_file = "ex2.py" -full_points = 15 -points = full_points -python = sys.executable - -solutions_dir = os.path.join("unittest", "solutions") -outputs_dir = os.path.join("unittest", "outputs") - -# Remove previous outputs folder -shutil.rmtree(outputs_dir, ignore_errors=True) - -inputs = sorted(glob(os.path.join("unittest", "unittest_input_*"), recursive=True)) -if not len(inputs): - raise FileNotFoundError("Could not find unittest_input_* files") - -with open(os.path.join(solutions_dir, "counts.pkl"), "rb") as f: - sol_counts = pkl.load(f) - -for test_i, input_folder in enumerate(inputs): - comment = "" - fcall = "" - - with open(os.devnull, "w") as null: - # sys.stdout = null - try: - from ex2 import validate_images - - proper_import = True - except Exception as e: - outs = "" - errs = e - points -= full_points / len(inputs) - proper_import = False - finally: - sys.stdout.flush() - sys.stdout = sys.__stdout__ - - if proper_import: - with open(os.devnull, "w") as null: - # sys.stdout = null - try: - input_basename = os.path.basename(input_folder) - output_dir = os.path.join(outputs_dir, input_basename) - logfilepath = output_dir + ".log" - formatter = "06d" - counts = validate_images(input_dir=input_folder, output_dir=output_dir, log_file=logfilepath, - formatter=formatter) - fcall = f'validate_images(\n\tinput_dir="{input_folder}",\n\toutput_dir="{output_dir}",\n\tlog_file="{logfilepath}",\n\tformatter="{formatter}"\n)' - errs = "" - - try: - with open(os.path.join(outputs_dir, f"{input_basename}.log"), "r") as lfh: - logfile = lfh.read() - except FileNotFoundError: - # two cases: - # 1) no invalid files and thus no log file -> ok -> equal to empty tlogfile - # 2) invalid files but no log file -> not ok -> will fail the comparison with tlogfile (below) - logfile = "" - with open(os.path.join(solutions_dir, f"{input_basename}.log"), "r") as lfh: - # must replace the separator that was used when creating the solution files - tlogfile = lfh.read().replace("\\", os.path.sep) - - files = sorted(glob(os.path.join(outputs_dir, input_basename, "**", "*"), recursive=True)) - hashing_function = hashlib.sha256() - for file in files: - with open(file, "rb") as fh: - hashing_function.update(fh.read()) - hash = hashing_function.digest() - hashing_function = hashlib.sha256() - tfiles = sorted(glob(os.path.join(solutions_dir, input_basename, "**", "*"), recursive=True)) - for file in tfiles: - with open(file, "rb") as fh: - hashing_function.update(fh.read()) - thash = hashing_function.digest() - - tcounts = sol_counts[input_basename] - - if not counts == tcounts: - points -= full_points / len(inputs) - comment = f"Function should return {tcounts} but returned {counts}" - elif not [f.split(os.path.sep)[-2:] for f in files] == [f.split(os.path.sep)[-2:] for f in tfiles]: - points -= full_points / len(inputs) - comment = f"Contents of output directory do not match (see directory 'solutions')" - elif not hash == thash: - points -= full_points / len(inputs) - comment = f"Hash value of the files in the output directory do not match (see directory 'solutions')" - elif not logfile == tlogfile: - points -= full_points / len(inputs) - comment = f"Contents of logfiles do not match (see directory 'solutions')" - - except Exception as e: - outs = "" - errs = e - points -= full_points / len(inputs) - finally: - sys.stdout.flush() - sys.stdout = sys.__stdout__ - - print() - print_outs(f"Test {test_i}", line_token="#") - print("Function call:") - print_outs(fcall) - - if errs: - print(f"Some unexpected errors occurred:") - print_outs(f"{type(errs).__name__}: {errs}") - else: - print("Notes:") - print_outs("No issues found" if comment == "" else comment) - - # due to floating point calculations it could happen that we get -0 here - if points < 0: - assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?" - points = abs(points) - print(f"Current points: {points:.2f}") - -print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})") -if points < full_points: - print(f"Check the folder '{outputs_dir}' to see where your errors are") -else: - shutil.rmtree(os.path.join(outputs_dir)) -print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle " - f"for common mistakes that can still lead to 0 points.") diff --git a/ex3_unittest.py b/ex3_unittest.py deleted file mode 100644 index 014019a..0000000 --- a/ex3_unittest.py +++ /dev/null @@ -1,204 +0,0 @@ -""" -Author -- Michael Widrich, Andreas Schörgenhumer -Contact -- schoergenhumer@ml.jku.at -Date -- 02.03.2022 - -############################################################################### - -The following copyright statement applies to all code within this file. - -Copyright statement: -This material, no matter whether in printed or electronic form, -may be used for personal and non-commercial educational use only. -Any reproduction of this manuscript, no matter whether as a whole or in parts, -no matter whether in printed or in electronic form, requires explicit prior -acceptance of the authors. - -############################################################################### - -Images taken from: https://pixabay.com/ -""" - -import gzip -import os -import signal -import sys -from glob import glob -from types import GeneratorType - -import dill as pkl -import numpy as np - - -def print_outs(outs, line_token="-"): - print(line_token * 40) - print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n") - print(line_token * 40) - - -time_given = int(15) - -check_for_timeout = hasattr(signal, "SIGALRM") - -if check_for_timeout: - def handler(signum, frame): - raise TimeoutError(f"Timeout after {time_given}sec") - - - signal.signal(signal.SIGALRM, handler) - -ex_file = "ex3.py" -full_points = 15 -points = full_points -python = sys.executable - -inputs = sorted(glob(os.path.join("unittest", "unittest_input_*"), recursive=True)) - -if not len(inputs): - raise FileNotFoundError("Could not find unittest_input_* files") - -for test_i, input_folder in enumerate(inputs): - comment = "" - fcall = "" - - with open(os.devnull, "w") as null: - # sys.stdout = null - try: - if check_for_timeout: - signal.alarm(time_given) - from ex3 import ImageStandardizer - - signal.alarm(0) - else: - from ex3 import ImageStandardizer - proper_import = True - except Exception as e: - errs = e - points -= full_points / len(inputs) - proper_import = False - finally: - sys.stdout.flush() - sys.stdout = sys.__stdout__ - - if proper_import: - with open(os.devnull, "w") as null: - # sys.stdout = null - try: - if check_for_timeout: - signal.alarm(time_given) - # check constructor - instance = ImageStandardizer(input_dir=input_folder) - fcall = f"ImageStandardizer(input_dir='{input_folder}')" - signal.alarm(0) - else: - # check constructor - instance = ImageStandardizer(input_dir=input_folder) - fcall = f"ImageStandardizer(input_dir='{input_folder}')" - errs = "" - - # check correct file names + sorting - input_basename = os.path.basename(input_folder) - with open(os.path.join("unittest", "solutions", input_basename, f"filenames.txt"), "r") as f: - # must replace the separator that was used when creating the solution files - files_sol = f.read().replace("\\", os.path.sep).splitlines() - # for simplicity's sake, only compare relative paths here - common = os.path.commonprefix(instance.files) - rel_instance_files = [os.path.join(input_folder, f[len(common):]) for f in instance.files] - if not hasattr(instance, "files"): - points -= full_points / len(inputs) / 3 - comment += f"Attributes 'files' missing.\n" - elif rel_instance_files != files_sol: - points -= full_points / len(inputs) / 3 - comment += f"Attribute 'files' should be {files_sol} but is {instance.files} (see directory 'solutions').\n" - elif len(instance.files) != len(files_sol): - points -= full_points / len(inputs) / 3 - comment += f"Number of files should be {len(files_sol)} but is {len(instance.files)} (see directory 'solutions').\n" - - # check if class has method analyze_images - method = "analyze_images" - if not hasattr(instance, method): - comment += f"Method '{method}' missing.\n" - points -= full_points / len(inputs) / 3 - else: - # check for correct data types - stats = instance.analyze_images() - if (type(stats) is not tuple) or (len(stats) != 2): - points -= full_points / len(inputs) / 3 - comment += f"Incorrect return value of method '{method}' (should be tuple of length 2).\n" - else: - with open(os.path.join("unittest", "solutions", input_basename, f"mean_and_std.pkl"), - "rb") as fh: - data = pkl.load(fh) - m = data["mean"] - s = data["std"] - if not (isinstance(stats[0], np.ndarray) and isinstance(stats[1], np.ndarray) and - stats[0].dtype == np.float64 and stats[1].dtype == np.float64 and - stats[0].shape == (3,) and stats[1].shape == (3,)): - points -= full_points / len(inputs) / 3 - comment += f"Incorrect return data type of method '{method}' (tuple entries should be np.ndarray of dtype np.float64 and shape (3,)).\n" - else: - if not np.isclose(stats[0], m, atol=0).all(): - points -= full_points / len(inputs) / 6 - comment += f"Mean should be {m} but is {stats[0]} (see directory 'solutions').\n" - if not np.isclose(stats[1], s, atol=0).all(): - points -= full_points / len(inputs) / 6 - comment += f"Std should be {s} but is {stats[1]} (see directory 'solutions').\n" - - # check if class has method get_standardized_images - method = "get_standardized_images" - if not hasattr(instance, method): - comment += f"Method '{method}' missing.\n" - points -= full_points / len(inputs) / 3 - # check for correct data types - elif not isinstance(instance.get_standardized_images(), GeneratorType): - points -= full_points / len(inputs) / 3 - comment += f"'{method}' is not a generator.\n" - else: - # Read correct image solutions - with gzip.open(os.path.join("unittest", "solutions", input_basename, "images.pkl"), "rb") as fh: - ims_sol = pkl.load(file=fh) - - # Get image submissions - ims_sub = list(instance.get_standardized_images()) - - if not len(ims_sub) == len(ims_sol): - points -= full_points / len(inputs) / 3 - comment += f"{len(ims_sol)} image arrays should have been returned but got {len(ims_sub)}.\n" - elif any([im_sub.dtype.num != np.dtype(np.float32).num for im_sub in ims_sub]): - points -= full_points / len(inputs) / 3 - comment += f"Returned image arrays should have datatype np.float32 but at least one array isn't.\n" - else: - equal = [np.all(np.isclose(im_sub, im_sol, atol=0)) for im_sub, im_sol in zip(ims_sub, ims_sol)] - if not all(equal): - points -= full_points / len(inputs) / 3 - comment += f"Returned images {list(np.where(np.logical_not(equal))[0])} do not match solution (see images.pkl files for solution).\n" - except Exception as e: - errs = e - points -= full_points / len(inputs) - finally: - sys.stdout.flush() - sys.stdout = sys.__stdout__ - - print() - print_outs(f"Test {test_i}", line_token="#") - print("Function call:") - print_outs(fcall) - - if errs: - print(f"Some unexpected errors occurred:") - print_outs(f"{type(errs).__name__}: {errs}") - else: - print("Notes:") - print_outs("No issues found" if comment == "" else comment) - - # due to floating point calculations it could happen that we get -0 here - if points < 0: - assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?" - points = abs(points) - print(f"Current points: {points:.2f}") - -print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})") -print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle " - f"for common mistakes that can still lead to 0 points.") -if not check_for_timeout: - print("\n!!Warning: Had to switch to Windows compatibility version and did not check for timeouts!!") diff --git a/ex4_unittest.py b/ex4_unittest.py deleted file mode 100644 index ec1215e..0000000 --- a/ex4_unittest.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Author -- Michael Widrich -Contact -- widrich@ml.jku.at -Date -- 01.10.2019 - -############################################################################### - -The following copyright statement applies to all code within this file. - -Copyright statement: -This material, no matter whether in printed or electronic form, -may be used for personal and non-commercial educational use only. -Any reproduction of this manuscript, no matter whether as a whole or in parts, -no matter whether in printed or in electronic form, requires explicit prior -acceptance of the authors. - -############################################################################### - -""" - -import os -import sys -import traceback - -import dill as pkl -import numpy as np - - -def print_outs(outs, line_token="-"): - print(line_token * 40) - print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n") - print(line_token * 40) - - -ex_file = 'ex4.py' -full_points = 20 -points = full_points -python = sys.executable - -with open(os.path.join("unittest", "unittest_inputs_outputs.pkl"), "rb") as ufh: - all_inputs_outputs = pkl.load(ufh) - all_inputs = all_inputs_outputs['inputs'] - all_outputs = all_inputs_outputs['outputs'] - -feedback = '' - -for test_i, (inputs, outputs) in enumerate(zip(all_inputs, all_outputs)): - - comment = '' - fcall = '' - with open(os.devnull, 'w') as null: - # sys.stdout = null - try: - from ex4 import ex4 - proper_import = True - except Exception: - outs = '' - errs = traceback.format_exc() - points -= full_points / len(all_inputs_outputs) - proper_import = False - finally: - sys.stdout.flush() - sys.stdout = sys.__stdout__ - - if proper_import: - with open(os.devnull, 'w') as null: - # sys.stdout = null - try: - errs = '' - fcall = f"ex4(image_array={inputs[0]}, offset={inputs[1]}, spacing={inputs[2]}))" - returns = ex4(image_array=inputs[0], offset=inputs[1], - spacing=inputs[2]) - - # Check if returns and outputs are of same type - if type(returns) != type(outputs): - comment = f"Output should be: {type(outputs).__name__} ('{outputs}'). \n" \ - f" but is: {returns}" - points -= full_points / len(all_inputs) - else: - # Check input_array output - if (len(returns) != 3 - or not isinstance(returns[0], np.ndarray) - or returns[0].dtype != outputs[0].dtype - or returns[0].shape != outputs[0].shape - or np.any(returns[0] != outputs[0])): - points -= (full_points / len(all_inputs)) / 3 - comment = f"Incorrect 'input_array'. Output should be: " \ - f"{outputs} \n" \ - f"but is {returns}" - - # Check known_array output - if (len(returns) != 3 - or not isinstance(returns[1], np.ndarray) - or returns[1].dtype != outputs[1].dtype - or returns[1].shape != outputs[1].shape - or np.any(returns[1] != outputs[1])): - points -= (full_points / len(all_inputs)) / 3 - comment = f"Incorrect 'known_array'. Output should be: " \ - f"{outputs} \n" \ - f"but is {returns}" - - # Check target_array output - if (len(returns) != 3 - or not isinstance(returns[2], np.ndarray) - or returns[2].dtype != outputs[2].dtype - or returns[2].shape != outputs[2].shape - or np.any(returns[2] != outputs[2])): - points -= (full_points / len(all_inputs)) / 3 - comment = f"Incorrect 'target_array'. Output should be: " \ - f"{outputs} \n" \ - f"but is {returns}" - - except Exception as e: - outs = '' - if not type(e) == type(outputs): - comment = f"Output should be: {type(outputs).__name__} ('{outputs}'). \n" \ - f" but is:\n{traceback.format_exc()}" - points -= full_points / len(all_inputs) - finally: - sys.stdout.flush() - sys.stdout = sys.__stdout__ - - print() - print_outs(f"Test {test_i}", line_token="#") - print("Function call:") - print_outs(fcall) - - if errs: - print(f"Some unexpected errors occurred:") - print_outs(errs) - else: - print("Notes:") - print_outs("No issues found" if comment == "" else comment) - - # due to floating point calculations it could happen that we get -0 here - if points < 0: - assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?" - points = 0 - print(f"Current points: {points:.2f}") - -print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})") -print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle " - f"for common mistakes that can still lead to 0 points.") diff --git a/netio.py b/netio.py index d9ed115..f6a9e1f 100644 --- a/netio.py +++ b/netio.py @@ -4,11 +4,11 @@ from Net import ImageNN def save_model(model: torch.nn.Module): - torch.save(model.state_dict(), 'impaintmodel.pth') + torch.save(model, 'impaintmodel.pt') def load_model(): model = ImageNN() - model.load_state_dict(torch.load('model_weights.pth')) + model.load_state_dict(torch.load('impaintmodel.pt')) model.eval() return model