From 426b5b79d7309ac763298344c8c87630e71aca7a Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 3 May 2024 16:21:58 +0200 Subject: [PATCH] Added static constructor for making RoiSets from segmentation masks --- model_server/base/roiset.py | 19 +++++++++++++++++++ tests/test_roiset.py | 21 ++++++++------------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index f3c0c947..d289284f 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -152,6 +152,25 @@ class RoiSet(object): """Expose ROI meta information via the Pandas.DataFrame API""" return self._df.itertuples(name='Roi') + @staticmethod + def from_segmentation( + acc_raw: GenericImageDataAccessor, + acc_seg: GenericImageDataAccessor, + allow_3d=False, + connect_3d=True, + params: RoiSetMetaParams = RoiSetMetaParams() + ): + """ + Create a RoiSet from a binary segmentation mask (either 2D or 3D) + :param acc_raw: accessor to a generally a multichannel z-stack + :param acc_seg: accessor of a binary segmentation mask (mono) of either two or three dimensions + :param allow_3d: return a 3D map if True; return a 2D map of the mask's maximum intensity project if False + :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False + :param params: optional arguments that influence the definition and representation of ROIs + :return: object identities map + """ + return RoiSet(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params) + @staticmethod def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame: """ diff --git a/tests/test_roiset.py b/tests/test_roiset.py index 47856888..c1d773c6 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -10,7 +10,7 @@ import pandas as pd from model_server.conf.testing import output_path, roiset_test_data from model_server.base.roiset import RoiSetExportParams, RoiSetMetaParams -from model_server.base.roiset import _get_label_ids, RoiSet +from model_server.base.roiset import RoiSet from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack from model_server.base.models import DummyInstanceSegmentationModel @@ -26,10 +26,9 @@ class BaseTestRoiSetMonoProducts(object): class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def _make_roi_set(self, mask_type='boxes', **kwargs): - id_map = _get_label_ids(self.seg_mask) - roiset = RoiSet( + roiset = RoiSet.from_segmentation( self.stack_ch_pa, - id_map, + self.seg_mask, params=RoiSetMetaParams( mask_type=mask_type, filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}), @@ -67,9 +66,8 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_roiset_from_non_zstacks(self, **kwargs): acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0]) self.assertEqual(acc_zstack_slice.nz, 1) - id_map = _get_label_ids(self.seg_mask) - roiset = RoiSet(acc_zstack_slice, id_map, params=RoiSetMetaParams(mask_type='boxes')) + roiset = RoiSet.from_segmentation(acc_zstack_slice, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes')) zmask = roiset.get_zmask() zmask_acc = InMemoryDataAccessor(zmask) @@ -158,9 +156,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertEqual(result.shape, roiset.acc_raw.shape) def test_flatten_image(self): - id_map = _get_label_ids(self.seg_mask) - - roiset = RoiSet(self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='boxes')) + roiset = RoiSet.from_segmentation(self.stack_ch_pa, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes')) df = roiset.get_df() from model_server.base.roiset import project_stack_from_focal_points @@ -270,10 +266,9 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa def setUp(self) -> None: super().setUp() - id_map = _get_label_ids(self.seg_mask) - self.roiset = RoiSet( + self.roiset = RoiSet.from_segmentation( self.stack, - id_map, + self.seg_mask, params=RoiSetMetaParams( expand_box_by=(128, 2), mask_type='boxes', @@ -492,7 +487,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa print('res') - +from model_server.base.roiset import _get_label_ids class TestRoiSetSerialization(unittest.TestCase): def setUp(self) -> None: -- GitLab