diff --git a/model_server/base/pipelines/segment.py b/model_server/base/pipelines/segment.py index 887ccea548105933379af146622d2de6b280bb03..e0180b557e7ab467f6d17812aa5f5a2a96a76a2f 100644 --- a/model_server/base/pipelines/segment.py +++ b/model_server/base/pipelines/segment.py @@ -36,7 +36,7 @@ def segment_pipeline( model = models.get('model') if not isinstance(model, SemanticSegmentationModel): - raise IncompatibleModelsError('Expecting a pixel classification model') + raise IncompatibleModelsError('Expecting a semantic segmentation model') if ch := k.get('channel') is not None: d['mono'] = d['input'].get_mono(ch) diff --git a/model_server/base/session.py b/model_server/base/session.py index 8f25520d9dca34ca143337327bb14c74dc197ea3..4aa88a1e212fc5e7b759fd82c3ef69f2e45d1ff3 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -251,7 +251,7 @@ class _Session(object): :param params: optional parameters that are passed to the model's construct :return: model_id of loaded model """ - mi = ModelClass(params=params) + mi = ModelClass(params=params.dict()) assert mi.loaded, f'Error loading instance of {ModelClass.__name__}' ii = 0 diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 75e0434cdacd58426b2070cb4124a72451151ade..dccb269ade99d677fc22336979ccd98c61a1076e 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -2,11 +2,9 @@ import json from logging import getLogger import os from pathlib import Path -from typing import Union import warnings import numpy as np -from pydantic import BaseModel, Field import vigra import model_server.extensions.ilastik.conf @@ -14,17 +12,10 @@ from ...base.accessors import PatchStack from ...base.accessors import GenericImageDataAccessor, InMemoryDataAccessor from ...base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel -# TODO: move params models to router only; model classes to be created only with dict -class IlastikParams(BaseModel): - project_file: str = Field(description='(*.ilp) ilastik project filename') - duplicate: bool = Field( - True, - description='Load another instance of the same project file if True; return existing one if False' - ) class IlastikModel(Model): - def __init__(self, params: IlastikParams, autoload=True, enforce_embedded=True): + def __init__(self, params: dict, autoload=True, enforce_embedded=True): """ Base class for models that run via ilastik shell API :param params: @@ -33,7 +24,7 @@ class IlastikModel(Model): :param enforce_embedded: raise an error if all input data are not embedded in the project file, i.e. on the filesystem """ - pf = Path(params.project_file) + pf = Path(params['project_file']) self.enforce_embedded = enforce_embedded if pf.is_absolute(): pap = pf @@ -103,14 +94,11 @@ class IlastikModel(Model): def model_3d(self): return self.model_shape_dict['Z'] > 1 -class IlastikPixelClassifierParams(IlastikParams): - px_class: int = 0 - px_prob_threshold: float = 0.5 class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): operations = ['segment', ] - def __init__(self, params: IlastikPixelClassifierParams, **kwargs): + def __init__(self, params: dict, **kwargs): super(IlastikPixelClassifierModel, self).__init__(params, **kwargs) @staticmethod diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py index ea8527985593b348749f8be7694196fc2dcb3199..e4412221fb1fcfed9c8c3917095457192170fed1 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/extensions/ilastik/router.py @@ -1,4 +1,5 @@ from fastapi import APIRouter +from pydantic import BaseModel, Field from model_server.base.session import session @@ -13,8 +14,20 @@ router = APIRouter( import model_server.extensions.ilastik.pipelines.px_then_ob router.include_router(model_server.extensions.ilastik.pipelines.px_then_ob.router) +# TODO: move params models to router only; model classes to be created only with dict +class IlastikParams(BaseModel): + project_file: str = Field(description='(*.ilp) ilastik project filename') + duplicate: bool = Field( + True, + description='Load another instance of the same project file if True; return existing one if False' + ) + +class IlastikPixelClassifierParams(IlastikParams): + px_class: int = 0 + px_prob_threshold: float = 0.5 + @router.put('/seg/load/') -def load_px_model(p: ilm.IlastikPixelClassifierParams, model_id=None) -> dict: +def load_px_model(p: IlastikPixelClassifierParams, model_id=None) -> dict: """ Load an ilastik pixel classifier model from its project file """ @@ -25,7 +38,7 @@ def load_px_model(p: ilm.IlastikPixelClassifierParams, model_id=None) -> dict: ) @router.put('/pxmap_to_obj/load/') -def load_pxmap_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict: +def load_pxmap_to_obj_model(p: IlastikParams, model_id=None) -> dict: """ Load an ilastik object classifier from pixel predictions model from its project file """ @@ -36,7 +49,7 @@ def load_pxmap_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict: ) @router.put('/seg_to_obj/load/') -def load_seg_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict: +def load_seg_to_obj_model(p: IlastikParams, model_id=None) -> dict: """ Load an ilastik object classifier from segmentation model from its project file """ @@ -46,13 +59,13 @@ def load_seg_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict: model_id=model_id, ) -def load_ilastik_model(model_class: ilm.IlastikModel, p: ilm.IlastikParams, model_id=None) -> dict: - project_file = p.project_file +def load_ilastik_model(model_class: ilm.IlastikModel, p: IlastikParams, model_id=None) -> dict: + pf = p.project_file if not p.duplicate: - existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True) + existing_model_id = session.find_param_in_loaded_models('project_file', pf, is_path=True) if existing_model_id is not None: - session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate') + session.log_info(f'An ilastik model from {pf} already existing exists; did not load a duplicate') return {'model_id': existing_model_id} result = session.load_model(model_class, key=model_id, params=p) - session.log_info(f'Loaded ilastik model {result} from {project_file}') + session.log_info(f'Loaded ilastik model {result} from {pf}') return {'model_id': result} \ No newline at end of file diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py index 40ba7e0b4b9a5e4e1e6ed98940761eafa0096a6f..a6f467f9b7da2c30f14f07c09fc1a5f32db40f4a 100644 --- a/tests/test_ilastik/test_ilastik.py +++ b/tests/test_ilastik/test_ilastik.py @@ -29,14 +29,18 @@ class TestIlastikPixelClassification(unittest.TestCase): self.cf = CziImageFileAccessor(czifile['path']) self.channel = 0 self.model = ilm.IlastikPixelClassifierModel( - params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px']['path'].__str__()) + params={ + 'project_file': ilastik_classifiers['px']['path'].__str__(), + 'px_class': 0, + 'px_prob_threshold': 0.5, + } ) self.mono_image = self.cf.get_mono(self.channel) def test_raise_error_if_autoload_disabled(self): model = ilm.IlastikPixelClassifierModel( - params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px']['path'].__str__()), + params={'project_file': ilastik_classifiers['px']['path'].__str__()}, autoload=False ) w = 512 @@ -80,11 +84,12 @@ class TestIlastikPixelClassification(unittest.TestCase): def test_label_pixels_with_params(self): def _run_seg(tr, sig): mod = ilm.IlastikPixelClassifierModel( - params=ilm.IlastikPixelClassifierParams( - project_file=ilastik_classifiers['px']['path'].__str__(), - px_prob_threshold=tr, - px_smoothing=sig, - ), + params={ + 'project_file': ilastik_classifiers['px']['path'].__str__(), + 'px_class': 0, + 'px_prob_threshold': tr, + 'px_smoothing': sig, + }, ) mask = mod.label_pixel_class(self.mono_image) write_accessor_data_to_file( @@ -154,7 +159,7 @@ class TestIlastikPixelClassification(unittest.TestCase): self.test_run_pixel_classifier() fp = czifile['path'] model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( - params=ilm.IlastikParams(project_file=ilastik_classifiers['pxmap_to_obj']['path'].__str__()) + params={'project_file': ilastik_classifiers['pxmap_to_obj']['path'].__str__()} ) mask = self.model.label_pixel_class(self.mono_image) objmap, _ = model.infer(self.mono_image, mask) @@ -172,7 +177,7 @@ class TestIlastikPixelClassification(unittest.TestCase): self.test_run_pixel_classifier() fp = czifile['path'] model = ilm.IlastikObjectClassifierFromSegmentationModel( - params=ilm.IlastikParams(project_file=ilastik_classifiers['seg_to_obj']['path'].__str__()) + params={'project_file': ilastik_classifiers['seg_to_obj']['path'].__str__()} ) mask = self.model.label_pixel_class(self.mono_image) objmap = model.label_instance_class(self.mono_image, mask) @@ -191,11 +196,11 @@ class TestIlastikPixelClassification(unittest.TestCase): 'accessor': generate_file_accessor(czifile['path']) }, models={ - 'model': ilm.IlastikPixelClassifierModel( - params=ilm.IlastikPixelClassifierParams( - project_file=ilastik_classifiers['px']['path'].__str__() - ), - ), + 'model': ilm.IlastikPixelClassifierModel({ + 'project_file': ilastik_classifiers['px']['path'].__str__(), + 'px_class': 0, + 'px_prob_threshold': 0.5, + }), }, channel=0, ) @@ -327,7 +332,7 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase): def test_classify_pixels(self): img = generate_file_accessor(self.pa_input_image) self.assertGreater(img.chroma, 1) - mod = ilm.IlastikPixelClassifierModel(ilm.IlastikPixelClassifierParams(project_file=self.pa_px_classifier.__str__())) + mod = ilm.IlastikPixelClassifierModel({'project_file': self.pa_px_classifier.__str__()}) pxmap = mod.infer(img)[0] self.assertEqual(pxmap.hw, img.hw) self.assertEqual(pxmap.nz, img.nz) @@ -337,7 +342,7 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase): pxmap = self.test_classify_pixels() img = generate_file_accessor(self.pa_input_image) mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel( - ilm.IlastikParams(project_file=self.pa_ob_pxmap_classifier.__str__()) + {'project_file': self.pa_ob_pxmap_classifier.__str__()} ) obmap = mod.infer(img, pxmap)[0] self.assertEqual(obmap.hw, img.hw) @@ -354,10 +359,10 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase): }, models={ 'px_model': ilm.IlastikPixelClassifierModel( - ilm.IlastikParams(project_file=self.pa_px_classifier.__str__()), + {'project_file': self.pa_px_classifier.__str__()}, ), 'ob_model': ilm.IlastikObjectClassifierFromPixelPredictionsModel( - ilm.IlastikParams(project_file=self.pa_ob_pxmap_classifier.__str__()), + {'project_file': self.pa_ob_pxmap_classifier.__str__()} ) }, channel=channel, @@ -428,7 +433,7 @@ class TestIlastikObjectClassification(unittest.TestCase): ) self.classifier = ilm.IlastikObjectClassifierFromSegmentationModel( - params=ilm.IlastikParams(project_file=ilastik_classifiers['seg_to_obj']['path'].__str__()), + params={'project_file': ilastik_classifiers['seg_to_obj']['path'].__str__()}, ) self.raw = self.roiset.get_patches_acc() self.masks = self.roiset.get_patch_masks_acc()