diff --git a/DataLoader.py b/DataLoader.py index 5d71143..4d15e40 100644 --- a/DataLoader.py +++ b/DataLoader.py @@ -80,17 +80,17 @@ def get_image_loader(path: str, precision: np.float32 or np.float64): transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip()]), precision=precision) - - image_dataset_augmented = ImageDataset(path, - offsetrange=(0, 8), - spacingrange=(2, 6), - transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(), - transforms.RandomVerticalFlip(), - transforms.GaussianBlur(3, 4)]), - precision=precision) + # + # image_dataset_augmented = ImageDataset(path, + # offsetrange=(0, 8), + # spacingrange=(2, 6), + # transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(), + # transforms.RandomVerticalFlip(), + # transforms.GaussianBlur(3, 4)]), + # precision=precision) # merge different datasets here! - merged_dataset = torch.utils.data.ConcatDataset([image_dataset, image_dataset_augmented]) + merged_dataset = torch.utils.data.ConcatDataset([image_dataset]) totlen = len(merged_dataset) test_set_size = .1 diff --git a/ImageImpaint.py b/ImageImpaint.py index 65fcff4..fe89343 100644 --- a/ImageImpaint.py +++ b/ImageImpaint.py @@ -48,13 +48,22 @@ def train_model(): print(f"Epoch {epoch}/{n_epochs}\n") i = 0 for input_tensor, mask, target_tensor in train_loader: + input_tensor = input_tensor.to(device) + mask = mask.to(device) + target_tensor = target_tensor.to(device) + optimizer.zero_grad() # reset gradients input_tensor = torch.cat((input_tensor, mask), 1) - output = nn(input_tensor.to(device)) # get model output (forward pass) + output = nn(input_tensor) # get model output (forward pass) - loss = loss_function(output.to(device), - target_tensor.to(device)) # compute loss given model output and true target + output_flat = output * (1 - mask) + output_flat = output_flat[1 - mask > 0] + + rest = target_tensor * (1 - mask) + rest = rest[1 - mask > 0] + + loss = loss_function(output_flat, rest) # compute loss given model output and true target loss.backward() # compute gradients (backward pass) optimizer.step() # perform gradient descent update step @@ -101,7 +110,14 @@ def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, input = torch.cat((input, mask), 1) out = model(input) - loss += loss_fn(out, target).item() + + out = out * (1 - mask) + out = out[1 - mask > 0] + + rest = target * (1 - mask) + rest = rest[1 - mask > 0] + + loss += loss_fn(out, rest).item() print(f'\rEval prog[{i}/{len(dataloader) * dataloader.batch_size}]', end='') i += dataloader.batch_size print() @@ -112,9 +128,9 @@ def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, def plot(inputs, targets, predictions, path, update, epoch): """Plotting the inputs, targets and predictions to file `path`""" os.makedirs(path, exist_ok=True) - fig, axes = plt.subplots(ncols=3, figsize=(15, 5)) + fig, axes = plt.subplots(ncols=4, figsize=(15, 5)) - for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]): + for ax, data, title in zip(axes, [inputs, targets, predictions, predictions-targets], ["Input", "Target", "Prediction", "diff"]): ax.clear() ax.set_title(title) ax.imshow(DataLoader.postprocess(np.transpose(data[:3, :, :], (1, 2, 0))), interpolation="none") diff --git a/Net.py b/Net.py index 60dd649..074e8c0 100644 --- a/Net.py +++ b/Net.py @@ -7,18 +7,24 @@ from torch import nn class ImageNN(torch.nn.Module): def __init__(self, precision: np.float32 or np.float64, n_in_channels: int = 1, n_hidden_layers: int = 3, - n_kernels: int = 32, kernel_size: int = 7): + n_kernels: int = 32, kernel_size: int = 9): """Simple CNN with `n_hidden_layers`, `n_kernels`, and `kernel_size` as hyperparameters""" super().__init__() + ksize = kernel_size + cnn = [] for i in range(n_hidden_layers): cnn.append(torch.nn.Conv2d( in_channels=n_in_channels, out_channels=n_kernels, - kernel_size=kernel_size, - padding=int(kernel_size / 2) + kernel_size=ksize, + padding=int(ksize / 2), + padding_mode="replicate" )) + + kernel_size -= 2 + cnn.append(torch.nn.ReLU()) n_in_channels = n_kernels self.hidden_layers = torch.nn.Sequential(*cnn) diff --git a/netio.py b/netio.py index 6da3a45..dba1105 100644 --- a/netio.py +++ b/netio.py @@ -36,7 +36,10 @@ def eval_evalset(): piclen = len(b['input_arrays']) for input_array, known_array in zip(b['input_arrays'], b['known_arrays']): input_array = DataLoader.preprocess(input_array, precision=np.float32) - input_tensor = torch.cat((torch.from_numpy(input_array), torch.from_numpy(known_array)), 0) + + input_array = np.expand_dims(input_array, 0) + known_array = np.expand_dims(known_array, 0) + input_tensor = torch.cat((torch.from_numpy(input_array), torch.from_numpy(known_array)), 1) out = model(input_tensor) out = DataLoader.postprocess(out.cpu().detach().numpy())