diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index d437683cd0c929bc2111f0e9652fd9209189bed1..1e39496b072d43a5e0cd31878ea51974dc580a86 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -123,6 +123,7 @@ def _safe_add(a, g, b): class RoiSet(object): + # TODO: __init__ to take bounding boxes e.g. from obj det model; flag if overlaps are allowed def __init__( self, acc_raw: GenericImageDataAccessor, @@ -157,8 +158,9 @@ class RoiSet(object): """Expose ROI meta information via the Pandas.DataFrame API""" return self._df.itertuples(name='Roi') + # TODO: add or overload for object detection case @staticmethod - def from_segmentation( + def from_binary_mask( acc_raw: GenericImageDataAccessor, acc_seg: GenericImageDataAccessor, allow_3d=False, @@ -247,6 +249,8 @@ class RoiSet(object): axis=1, result_type='reduce', ) + + # TODO: make this contingent on whether seg is included df['binary_mask'] = df.apply( lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0], axis=1, @@ -389,6 +393,7 @@ class RoiSet(object): self._df['classify_by_' + name] = pd.Series(dtype='Int64') + # TODO: separate method to get object map # assign labels to object map: for i, roi in enumerate(self): oc = np.unique( @@ -673,6 +678,7 @@ class RoiSet(object): record['tight_patch_masks'] = list(se_pa) return record + # TODO: implement def serialize_coco(self, where: Path, prefix='') -> dict: """ Export the RoiSet according to the COCO seg standard @@ -695,6 +701,8 @@ class RoiSet(object): return {} + # TODO: add docstring + # TODO: make this work with obj det dataset @staticmethod def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix=''): df = pd.read_csv(where / 'dataframe' / (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']] diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index cff0ec010748c6e8b9d6f7aa5234460691fc3f20..beff2a7369a041226c8cc2b122be97b371e27b63 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -29,7 +29,7 @@ class BaseTestRoiSetMonoProducts(object): class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def _make_roi_set(self, mask_type='boxes', **kwargs): - roiset = RoiSet.from_segmentation( + roiset = RoiSet.from_binary_mask( self.stack_ch_pa, self.seg_mask, params=RoiSetMetaParams( @@ -70,7 +70,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0]) self.assertEqual(acc_zstack_slice.nz, 1) - roiset = RoiSet.from_segmentation(acc_zstack_slice, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes')) + roiset = RoiSet.from_binary_mask(acc_zstack_slice, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes')) zmask = roiset.get_zmask() zmask_acc = InMemoryDataAccessor(zmask) @@ -164,7 +164,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertEqual(result.shape, roiset.acc_raw.shape) def test_flatten_image(self): - roiset = RoiSet.from_segmentation(self.stack_ch_pa, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes')) + roiset = RoiSet.from_binary_mask(self.stack_ch_pa, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes')) df = roiset.get_df() from model_server.base.roiset import project_stack_from_focal_points @@ -205,7 +205,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): return roiset def test_classify_by_multiple_channels(self): - roiset = RoiSet.from_segmentation(self.stack, self.seg_mask) + roiset = RoiSet.from_binary_mask(self.stack, self.seg_mask) roiset.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel()) self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1])) @@ -216,7 +216,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def infer(self, img, mask): return PatchStack(super().infer(img, mask).data * img.chroma) - roiset = RoiSet.from_segmentation( + roiset = RoiSet.from_binary_mask( self.stack, self.seg_mask, params=RoiSetMetaParams( @@ -282,7 +282,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa def setUp(self) -> None: super().setUp() - self.roiset = RoiSet.from_segmentation( + self.roiset = RoiSet.from_binary_mask( self.stack, self.seg_mask, params=RoiSetMetaParams( @@ -617,7 +617,7 @@ class TestRoiSetSerialization(unittest.TestCase): self.assertTrue(np.all(r_acc.data == t_acc.data)) def test_serialize_coco(self): - roiset = RoiSet.from_segmentation( + roiset = RoiSet.from_binary_mask( self.stack_ch_pa, self.seg_mask_3d, params=RoiSetMetaParams(mask_type='contours')