From c038f5b16e48a81f76ad37471891e1f465e88fb1 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 31 Oct 2024 06:02:34 +0100 Subject: [PATCH] Added option to crop list representation of patch stack --- model_server/base/accessors.py | 7 +++++-- model_server/extensions/ilastik/models.py | 4 ++-- tests/base/test_accessors.py | 7 +++++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index 6c372b6e..0eb22a65 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -484,8 +484,11 @@ class PatchStack(InMemoryDataAccessor): def shape_dict(self): return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) - def get_list(self): - return [self._data[i][self._slices[i]] for i in range(0, self.count)] + def get_list(self, crop=True): + if crop: + return [self._data[i][self._slices[i]] for i in range(0, self.count)] + else: + return [self._data[i] for i in range(0, self.count)] @property def pyxcz(self): diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 67444f74..1424ee55 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -152,7 +152,7 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): ) return InMemoryDataAccessor(data=yxcz) - def infer_patch_stack(self, img: PatchStack, **kwargs) -> (np.ndarray, dict): + def infer_patch_stack(self, img: PatchStack, crop=True, **kwargs) -> (np.ndarray, dict): """ Iterative over a patch stack, call inference separately on each cropped patch """ @@ -161,7 +161,7 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): 'Raw Data': self.PreloadedArrayDatasetInfo( preloaded_array=vigra.taggedView(patch, 'yxcz')) - } for patch in img.get_list() + } for patch in img.get_list(crop=crop) ] pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] yxcz = [np.moveaxis(pm, [1, 2, 3, 0], [0, 1, 2, 3]) for pm in pxmaps] diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py index 5e3af0af..16053c91 100644 --- a/tests/base/test_accessors.py +++ b/tests/base/test_accessors.py @@ -290,8 +290,11 @@ class TestPatchStackAccessor(unittest.TestCase): self.assertEqual(patches[i].shape, acc.iat(i, crop=True).shape) self.assertEqual(acc.shape[1:], acc.iat(i, crop=False).shape) - ps_list = acc.get_list() - self.assertTrue(all([np.all(ps_list[i] == patches[i]) for i in range(0, n)])) + ps_list_cropped = acc.get_list(crop=True) + self.assertTrue(all([np.all(ps_list_cropped[i] == patches[i]) for i in range(0, n)])) + + ps_list_uncropped = acc.get_list(crop=False) + self.assertTrue(all([p.shape == acc.shape[1:] for p in ps_list_uncropped])) def test_make_3d_patch_stack_from_list_force_long_dim(self): def _r(h, w): -- GitLab