"""Helpers for saving and restoring the weights of an ``ImageNN`` model."""

import torch
from Net import ImageNN

# Default checkpoint locations, kept exactly as the original hard-coded values.
# NOTE(review): the save and load defaults differ, so a checkpoint written by
# ``save_model(model)`` is not the file ``load_model()`` reads by default.
# Confirm whether this is intentional; until then, pass ``path`` explicitly to
# round-trip a model.
_DEFAULT_SAVE_PATH = 'impaintmodel.pth'
_DEFAULT_LOAD_PATH = 'model_weights.pth'


def save_model(model: torch.nn.Module, path: str = _DEFAULT_SAVE_PATH) -> None:
    """Write the model's ``state_dict`` to ``path`` with ``torch.save``.

    Only the state dict is serialized, not the module object itself.

    Args:
        model: Module whose ``state_dict()`` is saved.
        path: Destination file. Defaults to ``'impaintmodel.pth'``.
    """
    torch.save(model.state_dict(), path)


def load_model(path: str = _DEFAULT_LOAD_PATH, map_location=None) -> ImageNN:
    """Build an ``ImageNN``, load weights from ``path`` and return it.

    Args:
        path: File holding a state dict. Defaults to ``'model_weights.pth'``.
        map_location: Forwarded unchanged to ``torch.load``. The default
            ``None`` matches a call without the argument; pass e.g. ``'cpu'``
            to remap storages when loading on a different device.

    Returns:
        A freshly constructed ``ImageNN`` with the loaded state dict applied,
        after ``eval()`` has been called on it.
    """
    model = ImageNN()
    # NOTE(security): torch.load unpickles the file unless weights-only loading
    # is in effect; only load checkpoints from trusted sources.
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    return model