From e3162d6bf63329c90ca09eee9637dd4211a8d5e3 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Tue, 29 Aug 2023 15:58:18 +0200
Subject: [PATCH] Updated test for channel selection, and conforming image data
 to XYCZ order

---
 api.py                   |  2 +-
 model_server/ilastik.py  |  9 +++++----
 model_server/image.py    | 32 ++++++++++++++++++++++----------
 model_server/model.py    |  9 ++++-----
 model_server/workflow.py |  4 ++--
 tests/test_api.py        | 15 +++++++++++----
 tests/test_ilastik.py    | 21 +++++++++++++++++----
 tests/test_image.py      | 20 +++++++++++++++++++-
 tests/test_model.py      |  8 ++++----
 tests/test_workflow.py   |  4 ++--
 10 files changed, 87 insertions(+), 37 deletions(-)

diff --git a/api.py b/api.py
index 89e42d8f..1bf0f17d 100644
--- a/api.py
+++ b/api.py
@@ -30,7 +30,7 @@ def load_model(model_id: str, params: Dict[str, str] = None) -> dict:
     session.load_model(model_id, params=params)
     return session.describe_loaded_models()
 
-@app.put('/i2i/infer/{model_id}') # image file in, image file out
+@app.put('/i2i/infer/') # image file in, image file out
 def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
     if model_id not in session.describe_loaded_models().keys():
         raise HTTPException(
diff --git a/model_server/ilastik.py b/model_server/ilastik.py
index a345925f..7be2269c 100644
--- a/model_server/ilastik.py
+++ b/model_server/ilastik.py
@@ -8,7 +8,8 @@ from ilastik.workflows.objectClassification.objectClassificationWorkflow import
 
 import numpy as np
 
-from model_server.model import GenericImageDataAccessor, ImageToImageModel, ParameterExpectedError
+from model_server.image import GenericImageDataAccessor, InMemoryDataAccessor
+from model_server.model import ImageToImageModel, ParameterExpectedError
 
 
 class IlastikImageToImageModel(ImageToImageModel):
@@ -64,11 +65,11 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel):
     def infer(self, input_img: GenericImageDataAccessor, channel=None) -> (np.ndarray, dict):
         dsi = [
             {
-                'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=input_img.data),
+                'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=input_img.data[0:3]),
             }
         ]
-        pxmap = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)
-        return pxmap
+        pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)
+        return InMemoryDataAccessor(data=pxmaps[0])
 
 
 
diff --git a/model_server/image.py b/model_server/image.py
index 2e314d1b..4bea66a1 100644
--- a/model_server/image.py
+++ b/model_server/image.py
@@ -14,7 +14,7 @@ class GenericImageDataAccessor(ABC):
     def __init__(self):
         """
         Abstract base class that exposes an interfaces for image data, irrespective of whether it is instantiated
-        from file I/O or other means.
+        from file I/O or other means.  Enforces X, Y, C, Z dimensions in that order.
         """
         pass
 
@@ -22,23 +22,34 @@ class GenericImageDataAccessor(ABC):
     def chroma(self):
         return self.shape_dict['C']
 
+    @staticmethod
+    def conform_data(data):
+        if len(data.shape) > 4:
+            raise DataShapeError(f'Cannot handle image with dimensions other than X, Y, C, and Z: {data.shape}')
+        ones = [1 for i in range(0, 4 - len(data.shape))]
+        return data.reshape(*data.shape, *ones)
+
     def is_3d(self):
         return True if self.shape_dict['Z'] > 1 else False
 
     def get_one_channel_data (self, channel: int):
-        return InMemoryDataAccessor(self.data[:, :, channel, :])
+        return InMemoryDataAccessor(self.data[:, :, int(channel), :])
 
     @property
-    def data(self): # XYCZ enforced
+    def data(self):
         return self._data
 
+    @property
+    def shape(self):
+        return self._data.shape
+
     @property
     def shape_dict(self):
         return dict(zip(('X', 'Y', 'C', 'Z'), self.data.shape))
 
 class InMemoryDataAccessor(GenericImageDataAccessor):
     def __init__(self, data):
-        self._data = data
+        self._data = self.conform_data(data)
 
 class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded from a file
     def __init__(self, fpath: Path):
@@ -67,7 +78,7 @@ class CziImageFileAccessor(GenericImageFileAccessor):
 
         sd = {ch: cf.shape[cf.axes.index(ch)] for ch in cf.axes}
         if sd['S'] > 1 or sd['T'] > 1:
-            raise FileShapeError(f'Cannot handle image with multiple positions or time points: {sd}')
+            raise DataShapeError(f'Cannot handle image with multiple positions or time points: {sd}')
 
         idx = {k: sd[k] for k in ['X', 'Y', 'C', 'Z']}
         xycz = np.moveaxis(
@@ -76,10 +87,11 @@ class CziImageFileAccessor(GenericImageFileAccessor):
             [0, 1, 2, 3]
         )
 
-        try:
-            self._data = xycz.reshape(xycz.shape[0:4])
-        except Exception:
-            raise FileShapeError(f'Cannot handle image with dimensions other than X, Y, C, and Z')
+        # try:
+        #     self._data = xycz.reshape(xycz.shape[0:4])
+        # except Exception:
+        #     raise FileShapeError(f'Cannot handle image with dimensions other than X, Y, C, and Z')
+        self._data = self.conform_data(xycz.reshape(xycz.shape[0:4]))
 
     def __del__(self):
         self.czifile.close()
@@ -108,7 +120,7 @@ class FileAccessorError(Error):
 class FileNotFoundError(Error):
     pass
 
-class FileShapeError(Error):
+class DataShapeError(Error):
     pass
 
 class FileWriteError(Error):
diff --git a/model_server/model.py b/model_server/model.py
index 672bb63d..8c07d2d8 100644
--- a/model_server/model.py
+++ b/model_server/model.py
@@ -4,7 +4,7 @@ import os
 
 import numpy as np
 
-from model_server.image import GenericImageDataAccessor
+from model_server.image import GenericImageDataAccessor, InMemoryDataAccessor
 
 
 class Model(ABC):
@@ -59,7 +59,7 @@ class ImageToImageModel(Model):
     """
 
     @abstractmethod
-    def infer(self, img) -> (np.ndarray, dict):
+    def infer(self, img) -> (GenericImageDataAccessor, dict):
         super().infer(img)
 
 class DummyImageToImageModel(ImageToImageModel):
@@ -69,14 +69,13 @@ class DummyImageToImageModel(ImageToImageModel):
     def load(self):
         return True
 
-    def infer(self, img: GenericImageDataAccessor) -> (np.ndarray, dict):
+    def infer(self, img: GenericImageDataAccessor) -> GenericImageDataAccessor:
         super().infer(img)
         w = img.shape_dict['X']
         h = img.shape_dict['Y']
         result = np.zeros([h, w], dtype='uint8')
         result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255
-        return (result, {'success': True})
-
+        return InMemoryDataAccessor(data=result)
 
 class Error(Exception):
     pass
diff --git a/model_server/workflow.py b/model_server/workflow.py
index ece9b1b4..38803ed6 100644
--- a/model_server/workflow.py
+++ b/model_server/workflow.py
@@ -28,7 +28,7 @@ def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict:
 
     # run model inference
     # TODO: call this async / await and report out infer status to optional callback
-    outdata, messages = model.infer(img)
+    outdata = model.infer(img)
     dt_inf = time() - t0
 
     # TODO: assert outdata format
@@ -49,7 +49,7 @@ def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict:
         model_id=model.model_id,
         input_filepath=str(fpi),
         output_filepath=str(outpath),
-        success=messages['success'],
+        success=True,
         timer_results=timer_results
     )
 
diff --git a/tests/test_api.py b/tests/test_api.py
index e290d163..f5869bdd 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -62,8 +62,11 @@ class TestApiFromAutomatedClient(unittest.TestCase):
     def test_i2i_inference_errors_model_not_found(self):
         model_id = 'not_a_real_model'
         resp = requests.put(
-            self.uri + f'i2i/infer/{model_id}',
-            params={'input_filename': 'not_a_real_file.name'}
+            self.uri + f'i2i/infer/',
+            params={
+                'model_id': model_id,
+                'input_filename': 'not_a_real_file.name'
+            }
         )
         print(resp.content)
         self.assertEqual(resp.status_code, 409)
@@ -77,7 +80,11 @@ class TestApiFromAutomatedClient(unittest.TestCase):
         self.assertEqual(resp_load.status_code, 200, f'Error loading {model.model_id}')
         self.copy_input_file_to_server()
         resp_infer = requests.put(
-            self.uri + f'i2i/infer/{model.model_id}',
-            params={'input_filename': czifile['filename']},
+            self.uri + f'i2i/infer/',
+            params={
+                'model_id': model.model_id,
+                'input_filename': czifile['filename'],
+                'channel': 2,
+            },
         )
         self.assertEqual(resp_infer.status_code, 200, f'Error inferring from {model.model_id}')
\ No newline at end of file
diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py
index 6fd34563..94a7cbbb 100644
--- a/tests/test_ilastik.py
+++ b/tests/test_ilastik.py
@@ -1,6 +1,6 @@
 import unittest
-from conf.testing import czifile, ilastik
-from model_server.image import CziImageFileAccessor
+from conf.testing import czifile, ilastik, output_path
+from model_server.image import CziImageFileAccessor, write_accessor_data_to_file
 from model_server.ilastik import IlastikPixelClassifierModel
 
 class TestIlastikPixelClassification(unittest.TestCase):
@@ -18,8 +18,21 @@ class TestIlastikPixelClassification(unittest.TestCase):
         with self.assertRaises(io.UnsupportedOperation):
             faulthandler.enable(file=sys.stdout)
 
-    def test_instantiate_pixel_classifier(self):
+    def test_run_pixel_classifier(self):
+        channel = 2
         model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
         cf = CziImageFileAccessor(czifile['path'])
-        model.infer(cf.get_one_channel_data(2))
+        pxmap = model.infer(cf.get_one_channel_data(channel))
 
+        print(pxmap.shape_dict)
+
+        # self.assertEqual(pxmap.shape[0:2], cf.shape[0:2])
+        # self.assertEqual(pxmap.shape_dict['C'], 2)
+        # self.assertEqual(pxmap.shape_dict['Z'], 1)
+        #
+        # self.assertTrue(
+        #     write_accessor_data_to_file(
+        #         output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
+        #         pxmap
+        #     )
+        # )
\ No newline at end of file
diff --git a/tests/test_image.py b/tests/test_image.py
index 8aef0eaf..3ca3cb2e 100644
--- a/tests/test_image.py
+++ b/tests/test_image.py
@@ -1,6 +1,9 @@
 import unittest
+
+import numpy as np
+
 from conf.testing import czifile, output_path
-from model_server.image import CziImageFileAccessor, write_accessor_data_to_file
+from model_server.image import CziImageFileAccessor, DataShapeError, InMemoryDataAccessor, write_accessor_data_to_file
 
 class TestCziImageFileAccess(unittest.TestCase):
     def setUp(self) -> None:
@@ -13,6 +16,8 @@ class TestCziImageFileAccess(unittest.TestCase):
         self.assertEqual(cf.chroma, czifile['c'])
         self.assertFalse(cf.is_3d())
         self.assertEqual(len(cf.data.shape), 4)
+        self.assertEqual(cf.shape[0], czifile['h'])
+        self.assertEqual(cf.shape[1], czifile['w'])
 
     def test_write_single_channel_tif(self):
         ch = 2
@@ -26,3 +31,16 @@ class TestCziImageFileAccess(unittest.TestCase):
         )
         self.assertEqual(cf.data.shape[0:2], mono.data.shape[0:2])
         self.assertEqual(cf.data.shape[3], mono.data.shape[2])
+
+    def test_conform_data_shorter_than_xycz(self):
+        data = np.random.rand(256, 512)
+        acc = InMemoryDataAccessor(data)
+        self.assertEqual(
+            acc.shape_dict,
+            {'X': 256, 'Y': 512, 'C': 1, 'Z': 1}
+        )
+
+    def test_conform_data_longer_than_xycz(self):
+        data = np.random.rand(256, 512, 12, 8, 3)
+        with self.assertRaises(DataShapeError):
+            acc = InMemoryDataAccessor(data)
\ No newline at end of file
diff --git a/tests/test_model.py b/tests/test_model.py
index 2c1d7c64..9f4ab388 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -25,25 +25,25 @@ class TestCziImageFileAccess(unittest.TestCase):
 
     def test_czifile_is_correct_shape(self):
         model = DummyImageToImageModel()
-        img, _ = model.infer(self.cf)
+        img = model.infer(self.cf)
 
         w = czifile['w']
         h = czifile['h']
 
         self.assertEqual(
             img.shape,
-            (h, w),
+            (h, w, 1, 1),
             'Inferred image is not the expected shape'
         )
 
         self.assertEqual(
-            img[int(w/2), int(h/2)],
+            img.data[int(w/2), int(h/2)],
             255,
             'Middle pixel is not white as expected'
         )
 
         self.assertEqual(
-            img[0, 0],
+            img.data[0, 0],
             0,
             'First pixel is not black as expected'
         )
diff --git a/tests/test_workflow.py b/tests/test_workflow.py
index fff03bbc..83935d3d 100644
--- a/tests/test_workflow.py
+++ b/tests/test_workflow.py
@@ -10,7 +10,7 @@ class TestGetSessionObject(unittest.TestCase):
         self.model = DummyImageToImageModel()
 
     def test_single_session_instance(self):
-        result = infer_image_to_image(czifile['path'], self.model, output_path)
+        result = infer_image_to_image(czifile['path'], self.model, output_path, channel=2)
         self.assertTrue(result.success)
 
         import tifffile
@@ -20,7 +20,7 @@ class TestGetSessionObject(unittest.TestCase):
 
         self.assertEqual(
             img.shape,
-            (h, w),
+            (h, w, 1, 1),
             'Inferred image is not the expected shape'
         )
 
-- 
GitLab