2022-06-01 14:07:32 +00:00
|
|
|
import torch
|
|
|
|
|
|
|
|
from Net import ImageNN
|
|
|
|
|
|
|
|
|
|
|
|
def save_model(model: torch.nn.Module, path: str = 'impaintmodel.pt') -> None:
    """Write the weights of ``model`` to ``path``.

    Only the ``state_dict`` is stored, not the pickled module object, so the
    file can be fed straight to ``Module.load_state_dict`` when restoring.

    Args:
        model: Module whose parameters and buffers are saved.
        path: Destination file; defaults to the checkpoint name used by
            this file.
    """
    # Saving the module itself would produce a file that load_state_dict
    # cannot consume, because it expects a mapping of tensors.
    torch.save(model.state_dict(), path)
|
2022-06-01 14:07:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
def load_model(path: str = 'impaintmodel.pt') -> torch.nn.Module:
    """Create an ``ImageNN``, restore its weights from ``path`` and return it.

    Args:
        path: Checkpoint file to read; defaults to the checkpoint name used
            by this file.

    Returns:
        The ``ImageNN`` instance with the loaded weights, in evaluation mode.
    """
    model = ImageNN()

    # map_location='cpu' lets a checkpoint written on a GPU host be read on a
    # CPU-only machine. load_state_dict copies values into the model's own
    # parameters, so the model's device is not changed by this.
    checkpoint = torch.load(path, map_location='cpu')

    # A checkpoint may hold a whole pickled module rather than a state dict
    # (e.g. one written with torch.save(model, ...)); unwrap it so both
    # layouts are accepted.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which refuses whole-module pickles before this
    # branch is reached - confirm against the installed version.
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()

    model.load_state_dict(checkpoint)
    model.eval()
    return model
|