ImageImpaint_Python_II/netio.py

15 lines
254 B
Python

import torch
from Net import ImageNN
def save_model(model: torch.nn.Module, path: str = 'impaintmodel.pt') -> None:
    """Persist the parameters and buffers of *model* to *path*.

    Only ``model.state_dict()`` is written, not the pickled module object.
    ``load_model`` reads the file back with ``load_state_dict``, and that
    call requires a dict-like state dict.

    Args:
        model: The network whose state should be saved.
        path: Destination file (str or path-like). Defaults to
            ``'impaintmodel.pt'`` in the current working directory, which
            matches the previous hard-coded behaviour.
    """
    # Pickling the whole module ties the file to the class definition, and
    # the result cannot be passed to ``load_state_dict``. The state dict
    # avoids both problems.
    torch.save(model.state_dict(), path)
def load_model(path: str = 'impaintmodel.pt', map_location='cpu') -> torch.nn.Module:
    """Construct an ``ImageNN`` and restore its weights from *path*.

    Args:
        path: Checkpoint file (str or path-like). Defaults to
            ``'impaintmodel.pt'`` in the current working directory, which
            matches the previous hard-coded behaviour.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint written on a GPU machine can still be opened on a
            CPU-only one. ``load_state_dict`` copies the values into the
            freshly constructed model, so the mapping does not change the
            restored weights.

    Returns:
        The restored ``ImageNN``, switched to evaluation mode.
    """
    model = ImageNN()
    checkpoint = torch.load(path, map_location=map_location)
    # An earlier version of save_model pickled the whole module instead of
    # its state dict. Accept both formats so existing files keep loading.
    # NOTE(review): on PyTorch >= 2.6, torch.load defaults to
    # weights_only=True, which rejects pickled modules. Such legacy files
    # would need weights_only=False, and only trusted files should be
    # loaded that way.
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    model.load_state_dict(checkpoint)
    # Inference mode: affects layers such as dropout / batch-norm, if ImageNN has any.
    model.eval()
    return model