diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 3fdc59abe0ed849732fb3349c125ec43697465ce..d7ee2e2f7b48b2a4418f4ace92515dcf322a1f7b 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -200,7 +200,7 @@ class RoiSet(object): df['slice'] = df.apply( lambda r: - np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1) + 1], + np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)], axis=1 ) df['expanded_slice'] = df.apply( @@ -279,14 +279,12 @@ class RoiSet(object): projected = self.acc_raw.data.max(axis=-1) return projected - # TODO: rename get_patches_acc - def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches + def get_patches_acc(self, channel=None, **kwargs): # padded, un-annotated 2d patches if channel: - patches_df = self.get_patches(white_channel=channel, pad_to=pad_to) + patches_df = self.get_patches(white_channel=channel, **kwargs) else: - patches_df = self.get_patches(pad_to=pad_to) - patches = list(patches_df['patch']) - return PatchStack(patches) + patches_df = self.get_patches(**kwargs) + return PatchStack(list(patches_df.patch)) def export_annotated_zstack(self, where, prefix='zstack', **kwargs) -> Path: annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs)) @@ -335,8 +333,8 @@ class RoiSet(object): # do this on a patch basis, i.e. only one object per frame obmap_patches = object_classification_model.label_patch_stack( - self.get_raw_patches(channel=channel, pad_to=256), - self.get_patch_masks_acc(expanded=True, pad_to=256) + self.get_patches_acc(channel=channel, expaned=False, pad_to=None), + self.get_patch_masks_acc(expanded=False, pad_to=None) ) om = np.zeros(self.acc_obj_ids.shape, self.acc_obj_ids.dtype) @@ -419,13 +417,13 @@ class RoiSet(object): def get_patches( self, rescale_clip: float = 0.0, - pad_to: int = 256, + pad_to: int = None, make_3d: bool = False, focus_metric: str = None, rgb_overlay_channels: list = None, rgb_overlay_weights: list = [1.0, 1.0, 1.0], white_channel: int = None, - expanded=True, + expanded=False, **kwargs ) -> pd.DataFrame: @@ -503,6 +501,12 @@ class RoiSet(object): assert len(patch.shape) == 4 + mask = np.zeros(patch3d.shape[0:2], dtype=bool) + if expanded: + mask[roi.relative_slice[0:2]] = roi.binary_mask + else: + mask = roi.binary_mask + if rescale_clip is not None: patch = rescale(patch, rescale_clip) @@ -520,15 +524,15 @@ class RoiSet(object): if kwargs.get('draw_mask'): mci = kwargs.get('mask_channel', 0) - mask = np.zeros(patch.shape[0:2], dtype=bool) - mask[roi.relative_slice[0:2]] = roi.binary_mask + # mask = np.zeros(patch.shape[0:2], dtype=bool) + # mask[roi.relative_slice[0:2]] = roi.binary_mask for zi in range(0, patch.shape[3]): patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi] if kwargs.get('draw_contour'): mci = kwargs.get('contour_channel', 0) - mask = np.zeros(patch.shape[0:2], dtype=bool) - mask[roi.relative_slice[0:2]] = roi.binary_mask + # mask = np.zeros(patch.shape[0:2], dtype=bool) + # mask[roi.relative_slice[0:2]] = roi.binary_mask for zi in range(0, patch.shape[3]): patch[:, :, mci, zi] = draw_contours_on_patch( diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index 15bca41de198bfcd5c4fa714140c5fc8f1f331b0..d042e57150413c18ca36a2bbb8cc379a122bdff5 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -329,8 +329,8 @@ class TestIlastikObjectClassification(unittest.TestCase): def test_classify_patches(self): - raw_patches = self.roiset.get_raw_patches() - patch_masks = self.roiset.get_patch_masks() + raw_patches = self.roiset.get_patches_acc() + patch_masks = self.roiset.get_patch_masks_acc() res_patches = self.object_classifier.label_instance_class(raw_patches, patch_masks) self.assertEqual(res_patches.count, self.roiset.count) res_patches.export_pyxcz(output_path / 'res_patches.tif') diff --git a/tests/test_roiset.py b/tests/test_roiset.py index 87883687cfa1cc9b5539840bfaddd81f1451943e..714c20aed3fccb9a37bb53ed973d50a488080b39 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -105,7 +105,8 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): files = roiset.export_patches( output_path / 'expanded_2d_patches', draw_bounding_box=True, - expanded=True + expanded=True, + pad_to=256, ) df = roiset.get_df() for f in files: @@ -113,9 +114,6 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): la = int(re.search(r'la([\d]+)', str(f)).group(1)) roi_q = df.loc[df.label == la, :] self.assertEqual(len(roi_q), 1) - roi = roi_q.iloc[0] - h = int(roi.ebb_y1 - roi.ebb_y0) - w = int(roi.ebb_x1 - roi.ebb_x0) self.assertEqual((256, 256), acc.hw) def test_make_tight_2d_patches(self): @@ -201,10 +199,11 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_raw_patches_are_correct_shape(self): roiset = self._make_roi_set() - patches = roiset.get_raw_patches() + patches = roiset.get_patches_acc() np, h, w, nc, nz = patches.shape self.assertEqual(np, roiset.count) self.assertEqual(nc, roiset.acc_raw.chroma) + self.assertEqual(nz, 1) def test_patch_masks_are_correct_shape(self): roiset = self._make_roi_set()