From c1cd28eb667555ba0b9fa03adc68945fdec5233d Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Tue, 13 Aug 2024 17:35:24 +0200
Subject: [PATCH] Test covers derived inputs pipeline method directly

---
 model_server/base/pipelines/roiset_obmap.py | 10 ++-
 tests/test_ilastik/test_roiset_workflow.py  | 74 +--------------------
 2 files changed, 11 insertions(+), 73 deletions(-)

diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py
index 1bd9220a..3efa5189 100644
--- a/model_server/base/pipelines/roiset_obmap.py
+++ b/model_server/base/pipelines/roiset_obmap.py
@@ -139,7 +139,7 @@ def roiset_object_map_pipeline(
     rois = RoiSet.from_binary_mask(d['raw'], d.last, RoiSetMetaParams(**k['roi_params']))
 
     # optionally derive additional inputs for object classification
-    if dpmod := models.get('pixel_classifier_derived'):
+    if dpmod := models.get('pixel_classifier_derived_model'):
         ic = k['derived_channels_input_channel']
         ocs = k['derived_channels_output_channels']
         assert ic < d['raw'].chroma
@@ -157,7 +157,7 @@ def roiset_object_map_pipeline(
     else:
         derived_channel_handles = None
 
-    # optionally classify if an object classifier is passed
+    # optionally run an object classifier if specified
     if obmod := models.get('object_classifier_model'):
         obmod_name = k['object_classifier_model_id']
         assert isinstance(obmod, InstanceSegmentationModel)
@@ -167,8 +167,14 @@ def roiset_object_map_pipeline(
             obmod,
             derived_channel_functions=derived_channel_handles
         )
+        # add derived channels to intermediate data
+        for di, dacc in enumerate(rois.accs_derived):
+            d[f'derived_{di:02d}'] = dacc
+
         d[obmod_name] = rois.get_object_class_map(obmod_name)
 
+
+
     # TODO: subclass this boundary pipeline in separate module
     # if p.label_params and p.export_label_interm:
     #     fp_la_interm = Path(output_folder_path) / 'mip_masks' / f'interm_{fstem}.png'
diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/test_ilastik/test_roiset_workflow.py
index 860d8089..27fd0d7f 100644
--- a/tests/test_ilastik/test_roiset_workflow.py
+++ b/tests/test_ilastik/test_roiset_workflow.py
@@ -1,15 +1,11 @@
-import shutil
 from pathlib import Path
-from shutil import copyfile
 import unittest
 
 import numpy as np
 
-import model_server.conf.testing as conf
 
 from model_server.base.accessors import generate_file_accessor
 from tests.base.test_model import DummyInstanceSegmentationModel
-from model_server.base.roiset import RoiSetMetaParams, RoiSetExportParams
 
 import model_server.conf.testing as conf
 from model_server.base.pipelines.roiset_obmap import RoiSetObjectMapParams, roiset_object_map_pipeline
@@ -121,6 +117,7 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
         self.assertTrue('ob_id' in trace.keys())
         self.assertEqual(len(trace['ob_id']._unique()[0]), 2)
 
+
     def test_object_map_workflow_with_derived_channels(self):
         acc_in = generate_file_accessor(self.fpi)
         params = RoiSetObjectMapParams(
@@ -140,7 +137,8 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
         self.assertTrue('ob_id' in trace.keys())
         self.assertEqual(len(trace['ob_id']._unique()[0]), 2)
-        self.assertTrue(all([Path(pa).exists() for pa in record['interm']['derived_channels']]))
+        self.assertTrue('derived_00' in trace.keys())
+        self.assertEqual(trace['derived_00'].chroma, 1)
 
 
 class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts):
@@ -255,69 +253,3 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd
     #     self.assertTrue(all(np.unique(acc_obmap.data) == [0, 1]))
 
 
-
-class TestRoiSetWithDerivedChannels(conf.TestServerBaseClass):
-
-    app_name = 'api:app'
-    input_data = data['multichannel_zstack_raw']
-
-    # def test_object_map_workflow_with_derived_channels(self):
-    #     models = _get_models()
-    #     models['pixel_classifier_derived'] = models['pixel_classifier_segmentation'] # re-use same classifier
-    #     where_out = output_path / 'roiset' / 'workflow'
-    #
-    #     p = ClassifyZStackApiParams(
-    #         input_filename=str(data['multichannel_zstack_raw']['path']),
-    #         model_ids={
-    #             'pixel_classifier_segmentation': 'id_px',
-    #             'object_classifier': 'id_ob',
-    #             'pixel_classifier_derived': 'id_px',
-    #         },
-    #         derived_channels_input_channel=0,
-    #         derived_channels_output_channels=[1],
-    #         segmentation_channel=test_params['segmentation_channel'],
-    #         patches_channel=test_params['patches_channel'],
-    #         roi_params=_get_roi_params(),
-    #         export_params=_get_export_params(),
-    #     )
-    #
-    #     record = export_zstack_roiset(
-    #         data['multichannel_zstack_raw']['path'],
-    #         where_out,
-    #         models,
-    #         p
-    #     )
-    #     self.assertTrue(all([Path(pa).exists() for pa in record['interm']['derived_channels']]))
-
-    def test_derived_channels_api(self):
-        resp = self._put(
-            'ilastik/seg/load/',
-            body={'project_file': _get_model_params()['pixel_classifier_segmentation']['project_file'].__str__()},
-        )
-        self.assertEqual(resp.status_code, 200)
-        mid_px = resp.json()['model_id']
-        resp = self._put(f'models/dummy_instance/load')
-        mid_ob = resp.json()['model_id']
-
-
-        p = ClassifyZStackApiParams(
-            input_filename=data['multichannel_zstack_raw']['path'].__str__(),
-            model_ids={
-                'pixel_classifier_segmentation': mid_px,
-                'object_classifier': mid_ob,
-                'pixel_classifier_derived': mid_px,
-            },
-            derived_channels_input_channel=0,
-            derived_channels_output_channels=[1],
-            segmentation_channel=test_params['segmentation_channel'],
-            patches_channel=test_params['patches_channel'],
-            roi_params=_get_roi_params(),
-            export_params=_get_export_params(),
-        )
-        resp = self._put('chaeo/classify_zstack/infer', body=p.dict())
-        self.assertEqual(resp.status_code, 200, resp.json())
-        omfp = resp.json()['object_map_filepath']
-        self.assertTrue(Path(omfp).exists())
-        acc_obmap = generate_file_accessor(omfp)
-        self.assertTrue(all(np.unique(acc_obmap.data) == [0, 1]))
-
-- 
GitLab