# ImageImpaint_Python_II/netio.py

import pickle
import sys
import numpy as np
import torch
import Compress
import DataLoader
from Net import ImageNN
# Checkpoint file: written by save_model() via torch.save and read back by
# load_model() via torch.load.
MODEL_PATH = "impaintmodel.pt"
# Destination of the pickled test-set predictions produced by save_model();
# the file is compressed afterwards. (Spelling of the name is kept as-is
# because other modules may import it.)
PICKEL_PATH = "impaintmodel.pkl"

def save_model(model: torch.nn.Module):
    """Persist ``model``'s weights and export its predictions on the test set.

    Steps:
      1. Write ``model.state_dict()`` to ``MODEL_PATH``.
      2. Run the model over every entry of ``input_arrays`` in
         ``testing/inputs.pkl`` and pickle each post-processed output to
         ``PICKEL_PATH``.
      3. Pass ``PICKEL_PATH`` to ``Compress.compress``.

    Side effect: ``model`` is switched to eval mode and left there.

    Args:
        model: the trained network to persist and evaluate.
    """
    # load_model() rebuilds an ImageNN and calls load_state_dict() on the
    # loaded object, so store the state_dict rather than pickling the whole
    # module; the two functions now round-trip.
    torch.save(model.state_dict(), MODEL_PATH)
    # Report success only after the save actually happened.
    print(f"Saved raw model to {MODEL_PATH}")

    # read the provided testing pickle file and apply the model to it
    print("Generating pickle file with provided test data")
    model.eval()
    _write_predictions(model, 'testing/inputs.pkl', PICKEL_PATH)
    # compress the generated pickle file
    Compress.compress(PICKEL_PATH)


def _write_predictions(model: torch.nn.Module, input_path: str,
                       output_path: str) -> None:
    """Run ``model`` on each ``input_arrays`` entry of ``input_path``; pickle outputs to ``output_path``.

    NOTE(review): every prediction is written with its own ``pickle.dump``,
    so ``output_path`` holds a sequence of concatenated pickles, not a single
    list; a reader must call ``pickle.load`` repeatedly until EOF. Confirm
    this matches what the consumer of the file expects.
    """
    # Load the inputs before opening the output so a failure while reading
    # them cannot truncate an existing output file. pickle.load can execute
    # arbitrary code -- only use it on the trusted, provided test file.
    with open(input_path, 'rb') as handle:
        test_data: dict = pickle.load(handle)
    pictures = test_data['input_arrays']
    total = len(pictures)

    # no_grad: pure inference, so skip building the autograd graph.
    with open(output_path, 'wb') as writehandle, torch.no_grad():
        for done, pic in enumerate(pictures, start=1):
            pic = DataLoader.preprocess(pic)
            out = model(torch.from_numpy(pic))
            out = DataLoader.postprocess(out.detach().numpy())
            pickle.dump(out, writehandle, protocol=pickle.HIGHEST_PROTOCOL)
            # counter is reported after the item is written: [1/N] .. [N/N]
            print(f'\rApplying model [{done}/{total}]', end='')
    # terminate the \r-rewritten progress line
    print()

def load_model() -> ImageNN:
    """Build an ``ImageNN`` and load its weights from ``MODEL_PATH``.

    Returns:
        The ``ImageNN`` instance with the stored weights, in eval mode.
    """
    model = ImageNN()
    # map_location='cpu': a checkpoint written on a GPU machine still loads on
    # a CPU-only one; load_state_dict copies the values into model's own
    # (CPU-constructed) parameters anyway.
    checkpoint = torch.load(MODEL_PATH, map_location='cpu')
    # The file may contain a whole pickled module (as written by
    # ``torch.save(model, ...)``) instead of a state_dict; accept both.
    # NOTE: torch versions whose torch.load defaults to weights_only=True may
    # refuse to unpickle a whole module -- re-save such files as state_dicts.
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    model.load_state_dict(checkpoint)
    model.eval()
    return model