15 lines
270 B
Python
15 lines
270 B
Python
|
import torch
|
||
|
|
||
|
from Net import ImageNN
|
||
|
|
||
|
|
||
|
def save_model(model: torch.nn.Module, path: str = "impaintmodel.pth"):
    """Serialize the parameters of *model* to disk.

    Only ``model.state_dict()`` is saved, not the pickled module object.
    Reloading therefore requires constructing the same architecture first
    and then calling ``load_state_dict``.

    Args:
        model: The module whose ``state_dict()`` is written.
        path: Destination file. Any target accepted by ``torch.save`` works
            (a ``str`` or ``os.PathLike``). Defaults to the historical
            ``"impaintmodel.pth"``.
    """
    # NOTE(review): the default filename differs from the one load_model()
    # reads by default ("model_weights.pth"); pass an explicit path to both
    # to get a working save/load round-trip.
    torch.save(model.state_dict(), path)
|
||
|
|
||
|
|
||
|
def load_model(path: str = "model_weights.pth", map_location=None) -> torch.nn.Module:
    """Build an ``ImageNN`` and load saved parameters into it.

    Args:
        path: File holding a ``state_dict`` written by ``torch.save``.
            Defaults to the historical ``"model_weights.pth"``.
        map_location: Forwarded to ``torch.load``. Pass ``"cpu"`` or a
            ``torch.device`` to remap tensors, e.g. to load a GPU-saved
            checkpoint on a CPU-only machine. The default ``None`` keeps
            ``torch.load``'s own behaviour.

    Returns:
        The ``ImageNN`` instance with loaded weights, switched to eval mode.
    """
    # NOTE(review): save_model() writes "impaintmodel.pth" by default, but this
    # default reads "model_weights.pth". Confirm which file is intended, or pass
    # an explicit path to both functions.
    model = ImageNN()
    # NOTE(review): torch.load unpickles the file. Consider weights_only=True
    # (torch >= 1.13) if checkpoints may come from untrusted sources.
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    # Switch layers such as dropout/batch-norm to inference behaviour.
    model.eval()
    return model
|