From b9037651b97dd8ddd58f985db2fdb260fd2e72b4 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 26 Apr 2024 15:05:13 +0200
Subject: [PATCH] Test coverage of ilastik workflow and API endpoint that
 currently restricts multichannel inputs

---
 .../extensions/ilastik/tests/test_ilastik.py  | 58 ++++++++++++++++---
 1 file changed, 50 insertions(+), 8 deletions(-)

diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py
index d98a41eb..2406cdbe 100644
--- a/model_server/extensions/ilastik/tests/test_ilastik.py
+++ b/model_server/extensions/ilastik/tests/test_ilastik.py
@@ -7,6 +7,7 @@ import numpy as np
 from model_server.conf.testing import czifile, ilastik_classifiers, output_path, roiset_test_data
 from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file
 from model_server.extensions.ilastik import models as ilm
+from model_server.extensions.ilastik.workflows import infer_px_then_ob_model
 from model_server.base.models import InvalidObjectLabelsError
 from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams
 from model_server.base.workflows import classify_pixels
@@ -325,7 +326,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
         super(TestIlastikOnMultichannelInputs, self).setUp()
         self.pa_px_classifier = ilastik_classifiers['px_color_zstack']
         self.pa_ob_classifier = ilastik_classifiers['ob_color_zstack']
-        self.input_image = generate_file_accessor(roiset_test_data['multichannel_zstack']['path'])
+        self.pa_input_image = roiset_test_data['multichannel_zstack']['path']
 
     def _copy_input_file_to_server(self):
         from shutil import copyfile
@@ -340,21 +341,62 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
         )
 
     def test_classify_pixels(self):
-        self.assertGreater(self.input_image.chroma, 1)
+        img = generate_file_accessor(self.pa_input_image)
+        self.assertGreater(img.chroma, 1)
         mod = ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier})
-        pxmap = mod.infer(self.input_image)[0]
-        self.assertEqual(pxmap.hw, self.input_image.hw)
-        self.assertEqual(pxmap.nz, self.input_image.nz)
+        pxmap = mod.infer(img)[0]
+        self.assertEqual(pxmap.hw, img.hw)
+        self.assertEqual(pxmap.nz, img.nz)
         return pxmap
 
     def test_classify_objects(self):
         pxmap = self.test_classify_pixels()
+        img = generate_file_accessor(self.pa_input_image)
         mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_classifier})
-        obmap = mod.infer(self.input_image, pxmap)[0]
-        self.assertEqual(obmap.hw, self.input_image.hw)
-        self.assertEqual(obmap.nz, self.input_image.nz)
+        obmap = mod.infer(img, pxmap)[0]
+        self.assertEqual(obmap.hw, img.hw)
+        self.assertEqual(obmap.nz, img.nz)
+
+    def _call_workflow(self, channel):
+        return infer_px_then_ob_model(
+            self.pa_input_image,
+            ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier}),
+            ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_classifier}),
+            output_path,
+            channel=channel,
+        )
+
+    def test_workflow(self):
+        with self.assertRaises(ilm.IlastikInputShapeError):
+            self._call_workflow(channel=0)
+
+        res = self._call_workflow(channel=None)
 
+    def test_api(self):
+        resp_load = self._put(
+            'ilastik/seg/load/',
+            query={'project_file': str(self.pa_px_classifier)},
+        )
+        self.assertEqual(resp_load.status_code, 200, resp_load.json())
+        px_model_id = resp_load.json()['model_id']
+
+        resp_load = self._put(
+            'ilastik/pxmap_to_obj/load/',
+            query={'project_file': str(self.pa_ob_classifier)},
+        )
+        self.assertEqual(resp_load.status_code, 200, resp_load.json())
+        ob_model_id = resp_load.json()['model_id']
 
+        resp_infer = self._put(
+            'ilastik/pixel_then_object_classification/infer/',
+            query={
+                'px_model_id': px_model_id,
+                'ob_model_id': ob_model_id,
+                'input_filename': self.pa_input_image,
+                # 'channel': 0,
+            }
+        )
+        self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
 
 
 class TestIlastikObjectClassification(unittest.TestCase):
-- 
GitLab