implement basic structure of project
This commit is contained in:
parent
24302f1c35
commit
2966df9a39
31
DataLoader.py
Normal file
31
DataLoader.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import glob
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
|
||||||
|
class ImageDataset(Dataset):
|
||||||
|
def __init__(self, image_dir):
|
||||||
|
self.image_files = sorted(glob.glob(os.path.join(image_dir, "**", "*.jpg"), recursive=True))
|
||||||
|
# Mean and std arrays could also be defined as class attributes
|
||||||
|
self.norm_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
||||||
|
self.norm_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
# Open image file, convert to numpy array and scale to [0, 1]
|
||||||
|
image = np.array(Image.open(self.image_files[index]), dtype=np.float32) / 255
|
||||||
|
# Perform normalization for each channel
|
||||||
|
image = (image - self.norm_mean) / self.norm_std
|
||||||
|
return image, index
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.image_files)
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_loader(path: str):
|
||||||
|
image_dataset = ImageDataset(path)
|
||||||
|
image_loader = DataLoader(image_dataset, shuffle=True, batch_size=10)
|
||||||
|
return image_loader
|
32
ImageImpaint.py
Normal file
32
ImageImpaint.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from DataLoader import get_image_loader
|
||||||
|
from Net import ImageNN
|
||||||
|
|
||||||
|
|
||||||
|
def train_model():
|
||||||
|
image_loader = get_image_loader("my/supercool/image/dir")
|
||||||
|
# todo split to train and test (maybe evaluation sets)
|
||||||
|
nn = ImageNN() # todo pass size ason.
|
||||||
|
|
||||||
|
optimizer = torch.optim.SGD(nn.parameters(), lr=0.1) # todo adjust parameters and lr
|
||||||
|
loss_function = torch.nn.CrossEntropyLoss()
|
||||||
|
n_epochs = 15 # todo epcchs here
|
||||||
|
|
||||||
|
# Training
|
||||||
|
losses = []
|
||||||
|
for epoch in range(n_epochs):
|
||||||
|
for input_tensor, target_tensor in image_loader:
|
||||||
|
output = nn(input_tensor) # get model output (forward pass)
|
||||||
|
loss = loss_function(output, target_tensor) # compute loss given model output and true target
|
||||||
|
loss.backward() # compute gradients (backward pass)
|
||||||
|
optimizer.step() # perform gradient descent update step
|
||||||
|
optimizer.zero_grad() # reset gradients
|
||||||
|
losses.append(loss.item())
|
||||||
|
|
||||||
|
# todo evaluate trained model
|
||||||
|
# todo save trained model to blob file
|
||||||
|
|
||||||
|
|
||||||
|
def apply_model():
|
||||||
|
pass
|
11
Net.py
Normal file
11
Net.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class ImageNN(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
# todo implement the nn structure
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
pass
|
||||||
|
# todo implement forward
|
Loading…
Reference in New Issue
Block a user