From 3fa8292a1da8b43c623f5c346af9407d82f21789 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 15 Aug 2024 16:46:49 +0200 Subject: [PATCH] API maps RoiSet products as accessors --- model_server/base/models.py | 3 +-- model_server/base/pipelines/roiset_obmap.py | 19 +++++-------------- model_server/base/session.py | 7 +++++-- tests/test_ilastik/test_roiset_workflow.py | 12 ++++-------- 4 files changed, 15 insertions(+), 26 deletions(-) diff --git a/model_server/base/models.py b/model_server/base/models.py index ffaa346f..eaded2c9 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -50,7 +50,7 @@ class Model(ABC): @property def name(self): - return f'{self.__class__}' + return f'{self.__class__.__name__}' @@ -131,7 +131,6 @@ class InstanceSegmentationModel(ImageToImageModel): class BinaryThresholdSegmentationModel(SemanticSegmentationModel): - # TODO: also allow relative threshold def __init__(self, tr: float = 0.5): self.tr = tr diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py index 376dbe3c..96e03390 100644 --- a/model_server/base/pipelines/roiset_obmap.py +++ b/model_server/base/pipelines/roiset_obmap.py @@ -53,20 +53,7 @@ class RoiSetObjectMapParams(PipelineParams): }, 'expand_box_by': [128, 2] }) - # TODO: maybe don't support all these exports here; instead leverage interm accessors - export_params: RoiSetExportParams = RoiSetExportParams(**{ - 'annotated_patches_2d': { - 'draw_bounding_box': True, - 'pad_to': 256, - }, - 'patches_2d': { - 'draw_bounding_box': True, - 'draw_mask': False, - }, - 'patch_masks': {}, - 'object_classes': True, - 'dataframe': True, - }) + export_params: RoiSetExportParams = RoiSetExportParams() derived_channels_input_channel: Union[int, None] = Field( None, description='Channel of input image from which to compute derived channels; use all if empty' @@ -114,6 +101,10 @@ def roiset_object_map_pipeline( d['labeled'] = get_label_ids(d.last) rois = RoiSet.from_object_ids(d['input'], d['labeled'], RoiSetMetaParams(**k['roi_params'])) + # optionally append RoiSet products + for ki, vi in rois.get_export_product_accessors(k['patches_channel'], RoiSetExportParams(**k['export_params'])).items(): + d[ki] = vi + # optionally run an object classifier if specified if obmod := models.get('object_classifier_model'): obmod_name = k['object_classifier_model_id'] diff --git a/model_server/base/session.py b/model_server/base/session.py index 8358a957..f22655c6 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -10,7 +10,7 @@ from typing import Union import pandas as pd from ..conf import defaults -from .accessors import GenericImageDataAccessor +from .accessors import GenericImageDataAccessor, PatchStack from .models import Model logger = logging.getLogger(__name__) @@ -166,7 +166,10 @@ class _Session(object): f'Cannot overwrite accessor that is already written to {old_fp}' ) - acc.write(fp) + if isinstance(acc, PatchStack): + acc.export_pyxcz(fp) + else: + acc.write(fp) self.accessors[acc_id]['filepath'] = fp.__str__() return fp.name diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/test_ilastik/test_roiset_workflow.py index cd6d2ded..c02dd6d9 100644 --- a/tests/test_ilastik/test_roiset_workflow.py +++ b/tests/test_ilastik/test_roiset_workflow.py @@ -39,8 +39,7 @@ class BaseTestRoiSetMonoProducts(object): def _get_export_params(self): return { - 'pixel_probabilities': True, - 'patches_3d': {}, + 'patches_3d': None, 'annotated_patches_2d': { 'draw_bounding_box': True, 'rgb_overlay_channels': [3, None, None], @@ -51,12 +50,8 @@ class BaseTestRoiSetMonoProducts(object): 'draw_bounding_box': False, 'draw_mask': False, }, - 'patch_masks': { - 'pad_to': 256, - }, - 'annotated_zstacks': {}, + 'annotated_zstacks': None, 'object_classes': True, - 'dataframe': True, } def _get_roi_params(self): @@ -119,7 +114,8 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase): {f'{k}_model': v['model'] for k, v in self._get_models().items()}, **params.dict() ) - + self.assertEqual(trace.pop('annotated_patches_2d').count, 13) + self.assertEqual(trace.pop('patches_2d').count, 13) trace.write_interm(Path(output_path) / 'trace', 'roiset_worfklow_trace', skip_first=False, skip_last=False) self.assertTrue('ob_id' in trace.keys()) self.assertEqual(len(trace['labeled'].unique()[0]), 14) -- GitLab