67 lines
1.7 KiB
Python
67 lines
1.7 KiB
Python
import os.path
|
|
from glob import glob
|
|
|
|
import numpy as np
|
|
from PIL import Image, ImageStat
|
|
|
|
|
|
class ImageStandardizer:
    """Standardize a directory tree of ``.jpg`` images channel by channel.

    Typical use: construct with an input directory, call
    :meth:`analyze_images` once to compute the statistics, then iterate over
    :meth:`get_standardized_images`.

    Attributes are plain instance fields:
      * ``files`` - sorted list of absolute paths of all ``.jpg`` files found.
      * ``mean``  - per-channel mean (``None`` until ``analyze_images`` ran).
      * ``std``   - per-channel std-dev (``None`` until ``analyze_images`` ran).
    """

    def __init__(self, input_dir: str) -> None:
        """Collect all ``.jpg`` files below *input_dir* (recursively).

        Args:
            input_dir: Directory that is searched recursively for ``*.jpg``.

        Raises:
            ValueError: If no ``.jpg`` file is found.
        """
        super().__init__()

        # '**' together with recursive=True also matches files that sit
        # directly in input_dir, not only those in sub-directories.
        pattern = os.path.join(input_dir, '**', '*.jpg')
        files = sorted(
            os.path.abspath(path) for path in glob(pattern, recursive=True)
        )
        if not files:
            raise ValueError(f"no .jpg files found under {input_dir!r}")

        self.files = files
        self.mean = None
        self.std = None

    def analyze_images(self) -> "tuple[np.ndarray, np.ndarray]":
        """Compute the per-channel mean and standard deviation of all images.

        Each image's channel means and standard deviations (as reported by
        ``ImageStat.Stat``) are averaged over all files. Note that the result
        for ``std`` is therefore the average of the per-image deviations, not
        the deviation of the pooled pixels.

        Returns:
            ``(mean, std)`` - two ``float64`` arrays of shape ``(3,)``. The
            same arrays are stored in ``self.mean`` and ``self.std``.
        """
        mean_sum = np.zeros((3,), dtype=np.float64)
        std_sum = np.zeros((3,), dtype=np.float64)

        for file in self.files:
            # The context manager closes the file handle deterministically;
            # the statistics must be read while the image is still open.
            with Image.open(file) as img:
                stats = ImageStat.Stat(img)
                mean_sum += stats.mean
                std_sum += stats.stddev

        count = len(self.files)
        self.mean = mean_sum / count
        self.std = std_sum / count

        return self.mean, self.std

    def get_standardized_images(self):
        """Yield every image standardized with the stored mean and std.

        This is a generator, so the checks below run when the first item is
        requested, not when the method is called.

        Yields:
            ``float32`` arrays of shape ``(height, width, 3)`` holding
            ``(pixels - mean) / std``, in the order of ``self.files``.

        Raises:
            ValueError: If ``analyze_images`` has not been run yet, or if an
                image does not have exactly three channels.
        """
        if self.mean is None or self.std is None:
            raise ValueError(
                "mean/std are not set; call analyze_images() first"
            )

        for file in self.files:
            # Converting to float32 copies the pixel data, so the array stays
            # valid after the file is closed at the end of the with-block.
            with Image.open(file) as img:
                pixels = np.asarray(img, dtype=np.float32)

            if pixels.ndim != 3 or pixels.shape[2] != 3:
                raise ValueError(
                    f"expected a 3-channel image, got shape {pixels.shape} "
                    f"for {file!r}"
                )

            # mean/std are float64, so the arithmetic is done in float64 and
            # the result is cast back to float32 for the caller.
            yield ((pixels - self.mean) / self.std).astype(np.float32)
|
|
|
|
|
|
# Press the green button in the gutter to run the script.
|
|
if __name__ == '__main__':
    # Script entry point: analyze the images below the given directory,
    # print the resulting statistics, then print each standardized image.
    standardizer = ImageStandardizer(input_dir='unittest/unittest_input_0')
    statistics = standardizer.analyze_images()
    print(statistics)

    for standardized in standardizer.get_standardized_images():
        print(standardized)
|