From 644b4e69bda963709abc25b0bdbe7827b6791f75 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 30 Oct 2024 12:14:35 +0100 Subject: [PATCH] Threshold object classifier using PatchStacks passes tests --- model_server/base/roiset.py | 33 ++++++++++++++++++++------------- tests/base/test_accessors.py | 1 + 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 23532707..27e32944 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -1190,19 +1190,26 @@ class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationMo allow_3d: bool = False, connect_3d: bool = True, ) -> GenericImageDataAccessor: - labels = get_label_ids(mask) - df = pd.DataFrame(regionprops_table( - labels.data_yxz, - intensity_image=img.get_mono(self.channel).data_yxz, - properties=('label', 'area', 'intensity_mean', 'bbox') - )) - df['intensity_mean'] > self.tr - - om = np.zeros(labels.shape, labels.dtype) - def _label_object_class(la): - om[labels.data == la] = 1 - df.loc[df['intensity_mean'] > (self.tr * img.dtype_max), 'label'].apply(_label_object_class) - return InMemoryDataAccessor(om) + if isinstance(img, PatchStack): # assume one object per patch + df = img.get_object_df(mask) + om = np.zeros(mask.shape, 'uint16') + def _label_patch_class(la): + om[la] = (mask.iat(la).data > 0) * 1 + df.loc[df['intensity_mean'] > (self.tr * img.dtype_max), 'label'].apply(_label_patch_class) + return PatchStack(om) + else: + labels = get_label_ids(mask) + df = pd.DataFrame(regionprops_table( + labels.data_yxz, + intensity_image=img.get_mono(self.channel).data_yxz, + properties=('label', 'area', 'intensity_mean') + )) + + om = np.zeros(labels.shape, labels.dtype) + def _label_object_class(la): + om[labels.data == la] = 1 + df.loc[df['intensity_mean'] > (self.tr * img.dtype_max), 'label'].apply(_label_object_class) + return InMemoryDataAccessor(om) def label_instance_class( self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py index 05663dd7..2c5c052b 100644 --- a/tests/base/test_accessors.py +++ b/tests/base/test_accessors.py @@ -302,6 +302,7 @@ class TestPatchStackAccessor(unittest.TestCase): self.assertEqual(acc.hw, (h, w)) self.assertEqual(acc.get_mono(channel=0).data_yxz.shape, (n, h, w, nz)) self.assertEqual(acc.get_mono(channel=0, mip=True).data_yx.shape, (n, h, w)) + return acc def test_object_df(self): w = 30 -- GitLab