basic training logic
This commit is contained in:
		| @@ -3,9 +3,14 @@ import os | |||||||
|  |  | ||||||
| import numpy as np | import numpy as np | ||||||
| import torch.utils.data.dataset | import torch.utils.data.dataset | ||||||
| from PIL import Image |  | ||||||
| from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||||
| from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||||
|  | from torchvision import transforms | ||||||
|  | from PIL import Image | ||||||
|  |  | ||||||
|  | import ex4 | ||||||
|  |  | ||||||
|  | IMG_SIZE = 100 | ||||||
|  |  | ||||||
|  |  | ||||||
| class ImageDataset(Dataset): | class ImageDataset(Dataset): | ||||||
| @@ -17,10 +22,27 @@ class ImageDataset(Dataset): | |||||||
|  |  | ||||||
|     def __getitem__(self, index): |     def __getitem__(self, index): | ||||||
|         # Open image file, convert to numpy array and scale to [0, 1] |         # Open image file, convert to numpy array and scale to [0, 1] | ||||||
|         image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255 |         target_image = Image.open(self.image_files[index]) | ||||||
|  |         # image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255 | ||||||
|  |         resize_transforms = transforms.Compose([ | ||||||
|  |             transforms.Resize(size=IMG_SIZE), | ||||||
|  |             transforms.CenterCrop(size=(IMG_SIZE, IMG_SIZE)), | ||||||
|  |         ]) | ||||||
|  |         target_image = resize_transforms(target_image) | ||||||
|  |  | ||||||
|  |         # normalize image from 0-1 | ||||||
|  |         target_image = np.array(target_image, dtype=np.float64) / 255.0 | ||||||
|  |  | ||||||
|         # Perform normalization for each channel |         # Perform normalization for each channel | ||||||
|         image = (image - self.norm_mean) / self.norm_std |         # image = (image - self.norm_mean) / self.norm_std | ||||||
|         return image, index |  | ||||||
|  |         # calculate image with black grid | ||||||
|  |         doomed_image = ex4.ex4(target_image, (5, 5), (4, 4)) | ||||||
|  |  | ||||||
|  |         # convert image to grayscale | ||||||
|  |         # target_image = rgb2gray(target_image)  # todo look if gray image makes sense | ||||||
|  |  | ||||||
|  |         return doomed_image[0], np.transpose(target_image, (2, 0, 1)) | ||||||
|  |  | ||||||
|     def __len__(self): |     def __len__(self): | ||||||
|         return len(self.image_files) |         return len(self.image_files) | ||||||
| @@ -35,15 +57,21 @@ def get_image_loader(path: str): | |||||||
|     train_loader = DataLoader( |     train_loader = DataLoader( | ||||||
|         trains, |         trains, | ||||||
|         shuffle=True,  # shuffle the order of our samples |         shuffle=True,  # shuffle the order of our samples | ||||||
|         batch_size=4,  # stack 4 samples to a minibatch |         batch_size=5,  # stack 4 samples to a minibatch | ||||||
|         num_workers=0  # no background workers (see comment below) |         num_workers=2  # no background workers (see comment below) | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     test_loader = DataLoader( |     test_loader = DataLoader( | ||||||
|         tsts, |         tests, | ||||||
|         shuffle=True,  # shuffle the order of our samples |         shuffle=True,  # shuffle the order of our samples | ||||||
|         batch_size=4,  # stack 4 samples to a minibatch |         batch_size=1,  # stack 4 samples to a minibatch | ||||||
|         num_workers=0  # no background workers (see comment below) |         num_workers=0  # no background workers (see comment below) | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     return train_loader, test_loader |     return train_loader, test_loader | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def rgb2gray(rgb_array: np.ndarray): | ||||||
|  |     r, g, b = rgb_array[:, :, 0], rgb_array[:, :, 1], rgb_array[:, :, 2] | ||||||
|  |     gray = 0.2989 * r + 0.5870 * g + 0.1140 * b | ||||||
|  |     return gray | ||||||
|   | |||||||
| @@ -1,26 +1,48 @@ | |||||||
|  | import numpy as np | ||||||
| import torch | import torch | ||||||
|  |  | ||||||
| from DataLoader import get_image_loader | from DataLoader import get_image_loader | ||||||
| from Net import ImageNN | from Net import ImageNN | ||||||
|  |  | ||||||
|  |  | ||||||
| # 01.05.22 -- 0.5h | # 01.05.22 -- 0.5h | ||||||
| from netio import save_model, load_model | from netio import save_model, load_model | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_train_device(): | ||||||
|  |     device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||||||
|  |     print(f'Device to train net: {device}') | ||||||
|  |     return torch.device(device) | ||||||
|  |  | ||||||
|  |  | ||||||
| def train_model(): | def train_model(): | ||||||
|     train_loader, test_loader = get_image_loader("my/supercool/image/dir") |     # Set a known random seed for reproducibility | ||||||
|     nn = ImageNN()  # todo pass size ason. |     np.random.seed(0) | ||||||
|  |     torch.manual_seed(0) | ||||||
|  |     device = get_train_device() | ||||||
|  |  | ||||||
|  |     # Load datasets | ||||||
|  |     train_loader, test_loader = get_image_loader("training/") | ||||||
|  |     nn = ImageNN(n_in_channels=3)  # todo pass size ason. | ||||||
|     nn.train()  # init with train mode |     nn.train()  # init with train mode | ||||||
|  |     nn.to(device)  # send net to device available | ||||||
|  |  | ||||||
|     optimizer = torch.optim.SGD(nn.parameters(), lr=0.1)  # todo adjust parameters and lr |     optimizer = torch.optim.SGD(nn.parameters(), lr=0.1)  # todo adjust parameters and lr | ||||||
|     loss_function = torch.nn.CrossEntropyLoss() |     loss_function = torch.nn.MSELoss() | ||||||
|     n_epochs = 15  # todo epcchs here |     n_epochs = 15  # todo epcchs here | ||||||
|  |  | ||||||
|  |     # todo look wtf is that | ||||||
|  |     nn.double() | ||||||
|  |  | ||||||
|  |     train_sample_size = len(train_loader) | ||||||
|     losses = [] |     losses = [] | ||||||
|  |     best_eval_loss = 0 | ||||||
|     for epoch in range(n_epochs): |     for epoch in range(n_epochs): | ||||||
|         print(f"Epoch {epoch}/{n_epochs}\n") |         print(f"Epoch {epoch}/{n_epochs}\n") | ||||||
|  |         i = 0 | ||||||
|         for input_tensor, target_tensor in train_loader: |         for input_tensor, target_tensor in train_loader: | ||||||
|  |             input_tensor.to(device) | ||||||
|  |             target_tensor.to(device) | ||||||
|  |  | ||||||
|             output = nn(input_tensor)  # get model output (forward pass) |             output = nn(input_tensor)  # get model output (forward pass) | ||||||
|             loss = loss_function(output, target_tensor)  # compute loss given model output and true target |             loss = loss_function(output, target_tensor)  # compute loss given model output and true target | ||||||
|             loss.backward()  # compute gradients (backward pass) |             loss.backward()  # compute gradients (backward pass) | ||||||
| @@ -28,19 +50,45 @@ def train_model(): | |||||||
|             optimizer.zero_grad()  # reset gradients |             optimizer.zero_grad()  # reset gradients | ||||||
|             losses.append(loss.item()) |             losses.append(loss.item()) | ||||||
|  |  | ||||||
|     # switch net to eval mode |             print( | ||||||
|     nn.eval() |                 f'\rTraining epoch {epoch} [{i}/{train_sample_size * train_loader.batch_size}] (curr loss: {loss.item():.3})', | ||||||
|     with torch.no_grad(): |                 end='') | ||||||
|         for input_tensor, target_tensor in test_loader: |             i += train_loader.batch_size | ||||||
|             # iterate testloader and we have to decide somhow now the goodness of the prediction |  | ||||||
|             out = nn(input_tensor)  # apply model |  | ||||||
|  |  | ||||||
|             diff = out - target_tensor |             # eval model every 500th element | ||||||
|         # todo evaluate trained model on testset loader |             if i % 500 == 0: | ||||||
|  |                 print(f"\nEvaluating model") | ||||||
|     # todo save trained model to blob file |                 eval_loss = eval_model(nn, test_loader, loss_function, device) | ||||||
|  |                 print(f"Evalution loss={eval_loss}") | ||||||
|  |                 if eval_loss > best_eval_loss: | ||||||
|  |                     best_eval_loss = eval_loss | ||||||
|                     save_model(nn) |                     save_model(nn) | ||||||
|  |  | ||||||
|  |     # switch net to eval mode | ||||||
|  |     print(eval_model(nn, test_loader, loss_function, device=device)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # func to evaluate our trained model | ||||||
|  | def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device): | ||||||
|  |     # switch to eval mode | ||||||
|  |     model.eval() | ||||||
|  |     loss = .0 | ||||||
|  |     # disable gradient calculations | ||||||
|  |     with torch.no_grad(): | ||||||
|  |         i = 0 | ||||||
|  |         for input, target in dataloader: | ||||||
|  |             input.to(device) | ||||||
|  |             target.to(device) | ||||||
|  |  | ||||||
|  |             out = model(input) | ||||||
|  |             loss += loss_fn(out, target).item() | ||||||
|  |             print(f'\rEval prog[{i}/{len(dataloader) * dataloader.batch_size}]', end='') | ||||||
|  |             i += dataloader.batch_size | ||||||
|  |     print() | ||||||
|  |     loss /= len(dataloader) | ||||||
|  |     model.train() | ||||||
|  |     return loss | ||||||
|  |  | ||||||
|  |  | ||||||
| def apply_model(): | def apply_model(): | ||||||
|     model = load_model() |     model = load_model() | ||||||
|   | |||||||
							
								
								
									
										31
									
								
								Net.py
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								Net.py
									
									
									
									
									
								
							| @@ -2,10 +2,31 @@ import torch | |||||||
|  |  | ||||||
|  |  | ||||||
| class ImageNN(torch.nn.Module): | class ImageNN(torch.nn.Module): | ||||||
|     def __init__(self): |     def __init__(self, n_in_channels: int = 1, n_hidden_layers: int = 3, n_kernels: int = 32, kernel_size: int = 7): | ||||||
|  |         """Simple CNN with `n_hidden_layers`, `n_kernels`, and `kernel_size` as hyperparameters""" | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         # todo implement the nn structure |  | ||||||
|  |  | ||||||
|     def forward(self, x: torch.Tensor): |         cnn = [] | ||||||
|         pass |         for i in range(n_hidden_layers): | ||||||
|         # todo implement forward |             cnn.append(torch.nn.Conv2d( | ||||||
|  |                 in_channels=n_in_channels, | ||||||
|  |                 out_channels=n_kernels, | ||||||
|  |                 kernel_size=kernel_size, | ||||||
|  |                 padding=int(kernel_size / 2) | ||||||
|  |             )) | ||||||
|  |             cnn.append(torch.nn.ReLU()) | ||||||
|  |             n_in_channels = n_kernels | ||||||
|  |         self.hidden_layers = torch.nn.Sequential(*cnn) | ||||||
|  |  | ||||||
|  |         self.output_layer = torch.nn.Conv2d( | ||||||
|  |             in_channels=n_in_channels, | ||||||
|  |             out_channels=3, | ||||||
|  |             kernel_size=kernel_size, | ||||||
|  |             padding=int(kernel_size / 2) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def forward(self, x): | ||||||
|  |         """Apply CNN to input `x` of shape (N, n_channels, X, Y), where N=n_samples and X, Y are spatial dimensions""" | ||||||
|  |         cnn_out = self.hidden_layers(x)  # apply hidden layers (N, n_in_channels, X, Y) -> (N, n_kernels, X, Y) | ||||||
|  |         pred = self.output_layer(cnn_out)  # apply output layer (N, n_kernels, X, Y) -> (N, 1, X, Y) | ||||||
|  |         return pred | ||||||
|   | |||||||
							
								
								
									
										159
									
								
								ex2_unittest.py
									
									
									
									
									
								
							
							
						
						
									
										159
									
								
								ex2_unittest.py
									
									
									
									
									
								
							| @@ -1,159 +0,0 @@ | |||||||
| """ |  | ||||||
| Author -- Michael Widrich, Andreas Schörgenhumer |  | ||||||
| Contact -- schoergenhumer@ml.jku.at |  | ||||||
| Date -- 04.03.2022 |  | ||||||
|  |  | ||||||
| ############################################################################### |  | ||||||
|  |  | ||||||
| The following copyright statement applies to all code within this file. |  | ||||||
|  |  | ||||||
| Copyright statement: |  | ||||||
| This  material,  no  matter  whether  in  printed  or  electronic  form, |  | ||||||
| may  be  used  for personal  and non-commercial educational use only. |  | ||||||
| Any reproduction of this manuscript, no matter whether as a whole or in parts, |  | ||||||
| no matter whether in printed or in electronic form, requires explicit prior |  | ||||||
| acceptance of the authors. |  | ||||||
|  |  | ||||||
| ############################################################################### |  | ||||||
|  |  | ||||||
| Images taken from: https://pixabay.com/ |  | ||||||
| """ |  | ||||||
|  |  | ||||||
| import hashlib |  | ||||||
| import os |  | ||||||
| import shutil |  | ||||||
| import sys |  | ||||||
| from glob import glob |  | ||||||
|  |  | ||||||
| import dill as pkl |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def print_outs(outs, line_token="-"): |  | ||||||
|     print(line_token * 40) |  | ||||||
|     print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n") |  | ||||||
|     print(line_token * 40) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| ex_file = "ex2.py" |  | ||||||
| full_points = 15 |  | ||||||
| points = full_points |  | ||||||
| python = sys.executable |  | ||||||
|  |  | ||||||
| solutions_dir = os.path.join("unittest", "solutions") |  | ||||||
| outputs_dir = os.path.join("unittest", "outputs") |  | ||||||
|  |  | ||||||
| # Remove previous outputs folder |  | ||||||
| shutil.rmtree(outputs_dir, ignore_errors=True) |  | ||||||
|  |  | ||||||
| inputs = sorted(glob(os.path.join("unittest", "unittest_input_*"), recursive=True)) |  | ||||||
| if not len(inputs): |  | ||||||
|     raise FileNotFoundError("Could not find unittest_input_* files") |  | ||||||
|  |  | ||||||
| with open(os.path.join(solutions_dir, "counts.pkl"), "rb") as f: |  | ||||||
|     sol_counts = pkl.load(f) |  | ||||||
|  |  | ||||||
| for test_i, input_folder in enumerate(inputs): |  | ||||||
|     comment = "" |  | ||||||
|     fcall = "" |  | ||||||
|      |  | ||||||
|     with open(os.devnull, "w") as null: |  | ||||||
|         # sys.stdout = null |  | ||||||
|         try: |  | ||||||
|             from ex2 import validate_images |  | ||||||
|              |  | ||||||
|             proper_import = True |  | ||||||
|         except Exception as e: |  | ||||||
|             outs = "" |  | ||||||
|             errs = e |  | ||||||
|             points -= full_points / len(inputs) |  | ||||||
|             proper_import = False |  | ||||||
|         finally: |  | ||||||
|             sys.stdout.flush() |  | ||||||
|             sys.stdout = sys.__stdout__ |  | ||||||
|      |  | ||||||
|     if proper_import: |  | ||||||
|         with open(os.devnull, "w") as null: |  | ||||||
|             # sys.stdout = null |  | ||||||
|             try: |  | ||||||
|                 input_basename = os.path.basename(input_folder) |  | ||||||
|                 output_dir = os.path.join(outputs_dir, input_basename) |  | ||||||
|                 logfilepath = output_dir + ".log" |  | ||||||
|                 formatter = "06d" |  | ||||||
|                 counts = validate_images(input_dir=input_folder, output_dir=output_dir, log_file=logfilepath, |  | ||||||
|                                          formatter=formatter) |  | ||||||
|                 fcall = f'validate_images(\n\tinput_dir="{input_folder}",\n\toutput_dir="{output_dir}",\n\tlog_file="{logfilepath}",\n\tformatter="{formatter}"\n)' |  | ||||||
|                 errs = "" |  | ||||||
|                  |  | ||||||
|                 try: |  | ||||||
|                     with open(os.path.join(outputs_dir, f"{input_basename}.log"), "r") as lfh: |  | ||||||
|                         logfile = lfh.read() |  | ||||||
|                 except FileNotFoundError: |  | ||||||
|                     # two cases: |  | ||||||
|                     # 1) no invalid files and thus no log file -> ok -> equal to empty tlogfile |  | ||||||
|                     # 2) invalid files but no log file -> not ok -> will fail the comparison with tlogfile (below) |  | ||||||
|                     logfile = "" |  | ||||||
|                 with open(os.path.join(solutions_dir, f"{input_basename}.log"), "r") as lfh: |  | ||||||
|                     # must replace the separator that was used when creating the solution files |  | ||||||
|                     tlogfile = lfh.read().replace("\\", os.path.sep) |  | ||||||
|                  |  | ||||||
|                 files = sorted(glob(os.path.join(outputs_dir, input_basename, "**", "*"), recursive=True)) |  | ||||||
|                 hashing_function = hashlib.sha256() |  | ||||||
|                 for file in files: |  | ||||||
|                     with open(file, "rb") as fh: |  | ||||||
|                         hashing_function.update(fh.read()) |  | ||||||
|                 hash = hashing_function.digest() |  | ||||||
|                 hashing_function = hashlib.sha256() |  | ||||||
|                 tfiles = sorted(glob(os.path.join(solutions_dir, input_basename, "**", "*"), recursive=True)) |  | ||||||
|                 for file in tfiles: |  | ||||||
|                     with open(file, "rb") as fh: |  | ||||||
|                         hashing_function.update(fh.read()) |  | ||||||
|                 thash = hashing_function.digest() |  | ||||||
|                  |  | ||||||
|                 tcounts = sol_counts[input_basename] |  | ||||||
|                  |  | ||||||
|                 if not counts == tcounts: |  | ||||||
|                     points -= full_points / len(inputs) |  | ||||||
|                     comment = f"Function should return {tcounts} but returned {counts}" |  | ||||||
|                 elif not [f.split(os.path.sep)[-2:] for f in files] == [f.split(os.path.sep)[-2:] for f in tfiles]: |  | ||||||
|                     points -= full_points / len(inputs) |  | ||||||
|                     comment = f"Contents of output directory do not match (see directory 'solutions')" |  | ||||||
|                 elif not hash == thash: |  | ||||||
|                     points -= full_points / len(inputs) |  | ||||||
|                     comment = f"Hash value of the files in the output directory do not match (see directory 'solutions')" |  | ||||||
|                 elif not logfile == tlogfile: |  | ||||||
|                     points -= full_points / len(inputs) |  | ||||||
|                     comment = f"Contents of logfiles do not match (see directory 'solutions')" |  | ||||||
|              |  | ||||||
|             except Exception as e: |  | ||||||
|                 outs = "" |  | ||||||
|                 errs = e |  | ||||||
|                 points -= full_points / len(inputs) |  | ||||||
|             finally: |  | ||||||
|                 sys.stdout.flush() |  | ||||||
|                 sys.stdout = sys.__stdout__ |  | ||||||
|      |  | ||||||
|     print() |  | ||||||
|     print_outs(f"Test {test_i}", line_token="#") |  | ||||||
|     print("Function call:") |  | ||||||
|     print_outs(fcall) |  | ||||||
|      |  | ||||||
|     if errs: |  | ||||||
|         print(f"Some unexpected errors occurred:") |  | ||||||
|         print_outs(f"{type(errs).__name__}: {errs}") |  | ||||||
|     else: |  | ||||||
|         print("Notes:") |  | ||||||
|         print_outs("No issues found" if comment == "" else comment) |  | ||||||
|      |  | ||||||
|     # due to floating point calculations it could happen that we get -0 here |  | ||||||
|     if points < 0: |  | ||||||
|         assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?" |  | ||||||
|         points = abs(points) |  | ||||||
|     print(f"Current points: {points:.2f}") |  | ||||||
|  |  | ||||||
| print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})") |  | ||||||
| if points < full_points: |  | ||||||
|     print(f"Check the folder '{outputs_dir}' to see where your errors are") |  | ||||||
| else: |  | ||||||
|     shutil.rmtree(os.path.join(outputs_dir)) |  | ||||||
| print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle " |  | ||||||
|       f"for common mistakes that can still lead to 0 points.") |  | ||||||
							
								
								
									
										204
									
								
								ex3_unittest.py
									
									
									
									
									
								
							
							
						
						
									
										204
									
								
								ex3_unittest.py
									
									
									
									
									
								
							| @@ -1,204 +0,0 @@ | |||||||
| """ |  | ||||||
| Author -- Michael Widrich, Andreas Schörgenhumer |  | ||||||
| Contact -- schoergenhumer@ml.jku.at |  | ||||||
| Date -- 02.03.2022 |  | ||||||
|  |  | ||||||
| ############################################################################### |  | ||||||
|  |  | ||||||
| The following copyright statement applies to all code within this file. |  | ||||||
|  |  | ||||||
| Copyright statement: |  | ||||||
| This  material,  no  matter  whether  in  printed  or  electronic  form, |  | ||||||
| may  be  used  for personal  and non-commercial educational use only. |  | ||||||
| Any reproduction of this manuscript, no matter whether as a whole or in parts, |  | ||||||
| no matter whether in printed or in electronic form, requires explicit prior |  | ||||||
| acceptance of the authors. |  | ||||||
|  |  | ||||||
| ############################################################################### |  | ||||||
|  |  | ||||||
| Images taken from: https://pixabay.com/ |  | ||||||
| """ |  | ||||||
|  |  | ||||||
| import gzip |  | ||||||
| import os |  | ||||||
| import signal |  | ||||||
| import sys |  | ||||||
| from glob import glob |  | ||||||
| from types import GeneratorType |  | ||||||
|  |  | ||||||
| import dill as pkl |  | ||||||
| import numpy as np |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def print_outs(outs, line_token="-"): |  | ||||||
|     print(line_token * 40) |  | ||||||
|     print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n") |  | ||||||
|     print(line_token * 40) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| time_given = int(15) |  | ||||||
|  |  | ||||||
| check_for_timeout = hasattr(signal, "SIGALRM") |  | ||||||
|  |  | ||||||
| if check_for_timeout: |  | ||||||
|     def handler(signum, frame): |  | ||||||
|         raise TimeoutError(f"Timeout after {time_given}sec") |  | ||||||
|      |  | ||||||
|      |  | ||||||
|     signal.signal(signal.SIGALRM, handler) |  | ||||||
|  |  | ||||||
| ex_file = "ex3.py" |  | ||||||
| full_points = 15 |  | ||||||
| points = full_points |  | ||||||
| python = sys.executable |  | ||||||
|  |  | ||||||
| inputs = sorted(glob(os.path.join("unittest", "unittest_input_*"), recursive=True)) |  | ||||||
|  |  | ||||||
| if not len(inputs): |  | ||||||
|     raise FileNotFoundError("Could not find unittest_input_* files") |  | ||||||
|  |  | ||||||
| for test_i, input_folder in enumerate(inputs): |  | ||||||
|     comment = "" |  | ||||||
|     fcall = "" |  | ||||||
|      |  | ||||||
|     with open(os.devnull, "w") as null: |  | ||||||
|         # sys.stdout = null |  | ||||||
|         try: |  | ||||||
|             if check_for_timeout: |  | ||||||
|                 signal.alarm(time_given) |  | ||||||
|                 from ex3 import ImageStandardizer |  | ||||||
|                  |  | ||||||
|                 signal.alarm(0) |  | ||||||
|             else: |  | ||||||
|                 from ex3 import ImageStandardizer |  | ||||||
|             proper_import = True |  | ||||||
|         except Exception as e: |  | ||||||
|             errs = e |  | ||||||
|             points -= full_points / len(inputs) |  | ||||||
|             proper_import = False |  | ||||||
|         finally: |  | ||||||
|             sys.stdout.flush() |  | ||||||
|             sys.stdout = sys.__stdout__ |  | ||||||
|      |  | ||||||
|     if proper_import: |  | ||||||
|         with open(os.devnull, "w") as null: |  | ||||||
|             # sys.stdout = null |  | ||||||
|             try: |  | ||||||
|                 if check_for_timeout: |  | ||||||
|                     signal.alarm(time_given) |  | ||||||
|                     # check constructor |  | ||||||
|                     instance = ImageStandardizer(input_dir=input_folder) |  | ||||||
|                     fcall = f"ImageStandardizer(input_dir='{input_folder}')" |  | ||||||
|                     signal.alarm(0) |  | ||||||
|                 else: |  | ||||||
|                     # check constructor |  | ||||||
|                     instance = ImageStandardizer(input_dir=input_folder) |  | ||||||
|                     fcall = f"ImageStandardizer(input_dir='{input_folder}')" |  | ||||||
|                 errs = "" |  | ||||||
|                  |  | ||||||
|                 # check correct file names + sorting |  | ||||||
|                 input_basename = os.path.basename(input_folder) |  | ||||||
|                 with open(os.path.join("unittest", "solutions", input_basename, f"filenames.txt"), "r") as f: |  | ||||||
|                     # must replace the separator that was used when creating the solution files |  | ||||||
|                     files_sol = f.read().replace("\\", os.path.sep).splitlines() |  | ||||||
|                     # for simplicity's sake, only compare relative paths here |  | ||||||
|                     common = os.path.commonprefix(instance.files) |  | ||||||
|                     rel_instance_files = [os.path.join(input_folder, f[len(common):]) for f in instance.files] |  | ||||||
|                     if not hasattr(instance, "files"): |  | ||||||
|                         points -= full_points / len(inputs) / 3 |  | ||||||
|                         comment += f"Attributes 'files' missing.\n" |  | ||||||
|                     elif rel_instance_files != files_sol: |  | ||||||
|                         points -= full_points / len(inputs) / 3 |  | ||||||
|                         comment += f"Attribute 'files' should be {files_sol} but is {instance.files} (see directory 'solutions').\n" |  | ||||||
|                     elif len(instance.files) != len(files_sol): |  | ||||||
|                         points -= full_points / len(inputs) / 3 |  | ||||||
|                         comment += f"Number of files should be {len(files_sol)} but is {len(instance.files)} (see directory 'solutions').\n" |  | ||||||
|                  |  | ||||||
|                 # check if class has method analyze_images |  | ||||||
|                 method = "analyze_images" |  | ||||||
|                 if not hasattr(instance, method): |  | ||||||
|                     comment += f"Method '{method}' missing.\n" |  | ||||||
|                     points -= full_points / len(inputs) / 3 |  | ||||||
|                 else: |  | ||||||
|                     # check for correct data types |  | ||||||
|                     stats = instance.analyze_images() |  | ||||||
|                     if (type(stats) is not tuple) or (len(stats) != 2): |  | ||||||
|                         points -= full_points / len(inputs) / 3 |  | ||||||
|                         comment += f"Incorrect return value of method '{method}' (should be tuple of length 2).\n" |  | ||||||
|                     else: |  | ||||||
|                         with open(os.path.join("unittest", "solutions", input_basename, f"mean_and_std.pkl"), |  | ||||||
|                                   "rb") as fh: |  | ||||||
|                             data = pkl.load(fh) |  | ||||||
|                             m = data["mean"] |  | ||||||
|                             s = data["std"] |  | ||||||
|                         if not (isinstance(stats[0], np.ndarray) and isinstance(stats[1], np.ndarray) and |  | ||||||
|                                 stats[0].dtype == np.float64 and stats[1].dtype == np.float64 and |  | ||||||
|                                 stats[0].shape == (3,) and stats[1].shape == (3,)): |  | ||||||
|                             points -= full_points / len(inputs) / 3 |  | ||||||
|                             comment += f"Incorrect return data type of method '{method}' (tuple entries should be np.ndarray of dtype np.float64 and shape (3,)).\n" |  | ||||||
|                         else: |  | ||||||
|                             if not np.isclose(stats[0], m, atol=0).all(): |  | ||||||
|                                 points -= full_points / len(inputs) / 6 |  | ||||||
|                                 comment += f"Mean should be {m} but is {stats[0]} (see directory 'solutions').\n" |  | ||||||
|                             if not np.isclose(stats[1], s, atol=0).all(): |  | ||||||
|                                 points -= full_points / len(inputs) / 6 |  | ||||||
|                                 comment += f"Std should be {s} but is {stats[1]} (see directory 'solutions').\n" |  | ||||||
|                  |  | ||||||
|                 # check if class has method get_standardized_images |  | ||||||
|                 method = "get_standardized_images" |  | ||||||
|                 if not hasattr(instance, method): |  | ||||||
|                     comment += f"Method '{method}' missing.\n" |  | ||||||
|                     points -= full_points / len(inputs) / 3 |  | ||||||
|                 # check for correct data types |  | ||||||
|                 elif not isinstance(instance.get_standardized_images(), GeneratorType): |  | ||||||
|                     points -= full_points / len(inputs) / 3 |  | ||||||
|                     comment += f"'{method}' is not a generator.\n" |  | ||||||
|                 else: |  | ||||||
|                     # Read correct image solutions |  | ||||||
|                     with gzip.open(os.path.join("unittest", "solutions", input_basename, "images.pkl"), "rb") as fh: |  | ||||||
|                         ims_sol = pkl.load(file=fh) |  | ||||||
|                      |  | ||||||
|                     # Get image submissions |  | ||||||
|                     ims_sub = list(instance.get_standardized_images()) |  | ||||||
|                      |  | ||||||
|                     if not len(ims_sub) == len(ims_sol): |  | ||||||
|                         points -= full_points / len(inputs) / 3 |  | ||||||
|                         comment += f"{len(ims_sol)} image arrays should have been returned but got {len(ims_sub)}.\n" |  | ||||||
|                     elif any([im_sub.dtype.num != np.dtype(np.float32).num for im_sub in ims_sub]): |  | ||||||
|                         points -= full_points / len(inputs) / 3 |  | ||||||
|                         comment += f"Returned image arrays should have datatype np.float32 but at least one array isn't.\n" |  | ||||||
|                     else: |  | ||||||
|                         equal = [np.all(np.isclose(im_sub, im_sol, atol=0)) for im_sub, im_sol in zip(ims_sub, ims_sol)] |  | ||||||
|                         if not all(equal): |  | ||||||
|                             points -= full_points / len(inputs) / 3 |  | ||||||
|                             comment += f"Returned images {list(np.where(np.logical_not(equal))[0])} do not match solution (see images.pkl files for solution).\n" |  | ||||||
|             except Exception as e: |  | ||||||
|                 errs = e |  | ||||||
|                 points -= full_points / len(inputs) |  | ||||||
|             finally: |  | ||||||
|                 sys.stdout.flush() |  | ||||||
|                 sys.stdout = sys.__stdout__ |  | ||||||
|      |  | ||||||
|     print() |  | ||||||
|     print_outs(f"Test {test_i}", line_token="#") |  | ||||||
|     print("Function call:") |  | ||||||
|     print_outs(fcall) |  | ||||||
|      |  | ||||||
|     if errs: |  | ||||||
|         print(f"Some unexpected errors occurred:") |  | ||||||
|         print_outs(f"{type(errs).__name__}: {errs}") |  | ||||||
|     else: |  | ||||||
|         print("Notes:") |  | ||||||
|         print_outs("No issues found" if comment == "" else comment) |  | ||||||
|          |  | ||||||
|     # due to floating point calculations it could happen that we get -0 here |  | ||||||
|     if points < 0: |  | ||||||
|         assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?" |  | ||||||
|         points = abs(points) |  | ||||||
|     print(f"Current points: {points:.2f}") |  | ||||||
|  |  | ||||||
| print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})") |  | ||||||
| print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle " |  | ||||||
|       f"for common mistakes that can still lead to 0 points.") |  | ||||||
| if not check_for_timeout: |  | ||||||
|     print("\n!!Warning: Had to switch to Windows compatibility version and did not check for timeouts!!") |  | ||||||
							
								
								
									
										143
									
								
								ex4_unittest.py
									
									
									
									
									
								
							
							
						
						
									
										143
									
								
								ex4_unittest.py
									
									
									
									
									
								
							| @@ -1,143 +0,0 @@ | |||||||
| """ |  | ||||||
| Author -- Michael Widrich |  | ||||||
| Contact -- widrich@ml.jku.at |  | ||||||
| Date -- 01.10.2019 |  | ||||||
|  |  | ||||||
| ############################################################################### |  | ||||||
|  |  | ||||||
| The following copyright statement applies to all code within this file. |  | ||||||
|  |  | ||||||
| Copyright statement: |  | ||||||
| This  material,  no  matter  whether  in  printed  or  electronic  form, |  | ||||||
| may  be  used  for personal  and non-commercial educational use only. |  | ||||||
| Any reproduction of this manuscript, no matter whether as a whole or in parts, |  | ||||||
| no matter whether in printed or in electronic form, requires explicit prior |  | ||||||
| acceptance of the authors. |  | ||||||
|  |  | ||||||
| ############################################################################### |  | ||||||
|  |  | ||||||
| """ |  | ||||||
|  |  | ||||||
| import os |  | ||||||
| import sys |  | ||||||
| import traceback |  | ||||||
|  |  | ||||||
| import dill as pkl |  | ||||||
| import numpy as np |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def print_outs(outs, line_token="-"): |  | ||||||
|     print(line_token * 40) |  | ||||||
|     print(outs, end="" if isinstance(outs, str) and outs.endswith("\n") else "\n") |  | ||||||
|     print(line_token * 40) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| ex_file = 'ex4.py' |  | ||||||
| full_points = 20 |  | ||||||
| points = full_points |  | ||||||
| python = sys.executable |  | ||||||
|  |  | ||||||
| with open(os.path.join("unittest", "unittest_inputs_outputs.pkl"), "rb") as ufh: |  | ||||||
|     all_inputs_outputs = pkl.load(ufh) |  | ||||||
|     all_inputs = all_inputs_outputs['inputs'] |  | ||||||
|     all_outputs = all_inputs_outputs['outputs'] |  | ||||||
|  |  | ||||||
| feedback = '' |  | ||||||
|  |  | ||||||
| for test_i, (inputs, outputs) in enumerate(zip(all_inputs, all_outputs)): |  | ||||||
|      |  | ||||||
|     comment = '' |  | ||||||
|     fcall = '' |  | ||||||
|     with open(os.devnull, 'w') as null: |  | ||||||
|         # sys.stdout = null |  | ||||||
|         try: |  | ||||||
|             from ex4 import ex4 |  | ||||||
|             proper_import = True |  | ||||||
|         except Exception: |  | ||||||
|             outs = '' |  | ||||||
|             errs = traceback.format_exc() |  | ||||||
|             points -= full_points / len(all_inputs_outputs) |  | ||||||
|             proper_import = False |  | ||||||
|         finally: |  | ||||||
|             sys.stdout.flush() |  | ||||||
|             sys.stdout = sys.__stdout__ |  | ||||||
|      |  | ||||||
|     if proper_import: |  | ||||||
|         with open(os.devnull, 'w') as null: |  | ||||||
|             # sys.stdout = null |  | ||||||
|             try: |  | ||||||
|                 errs = '' |  | ||||||
|                 fcall = f"ex4(image_array={inputs[0]}, offset={inputs[1]}, spacing={inputs[2]}))" |  | ||||||
|                 returns = ex4(image_array=inputs[0], offset=inputs[1], |  | ||||||
|                               spacing=inputs[2]) |  | ||||||
|  |  | ||||||
|                 # Check if returns and outputs are of same type |  | ||||||
|                 if type(returns) != type(outputs): |  | ||||||
|                     comment = f"Output should be: {type(outputs).__name__} ('{outputs}'). \n" \ |  | ||||||
|                               f"          but is: {returns}" |  | ||||||
|                     points -= full_points / len(all_inputs) |  | ||||||
|                 else: |  | ||||||
|                     # Check input_array output |  | ||||||
|                     if (len(returns) != 3 |  | ||||||
|                             or not isinstance(returns[0], np.ndarray) |  | ||||||
|                             or returns[0].dtype != outputs[0].dtype |  | ||||||
|                             or returns[0].shape != outputs[0].shape |  | ||||||
|                             or np.any(returns[0] != outputs[0])): |  | ||||||
|                         points -= (full_points / len(all_inputs)) / 3 |  | ||||||
|                         comment = f"Incorrect 'input_array'. Output should be: " \ |  | ||||||
|                                 f"{outputs} \n" \ |  | ||||||
|                                 f"but is {returns}" |  | ||||||
|                      |  | ||||||
|                     # Check known_array output |  | ||||||
|                     if (len(returns) != 3 |  | ||||||
|                             or not isinstance(returns[1], np.ndarray) |  | ||||||
|                             or returns[1].dtype != outputs[1].dtype |  | ||||||
|                             or returns[1].shape != outputs[1].shape |  | ||||||
|                             or np.any(returns[1] != outputs[1])): |  | ||||||
|                         points -= (full_points / len(all_inputs)) / 3 |  | ||||||
|                         comment = f"Incorrect 'known_array'. Output should be: " \ |  | ||||||
|                                 f"{outputs} \n" \ |  | ||||||
|                                 f"but is {returns}" |  | ||||||
|                      |  | ||||||
|                     # Check target_array output |  | ||||||
|                     if (len(returns) != 3 |  | ||||||
|                             or not isinstance(returns[2], np.ndarray) |  | ||||||
|                             or returns[2].dtype != outputs[2].dtype |  | ||||||
|                             or returns[2].shape != outputs[2].shape |  | ||||||
|                             or np.any(returns[2] != outputs[2])): |  | ||||||
|                         points -= (full_points / len(all_inputs)) / 3 |  | ||||||
|                         comment = f"Incorrect 'target_array'. Output should be: " \ |  | ||||||
|                                 f"{outputs} \n" \ |  | ||||||
|                                 f"but is {returns}" |  | ||||||
|              |  | ||||||
|             except Exception as e: |  | ||||||
|                 outs = '' |  | ||||||
|                 if not type(e) == type(outputs): |  | ||||||
|                     comment = f"Output should be: {type(outputs).__name__} ('{outputs}'). \n" \ |  | ||||||
|                               f"          but is:\n{traceback.format_exc()}" |  | ||||||
|                     points -= full_points / len(all_inputs) |  | ||||||
|             finally: |  | ||||||
|                 sys.stdout.flush() |  | ||||||
|                 sys.stdout = sys.__stdout__ |  | ||||||
|      |  | ||||||
|     print() |  | ||||||
|     print_outs(f"Test {test_i}", line_token="#") |  | ||||||
|     print("Function call:") |  | ||||||
|     print_outs(fcall) |  | ||||||
|  |  | ||||||
|     if errs: |  | ||||||
|         print(f"Some unexpected errors occurred:") |  | ||||||
|         print_outs(errs) |  | ||||||
|     else: |  | ||||||
|         print("Notes:") |  | ||||||
|         print_outs("No issues found" if comment == "" else comment) |  | ||||||
|  |  | ||||||
|     # due to floating point calculations it could happen that we get -0 here |  | ||||||
|     if points < 0: |  | ||||||
|         assert abs(points) < 1e-7, f"points were {points} < 0: error when subtracting points?" |  | ||||||
|         points = 0 |  | ||||||
|     print(f"Current points: {points:.2f}") |  | ||||||
|  |  | ||||||
| print(f"\nEstimated points upon submission: {points:.2f} (out of {full_points:.2f})") |  | ||||||
| print(f"This is only an estimate, see 'Instructions for submitting homework' in Moodle " |  | ||||||
|       f"for common mistakes that can still lead to 0 points.") |  | ||||||
							
								
								
									
										4
									
								
								netio.py
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								netio.py
									
									
									
									
									
								
							| @@ -4,11 +4,11 @@ from Net import ImageNN | |||||||
|  |  | ||||||
|  |  | ||||||
| def save_model(model: torch.nn.Module): | def save_model(model: torch.nn.Module): | ||||||
|     torch.save(model.state_dict(), 'impaintmodel.pth') |     torch.save(model, 'impaintmodel.pt') | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_model(): | def load_model(): | ||||||
|     model = ImageNN() |     model = ImageNN() | ||||||
|     model.load_state_dict(torch.load('model_weights.pth')) |     model.load_state_dict(torch.load('impaintmodel.pt')) | ||||||
|     model.eval() |     model.eval() | ||||||
|     return model |     return model | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user