ImageImpaint_Python_II/netio.py
2022-07-11 23:38:51 +02:00

66 lines
1.8 KiB
Python

import os
import pickle
import numpy as np
import torch
import Compress
import DataLoader
from Net import ImageNN
# Path of the full pickled model checkpoint written by save_model (torch.save)
# and read back by load_model (torch.load); the ONNX export is written next to
# it as MODEL_PATH + ".onnx".
MODEL_PATH = 'impaintmodel.pt'
# Output pickle holding eval_evalset's per-image predictions; it is compressed
# via Compress.compress after writing. The "PICKEL" spelling is presumably kept
# because other code may import this name - verify before renaming.
PICKEL_PATH = 'impaintmodel.pkl'
def save_model(model: torch.nn.Module, input_shape=(1, 6, 100, 100)):
    """Save ``model`` as a full pickled checkpoint and as an ONNX graph.

    The checkpoint is written to ``MODEL_PATH`` with ``torch.save`` and the
    ONNX export (opset 11) to ``MODEL_PATH + ".onnx"``.

    Args:
        model: The network to persist.
        input_shape: Shape of the random dummy tensor used to trace the model
            for ONNX export. Defaults to ``(1, 6, 100, 100)``; presumably one
            100x100 image with image and known-mask channels stacked on dim 1
            (see eval_evalset) - confirm against the model's expected input.
    """
    torch.save(model, MODEL_PATH)
    # Report only after the write actually succeeded.
    print(f"Saved raw model to {MODEL_PATH}")

    # Tracing requires the dummy input to live on the same device as the
    # model's weights; fall back to CPU for a parameter-less module.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randn(tuple(input_shape), device=device)

    onnx_path = MODEL_PATH + ".onnx"
    torch.onnx.export(model, dummy_input, onnx_path, verbose=False, opset_version=11)
    print(f"Exported ONNX model to {onnx_path}")
def eval_evalset():
    """Run the saved model over the provided test set and pickle its predictions.

    Reads ``testing/inputs.pkl`` (a dict with ``'input_arrays'`` and
    ``'known_arrays'``), feeds every image together with its known-array
    through the model returned by ``load_model``, keeps only the predicted
    values where ``1 - known_array > 0``, writes the resulting list to
    ``PICKEL_PATH`` and finally compresses that file with ``Compress.compress``.

    Raises:
        ValueError: if the test file has a different number of input arrays
            and known arrays (previously the extra entries were silently dropped).
    """
    print("Generating pickle file with provided test data")
    # Remove output from a previous run. A missing file is expected; any other
    # error (e.g. permissions) should surface instead of being swallowed.
    try:
        os.unlink(PICKEL_PATH)
    except FileNotFoundError:
        pass

    model = load_model()
    model.eval()
    # Inputs must be on the same device as the loaded weights.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # NOTE: pickle.load can execute arbitrary code - only use trusted test data.
    with open('testing/inputs.pkl', 'rb') as handle:
        b: dict = pickle.load(handle)

    input_arrays = b['input_arrays']
    known_arrays = b['known_arrays']
    if len(input_arrays) != len(known_arrays):
        raise ValueError(
            f"input_arrays ({len(input_arrays)}) and known_arrays "
            f"({len(known_arrays)}) differ in length"
        )

    piclen = len(input_arrays)
    outarr = []
    # Inference only: no_grad avoids building an autograd graph per sample.
    with torch.no_grad():
        for i, (input_array, known_array) in enumerate(zip(input_arrays, known_arrays), start=1):
            outarr.append(_predict_unknown(model, input_array, known_array, device))
            print(f'\rApplying model [{i}/{piclen}]', end='')
    print()  # terminate the carriage-return progress line

    write_to_pickle(PICKEL_PATH, outarr)
    # compress the generated pickle arr
    Compress.compress(PICKEL_PATH)


def _predict_unknown(model, input_array, known_array, device):
    """Return the model's predictions at the positions where ``1 - known_array > 0``."""
    input_array = DataLoader.preprocess(input_array, precision=np.float32)
    # Add a leading batch dimension of size 1 to both arrays.
    input_array = np.expand_dims(input_array, 0)
    known_array = np.expand_dims(known_array, 0)

    input_tensor = torch.from_numpy(input_array)
    # Match the mask's dtype to the preprocessed input so concatenation never
    # depends on implicit type promotion (e.g. a bool/uint8 or float64 mask).
    known_tensor = torch.from_numpy(known_array).to(input_tensor.dtype)
    model_input = torch.cat((input_tensor, known_tensor), 1).to(device)

    out = DataLoader.postprocess(model(model_input).cpu().numpy())
    unknown = 1 - known_array
    return (out * unknown)[unknown > 0]
def write_to_pickle(filename: str, data):
    """Serialize ``data`` into ``filename`` using the newest pickle protocol.

    The file is opened in binary write mode, so any existing content is replaced.
    """
    with open(filename, mode='wb') as out_file:
        pickler = pickle.Pickler(out_file, protocol=pickle.HIGHEST_PROTOCOL)
        pickler.dump(data)
def load_model(path=None, map_location=None) -> torch.nn.Module:
    """Load the full pickled model written by ``torch.save`` (see save_model).

    Args:
        path: Checkpoint file to read; ``None`` (the default) means ``MODEL_PATH``.
        map_location: Forwarded to ``torch.load``; e.g. ``"cpu"`` to load a
            checkpoint saved from a GPU on a machine without CUDA. ``None``
            keeps torch's default behaviour.

    Returns:
        The deserialized model object.

    Security: this unpickles arbitrary Python objects - only load trusted files.
    """
    import inspect

    if path is None:
        path = MODEL_PATH
    load_kwargs = {'map_location': map_location}
    # PyTorch >= 2.6 defaults to weights_only=True, which rejects a fully
    # pickled nn.Module; releases before 1.13 don't accept the argument at all.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, **load_kwargs)