diff --git a/extensions/chaeo/process.py b/extensions/chaeo/process.py new file mode 100644 index 0000000000000000000000000000000000000000..368f254b00d318db0d88e44ec83836e85029b349 --- /dev/null +++ b/extensions/chaeo/process.py @@ -0,0 +1,38 @@ +import numpy as np +from skimage.measure import label, regionprops_table + + +def mask_largest_object( + img: np.ndarray, + max_allowed: int = 10, + verbose: bool = True +) -> np.ndarray: + """ + Where more than one connected component is found in an image, return the largest object by area + :param img: (np.ndarray) containing object labels + :param max_allowed: raise an error if more than this number of objects is found + :param verbose: print a message each time more than one object is found + :return: np.ndarray of same size as img + """ + binary = img > 0 + ob_id = label(binary) + num_obj = len(np.unique(ob_id)) - 1 + if num_obj > max_allowed: + raise TooManyObjectError(f'Found {num_obj} objects in frame') + if num_obj > 1: + if verbose: + print(f'Found {num_obj} nonzero unique values in object map; keeping the one with the largest area') + pr = regionprops_table(ob_id, properties=['label', 'area']) + idx_max_area = pr['area'].argmax() + mask = ob_id == pr['label'][idx_max_area] + return mask * img + else: + return img + + +class Error(Exception): + pass + + +class TooManyObjectError(Exception): + pass diff --git a/extensions/chaeo/tests/test_process.py b/extensions/chaeo/tests/test_process.py new file mode 100644 index 0000000000000000000000000000000000000000..06f42c355365819f05248bf38f5c6dc8b2ecb53e --- /dev/null +++ b/extensions/chaeo/tests/test_process.py @@ -0,0 +1,21 @@ +import unittest + +import numpy as np + +from extensions.chaeo.process import mask_largest_object + +class TestMaskLargestObject(unittest.TestCase): + def test_mask_largest_object(self): + arr = np.zeros([5, 5]) + arr[0:3, 0:3] = 2 + arr[4, 2:5] = 4 + masked = mask_largest_object(arr) + self.assertTrue(np.all(np.unique(masked) == [0, 2])) + self.assertTrue(np.all(masked[4:5, :] == 0)) + self.assertTrue(np.all(masked[:, 4:5] == 0)) + + def test_no_change(self): + arr = np.zeros([5, 5]) + arr[0:3, 0:3] = 2 + masked = mask_largest_object(arr) + self.assertTrue(np.all(masked == arr))