From c02b572af5e4080e6c44ec1488c2b0be282f8ef4 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 30 Oct 2023 15:06:51 +0100
Subject: [PATCH] Separate method to mask out all but largest object

---
 extensions/chaeo/process.py            | 38 ++++++++++++++++++++++++++
 extensions/chaeo/tests/test_process.py | 21 ++++++++++++++
 2 files changed, 59 insertions(+)
 create mode 100644 extensions/chaeo/process.py
 create mode 100644 extensions/chaeo/tests/test_process.py

diff --git a/extensions/chaeo/process.py b/extensions/chaeo/process.py
new file mode 100644
index 00000000..368f254b
--- /dev/null
+++ b/extensions/chaeo/process.py
@@ -0,0 +1,38 @@
+import numpy as np
+from skimage.measure import label, regionprops_table
+
+
+def mask_largest_object(
+        img: np.ndarray,
+        max_allowed: int = 10,
+        verbose: bool = True
+) -> np.ndarray:
+    """
+    Where more than one connected component is found in an image, return the largest object by area
+    :param img: (np.ndarray) containing object labels
+    :param max_allowed: raise an error if more than this number of objects is found
+    :param verbose: print a message each time more than one object is found
+    :return: np.ndarray of same size as img
+    """
+    binary = img > 0
+    ob_id = label(binary)
+    num_obj = len(np.unique(ob_id)) - 1
+    if num_obj > max_allowed:
+        raise TooManyObjectError(f'Found {num_obj} objects in frame')
+    if num_obj > 1:
+        if verbose:
+            print(f'Found {num_obj} nonzero unique values in object map; keeping the one with the largest area')
+        pr = regionprops_table(ob_id, properties=['label', 'area'])
+        idx_max_area = pr['area'].argmax()
+        mask = ob_id == pr['label'][idx_max_area]
+        return mask * img
+    else:
+        return img
+
+
+class Error(Exception):
+    pass
+
+
+class TooManyObjectError(Exception):
+    pass
diff --git a/extensions/chaeo/tests/test_process.py b/extensions/chaeo/tests/test_process.py
new file mode 100644
index 00000000..06f42c35
--- /dev/null
+++ b/extensions/chaeo/tests/test_process.py
@@ -0,0 +1,21 @@
+import unittest
+
+import numpy as np
+
+from extensions.chaeo.process import mask_largest_object
+
+class TestMaskLargestObject(unittest.TestCase):
+     def test_mask_largest_object(self):
+         arr = np.zeros([5, 5])
+         arr[0:3, 0:3] = 2
+         arr[4, 2:5] = 4
+         masked = mask_largest_object(arr)
+         self.assertTrue(np.all(np.unique(masked) == [0, 2]))
+         self.assertTrue(np.all(masked[4:5, :] == 0))
+         self.assertTrue(np.all(masked[:, 4:5] == 0))
+
+     def test_no_change(self):
+        arr = np.zeros([5, 5])
+        arr[0:3, 0:3] = 2
+        masked = mask_largest_object(arr)
+        self.assertTrue(np.all(masked == arr))
-- 
GitLab