rm wrong file
edit net a bit
This commit is contained in:
parent
0f0c789981
commit
3b4ef675ad
@ -80,17 +80,17 @@ def get_image_loader(path: str, precision: np.float32 or np.float64):
|
|||||||
transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
|
transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
|
||||||
transforms.RandomVerticalFlip()]),
|
transforms.RandomVerticalFlip()]),
|
||||||
precision=precision)
|
precision=precision)
|
||||||
|
#
|
||||||
image_dataset_augmented = ImageDataset(path,
|
# image_dataset_augmented = ImageDataset(path,
|
||||||
offsetrange=(0, 8),
|
# offsetrange=(0, 8),
|
||||||
spacingrange=(2, 6),
|
# spacingrange=(2, 6),
|
||||||
transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
|
# transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
|
||||||
transforms.RandomVerticalFlip(),
|
# transforms.RandomVerticalFlip(),
|
||||||
transforms.GaussianBlur(3, 4)]),
|
# transforms.GaussianBlur(3, 4)]),
|
||||||
precision=precision)
|
# precision=precision)
|
||||||
|
|
||||||
# merge different datasets here!
|
# merge different datasets here!
|
||||||
merged_dataset = torch.utils.data.ConcatDataset([image_dataset, image_dataset_augmented])
|
merged_dataset = torch.utils.data.ConcatDataset([image_dataset])
|
||||||
|
|
||||||
totlen = len(merged_dataset)
|
totlen = len(merged_dataset)
|
||||||
test_set_size = .1
|
test_set_size = .1
|
||||||
|
@ -48,13 +48,22 @@ def train_model():
|
|||||||
print(f"Epoch {epoch}/{n_epochs}\n")
|
print(f"Epoch {epoch}/{n_epochs}\n")
|
||||||
i = 0
|
i = 0
|
||||||
for input_tensor, mask, target_tensor in train_loader:
|
for input_tensor, mask, target_tensor in train_loader:
|
||||||
|
input_tensor = input_tensor.to(device)
|
||||||
|
mask = mask.to(device)
|
||||||
|
target_tensor = target_tensor.to(device)
|
||||||
|
|
||||||
optimizer.zero_grad() # reset gradients
|
optimizer.zero_grad() # reset gradients
|
||||||
input_tensor = torch.cat((input_tensor, mask), 1)
|
input_tensor = torch.cat((input_tensor, mask), 1)
|
||||||
|
|
||||||
output = nn(input_tensor.to(device)) # get model output (forward pass)
|
output = nn(input_tensor) # get model output (forward pass)
|
||||||
|
|
||||||
loss = loss_function(output.to(device),
|
output_flat = output * (1 - mask)
|
||||||
target_tensor.to(device)) # compute loss given model output and true target
|
output_flat = output_flat[1 - mask > 0]
|
||||||
|
|
||||||
|
rest = target_tensor * (1 - mask)
|
||||||
|
rest = rest[1 - mask > 0]
|
||||||
|
|
||||||
|
loss = loss_function(output_flat, rest) # compute loss given model output and true target
|
||||||
loss.backward() # compute gradients (backward pass)
|
loss.backward() # compute gradients (backward pass)
|
||||||
optimizer.step() # perform gradient descent update step
|
optimizer.step() # perform gradient descent update step
|
||||||
|
|
||||||
@ -101,7 +110,14 @@ def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader,
|
|||||||
|
|
||||||
input = torch.cat((input, mask), 1)
|
input = torch.cat((input, mask), 1)
|
||||||
out = model(input)
|
out = model(input)
|
||||||
loss += loss_fn(out, target).item()
|
|
||||||
|
out = out * (1 - mask)
|
||||||
|
out = out[1 - mask > 0]
|
||||||
|
|
||||||
|
rest = target * (1 - mask)
|
||||||
|
rest = rest[1 - mask > 0]
|
||||||
|
|
||||||
|
loss += loss_fn(out, rest).item()
|
||||||
print(f'\rEval prog[{i}/{len(dataloader) * dataloader.batch_size}]', end='')
|
print(f'\rEval prog[{i}/{len(dataloader) * dataloader.batch_size}]', end='')
|
||||||
i += dataloader.batch_size
|
i += dataloader.batch_size
|
||||||
print()
|
print()
|
||||||
@ -112,9 +128,9 @@ def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader,
|
|||||||
def plot(inputs, targets, predictions, path, update, epoch):
|
def plot(inputs, targets, predictions, path, update, epoch):
|
||||||
"""Plotting the inputs, targets and predictions to file `path`"""
|
"""Plotting the inputs, targets and predictions to file `path`"""
|
||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
|
fig, axes = plt.subplots(ncols=4, figsize=(15, 5))
|
||||||
|
|
||||||
for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]):
|
for ax, data, title in zip(axes, [inputs, targets, predictions, predictions-targets], ["Input", "Target", "Prediction", "diff"]):
|
||||||
ax.clear()
|
ax.clear()
|
||||||
ax.set_title(title)
|
ax.set_title(title)
|
||||||
ax.imshow(DataLoader.postprocess(np.transpose(data[:3, :, :], (1, 2, 0))), interpolation="none")
|
ax.imshow(DataLoader.postprocess(np.transpose(data[:3, :, :], (1, 2, 0))), interpolation="none")
|
||||||
|
12
Net.py
12
Net.py
@ -7,18 +7,24 @@ from torch import nn
|
|||||||
|
|
||||||
class ImageNN(torch.nn.Module):
|
class ImageNN(torch.nn.Module):
|
||||||
def __init__(self, precision: np.float32 or np.float64, n_in_channels: int = 1, n_hidden_layers: int = 3,
|
def __init__(self, precision: np.float32 or np.float64, n_in_channels: int = 1, n_hidden_layers: int = 3,
|
||||||
n_kernels: int = 32, kernel_size: int = 7):
|
n_kernels: int = 32, kernel_size: int = 9):
|
||||||
"""Simple CNN with `n_hidden_layers`, `n_kernels`, and `kernel_size` as hyperparameters"""
|
"""Simple CNN with `n_hidden_layers`, `n_kernels`, and `kernel_size` as hyperparameters"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
ksize = kernel_size
|
||||||
|
|
||||||
cnn = []
|
cnn = []
|
||||||
for i in range(n_hidden_layers):
|
for i in range(n_hidden_layers):
|
||||||
cnn.append(torch.nn.Conv2d(
|
cnn.append(torch.nn.Conv2d(
|
||||||
in_channels=n_in_channels,
|
in_channels=n_in_channels,
|
||||||
out_channels=n_kernels,
|
out_channels=n_kernels,
|
||||||
kernel_size=kernel_size,
|
kernel_size=ksize,
|
||||||
padding=int(kernel_size / 2)
|
padding=int(ksize / 2),
|
||||||
|
padding_mode="replicate"
|
||||||
))
|
))
|
||||||
|
|
||||||
|
kernel_size -= 2
|
||||||
|
|
||||||
cnn.append(torch.nn.ReLU())
|
cnn.append(torch.nn.ReLU())
|
||||||
n_in_channels = n_kernels
|
n_in_channels = n_kernels
|
||||||
self.hidden_layers = torch.nn.Sequential(*cnn)
|
self.hidden_layers = torch.nn.Sequential(*cnn)
|
||||||
|
5
netio.py
5
netio.py
@ -36,7 +36,10 @@ def eval_evalset():
|
|||||||
piclen = len(b['input_arrays'])
|
piclen = len(b['input_arrays'])
|
||||||
for input_array, known_array in zip(b['input_arrays'], b['known_arrays']):
|
for input_array, known_array in zip(b['input_arrays'], b['known_arrays']):
|
||||||
input_array = DataLoader.preprocess(input_array, precision=np.float32)
|
input_array = DataLoader.preprocess(input_array, precision=np.float32)
|
||||||
input_tensor = torch.cat((torch.from_numpy(input_array), torch.from_numpy(known_array)), 0)
|
|
||||||
|
input_array = np.expand_dims(input_array, 0)
|
||||||
|
known_array = np.expand_dims(known_array, 0)
|
||||||
|
input_tensor = torch.cat((torch.from_numpy(input_array), torch.from_numpy(known_array)), 1)
|
||||||
out = model(input_tensor)
|
out = model(input_tensor)
|
||||||
out = DataLoader.postprocess(out.cpu().detach().numpy())
|
out = DataLoader.postprocess(out.cpu().detach().numpy())
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user