diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index 6c372b6eebd5c56f0228e771d9771b51a9f998cd..0eb22a65569c8bc5720383e235deac3ff84e517e 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -484,8 +484,11 @@ class PatchStack(InMemoryDataAccessor): def shape_dict(self): return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) - def get_list(self): - return [self._data[i][self._slices[i]] for i in range(0, self.count)] + def get_list(self, crop=True): + if crop: + return [self._data[i][self._slices[i]] for i in range(0, self.count)] + else: + return [self._data[i] for i in range(0, self.count)] @property def pyxcz(self): diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 67444f742c32e865d6a4d165b9db439209920dda..1424ee556b853ebfea2d4b8b4302527c9c958b1a 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -152,7 +152,7 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): ) return InMemoryDataAccessor(data=yxcz) - def infer_patch_stack(self, img: PatchStack, **kwargs) -> (np.ndarray, dict): + def infer_patch_stack(self, img: PatchStack, crop=True, **kwargs) -> (np.ndarray, dict): """ Iterative over a patch stack, call inference separately on each cropped patch """ @@ -161,7 +161,7 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): 'Raw Data': self.PreloadedArrayDatasetInfo( preloaded_array=vigra.taggedView(patch, 'yxcz')) - } for patch in img.get_list() + } for patch in img.get_list(crop=crop) ] pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] yxcz = [np.moveaxis(pm, [1, 2, 3, 0], [0, 1, 2, 3]) for pm in pxmaps] diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py index 5e3af0afe143695d63ac0bb37e31e8f3aa39390c..16053c91a8129fd83b5e7cc2bb5e372cf5efc96f 100644 --- a/tests/base/test_accessors.py +++ b/tests/base/test_accessors.py @@ -290,8 +290,11 @@ class TestPatchStackAccessor(unittest.TestCase): self.assertEqual(patches[i].shape, acc.iat(i, crop=True).shape) self.assertEqual(acc.shape[1:], acc.iat(i, crop=False).shape) - ps_list = acc.get_list() - self.assertTrue(all([np.all(ps_list[i] == patches[i]) for i in range(0, n)])) + ps_list_cropped = acc.get_list(crop=True) + self.assertTrue(all([np.all(ps_list_cropped[i] == patches[i]) for i in range(0, n)])) + + ps_list_uncropped = acc.get_list(crop=False) + self.assertTrue(all([p.shape == acc.shape[1:] for p in ps_list_uncropped])) def test_make_3d_patch_stack_from_list_force_long_dim(self): def _r(h, w):