From 583c415be167c64b1773000ed749ebbada1f1e4f Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Tue, 10 Dec 2024 09:33:15 +0100 Subject: [PATCH] Export products no longer throw exceptions when called on empty RoiSet --- model_server/base/roiset.py | 15 +++++++++++++-- tests/base/test_roiset.py | 8 ++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 41edb804..8a1ff62a 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -630,7 +630,9 @@ class RoiSet(object): def add_df_col(self, name, se: pd.Series) -> None: self._df[name] = se - def get_patches_acc(self, channels: list = None, **kwargs) -> PatchStack: # padded, un-annotated 2d patches + def get_patches_acc(self, channels: list = None, **kwargs) -> Union[PatchStack, None]: # padded, un-annotated 2d patches + if self.count == 0: + return None if channels and len(channels) == 1: return PatchStack(list(self.get_patches(white_channel=channels[0], **kwargs))) else: @@ -857,6 +859,10 @@ class RoiSet(object): return patches_df.apply(_export_patch, axis=1) def get_patch_masks(self, pad_to: int = None, expanded: bool = False, make_3d=True) -> pd.DataFrame: + + if self.count == 0: + return pd.DataFrame() + def _make_patch_mask(roi): if expanded: patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') @@ -877,7 +883,9 @@ class RoiSet(object): dfe['patch_mask'] = dfe.apply(_make_patch_mask, axis=1) return dfe - def get_patch_masks_acc(self, **kwargs) -> PatchStack: + def get_patch_masks_acc(self, **kwargs) -> Union[PatchStack, None]: + if self.count == 0: + return None se_pm = self.get_patch_masks(**kwargs).patch_mask se_ext = se_pm.apply(lambda x: np.expand_dims(x, 2)) return PatchStack(list(se_ext)) @@ -896,6 +904,9 @@ class RoiSet(object): **kwargs ) -> pd.Series: + if self.count == 0: + return pd.Series() + # arrange RGB channels if so specified, otherwise copy roiset.raw_acc data raw = self.acc_raw if isinstance(rgb_overlay_channels, (list, tuple)) and isinstance(rgb_overlay_weights, (list, tuple)): diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index c134aa9d..943fd6a9 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -871,8 +871,12 @@ class TestEmptyRoiSet(unittest.TestCase): def test_classify_by(self): roiset = self.empty_roiset - res = roiset.classify_by('permissive_model', [0], IntensityThresholdInstanceMaskSegmentationModel(tr=0.0)) - self.assertEqual(1, 0) + self.assertFalse('classify_by_permissive_model' in roiset.get_df().columns) + self.assertTrue( + roiset.classify_by('permissive_model', [0], IntensityThresholdInstanceMaskSegmentationModel(tr=0.0)) + ) + self.assertEqual(roiset.count, 0) + self.assertTrue('classify_by_permissive_model' in roiset.get_df().columns) class TestRoiSetObjectDetection(unittest.TestCase): -- GitLab