From c038f5b16e48a81f76ad37471891e1f465e88fb1 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 31 Oct 2024 06:02:34 +0100
Subject: [PATCH] Added option to crop list representation of patch stack

---
 model_server/base/accessors.py            | 7 +++++--
 model_server/extensions/ilastik/models.py | 4 ++--
 tests/base/test_accessors.py              | 7 +++++--
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py
index 6c372b6e..0eb22a65 100644
--- a/model_server/base/accessors.py
+++ b/model_server/base/accessors.py
@@ -484,8 +484,11 @@ class PatchStack(InMemoryDataAccessor):
     def shape_dict(self):
         return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))
 
-    def get_list(self):
-        return [self._data[i][self._slices[i]] for i in range(0, self.count)]
+    def get_list(self, crop=True):
+        if crop:
+            return [self._data[i][self._slices[i]] for i in range(0, self.count)]
+        else:
+            return [self._data[i] for i in range(0, self.count)]
 
     @property
     def pyxcz(self):
diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index 67444f74..1424ee55 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -152,7 +152,7 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
         )
         return InMemoryDataAccessor(data=yxcz)
 
-    def infer_patch_stack(self, img: PatchStack, **kwargs) -> (np.ndarray, dict):
+    def infer_patch_stack(self, img: PatchStack, crop=True, **kwargs) -> (np.ndarray, dict):
         """
         Iterative over a patch stack, call inference separately on each cropped patch
         """
@@ -161,7 +161,7 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
                 'Raw Data': self.PreloadedArrayDatasetInfo(
                     preloaded_array=vigra.taggedView(patch, 'yxcz'))
 
-            } for patch in img.get_list()
+            } for patch in img.get_list(crop=crop)
         ]
         pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)  # [z x h x w x n]
         yxcz = [np.moveaxis(pm, [1, 2, 3, 0], [0, 1, 2, 3]) for pm in pxmaps]
diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py
index 5e3af0af..16053c91 100644
--- a/tests/base/test_accessors.py
+++ b/tests/base/test_accessors.py
@@ -290,8 +290,11 @@ class TestPatchStackAccessor(unittest.TestCase):
             self.assertEqual(patches[i].shape, acc.iat(i, crop=True).shape)
             self.assertEqual(acc.shape[1:], acc.iat(i, crop=False).shape)
 
-        ps_list = acc.get_list()
-        self.assertTrue(all([np.all(ps_list[i] == patches[i]) for i in range(0, n)]))
+        ps_list_cropped = acc.get_list(crop=True)
+        self.assertTrue(all([np.all(ps_list_cropped[i] == patches[i]) for i in range(0, n)]))
+
+        ps_list_uncropped = acc.get_list(crop=False)
+        self.assertTrue(all([p.shape == acc.shape[1:] for p in ps_list_uncropped]))
 
     def test_make_3d_patch_stack_from_list_force_long_dim(self):
         def _r(h, w):
-- 
GitLab