import torch
from Net import ImageNN

# Default checkpoint file. A relative path resolves against the current
# working directory of the process.
MODEL_PATH = 'impaintmodel.pt'


def save_model(model: torch.nn.Module, path: str = MODEL_PATH) -> None:
    """Persist the parameters of *model* to *path*.

    Only ``model.state_dict()`` is written, not the pickled module object.
    That is the format ``load_model`` feeds to ``load_state_dict``. It can
    also be read by ``torch.load`` under its ``weights_only=True`` default
    in recent PyTorch releases.

    Args:
        model: Network whose parameters should be saved.
        path: Destination file; defaults to ``MODEL_PATH``.
    """
    torch.save(model.state_dict(), path)


def load_model(path: str = MODEL_PATH, map_location=None) -> ImageNN:
    """Build an ``ImageNN``, load saved parameters into it, and return it.

    The returned model is switched to evaluation mode via ``model.eval()``.

    Args:
        path: Checkpoint file written by ``save_model``; defaults to
            ``MODEL_PATH``.
        map_location: Forwarded to ``torch.load``. For example, pass
            ``'cpu'`` to load GPU-saved weights on a CPU-only machine.
            ``None`` keeps torch's default behaviour.

    Returns:
        The ``ImageNN`` instance with the loaded parameters, in eval mode.
    """
    model = ImageNN()
    state = torch.load(path, map_location=map_location)
    # Older checkpoints from this module pickled the whole nn.Module
    # instead of its state dict. When the installed torch can still
    # unpickle such a file, extract the parameters from it.
    if isinstance(state, torch.nn.Module):
        state = state.state_dict()
    model.load_state_dict(state)
    model.eval()
    return model