From a295a72dba3fc52d4834af56200103638ebc264e Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Tue, 13 Aug 2024 10:47:15 +0200
Subject: [PATCH] First test passes on direct pipeline call with accessors

---
 model_server/base/pipelines/params.py         |   5 +-
 model_server/base/pipelines/roiset_obmap.py   | 179 ++++++++++
 model_server/base/session.py                  |   1 +
 .../ilastik/pipelines/roiset_obmap.py         | 266 ---------------
 tests/base/test_pipelines.py                  |   2 +-
 tests/test_ilastik/test_roiset_workflow.py    | 317 ++++++++++++++++++
 6 files changed, 501 insertions(+), 269 deletions(-)
 create mode 100644 model_server/base/pipelines/roiset_obmap.py
 delete mode 100644 model_server/extensions/ilastik/pipelines/roiset_obmap.py
 create mode 100644 tests/test_ilastik/test_roiset_workflow.py

diff --git a/model_server/base/pipelines/params.py b/model_server/base/pipelines/params.py
index ad9585a5..b81ccf41 100644
--- a/model_server/base/pipelines/params.py
+++ b/model_server/base/pipelines/params.py
@@ -7,18 +7,19 @@ from ..session import session, AccessorIdError
 
 class PipelineParams(BaseModel):
     keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session')
+    api: bool = Field(True, description='Validate parameters against server session and map HTTP errors if True')
 
     @root_validator(pre=False)
     def models_are_loaded(cls, dd):
         for k, v in dd.items():
-            if k.endswith('model_id') and v not in session.describe_loaded_models().keys():
+            if dd['api'] and k.endswith('model_id') and v not in session.describe_loaded_models().keys():
                 raise HTTPException(status_code=409, detail=f'Model with {k} = {v} has not been loaded')
         return dd
 
     @root_validator(pre=False)
     def accessors_are_loaded(cls, dd):
         for k, v in dd.items():
-            if k.endswith('accessor_id'):
+            if dd['api'] and k.endswith('accessor_id'):
                 try:
                     info = session.get_accessor_info(v)
                 except AccessorIdError as e:
diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py
new file mode 100644
index 00000000..a25d740d
--- /dev/null
+++ b/model_server/base/pipelines/roiset_obmap.py
@@ -0,0 +1,179 @@
+from typing import Dict, Union
+
+from fastapi import APIRouter
+import pandas as pd
+from pydantic import Field
+
+
+from model_server.base.accessors import GenericImageDataAccessor, PatchStack
+from model_server.base.pipelines.params import PipelineParams, PipelineRecord
+from model_server.base.pipelines.util import call_pipeline
+from model_server.base.roiset import RoiSet, RoiSetMetaParams, RoiSetExportParams
+from model_server.base.session import session
+from model_server.base.util import PipelineTrace
+
+from model_server.base.models import Model, InstanceSegmentationModel, SemanticSegmentationModel
+
+router = APIRouter(
+    prefix='/pipelines',
+)
+
+class RoiSetObjectMapParams(PipelineParams):
+    accessor_id: str = Field(
+        description='ID(s) of previously loaded accessor(s) to use as pipeline input'
+    )
+    pixel_classifier_segmentation_model_id: str = Field(
+        description='Pixel classifier applied to segmentation_channel(s) to segment objects'
+    )
+    object_classifier_model_id: Union[str, None] = Field(
+        None,
+        description='Object classifier used to classify segmented objectss'
+    )
+    pixel_classifier_derived_model_id: Union[str, None] = Field(
+        None,
+        description='Pixel classifier used to derive channel(s) as additional inputs to object classification'
+    )
+    segmentation_channel: int = Field(
+        description='Channel of input image to use for solving segmentation'
+    )
+    patches_channel: int = Field(
+        description='Channel of input image used in patches sent to object classifier'
+    )
+    zmask_zindex: Union[int, None] = Field(
+        None,
+        description='z coordinate to use on input image when solving segmentation; apply MIP if empty',
+    )
+    roi_params: RoiSetMetaParams = RoiSetMetaParams(**{
+        'mask_type': 'boxes',
+        'filters': {
+            'area': {'min': 1e3, 'max': 1e8}
+        },
+        'expand_box_by': [128, 2]
+    })
+    # TODO: maybe don't support all these exports here; instead leverage interm accessors
+    export_params: RoiSetExportParams = RoiSetExportParams(**{
+        'annotated_patches_2d': {
+            'draw_bounding_box': True,
+            'pad_to': 256,
+        },
+        'patches_2d': {
+            'draw_bounding_box': True,
+            'draw_mask': False,
+        },
+        'patch_masks': {},
+        'object_classes': True,
+        'dataframe': True,
+    })
+    derived_channels_input_channel: Union[int, None] = Field(
+        None,
+        description='Channel of input image from which to compute derived channels; use all if empty'
+    )
+    derived_channels_output_channels: Union[int, list] = Field(
+        None,
+        description='Derived channels to send to object classifier; use all if empty'
+    )
+    # TODO: move to subclassed pipeline
+    # label_params: Union[LabelFromBoundarySegParams, None] = None
+    export_label_interm: bool = False
+
+
+class RoiSetToObjectMapRecord(PipelineRecord):
+    pass
+
+@router.put('/roiset_to_obmap/infer')
+def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord:
+    """
+    Compute a RoiSet from 2d segmentation, apply to z-stack, and optionally apply object classification.
+    """
+
+
+    record, dataframe = call_pipeline(roiset_object_map_pipeline, p)
+
+    # TODO: try labeling this correctly as input accessor_id instead of filename
+    session.write_to_table('RoiSet', {'input_filename': p.accessor_id}, dataframe)
+    return record
+
+
+def roiset_object_map_pipeline(
+        accessors: Dict[str, GenericImageDataAccessor],
+        models: Dict[str, Model],
+        **k
+) -> (RoiSetToObjectMapRecord, pd.DataFrame):
+    if not isinstance(models['pixel_classifier_segmentation_model'], SemanticSegmentationModel):
+        raise IncompatibleModelsError('Expecting a pixel classification model')
+
+    d = PipelineTrace()
+    d['raw'] = accessors['accessor']
+
+    # MIP if no zmask z-index is given, then classify pixels
+    zmi = k['zmask_zindex']
+    sch = k['segmentation_channel']
+
+    if isinstance(zmi, int):
+        assert 0 < zmi < d.last.nz
+        d['mip'] = d.last.get_mono(channel=sch).get_zi(zmi)
+    else:
+        d['mip'] = d.last.get_mono(channel=sch).apply(lambda x: x.max(axis=-1, keepdims=True))
+
+    d['mip_mask'] = models['pixel_classifier_segmentation_model'].label_pixel_class(d.last)
+
+    # TODO: subclass this boundary pipeline in separate module
+    # # optionally label objects using a boundary segmentation pipeline
+    # if p.label_params:
+    #     la_nda, interm = label_from_boundary_seg(
+    #         mip_mask.data[:, :, 0, 0],
+    #         p.label_params,
+    #         return_interm=True
+    #     )
+    #     labels = InMemoryDataAccessor(la_nda)
+    #     rois = RoiSet(stack, labels, params=p.roi_params)
+    # else:
+    #     rois = RoiSet.from_segmentation(stack, mip_mask, params=p.roi_params)
+    # ti.click('generate_zmasks')
+
+    rois = RoiSet.from_binary_mask(d['raw'], d.last, RoiSetMetaParams(**k['roi_params']))
+
+    # optionally derive additional inputs for object classification
+    if dpmod := models.get('pixel_classifier_derived'):
+        ic = k['derived_channels_input_channel']
+        ocs = k['derived_channels_output_channels']
+        assert ic < d['raw'].chroma
+        if not isinstance(dpmod, SemanticSegmentationModel):
+            raise IncompatibleModelsError('Expecting pixel_classifier_derived to be a pixel classification model')
+
+        def _derive(acc: GenericImageDataAccessor, oc: int):
+            acc_mono = acc.get_channels([ic])
+            pxmap = dpmod.infer_patch_stack(acc_mono)[0]
+            assert oc < pxmap.chroma
+            return PatchStack((pxmap.get_channels([oc]).data * 255).astype('uint8'))
+        derived_channel_handles = [
+            lambda a: _derive(a, oc) for oc in ocs
+        ]
+    else:
+        derived_channel_handles = None
+
+    # optionally classify if an object classifier is passed
+    if obmod := models.get('object_classifier_model'):
+        obmod_name = k['object_classifier_model_id']
+        assert isinstance(obmod, InstanceSegmentationModel)
+        rois.classify_by(
+            obmod_name,
+            [k['patches_channel']],
+            obmod,
+            derived_channel_functions=derived_channel_handles
+        )
+        d[obmod_name] = rois.get_object_class_map(obmod_name)
+
+    # TODO: subclass this boundary pipeline in separate module
+    # if p.label_params and p.export_label_interm:
+    #     fp_la_interm = Path(output_folder_path) / 'mip_masks' / f'interm_{fstem}.png'
+    #     plot_image_sequence_with_markers(interm, fp_la_interm, n_rows=3)
+    #     record['label_interm'] = fp_la_interm
+
+    return d
+
+class Error(Exception):
+    pass
+
+class IncompatibleModelsError(Error):
+    pass
\ No newline at end of file
diff --git a/model_server/base/session.py b/model_server/base/session.py
index ea1b3bf9..44e5855d 100644
--- a/model_server/base/session.py
+++ b/model_server/base/session.py
@@ -235,6 +235,7 @@ class _Session(object):
 
         if key is None:
             def mid(i):
+                # TODO: give model the option to report its own name
                 return f'{ModelClass.__name__}_{i:02d}'
 
             while mid(ii) in self.models.keys():
diff --git a/model_server/extensions/ilastik/pipelines/roiset_obmap.py b/model_server/extensions/ilastik/pipelines/roiset_obmap.py
deleted file mode 100644
index fec4f748..00000000
--- a/model_server/extensions/ilastik/pipelines/roiset_obmap.py
+++ /dev/null
@@ -1,266 +0,0 @@
-from pathlib import Path
-from typing import Dict, List, Union
-
-from fastapi import APIRouter, HTTPException
-import numpy as np
-import pandas as pd
-from pydantic import Field
-
-
-from skimage.measure import label
-from skimage.morphology import dilation
-from sklearn.model_selection import train_test_split
-
-# from extensions.chaeo.plotting import plot_image_sequence_with_markers
-# from extensions.chaeo.process import label_from_boundary_seg, LabelFromBoundarySegParams
-
-from model_server.base.accessors import GenericImageDataAccessor, PatchStack
-from model_server.base.models import Model, SemanticSegmentationModel
-from model_server.base.pipelines.params import PipelineParams, PipelineRecord
-from model_server.base.pipelines.util import call_pipeline
-from model_server.base.roiset import RoiSetMetaParams, RoiSetExportParams
-from model_server.base.process import mask_largest_object
-from model_server.base.roiset import RoiSet
-from model_server.base.session import session
-from model_server.base.util import PipelineTrace
-
-from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
-from model_server.base.models import Model, InstanceSegmentationModel, SemanticSegmentationModel
-# from model_server.base.workflows import Timer
-
-router = APIRouter(
-    prefix='/pipelines',
-)
-
-class RoiSetObjectMapParams(PipelineParams):
-    accessor_id: str = Field(
-        description='ID(s) of previously loaded accessor(s) to use as pipeline input'
-    )
-    pixel_classifier_segmentation_model_id: str = Field(
-        description='Pixel classifier applied to segmentation_channel(s) to segment objects'
-    )
-    object_classifier_model_id: Union[str, None] = Field(
-        None,
-        description='Object classifier used to classify segmented objectss'
-    )
-    pixel_classifier_derived_model_id: Union[str, None] = Field(
-        None,
-        description='Pixel classifier used to derive channel(s) as additional inputs to object classification'
-    )
-    segmentation_channel: int = Field(
-        description='Channel of input image to use for solving segmentation'
-    )
-    patches_channel: int = Field(
-        description='Channel of input image used in patches sent to object classifier'
-    )
-    zmask_zindex: Union[int, None] = Field(
-        None,
-        description='z coordinate to use on input image when solving segmentation; apply MIP if empty',
-    )
-    roi_params: RoiSetMetaParams = RoiSetMetaParams(**{
-        'mask_type': 'boxes',
-        'filters': {
-            'area': {'min': 1e3, 'max': 1e8}
-        },
-        'expand_box_by': [128, 2]
-    })
-    # TODO: maybe don't support all these exports here; instead leverage interm accessors
-    export_params: RoiSetExportParams = RoiSetExportParams(**{
-        'annotated_patches_2d': {
-            'draw_bounding_box': True,
-            'pad_to': 256,
-        },
-        'patches_2d': {
-            'draw_bounding_box': True,
-            'draw_mask': False,
-        },
-        'patch_masks': {},
-        'object_classes': True,
-        'dataframe': True,
-    })
-    derived_channels_input_channel: Union[int, None] = Field(
-        None,
-        description='Channel of input image from which to compute derived channels; use all if empty'
-    )
-    derived_channels_output_channels: Union[int, list] = Field(
-        None,
-        description='Derived channels to send to object classifier; use all if empty'
-    )
-    # TODO: move to subclassed pipeline
-    # label_params: Union[LabelFromBoundarySegParams, None] = None
-    export_label_interm: bool = False
-
-
-class RoiSetToObjectMapRecord(PipelineRecord):
-    pass
-#     class _ModelIds(BaseModel):
-#         pixel_classifier_segmentation: str
-#         object_classifier: Union[str, None]
-#         pixel_classifier_derived: Union[str, None]
-#     input_filename: str
-#     model_ids: _ModelIds
-
-@router.put('/roiset_to_obmap/infer')
-def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord:
-    """
-    Compute a RoiSet from 2d segmentation, apply to z-stack, and optionally apply object classification.
-    """
-    # inpath = session.paths['inbound_images'] / p.input_filename
-    # validate_workflow_inputs([p.model_ids.pixel_classifier_segmentation, p.model_ids.object_classifier], [inpath])
-
-    # def _get_model_dict(mid):
-    #     return {
-    #         'name': mid,
-    #         'model': session.models[mid]['object'],
-    #         'params': session.models[mid]['params'],
-    #     }
-
-
-    record, dataframe = call_pipeline(roiset_object_map_pipeline, p)
-
-    # TODO: try labeling this correctly as input accessor_id instead of filename
-    session.write_to_table('RoiSet', {'input_filename': p.accessor_id}, dataframe)
-    return record
-
-
-    # models = {'pixel_classifier_segmentation': _get_model_dict(p.model_ids.pixel_classifier_segmentation)}
-    # if p.model_ids.object_classifier is not None:
-    #     models['object_classifier'] = _get_model_dict(p.model_ids.object_classifier)
-    # if p.model_ids.pixel_classifier_derived is not None:
-    #     models['pixel_classifier_derived'] = _get_model_dict(p.model_ids.pixel_classifier_derived)
-
-    # record = export_zstack_roiset(inpath, session.paths['outbound_images'], models, p)
-    # object_map_filepaths = [record['interm'][k] for k in record['interm'].keys() if k.startswith('object_classes_')]
-
-    # df = record['dataframe']
-    # session.write_to_table('RoiSet', {'input_filename': p.input_filename}, df)
-    # session.log_info(f'Completed classification of {p.input_filename}, recorded {len(df)} ROIs')
-    #
-    # resp = WorkflowRunRecord(
-    #     pixel_model_id=p.model_ids.pixel_classifier_segmentation,
-    #     object_model_id=p.model_ids.object_classifier,
-    #     input_filepath=p.input_filename,
-    #     pixel_map_filepath=record['interm']['mask'].__str__(),
-    #     object_map_filepath=object_map_filepaths[0].__str__(),
-    #     success=True,
-    #     timer_results=record['timer_results'],
-    # ).dict()
-    # resp['interm'] = record['interm']
-    # return resp
-
-
-
-
-def roiset_object_map_pipeline(
-        accessors: Dict[str, GenericImageDataAccessor],
-        models: Dict[str, Model],
-        **k
-) -> (RoiSetToObjectMapRecord, pd.DataFrame):
-    if not isinstance(models['pixel_classifier_segmentation'], SemanticSegmentationModel):
-        raise IncompatibleModelsError('Expecting a pixel classification model')
-    # if not isinstance(models['ob_model'], IlastikObjectClassifierFromPixelPredictionsModel):
-    #     raise IncompatibleModelsError('Expecting an ilastik object classification from pixel predictions model')
-    # assert isinstance(models['pixel_classifier_segmentation']['model'], SemanticSegmentationModel)
-
-    # ti = Timer()
-    # stack = generate_file_accessor(input_file_path)
-    # fstem = Path(input_file_path).stem
-    # ti.click('file_input')
-
-    d = PipelineTrace()
-    d['raw'] = accessors['accessor']
-
-    # MIP if no zmask z-index is given, then classify pixels
-    zmi = k.get('zmask_zindex')
-    sch = k.get('segmentation_channel')
-
-    if isinstance(zmi, int):
-        assert 0 < zmi < d.last.nz
-        d['mip'] = d.last.get_mono(channel=sch).get_zi(zmi)
-    else:
-        d['mip'] = d.last.get_mono(channel=sch).data.max(axis=-1, keepdims=True)
-
-    d['mip_mask'] = models['pixel_classifier_segmentation'].label_pixel_class(d.last)
-
-    # TODO: subclass this boundary pipeline in separate module
-    # # optionally label objects using a boundary segmentation pipeline
-    # if p.label_params:
-    #     la_nda, interm = label_from_boundary_seg(
-    #         mip_mask.data[:, :, 0, 0],
-    #         p.label_params,
-    #         return_interm=True
-    #     )
-    #     labels = InMemoryDataAccessor(la_nda)
-    #     rois = RoiSet(stack, labels, params=p.roi_params)
-    # else:
-    #     rois = RoiSet.from_segmentation(stack, mip_mask, params=p.roi_params)
-    # ti.click('generate_zmasks')
-
-    rois = RoiSet.from_segmentation(d['raw'], d.last, params=k.get('roi_params'))
-
-    # optionally derive additional inputs for object classification
-    if dmod := models.get('pixel_classifier_derived'):
-        ic = k.get('derived_channels_input_channel')
-        ocs = k.get('derived_channels_output_channels')
-        assert ic < d['raw'].chroma
-        if not isinstance(dmod, SemanticSegmentationModel):
-            raise IncompatibleModelsError('Expecting pixel_classifier_derived to be a pixel classification model')
-
-        def _derive(acc: GenericImageDataAccessor, oc: int):
-            acc_mono = acc.get_channels([ic])
-            pxmap = dmod.infer_patch_stack(acc_mono)[0]
-            assert oc < pxmap.chroma
-            return PatchStack((pxmap.get_channels([oc]).data * 255).astype('uint8'))
-        derived_channel_handles = [
-            lambda a: _derive(a, oc) for oc in ocs
-        ]
-    else:
-        derived_channel_handles = None
-
-    # optionally classify if an object classifier is passed
-    if 'object_classifier' in models.keys():
-        assert isinstance(models['object_classifier']['model'], InstanceSegmentationModel)
-        rois.classify_by(
-            models['object_classifier']['name'],      # TODO: does model need a name?
-            [p.patches_channel],
-            models['object_classifier']['model'],
-            derived_channel_functions=derived_channel_handles
-        )
-        ti.click('classify_objects')
-
-    record = rois.run_exports(
-        Path(output_folder_path),
-        p.patches_channel,
-        fstem,
-        p.export_params,
-    )
-    ti.click('export_roi_products')
-
-    if p.label_params and p.export_label_interm:
-        fp_la_interm = Path(output_folder_path) / 'mip_masks' / f'interm_{fstem}.png'
-        plot_image_sequence_with_markers(interm, fp_la_interm, n_rows=3)
-        record['label_interm'] = fp_la_interm
-
-    fp_mask = Path(output_folder_path) / 'mip_masks' / f'mask_{fstem}.tif'
-    write_accessor_data_to_file(fp_mask, mip_mask)
-
-    unit_mask_data = (mip_mask.data > 0).astype('uint8')
-    fp_unit_mask = Path(output_folder_path) / 'unit_masks' / f'mask_{fstem}.tif'
-    write_accessor_data_to_file(fp_unit_mask, InMemoryDataAccessor(unit_mask_data))
-    record['object_classes_all'] = fp_unit_mask.__str__()
-    record['mask'] = fp_mask
-
-    # return {
-    #     'timer_results': ti.events,
-    #     'dataframe': rois.get_df(),
-    #     'interm': record,
-    #     'output_path': output_folder_path,
-    # }
-
-    return record, df
-
-class Error(Exception):
-    pass
-
-class IncompatibleModelsError(Error):
-    pass
\ No newline at end of file
diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py
index bbcd2305..84563127 100644
--- a/tests/base/test_pipelines.py
+++ b/tests/base/test_pipelines.py
@@ -17,7 +17,7 @@ class TestSegmentationPipeline(unittest.TestCase):
 
     def test_call_pipeline_function(self):
         acc = generate_file_accessor(czifile['path'])
-        trace = segment.segment_pipeline(acc, self.model, channel=2, smooth=3)
+        trace = segment.segment_pipeline({'accessor': acc}, {'model': self.model}, channel=2, smooth=3)
         outfp = output_path / 'classify_pixels.tif'
         write_accessor_data_to_file(outfp, trace.last)
 
diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/test_ilastik/test_roiset_workflow.py
new file mode 100644
index 00000000..62fb87f5
--- /dev/null
+++ b/tests/test_ilastik/test_roiset_workflow.py
@@ -0,0 +1,317 @@
+import shutil
+from pathlib import Path
+from shutil import copyfile
+import unittest
+
+import numpy as np
+
+import model_server.conf.testing as conf
+
+from model_server.base.accessors import generate_file_accessor
+from tests.base.test_model import DummyInstanceSegmentationModel
+from model_server.base.roiset import RoiSetMetaParams, RoiSetExportParams
+
+import model_server.conf.testing as conf
+# from extensions.chaeo.workflows import ClassifyZStackApiParams, export_zstack_roiset
+from model_server.base.pipelines.roiset_obmap import RoiSetObjectMapParams, roiset_object_map_pipeline
+from model_server.extensions.ilastik.models import IlastikPixelClassifierModel, IlastikPixelClassifierParams
+
+data = conf.meta['image_files']
+output_path = conf.meta['output_path']
+test_params = conf.meta['roiset']
+classifiers = conf.meta['ilastik_classifiers']
+
+class BaseTestRoiSetMonoProducts(object):
+
+    def setUp(self) -> None:
+        # set up test raw data and segmentation from file
+        self.stack = generate_file_accessor(data['multichannel_zstack_raw']['path'])
+        self.stack_ch_pa = self.stack.get_mono(test_params['patches_channel'])
+        self.seg_mask = generate_file_accessor(data['multichannel_zstack_mask2d']['path'])
+
+
+def _get_export_params():
+    return RoiSetExportParams(**{
+        'pixel_probabilities': True,
+        'patches_3d': {},
+        'annotated_patches_2d': {
+            'draw_bounding_box': True,
+            'rgb_overlay_channels': [3, None, None],
+            'rgb_overlay_weights': [0.2, 1.0, 1.0],
+            'pad_to': 512,
+        },
+        'patches_2d': {
+            'draw_bounding_box': False,
+            'draw_mask': False,
+        },
+        'patch_masks': {
+            'pad_to': 256,
+        },
+        'annotated_zstacks': {},
+        'object_classes': True,
+        'dataframe': True,
+    })
+
+def _get_roi_params():
+    return RoiSetMetaParams(**{
+        'mask_type': 'boxes',
+        'filters': {
+            'area': {'min': 1e3, 'max': 1e8}
+        },
+        'expand_box_by': [128, 2]
+    })
+
+# def _get_model_params():
+#     return {
+#         'pixel_classifier_segmentation': {
+#             # 'project_file': params['pixel_classifier'],
+#             'project_file': classifiers['px']['path'],
+#         },
+#         'object_classifier': {
+#             'name': 'dummy',
+#         }
+#     }
+
+def _get_models():
+
+    return {
+        'pixel_classifier_segmentation': {
+            'name': 'ilastik_px_mod',
+            'project_file': classifiers['px']['path'],
+            'model': IlastikPixelClassifierModel(
+                IlastikPixelClassifierParams(
+                    project_file=classifiers['px']['path'].__str__()
+                )
+            )
+        },
+        'object_classifier': {
+            'name': 'dummy_ob_mod',
+            'model': DummyInstanceSegmentationModel()
+        },
+
+    }
+
+
+class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
+
+    def test_object_map_workflow(self):
+        acc_in = generate_file_accessor(data['multichannel_zstack_raw']['path'])
+        params = RoiSetObjectMapParams(
+            api=False,
+            accessor_id='acc_id',
+            pixel_classifier_segmentation_model_id='px_id',
+            object_classifier_model_id='ob_id',
+            segmentation_channel=test_params['segmentation_channel'],
+            patches_channel=test_params['patches_channel'],
+            roi_params=_get_roi_params(),
+            export_params=_get_export_params(),
+        )
+        trace = roiset_object_map_pipeline(
+            {'accessor': acc_in},
+            {f'{k}_model': v['model'] for k, v in _get_models().items()},
+            **params.dict()
+        )
+
+        self.assertTrue('ob_id' in trace.keys())
+        self.assertEqual(len(trace['ob_id']._unique()[0]), 2)
+
+
+class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass):
+
+    app_name = 'api:app'
+
+    def _copy_input_file_to_server(self):
+        resp = self._get('paths')
+        pa = resp.json()['inbound_images']
+        copyfile(
+            data['multichannel_zstack_raw']['path'],
+            Path(pa) / data['multichannel_zstack_raw']['name']
+        )
+
+    def setUp(self) -> None:
+        self.where_out = output_path / 'trec-adaptive-feedback' / 'roiset'
+        self.where_out.mkdir(parents=True, exist_ok=True)
+        return super().setUp()
+
+    def test_bounceback(self):
+        resp = self._put('chaeo/bounce_back', query={'par1': 'hello'})
+        self.assertEqual(resp.status_code, 200, resp.json())
+        self.assertEqual(resp.json()['params']['par1'], 'hello', resp.json())
+
+    def test_load_pixel_classifier(self):
+        resp = self._put(
+            'ilastik/seg/load/',
+            body={'project_file': _get_model_params()['pixel_classifier_segmentation']['project_file'].__str__()},
+        )
+        model_id = resp.json()['model_id']
+        self.assertTrue(model_id.startswith('IlastikPixelClassifierModel'))
+        return model_id
+
+    def test_load_object_classifier(self):
+        resp = self._put(f'models/dummy_instance/load')
+        model_id = resp.json()['model_id']
+        self.assertTrue(model_id.startswith('DummyInstanceSegmentationModel'))
+        return model_id
+
+    def test_object_map_workflow(self):
+        mid_px = self.test_load_pixel_classifier()
+        mid_ob = self.test_load_object_classifier()
+        resp = self._put(
+            'chaeo/classify_zstack/infer',
+            body=ClassifyZStackApiParams(
+                model_ids={
+                    'pixel_classifier_segmentation': mid_px,
+                    'object_classifier': mid_ob,
+                },
+                **{
+                    'input_filename': data['multichannel_zstack_raw']['path'].__str__(),
+                    'segmentation_channel': 0,
+                    'patches_channel': 1,
+                    'roi_params': _get_roi_params(),
+                    'export_params': {'object_classes': True},
+                },
+            ).dict()
+        )
+        self.assertEqual(resp.status_code, 200, resp.json())
+        omfp = Path(resp.json()['object_map_filepath'])
+        self.assertTrue(omfp.exists())
+        acc_obmap = generate_file_accessor(omfp)
+        self.assertTrue(all(np.unique(acc_obmap.data) == [0, 1]))
+        shutil.copy(omfp, self.where_out / f'normal_{omfp.name}')
+
+    def test_workflow_without_object_classifier(self):
+        mid_px = self.test_load_pixel_classifier()
+        resp = self._put(
+            'chaeo/classify_zstack/infer',
+            body=ClassifyZStackApiParams(
+                model_ids={
+                    'pixel_classifier_segmentation': mid_px,
+                },
+                **{
+                    'input_filename': data['multichannel_zstack_raw']['path'].__str__(),
+                    'segmentation_channel': 0,
+                    'patches_channel': 1,
+                    'roi_params': _get_roi_params(),
+                    'export_params': {'object_classes': True},
+                },
+            ).dict()
+        )
+        self.assertEqual(resp.status_code, 200, resp.json())
+        omfp = Path(resp.json()['object_map_filepath'])
+        self.assertTrue(omfp.exists())
+        acc_obmap = generate_file_accessor(omfp)
+        self.assertTrue(all(np.unique(acc_obmap.data) == [0, 1]))
+        shutil.copy(omfp, self.where_out / f'normal_{omfp.name}')
+
+    def test_object_map_workflow_boundary_channel(self):
+        fp_in = data['multichannel_zstack_raw']['path']
+        resp = self._put(
+            'ilastik/seg/load/',
+            body={
+                'project_file': classifiers['px']['path'].__str__(),
+                'px_class': 1,
+                'px_prob_threshold': 0.5
+            },
+        )
+        mid_px = resp.json()['model_id']
+        mid_ob = self.test_load_object_classifier()
+        resp = self._put(
+            'chaeo/classify_zstack/infer',
+            body=ClassifyZStackApiParams(
+                model_ids={
+                    'pixel_classifier_segmentation': mid_px,
+                    'object_classifier': mid_ob,
+                },
+                **{
+                    'input_filename': fp_in.__str__(),
+                    'segmentation_channel': 0,
+                    'patches_channel': 1,
+                    'roi_params': _get_roi_params(),
+                    'export_params': {'object_classes': True},
+                    'label_params': {
+                        'bbox_filter': {'area': {'min': 1e3, 'max': 1e8}},
+                        'px_expand': 0,
+                        'n_dilate': 2,
+                        'marker_source': 'distmax',
+                        'min_marker_dist': 10,
+                        'background_method': 'threshold',
+                        'coarse_sig': 15,
+                        'background_tr': 0.2,
+                    },
+                    'export_label_interm': False,
+                },
+            ).dict()
+        )
+        self.assertEqual(resp.status_code, 200, resp.json())
+        fp_obmap = Path(resp.json()['object_map_filepath'])
+        print(fp_obmap)
+        self.assertTrue(fp_obmap.exists())
+        acc_obmap = generate_file_accessor(fp_obmap)
+        self.assertTrue(all(np.unique(acc_obmap.data) == [0, 1]))
+
+
+
+class TestRoiSetWithDerivedChannels(conf.TestServerBaseClass):
+
+    app_name = 'api:app'
+
+    def test_object_map_workflow_with_derived_channels(self):
+        models = _get_models()
+        models['pixel_classifier_derived'] = models['pixel_classifier_segmentation'] # re-use same classifier
+        where_out = output_path / 'roiset' / 'workflow'
+
+        p = ClassifyZStackApiParams(
+            input_filename=str(data['multichannel_zstack_raw']['path']),
+            model_ids={
+                'pixel_classifier_segmentation': 'id_px',
+                'object_classifier': 'id_ob',
+                'pixel_classifier_derived': 'id_px',
+            },
+            derived_channels_input_channel=0,
+            derived_channels_output_channels=[1],
+            segmentation_channel=test_params['segmentation_channel'],
+            patches_channel=test_params['patches_channel'],
+            roi_params=_get_roi_params(),
+            export_params=_get_export_params(),
+        )
+
+        record = export_zstack_roiset(
+            data['multichannel_zstack_raw']['path'],
+            where_out,
+            models,
+            p
+        )
+        self.assertTrue(all([Path(pa).exists() for pa in record['interm']['derived_channels']]))
+
+    def test_derived_channels_api(self):
+        resp = self._put(
+            'ilastik/seg/load/',
+            body={'project_file': _get_model_params()['pixel_classifier_segmentation']['project_file'].__str__()},
+        )
+        self.assertEqual(resp.status_code, 200)
+        mid_px = resp.json()['model_id']
+        resp = self._put(f'models/dummy_instance/load')
+        mid_ob = resp.json()['model_id']
+
+
+        p = ClassifyZStackApiParams(
+            input_filename=data['multichannel_zstack_raw']['path'].__str__(),
+            model_ids={
+                'pixel_classifier_segmentation': mid_px,
+                'object_classifier': mid_ob,
+                'pixel_classifier_derived': mid_px,
+            },
+            derived_channels_input_channel=0,
+            derived_channels_output_channels=[1],
+            segmentation_channel=test_params['segmentation_channel'],
+            patches_channel=test_params['patches_channel'],
+            roi_params=_get_roi_params(),
+            export_params=_get_export_params(),
+        )
+        resp = self._put('chaeo/classify_zstack/infer', body=p.dict())
+        self.assertEqual(resp.status_code, 200, resp.json())
+        omfp = resp.json()['object_map_filepath']
+        self.assertTrue(Path(omfp).exists())
+        acc_obmap = generate_file_accessor(omfp)
+        self.assertTrue(all(np.unique(acc_obmap.data) == [0, 1]))
+
-- 
GitLab