15 lines
254 B
Python
15 lines
254 B
Python
import torch
|
|
|
|
from Net import ImageNN
|
|
|
|
|
|
def save_model(model: torch.nn.Module) -> None:
    """Write the weights of *model* to 'impaintmodel.pt'.

    Only ``model.state_dict()`` is stored, not the pickled module object,
    because ``load_model`` feeds the loaded file to ``load_state_dict``,
    which expects a mapping of parameter names to tensors.

    Args:
        model: The network whose parameters and buffers should be saved.
    """
    # The relative path means the file lands in the current working directory.
    torch.save(model.state_dict(), 'impaintmodel.pt')
|
|
|
|
|
|
def load_model():
    """Build a fresh ``ImageNN`` and restore its weights from 'impaintmodel.pt'.

    Returns:
        The ``ImageNN`` instance with the stored weights loaded, already
        switched to evaluation mode via ``eval()``.
    """
    model = ImageNN()
    # map_location='cpu' lets a checkpoint that was written from a GPU model
    # be read on a machine without CUDA. load_state_dict copies the values
    # into the model's existing parameters, so the model's device is unchanged.
    checkpoint = torch.load('impaintmodel.pt', map_location='cpu')
    # A checkpoint file may hold a whole pickled module (what
    # ``torch.save(model, path)`` writes) instead of a state_dict; accept both.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which rejects whole-module pickles before this
    # branch is reached - such files need re-saving as a state_dict.
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    model.load_state_dict(checkpoint)
    model.eval()
    return model
|