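"""Persistence helpers for the image inpainting model.

save_model() stores the trained weights, runs the network over the provided
test inputs and writes the pickled (then compressed) outputs; load_model()
restores the network for inference.
"""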
import pickle
import numpy as np
import torch
import Compress
import DataLoader
from Net import ImageNN
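# Compress, DataLoader and Net are (presumably) project-local modules: a
# compression helper, the pre-/post-processing routines, and the ImageNN
# architecture, respectively.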

# file the trained weights are written to / read from
MODEL_PATH = 'impaintmodel.pt'
# pickle file that collects the model outputs for the provided test inputs
PICKEL_PATH = 'impaintmodel.pkl'


def save_model(model: torch.nn.Module):
    # persist the trained weights; load_model() restores them via load_state_dict
    torch.save(model.state_dict(), MODEL_PATH)
    print(f"Saved model weights to {MODEL_PATH}")

    # read the provided testing pickle file and pickle the model outputs
    print("Generating pickle file with provided test data")
    model.eval()
    with open('testing/inputs.pkl', 'rb') as handle, \
            open(PICKEL_PATH, 'wb') as writehandle:
        # the test file is a dict whose 'input_arrays' entry holds the inputs
        b: dict = pickle.load(handle)
        piclen = len(b['input_arrays'])
        for i, pic in enumerate(b['input_arrays']):
            pic = DataLoader.preprocess(pic)
            with torch.no_grad():
                out = model(torch.from_numpy(pic))
            out = DataLoader.postprocess(out.numpy())
            pickle.dump(out, writehandle, protocol=pickle.HIGHEST_PROTOCOL)

            # show progress and how many bytes have been pickled so far
            print(f'\rApplying model [{i + 1}/{piclen}] '
                  f'({writehandle.tell()} bytes written)', end='', flush=True)

    # compress the generated pickle file
    Compress.compress(PICKEL_PATH)


def load_model():
    model = ImageNN()
    model.load_state_dict(torch.load(MODEL_PATH))
    model.eval()
    return model
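

# Minimal usage sketch (hypothetical; `train` is assumed to exist elsewhere in
# the project and is not part of this module):
#
#     model = ImageNN()
#     train(model)                # assumed project-specific training routine
#     save_model(model)           # writes MODEL_PATH and the compressed PICKEL_PATH
#     restored = load_model()     # later: restore the trained network for inference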