rm wrong file
edit net a bit
This commit is contained in:
		@@ -80,17 +80,17 @@ def get_image_loader(path: str, precision: np.float32 or np.float64):
 | 
				
			|||||||
                                 transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
 | 
					                                 transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
 | 
				
			||||||
                                                                     transforms.RandomVerticalFlip()]),
 | 
					                                                                     transforms.RandomVerticalFlip()]),
 | 
				
			||||||
                                 precision=precision)
 | 
					                                 precision=precision)
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
    image_dataset_augmented = ImageDataset(path,
 | 
					    # image_dataset_augmented = ImageDataset(path,
 | 
				
			||||||
                                           offsetrange=(0, 8),
 | 
					    #                                        offsetrange=(0, 8),
 | 
				
			||||||
                                           spacingrange=(2, 6),
 | 
					    #                                        spacingrange=(2, 6),
 | 
				
			||||||
                                           transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
 | 
					    #                                        transform_chain=transforms.Compose([transforms.RandomHorizontalFlip(),
 | 
				
			||||||
                                                                               transforms.RandomVerticalFlip(),
 | 
					    #                                                                            transforms.RandomVerticalFlip(),
 | 
				
			||||||
                                                                               transforms.GaussianBlur(3, 4)]),
 | 
					    #                                                                            transforms.GaussianBlur(3, 4)]),
 | 
				
			||||||
                                           precision=precision)
 | 
					    #                                        precision=precision)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # merge different datasets here!
 | 
					    # merge different datasets here!
 | 
				
			||||||
    merged_dataset = torch.utils.data.ConcatDataset([image_dataset, image_dataset_augmented])
 | 
					    merged_dataset = torch.utils.data.ConcatDataset([image_dataset])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    totlen = len(merged_dataset)
 | 
					    totlen = len(merged_dataset)
 | 
				
			||||||
    test_set_size = .1
 | 
					    test_set_size = .1
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -48,13 +48,22 @@ def train_model():
 | 
				
			|||||||
        print(f"Epoch {epoch}/{n_epochs}\n")
 | 
					        print(f"Epoch {epoch}/{n_epochs}\n")
 | 
				
			||||||
        i = 0
 | 
					        i = 0
 | 
				
			||||||
        for input_tensor, mask, target_tensor in train_loader:
 | 
					        for input_tensor, mask, target_tensor in train_loader:
 | 
				
			||||||
 | 
					            input_tensor = input_tensor.to(device)
 | 
				
			||||||
 | 
					            mask = mask.to(device)
 | 
				
			||||||
 | 
					            target_tensor = target_tensor.to(device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            optimizer.zero_grad()  # reset gradients
 | 
					            optimizer.zero_grad()  # reset gradients
 | 
				
			||||||
            input_tensor = torch.cat((input_tensor, mask), 1)
 | 
					            input_tensor = torch.cat((input_tensor, mask), 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            output = nn(input_tensor.to(device))  # get model output (forward pass)
 | 
					            output = nn(input_tensor)  # get model output (forward pass)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            loss = loss_function(output.to(device),
 | 
					            output_flat = output * (1 - mask)
 | 
				
			||||||
                                 target_tensor.to(device))  # compute loss given model output and true target
 | 
					            output_flat = output_flat[1 - mask > 0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            rest = target_tensor * (1 - mask)
 | 
				
			||||||
 | 
					            rest = rest[1 - mask > 0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            loss = loss_function(output_flat, rest)  # compute loss given model output and true target
 | 
				
			||||||
            loss.backward()  # compute gradients (backward pass)
 | 
					            loss.backward()  # compute gradients (backward pass)
 | 
				
			||||||
            optimizer.step()  # perform gradient descent update step
 | 
					            optimizer.step()  # perform gradient descent update step
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -101,7 +110,14 @@ def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader,
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            input = torch.cat((input, mask), 1)
 | 
					            input = torch.cat((input, mask), 1)
 | 
				
			||||||
            out = model(input)
 | 
					            out = model(input)
 | 
				
			||||||
            loss += loss_fn(out, target).item()
 | 
					
 | 
				
			||||||
 | 
					            out = out * (1 - mask)
 | 
				
			||||||
 | 
					            out = out[1 - mask > 0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            rest = target * (1 - mask)
 | 
				
			||||||
 | 
					            rest = rest[1 - mask > 0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            loss += loss_fn(out, rest).item()
 | 
				
			||||||
            print(f'\rEval prog[{i}/{len(dataloader) * dataloader.batch_size}]', end='')
 | 
					            print(f'\rEval prog[{i}/{len(dataloader) * dataloader.batch_size}]', end='')
 | 
				
			||||||
            i += dataloader.batch_size
 | 
					            i += dataloader.batch_size
 | 
				
			||||||
    print()
 | 
					    print()
 | 
				
			||||||
@@ -112,9 +128,9 @@ def eval_model(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader,
 | 
				
			|||||||
def plot(inputs, targets, predictions, path, update, epoch):
 | 
					def plot(inputs, targets, predictions, path, update, epoch):
 | 
				
			||||||
    """Plotting the inputs, targets and predictions to file `path`"""
 | 
					    """Plotting the inputs, targets and predictions to file `path`"""
 | 
				
			||||||
    os.makedirs(path, exist_ok=True)
 | 
					    os.makedirs(path, exist_ok=True)
 | 
				
			||||||
    fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
 | 
					    fig, axes = plt.subplots(ncols=4, figsize=(15, 5))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]):
 | 
					    for ax, data, title in zip(axes, [inputs, targets, predictions, predictions-targets], ["Input", "Target", "Prediction", "diff"]):
 | 
				
			||||||
        ax.clear()
 | 
					        ax.clear()
 | 
				
			||||||
        ax.set_title(title)
 | 
					        ax.set_title(title)
 | 
				
			||||||
        ax.imshow(DataLoader.postprocess(np.transpose(data[:3, :, :], (1, 2, 0))), interpolation="none")
 | 
					        ax.imshow(DataLoader.postprocess(np.transpose(data[:3, :, :], (1, 2, 0))), interpolation="none")
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										12
									
								
								Net.py
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								Net.py
									
									
									
									
									
								
							@@ -7,18 +7,24 @@ from torch import nn
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class ImageNN(torch.nn.Module):
 | 
					class ImageNN(torch.nn.Module):
 | 
				
			||||||
    def __init__(self, precision: np.float32 or np.float64, n_in_channels: int = 1, n_hidden_layers: int = 3,
 | 
					    def __init__(self, precision: np.float32 or np.float64, n_in_channels: int = 1, n_hidden_layers: int = 3,
 | 
				
			||||||
                 n_kernels: int = 32, kernel_size: int = 7):
 | 
					                 n_kernels: int = 32, kernel_size: int = 9):
 | 
				
			||||||
        """Simple CNN with `n_hidden_layers`, `n_kernels`, and `kernel_size` as hyperparameters"""
 | 
					        """Simple CNN with `n_hidden_layers`, `n_kernels`, and `kernel_size` as hyperparameters"""
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ksize = kernel_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        cnn = []
 | 
					        cnn = []
 | 
				
			||||||
        for i in range(n_hidden_layers):
 | 
					        for i in range(n_hidden_layers):
 | 
				
			||||||
            cnn.append(torch.nn.Conv2d(
 | 
					            cnn.append(torch.nn.Conv2d(
 | 
				
			||||||
                in_channels=n_in_channels,
 | 
					                in_channels=n_in_channels,
 | 
				
			||||||
                out_channels=n_kernels,
 | 
					                out_channels=n_kernels,
 | 
				
			||||||
                kernel_size=kernel_size,
 | 
					                kernel_size=ksize,
 | 
				
			||||||
                padding=int(kernel_size / 2)
 | 
					                padding=int(ksize / 2),
 | 
				
			||||||
 | 
					                padding_mode="replicate"
 | 
				
			||||||
            ))
 | 
					            ))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            kernel_size -= 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            cnn.append(torch.nn.ReLU())
 | 
					            cnn.append(torch.nn.ReLU())
 | 
				
			||||||
            n_in_channels = n_kernels
 | 
					            n_in_channels = n_kernels
 | 
				
			||||||
        self.hidden_layers = torch.nn.Sequential(*cnn)
 | 
					        self.hidden_layers = torch.nn.Sequential(*cnn)
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										5
									
								
								netio.py
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								netio.py
									
									
									
									
									
								
							@@ -36,7 +36,10 @@ def eval_evalset():
 | 
				
			|||||||
        piclen = len(b['input_arrays'])
 | 
					        piclen = len(b['input_arrays'])
 | 
				
			||||||
        for input_array, known_array in zip(b['input_arrays'], b['known_arrays']):
 | 
					        for input_array, known_array in zip(b['input_arrays'], b['known_arrays']):
 | 
				
			||||||
            input_array = DataLoader.preprocess(input_array, precision=np.float32)
 | 
					            input_array = DataLoader.preprocess(input_array, precision=np.float32)
 | 
				
			||||||
            input_tensor = torch.cat((torch.from_numpy(input_array), torch.from_numpy(known_array)), 0)
 | 
					
 | 
				
			||||||
 | 
					            input_array = np.expand_dims(input_array, 0)
 | 
				
			||||||
 | 
					            known_array = np.expand_dims(known_array, 0)
 | 
				
			||||||
 | 
					            input_tensor = torch.cat((torch.from_numpy(input_array), torch.from_numpy(known_array)), 1)
 | 
				
			||||||
            out = model(input_tensor)
 | 
					            out = model(input_tensor)
 | 
				
			||||||
            out = DataLoader.postprocess(out.cpu().detach().numpy())
 | 
					            out = DataLoader.postprocess(out.cpu().detach().numpy())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user