# ImageImpaint_Python_II/netio.py
import os
import pickle
import numpy as np
import torch
import Compress
import DataLoader
from Net import ImageNN
MODEL_PATH = 'impaintmodel.pt'
PICKEL_PATH = 'impaintmodel.pkl'


def save_model(model: torch.nn.Module):
    """Serialize the complete model (architecture + weights) to MODEL_PATH.

    NOTE(review): saving the whole module pickles its class, so loading
    requires the same source tree; ``model.state_dict()`` would be more
    portable — confirm before changing, load_model expects a full module.
    """
    torch.save(model, MODEL_PATH)
    # Report success only after the save has actually happened
    # (the original printed before saving, lying on failure).
    print(f"Saved raw model to {MODEL_PATH}")
def eval_evalset():
    """Apply the saved model to the provided test set and pickle the results.

    Reads ``testing/inputs.pkl`` (a dict; its ``'input_arrays'`` entry is a
    sequence of pictures — assumed to preprocess to model-ready float32
    arrays, TODO confirm against DataLoader), runs every picture through the
    model, collects the uint8 outputs (3, 100, 100 each) into a list written
    to PICKEL_PATH, then compresses that file via Compress.compress.
    """
    print("Generating pickle file with provided test data")
    # Remove a stale output file; a missing file is the normal case.
    try:
        os.unlink(PICKEL_PATH)
    except FileNotFoundError:
        pass
    model = load_model()
    model.eval()
    with open('testing/inputs.pkl', 'rb') as handle:
        b: dict = pickle.load(handle)
    pics = b['input_arrays']
    piclen = len(pics)
    outarr = np.zeros(dtype=np.uint8, shape=(piclen, 3, 100, 100))
    # Inference only — disable gradient bookkeeping.
    with torch.no_grad():
        for i, pic in enumerate(pics):
            pic = DataLoader.preprocess(pic, precision=np.float32)
            out = model(torch.from_numpy(pic))
            outarr[i] = DataLoader.postprocess(out.cpu().detach().numpy())
            # i + 1 so the counter reads 1..piclen instead of 0..piclen-1.
            print(f'\rApplying model [{i + 1}/{piclen}]', end='')
    print()
    write_to_pickle(PICKEL_PATH, list(outarr))
    # compress the generated pickle arr
    Compress.compress(PICKEL_PATH)
def write_to_pickle(filename: str, data):
    """Serialize *data* into *filename* using the highest pickle protocol."""
    payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
    with open(filename, 'wb') as sink:
        sink.write(payload)
def load_model():
    """Restore the full model object previously written by save_model.

    NOTE(review): torch.load unpickles arbitrary objects — only load
    MODEL_PATH files from a trusted source. On torch >= 2.6 the
    weights_only default changed and restoring a whole module may need
    ``weights_only=False`` — confirm against the installed torch version.
    """
    return torch.load(MODEL_PATH)