# ImageImpaint_Python_II/netio.py
#
# 15 lines
# 270 B
# Python

import torch
from Net import ImageNN
def save_model(model: torch.nn.Module, path: str = "impaintmodel.pth") -> None:
    """Write ``model``'s parameters (its ``state_dict``) to ``path``.

    Only the state dict is saved, not the module object itself. To restore it,
    load it into a freshly constructed model of the same architecture.

    Args:
        model: Module whose ``state_dict()`` is serialised with ``torch.save``.
        path: Destination file. Defaults to ``'impaintmodel.pth'``, which is
            resolved relative to the current working directory.
    """
    torch.save(model.state_dict(), path)
def load_model(path: str = "model_weights.pth", map_location=None):
    """Build an ``ImageNN`` and restore its weights from a saved state dict.

    Args:
        path: File to read the state dict from. Defaults to
            ``'model_weights.pth'``, resolved relative to the current working
            directory.
        map_location: Forwarded unchanged to ``torch.load``. For example, pass
            ``'cpu'`` to load weights that were saved from a GPU on a CPU-only
            machine. ``None`` (the default) keeps ``torch.load``'s own behaviour.

    Returns:
        The ``ImageNN`` instance holding the loaded weights, switched to eval mode.
    """
    # NOTE(review): save_model in this module writes 'impaintmodel.pth' by
    # default, but this reads 'model_weights.pth'. With both defaults, a
    # save-then-load round-trip will not find the saved file. Confirm which
    # filename is intended before changing the default.
    model = ImageNN()
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    # eval() switches any dropout / batch-norm layers to inference behaviour.
    model.eval()
    return model