diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index d33243e277f83ebaef26ce316e40b2873963bba3..d0b99080122dbb6f518ecd8913e65dc9ac48223e 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -417,6 +417,7 @@ class RoiSet(object): rgb_overlay_channels: list = None, rgb_overlay_weights: list = [1.0, 1.0, 1.0], white_channel: int = None, + expanded=True, **kwargs ) -> pd.DataFrame: @@ -463,9 +464,14 @@ class RoiSet(object): stack = raw.data def _make_patch(roi): - patch3d = stack[roi.expanded_slice] + if expanded: + patch3d = stack[roi.expanded_slice] + subpatch = patch3d[roi.relative_slice] + else: + patch3d = stack[roi.slice] + subpatch = patch3d + ph, pw, pc, pz = patch3d.shape - subpatch = patch3d[roi.relative_slice] # make a 3d patch if make_3d: @@ -492,7 +498,7 @@ class RoiSet(object): if rescale_clip is not None: patch = rescale(patch, rescale_clip) - if kwargs.get('draw_bounding_box') is True: + if kwargs.get('draw_bounding_box') is True and expanded: bci = kwargs.get('bounding_box_channel', 0) assert bci < 3 if bci > 0: @@ -522,7 +528,7 @@ class RoiSet(object): find_contours(mask) ) - if pad_to: + if pad_to and expanded: patch = pad(patch, pad_to) return patch diff --git a/tests/test_roiset.py b/tests/test_roiset.py index e1784b56118ac41a20c0ddb63065d374547eba16..b9150edbedcc9ccf52b47ac79714aa34d72ef276 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -100,15 +100,43 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertTrue(np.all([si >= 1 for si in rbb.shape])) - def test_make_2d_patches(self): + def test_make_expanded_2d_patches(self): roiset = self._make_roi_set() files = roiset.export_patches( - output_path / '2d_patches', + output_path / 'expanded_2d_patches', draw_bounding_box=True, + expanded=True ) - self.assertGreaterEqual(len(files), 1) + df = roiset.get_df() + for f in files: + acc = generate_file_accessor(f) + la = int(re.search(r'la([\d]+)', str(f)).group(1)) + roi_q = df.loc[df.label == la, :] + self.assertEqual(len(roi_q), 1) + roi = roi_q.iloc[0] + h = int(roi.ebb_y1 - roi.ebb_y0) + w = int(roi.ebb_x1 - roi.ebb_x0) + self.assertEqual((256, 256), acc.hw) + + def test_make_tight_2d_patches(self): + roiset = self._make_roi_set() + files = roiset.export_patches( + output_path / 'tight_2d_patches', + draw_bounding_box=True, + expanded=False + ) + df = roiset.get_df() + for f in files: + acc = generate_file_accessor(f) + la = int(re.search(r'la([\d]+)', str(f)).group(1)) + roi_q = df.loc[df.label == la, :] + self.assertEqual(len(roi_q), 1) + roi = roi_q.iloc[0] + h = int(roi.y1 - roi.y0) + w = int(roi.x1 - roi.x0) + self.assertEqual((h, w), acc.hw) - def test_make_3d_patches(self): + def test_make_expanded_3d_patches(self): roiset = self._make_roi_set() files = roiset.export_patches( output_path / '3d_patches',