47 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			47 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import pickle
 | |
| import sys
 | |
| 
 | |
| import numpy as np
 | |
| import torch
 | |
| import Compress
 | |
| import DataLoader
 | |
| from Net import ImageNN
 | |
| 
 | |
# Path of the saved model: written by save_model(), read back by load_model().
MODEL_PATH = 'impaintmodel.pt'
# Output file for the pickled test-set predictions produced by save_model()
# (later passed to Compress.compress). The "PICKEL" spelling is kept because
# other modules may import this name.
PICKEL_PATH = 'impaintmodel.pkl'
 | |
| 
 | |
| 
 | |
def save_model(model: torch.nn.Module):
    """Persist ``model`` and write its predictions for the provided test set.

    Steps:
      1. Save the model's ``state_dict`` to ``MODEL_PATH``. This is the format
         ``load_model`` reads back via ``load_state_dict``.
      2. Run the model over every array in
         ``testing/inputs.pkl['input_arrays']`` and pickle each
         post-processed output to ``PICKEL_PATH``.
      3. Pass ``PICKEL_PATH`` to ``Compress.compress``.

    Args:
        model: The trained network. As a side effect it is switched to
            eval mode.
    """
    # Save the weights, not the pickled module object. load_model() rebuilds
    # an ImageNN and calls load_state_dict(), which requires a state_dict.
    # Recent torch versions also refuse to torch.load() a pickled module
    # under their weights_only default.
    torch.save(model.state_dict(), MODEL_PATH)
    print(f"Saved raw model to {MODEL_PATH}")

    print("Generating pickle file with provided test data")
    model.eval()

    # Read the inputs before opening the output file. A missing or corrupt
    # input file then does not truncate an existing PICKEL_PATH.
    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open('testing/inputs.pkl', 'rb') as handle:
        test_data: dict = pickle.load(handle)
    input_arrays = test_data['input_arrays']
    total = len(input_arrays)

    # no_grad: pure inference, so no autograd graph is built per image.
    with open(PICKEL_PATH, 'wb') as writehandle, torch.no_grad():
        for done, pic in enumerate(input_arrays, start=1):
            pic = DataLoader.preprocess(pic)
            out = model(torch.from_numpy(pic))
            out = DataLoader.postprocess(out.detach().numpy())
            # Each output is appended as its own pickle, not as one list.
            # A reader must call pickle.load repeatedly until EOFError.
            pickle.dump(out, writehandle, protocol=pickle.HIGHEST_PROTOCOL)
            print(f'\rApplying model [{done}/{total}]', end='')
    # Terminate the carriage-return progress line.
    print()

    # Hand the generated prediction file to the project's compression step.
    Compress.compress(PICKEL_PATH)
 | |
| 
 | |
| 
 | |
def load_model():
    """Build an ``ImageNN``, load weights from ``MODEL_PATH``, and return it in eval mode.

    The checkpoint may be a ``state_dict``, or a whole pickled
    ``torch.nn.Module`` as written by older versions of ``save_model``. A
    module is unwrapped to its ``state_dict`` before loading.

    Returns:
        The ``ImageNN`` instance with loaded weights, in eval mode.
    """
    model = ImageNN()
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
    # hosts. load_state_dict copies the values into the fresh model's own
    # parameters regardless of where they were loaded.
    checkpoint = torch.load(MODEL_PATH, map_location='cpu')
    if isinstance(checkpoint, torch.nn.Module):
        # Legacy format: torch.save(model, ...) stored the full module.
        # NOTE(review): torch>=2.6 defaults to weights_only=True and will
        # refuse to unpickle such files. Re-save them as a state_dict.
        checkpoint = checkpoint.state_dict()
    model.load_state_dict(checkpoint)
    model.eval()
    return model
 | |
| 
 |