import pickle
import sys

import numpy as np
import torch

import Compress
import DataLoader
from Net import ImageNN

MODEL_PATH = 'impaintmodel.pt'
# (sic) "PICKEL" spelling kept because other modules may import this name.
PICKEL_PATH = 'impaintmodel.pkl'
# Pickle file holding the provided test inputs under the 'input_arrays' key.
TEST_INPUTS_PATH = 'testing/inputs.pkl'


def _write_predictions(model: torch.nn.Module, inputs_path: str, output_path: str) -> None:
    """Run *model* on every test input and stream the outputs to *output_path*.

    Reads the dict pickled at *inputs_path*, applies
    ``DataLoader.preprocess`` -> *model* -> ``DataLoader.postprocess`` to each
    entry of ``input_arrays``, and pickles every result individually into
    *output_path*. The caller is responsible for putting *model* into the
    desired (eval) mode.
    """
    with open(inputs_path, 'rb') as handle:
        # NOTE(review): pickle.load can execute arbitrary code; only use it
        # on the trusted, provided test file.
        data: dict = pickle.load(handle)
    pictures = data['input_arrays']
    total = len(pictures)

    # no_grad: pure inference, so don't build autograd graphs (saves memory).
    with open(output_path, 'wb') as writehandle, torch.no_grad():
        for i, pic in enumerate(pictures, start=1):
            pic = DataLoader.preprocess(pic)
            out = model(torch.from_numpy(pic))
            out = DataLoader.postprocess(out.detach().numpy())
            # Outputs are dumped one at a time, so the file is a sequence of
            # pickles: a reader must call pickle.load once per entry.
            pickle.dump(out, writehandle, protocol=pickle.HIGHEST_PROTOCOL)
            print(f'\rApplying model [{i}/{total}]', end='')
    # Terminate the '\r'-overwritten progress line.
    print()


def save_model(model: torch.nn.Module):
    """Persist *model* and generate its predictions for the provided test data.

    Steps:
      1. Write ``model.state_dict()`` to ``MODEL_PATH`` (the format that
         ``load_model`` reads).
      2. Run the model in eval mode over every input in ``TEST_INPUTS_PATH``
         and pickle each postprocessed output into ``PICKEL_PATH``.
      3. Pass ``PICKEL_PATH`` to ``Compress.compress``.

    The model's previous train/eval mode is restored afterwards, so calling
    this mid-training does not silently leave the model in eval mode.
    """
    # Save the state_dict rather than the pickled module: load_model() calls
    # load_state_dict(), which cannot consume a whole pickled Module.
    torch.save(model.state_dict(), MODEL_PATH)
    print(f"Saved raw model to {MODEL_PATH}")

    print("Generating pickle file with provided test data")
    was_training = model.training
    model.eval()
    try:
        _write_predictions(model, TEST_INPUTS_PATH, PICKEL_PATH)
    finally:
        model.train(was_training)

    # compress the generated pickle file
    Compress.compress(PICKEL_PATH)


def load_model():
    """Build an ``ImageNN`` and load the weights stored at ``MODEL_PATH``.

    Accepts either a ``state_dict`` (what ``save_model`` writes) or a file
    produced by older versions of ``save_model`` that pickled the whole
    module; in that case the module's weights are extracted.

    Returns:
        The ``ImageNN`` instance with loaded weights, in eval mode.
    """
    model = ImageNN()
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
    # machines; load_state_dict copies the tensors into the model anyway.
    checkpoint = torch.load(MODEL_PATH, map_location='cpu')
    if isinstance(checkpoint, torch.nn.Module):
        # Legacy format: torch.save(model, ...) stored the full module.
        # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
        # which rejects such files; re-save them as a state_dict if that bites.
        checkpoint = checkpoint.state_dict()
    model.load_state_dict(checkpoint)
    model.eval()
    return model