ImageImpaint_Python_II/netio.py

15 lines
254 B
Python
Raw Normal View History

2022-06-01 14:07:32 +00:00
import torch
from Net import ImageNN
def save_model(model: torch.nn.Module, path: str = 'impaintmodel.pt') -> None:
    """Write the parameters of ``model`` to ``path`` as a state dict.

    Only ``model.state_dict()`` is serialized, not the pickled module object.
    ``load_model`` feeds the loaded file to ``load_state_dict``, which needs a
    mapping of parameter names to tensors; a pickled module is rejected there.

    Args:
        model: The network whose parameters and buffers are saved.
        path: Destination file. Defaults to the checkpoint name that
            ``load_model`` reads.
    """
    torch.save(model.state_dict(), path)
2022-06-01 14:07:32 +00:00
def load_model(path: str = 'impaintmodel.pt') -> torch.nn.Module:
    """Build an ``ImageNN`` and restore its weights from ``path``.

    Args:
        path: Checkpoint file to read. Defaults to the name ``save_model``
            writes.

    Returns:
        The ``ImageNN`` instance with the stored weights loaded, switched to
        evaluation mode.
    """
    model = ImageNN()
    # Map tensors to the CPU so a checkpoint written on a GPU machine can
    # still be read where CUDA is unavailable; load_state_dict copies the
    # values into the parameters of the freshly built model.
    checkpoint = torch.load(path, map_location='cpu')
    # Checkpoints written with torch.save(model, ...) hold the whole module
    # rather than a state dict; unwrap those so load_state_dict receives a
    # mapping. NOTE(review): recent PyTorch versions may refuse to unpickle a
    # full module by default (weights_only) -- re-save such files as a state
    # dict.
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    model.load_state_dict(checkpoint)
    model.eval()
    return model