Skip to content
Snippets Groups Projects
Commit 93d843b5 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Merge branch 'dev_watershed_centroids' into 'staging'

Added static constructor for making RoiSets from segmentation masks

See merge request !41
parents cf492a3e 426b5b79
No related branches found
No related tags found
2 merge requests!50Release 2024.06.03,!41Added static constructor for making RoiSets from segmentation masks
......@@ -152,6 +152,25 @@ class RoiSet(object):
"""Expose ROI meta information via the Pandas.DataFrame API"""
return self._df.itertuples(name='Roi')
@staticmethod
def from_segmentation(
acc_raw: GenericImageDataAccessor,
acc_seg: GenericImageDataAccessor,
allow_3d=False,
connect_3d=True,
params: RoiSetMetaParams = RoiSetMetaParams()
):
"""
Create a RoiSet from a binary segmentation mask (either 2D or 3D)
:param acc_raw: accessor to a generally a multichannel z-stack
:param acc_seg: accessor of a binary segmentation mask (mono) of either two or three dimensions
:param allow_3d: return a 3D map if True; return a 2D map of the mask's maximum intensity project if False
:param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False
:param params: optional arguments that influence the definition and representation of ROIs
:return: object identities map
"""
return RoiSet(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params)
@staticmethod
def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
"""
......
......@@ -10,7 +10,7 @@ import pandas as pd
from model_server.conf.testing import output_path, roiset_test_data
from model_server.base.roiset import RoiSetExportParams, RoiSetMetaParams
from model_server.base.roiset import _get_label_ids, RoiSet
from model_server.base.roiset import RoiSet
from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack
from model_server.base.models import DummyInstanceSegmentationModel
......@@ -26,10 +26,9 @@ class BaseTestRoiSetMonoProducts(object):
class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
def _make_roi_set(self, mask_type='boxes', **kwargs):
id_map = _get_label_ids(self.seg_mask)
roiset = RoiSet(
roiset = RoiSet.from_segmentation(
self.stack_ch_pa,
id_map,
self.seg_mask,
params=RoiSetMetaParams(
mask_type=mask_type,
filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}),
......@@ -67,9 +66,8 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_roiset_from_non_zstacks(self, **kwargs):
acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
self.assertEqual(acc_zstack_slice.nz, 1)
id_map = _get_label_ids(self.seg_mask)
roiset = RoiSet(acc_zstack_slice, id_map, params=RoiSetMetaParams(mask_type='boxes'))
roiset = RoiSet.from_segmentation(acc_zstack_slice, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes'))
zmask = roiset.get_zmask()
zmask_acc = InMemoryDataAccessor(zmask)
......@@ -158,9 +156,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
self.assertEqual(result.shape, roiset.acc_raw.shape)
def test_flatten_image(self):
id_map = _get_label_ids(self.seg_mask)
roiset = RoiSet(self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='boxes'))
roiset = RoiSet.from_segmentation(self.stack_ch_pa, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes'))
df = roiset.get_df()
from model_server.base.roiset import project_stack_from_focal_points
......@@ -270,10 +266,9 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
def setUp(self) -> None:
super().setUp()
id_map = _get_label_ids(self.seg_mask)
self.roiset = RoiSet(
self.roiset = RoiSet.from_segmentation(
self.stack,
id_map,
self.seg_mask,
params=RoiSetMetaParams(
expand_box_by=(128, 2),
mask_type='boxes',
......@@ -492,7 +487,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
print('res')
from model_server.base.roiset import _get_label_ids
class TestRoiSetSerialization(unittest.TestCase):
def setUp(self) -> None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment