Skip to content
Snippets Groups Projects
Commit bfcd18af authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Tests parameterize models directly with dicts, not pydantic models, which are...

Tests parameterize models directly with dicts, not pydantic models, which are reserved for API calls
parent ef31604c
No related branches found
No related tags found
No related merge requests found
...@@ -36,7 +36,7 @@ def segment_pipeline( ...@@ -36,7 +36,7 @@ def segment_pipeline(
model = models.get('model') model = models.get('model')
if not isinstance(model, SemanticSegmentationModel): if not isinstance(model, SemanticSegmentationModel):
raise IncompatibleModelsError('Expecting a pixel classification model') raise IncompatibleModelsError('Expecting a semantic segmentation model')
if ch := k.get('channel') is not None: if ch := k.get('channel') is not None:
d['mono'] = d['input'].get_mono(ch) d['mono'] = d['input'].get_mono(ch)
......
...@@ -251,7 +251,7 @@ class _Session(object): ...@@ -251,7 +251,7 @@ class _Session(object):
:param params: optional parameters that are passed to the model's construct :param params: optional parameters that are passed to the model's construct
:return: model_id of loaded model :return: model_id of loaded model
""" """
mi = ModelClass(params=params) mi = ModelClass(params=params.dict())
assert mi.loaded, f'Error loading instance of {ModelClass.__name__}' assert mi.loaded, f'Error loading instance of {ModelClass.__name__}'
ii = 0 ii = 0
......
...@@ -2,11 +2,9 @@ import json ...@@ -2,11 +2,9 @@ import json
from logging import getLogger from logging import getLogger
import os import os
from pathlib import Path from pathlib import Path
from typing import Union
import warnings import warnings
import numpy as np import numpy as np
from pydantic import BaseModel, Field
import vigra import vigra
import model_server.extensions.ilastik.conf import model_server.extensions.ilastik.conf
...@@ -14,17 +12,10 @@ from ...base.accessors import PatchStack ...@@ -14,17 +12,10 @@ from ...base.accessors import PatchStack
from ...base.accessors import GenericImageDataAccessor, InMemoryDataAccessor from ...base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from ...base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel from ...base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel
# TODO: move params models to router only; model classes to be created only with dict
class IlastikParams(BaseModel):
project_file: str = Field(description='(*.ilp) ilastik project filename')
duplicate: bool = Field(
True,
description='Load another instance of the same project file if True; return existing one if False'
)
class IlastikModel(Model): class IlastikModel(Model):
def __init__(self, params: IlastikParams, autoload=True, enforce_embedded=True): def __init__(self, params: dict, autoload=True, enforce_embedded=True):
""" """
Base class for models that run via ilastik shell API Base class for models that run via ilastik shell API
:param params: :param params:
...@@ -33,7 +24,7 @@ class IlastikModel(Model): ...@@ -33,7 +24,7 @@ class IlastikModel(Model):
:param enforce_embedded: :param enforce_embedded:
raise an error if all input data are not embedded in the project file, i.e. on the filesystem raise an error if all input data are not embedded in the project file, i.e. on the filesystem
""" """
pf = Path(params.project_file) pf = Path(params['project_file'])
self.enforce_embedded = enforce_embedded self.enforce_embedded = enforce_embedded
if pf.is_absolute(): if pf.is_absolute():
pap = pf pap = pf
...@@ -103,14 +94,11 @@ class IlastikModel(Model): ...@@ -103,14 +94,11 @@ class IlastikModel(Model):
def model_3d(self): def model_3d(self):
return self.model_shape_dict['Z'] > 1 return self.model_shape_dict['Z'] > 1
class IlastikPixelClassifierParams(IlastikParams):
px_class: int = 0
px_prob_threshold: float = 0.5
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
operations = ['segment', ] operations = ['segment', ]
def __init__(self, params: IlastikPixelClassifierParams, **kwargs): def __init__(self, params: dict, **kwargs):
super(IlastikPixelClassifierModel, self).__init__(params, **kwargs) super(IlastikPixelClassifierModel, self).__init__(params, **kwargs)
@staticmethod @staticmethod
......
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel, Field
from model_server.base.session import session from model_server.base.session import session
...@@ -13,8 +14,20 @@ router = APIRouter( ...@@ -13,8 +14,20 @@ router = APIRouter(
import model_server.extensions.ilastik.pipelines.px_then_ob import model_server.extensions.ilastik.pipelines.px_then_ob
router.include_router(model_server.extensions.ilastik.pipelines.px_then_ob.router) router.include_router(model_server.extensions.ilastik.pipelines.px_then_ob.router)
# TODO: move params models to router only; model classes to be created only with dict
class IlastikParams(BaseModel):
project_file: str = Field(description='(*.ilp) ilastik project filename')
duplicate: bool = Field(
True,
description='Load another instance of the same project file if True; return existing one if False'
)
class IlastikPixelClassifierParams(IlastikParams):
px_class: int = 0
px_prob_threshold: float = 0.5
@router.put('/seg/load/') @router.put('/seg/load/')
def load_px_model(p: ilm.IlastikPixelClassifierParams, model_id=None) -> dict: def load_px_model(p: IlastikPixelClassifierParams, model_id=None) -> dict:
""" """
Load an ilastik pixel classifier model from its project file Load an ilastik pixel classifier model from its project file
""" """
...@@ -25,7 +38,7 @@ def load_px_model(p: ilm.IlastikPixelClassifierParams, model_id=None) -> dict: ...@@ -25,7 +38,7 @@ def load_px_model(p: ilm.IlastikPixelClassifierParams, model_id=None) -> dict:
) )
@router.put('/pxmap_to_obj/load/') @router.put('/pxmap_to_obj/load/')
def load_pxmap_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict: def load_pxmap_to_obj_model(p: IlastikParams, model_id=None) -> dict:
""" """
Load an ilastik object classifier from pixel predictions model from its project file Load an ilastik object classifier from pixel predictions model from its project file
""" """
...@@ -36,7 +49,7 @@ def load_pxmap_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict: ...@@ -36,7 +49,7 @@ def load_pxmap_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict:
) )
@router.put('/seg_to_obj/load/') @router.put('/seg_to_obj/load/')
def load_seg_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict: def load_seg_to_obj_model(p: IlastikParams, model_id=None) -> dict:
""" """
Load an ilastik object classifier from segmentation model from its project file Load an ilastik object classifier from segmentation model from its project file
""" """
...@@ -46,13 +59,13 @@ def load_seg_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict: ...@@ -46,13 +59,13 @@ def load_seg_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict:
model_id=model_id, model_id=model_id,
) )
def load_ilastik_model(model_class: ilm.IlastikModel, p: ilm.IlastikParams, model_id=None) -> dict: def load_ilastik_model(model_class: ilm.IlastikModel, p: IlastikParams, model_id=None) -> dict:
project_file = p.project_file pf = p.project_file
if not p.duplicate: if not p.duplicate:
existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True) existing_model_id = session.find_param_in_loaded_models('project_file', pf, is_path=True)
if existing_model_id is not None: if existing_model_id is not None:
session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate') session.log_info(f'An ilastik model from {pf} already existing exists; did not load a duplicate')
return {'model_id': existing_model_id} return {'model_id': existing_model_id}
result = session.load_model(model_class, key=model_id, params=p) result = session.load_model(model_class, key=model_id, params=p)
session.log_info(f'Loaded ilastik model {result} from {project_file}') session.log_info(f'Loaded ilastik model {result} from {pf}')
return {'model_id': result} return {'model_id': result}
\ No newline at end of file
...@@ -29,14 +29,18 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -29,14 +29,18 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.cf = CziImageFileAccessor(czifile['path']) self.cf = CziImageFileAccessor(czifile['path'])
self.channel = 0 self.channel = 0
self.model = ilm.IlastikPixelClassifierModel( self.model = ilm.IlastikPixelClassifierModel(
params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px']['path'].__str__()) params={
'project_file': ilastik_classifiers['px']['path'].__str__(),
'px_class': 0,
'px_prob_threshold': 0.5,
}
) )
self.mono_image = self.cf.get_mono(self.channel) self.mono_image = self.cf.get_mono(self.channel)
def test_raise_error_if_autoload_disabled(self): def test_raise_error_if_autoload_disabled(self):
model = ilm.IlastikPixelClassifierModel( model = ilm.IlastikPixelClassifierModel(
params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px']['path'].__str__()), params={'project_file': ilastik_classifiers['px']['path'].__str__()},
autoload=False autoload=False
) )
w = 512 w = 512
...@@ -80,11 +84,12 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -80,11 +84,12 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_label_pixels_with_params(self): def test_label_pixels_with_params(self):
def _run_seg(tr, sig): def _run_seg(tr, sig):
mod = ilm.IlastikPixelClassifierModel( mod = ilm.IlastikPixelClassifierModel(
params=ilm.IlastikPixelClassifierParams( params={
project_file=ilastik_classifiers['px']['path'].__str__(), 'project_file': ilastik_classifiers['px']['path'].__str__(),
px_prob_threshold=tr, 'px_class': 0,
px_smoothing=sig, 'px_prob_threshold': tr,
), 'px_smoothing': sig,
},
) )
mask = mod.label_pixel_class(self.mono_image) mask = mod.label_pixel_class(self.mono_image)
write_accessor_data_to_file( write_accessor_data_to_file(
...@@ -154,7 +159,7 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -154,7 +159,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.test_run_pixel_classifier() self.test_run_pixel_classifier()
fp = czifile['path'] fp = czifile['path']
model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
params=ilm.IlastikParams(project_file=ilastik_classifiers['pxmap_to_obj']['path'].__str__()) params={'project_file': ilastik_classifiers['pxmap_to_obj']['path'].__str__()}
) )
mask = self.model.label_pixel_class(self.mono_image) mask = self.model.label_pixel_class(self.mono_image)
objmap, _ = model.infer(self.mono_image, mask) objmap, _ = model.infer(self.mono_image, mask)
...@@ -172,7 +177,7 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -172,7 +177,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.test_run_pixel_classifier() self.test_run_pixel_classifier()
fp = czifile['path'] fp = czifile['path']
model = ilm.IlastikObjectClassifierFromSegmentationModel( model = ilm.IlastikObjectClassifierFromSegmentationModel(
params=ilm.IlastikParams(project_file=ilastik_classifiers['seg_to_obj']['path'].__str__()) params={'project_file': ilastik_classifiers['seg_to_obj']['path'].__str__()}
) )
mask = self.model.label_pixel_class(self.mono_image) mask = self.model.label_pixel_class(self.mono_image)
objmap = model.label_instance_class(self.mono_image, mask) objmap = model.label_instance_class(self.mono_image, mask)
...@@ -191,11 +196,11 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -191,11 +196,11 @@ class TestIlastikPixelClassification(unittest.TestCase):
'accessor': generate_file_accessor(czifile['path']) 'accessor': generate_file_accessor(czifile['path'])
}, },
models={ models={
'model': ilm.IlastikPixelClassifierModel( 'model': ilm.IlastikPixelClassifierModel({
params=ilm.IlastikPixelClassifierParams( 'project_file': ilastik_classifiers['px']['path'].__str__(),
project_file=ilastik_classifiers['px']['path'].__str__() 'px_class': 0,
), 'px_prob_threshold': 0.5,
), }),
}, },
channel=0, channel=0,
) )
...@@ -327,7 +332,7 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase): ...@@ -327,7 +332,7 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase):
def test_classify_pixels(self): def test_classify_pixels(self):
img = generate_file_accessor(self.pa_input_image) img = generate_file_accessor(self.pa_input_image)
self.assertGreater(img.chroma, 1) self.assertGreater(img.chroma, 1)
mod = ilm.IlastikPixelClassifierModel(ilm.IlastikPixelClassifierParams(project_file=self.pa_px_classifier.__str__())) mod = ilm.IlastikPixelClassifierModel({'project_file': self.pa_px_classifier.__str__()})
pxmap = mod.infer(img)[0] pxmap = mod.infer(img)[0]
self.assertEqual(pxmap.hw, img.hw) self.assertEqual(pxmap.hw, img.hw)
self.assertEqual(pxmap.nz, img.nz) self.assertEqual(pxmap.nz, img.nz)
...@@ -337,7 +342,7 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase): ...@@ -337,7 +342,7 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase):
pxmap = self.test_classify_pixels() pxmap = self.test_classify_pixels()
img = generate_file_accessor(self.pa_input_image) img = generate_file_accessor(self.pa_input_image)
mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel( mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
ilm.IlastikParams(project_file=self.pa_ob_pxmap_classifier.__str__()) {'project_file': self.pa_ob_pxmap_classifier.__str__()}
) )
obmap = mod.infer(img, pxmap)[0] obmap = mod.infer(img, pxmap)[0]
self.assertEqual(obmap.hw, img.hw) self.assertEqual(obmap.hw, img.hw)
...@@ -354,10 +359,10 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase): ...@@ -354,10 +359,10 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase):
}, },
models={ models={
'px_model': ilm.IlastikPixelClassifierModel( 'px_model': ilm.IlastikPixelClassifierModel(
ilm.IlastikParams(project_file=self.pa_px_classifier.__str__()), {'project_file': self.pa_px_classifier.__str__()},
), ),
'ob_model': ilm.IlastikObjectClassifierFromPixelPredictionsModel( 'ob_model': ilm.IlastikObjectClassifierFromPixelPredictionsModel(
ilm.IlastikParams(project_file=self.pa_ob_pxmap_classifier.__str__()), {'project_file': self.pa_ob_pxmap_classifier.__str__()}
) )
}, },
channel=channel, channel=channel,
...@@ -428,7 +433,7 @@ class TestIlastikObjectClassification(unittest.TestCase): ...@@ -428,7 +433,7 @@ class TestIlastikObjectClassification(unittest.TestCase):
) )
self.classifier = ilm.IlastikObjectClassifierFromSegmentationModel( self.classifier = ilm.IlastikObjectClassifierFromSegmentationModel(
params=ilm.IlastikParams(project_file=ilastik_classifiers['seg_to_obj']['path'].__str__()), params={'project_file': ilastik_classifiers['seg_to_obj']['path'].__str__()},
) )
self.raw = self.roiset.get_patches_acc() self.raw = self.roiset.get_patches_acc()
self.masks = self.roiset.get_patch_masks_acc() self.masks = self.roiset.get_patch_masks_acc()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment