From 531fac85288fcac6f6ceb6344e40b2049db6b72b Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 31 Jul 2024 16:48:04 +0200
Subject: [PATCH] Implemented polygon fitting

---
 model_server/base/roiset.py | 39 ++++++++++++++++++++++++++++++++++++-
 tests/base/test_roiset.py   | 26 ++++++++++++++++++++++++-
 2 files changed, 63 insertions(+), 2 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 854a53da..12310bec 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -10,7 +10,8 @@ from pydantic import BaseModel
 from scipy.stats import moment
 from skimage.filters import sobel
 
-from skimage.measure import label, regionprops_table, shannon_entropy
+from skimage.measure import approximate_polygon, find_contours, label, points_in_poly, regionprops, regionprops_table, shannon_entropy
+from skimage.morphology import binary_dilation, disk
 
 from .accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
 from .models import InstanceSegmentationModel
@@ -180,6 +181,10 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame
     )
     return df
 
+# TODO: implement
+def make_df_from_polygons(acc_raw, polygons:np.ndarray) -> pd.DataFrame:
+    pass
+
 
 def df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame:
     h = sd['Y']
@@ -347,6 +352,11 @@ class RoiSet(object):
         """
         return RoiSet.from_object_ids(acc_raw, get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params)
 
+    @staticmethod
+    #TODO: implement
+    def from_polygons(acc_raw, polygons: np.ndarray):
+        pass
+
 
     # TODO: get overlapping segments
     def get_overlap_seg(self) -> pd.DataFrame:
@@ -764,6 +774,33 @@ class RoiSet(object):
         record['tight_patch_masks'] = list(se_pa)
         return record
 
+    def get_polygons(self, poly_threshold=0, dilation_radius=1):
+        pad_to = 1
+
+        def _poly_from_mask(roi):
+            # mask = generate_file_accessor(roi.mask_path).data[:, :, 0, 0]
+            mask = roi.binary_mask
+
+            # label and fill holes
+            labeled = label(mask)
+            filled = [rp.image_filled for rp in regionprops(labeled)]
+            assert (np.unique(labeled)[-1] == 1) and (len(filled) == 1), 'Cannot fit multiple polygons in a single patch mask'
+
+            closed = binary_dilation(filled[0], footprint=disk(dilation_radius))
+            padded = np.pad(closed, pad_to) * 1.0
+            all_contours = find_contours(padded)
+
+            nc = len(all_contours)
+            for j in range(0, nc):
+                if all([points_in_poly(all_contours[k], all_contours[j]).all() for k in range(0, nc)]):
+                    contour = all_contours[j]
+                    break
+
+            rel_polygon = approximate_polygon(contour[:, [1, 0]], poly_threshold) - [pad_to, pad_to]
+            return rel_polygon + [roi.x0, roi.y0]
+
+        return self._df.apply(_poly_from_mask, axis=1)
+
     # TODO: implement
     def serialize_coco(self, where: Path, prefix='') -> dict:
         """
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 6264d41e..892445bf 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -1,4 +1,3 @@
-import os
 import re
 import unittest
 
@@ -6,6 +5,7 @@ import numpy as np
 from pathlib import Path
 
 import pandas as pd
+from skimage import draw
 
 from model_server.base.roiset import filter_overlap_bbox, RoiSetExportParams, RoiSetMetaParams
 from model_server.base.roiset import RoiSet
@@ -664,3 +664,27 @@ class TestRoiSetSerialization(unittest.TestCase):
         )
         roiset.serialize_coco(output_path / 'serialize_coco')
         self.assertEqual(1, 0)
+
+class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
+
+    def test_compute_polygons(self):
+        roiset = RoiSet.from_binary_mask(
+            self.stack_ch_pa,
+            self.seg_mask,
+            params=RoiSetMetaParams(
+                mask_type='contours',
+                filters={'area': {'min': 1e1, 'max': 1e6}}
+            )
+        )
+
+        poly = roiset.get_polygons()
+        binary_poly = np.zeros(self.seg_mask.hw, dtype=bool)
+        for p in poly:
+            pidcs = draw.polygon(p[:, 1], p[:, 0])
+            binary_poly[pidcs] = True
+
+        test_mask = np.logical_and(
+            np.logical_not(binary_poly),
+            (self.seg_mask.data[:, :, 0, 0] == 255)
+        )
+        self.assertLess(test_mask.sum() / test_mask.size, 0.001)  # most mask pixels are within in fitted polygon
\ No newline at end of file
-- 
GitLab