diff --git a/model_server/base/process.py b/model_server/base/process.py index d475b063ca641922e4c08707bdc2bfc36049f9ea..dac347cd8c2e7a43e1ef82c4c00d9d119fdf303b 100644 --- a/model_server/base/process.py +++ b/model_server/base/process.py @@ -18,7 +18,7 @@ def is_mask(img): return True elif img.dtype == 'uint8': unique = np.unique(img) - if unique.shape[0] == 2 and np.all(unique == [0, 255]): + if unique.shape[0] <= 2 and np.all(unique == [0, 255]): return True return False @@ -136,7 +136,14 @@ def smooth(img: np.ndarray, sig: float) -> np.ndarray: :param sig: threshold parameter :return: smoothed image """ - return gaussian(img, sig) + ga = gaussian(img, sig, preserve_range=True) + if is_mask(img): + if img.dtype == 'bool': + return ga > ga.mean() + elif img.dtype == 'uint8': + return (255 * (ga > ga.mean())).astype('uint8') + else: + return ga class Error(Exception): pass