Skip to content
Snippets Groups Projects
Commit 20e3a20e authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Option to specify model ID

parent 7f6e43da
No related branches found
No related tags found
No related merge requests found
......@@ -131,10 +131,16 @@ class _Session(object):
def log_error(self, msg):
logger.error(msg)
def load_model(self, ModelClass: Model, params: Union[BaseModel, None] = None) -> dict:
def load_model(
self,
ModelClass: Model,
key: Union[str, None] = None,
params: Union[BaseModel, None] = None,
) -> dict:
"""
Load an instance of a given model class and attach to this session's model registry
:param ModelClass: subclass of Model
:param key: unique identifier of model, or autogenerate if None
:param params: optional parameters that are passed to the model's construct
:return: model_id of loaded model
"""
......@@ -142,13 +148,17 @@ class _Session(object):
assert mi.loaded, f'Error loading instance of {ModelClass.__name__}'
ii = 0
def mid(i):
return f'{ModelClass.__name__}_{i:02d}'
if key is None:
def mid(i):
return f'{ModelClass.__name__}_{i:02d}'
while mid(ii) in self.models.keys():
ii += 1
while mid(ii) in self.models.keys():
ii += 1
key = mid(ii)
elif key in self.models.keys():
raise CouldNotInstantiateModelError(f'Model with key {key} already exists.')
key = mid(ii)
self.models[key] = {
'object': mi,
'params': getattr(mi, 'params', None)
......
......@@ -2,10 +2,11 @@ import json
from logging import getLogger
import os
from pathlib import Path
from typing import Union
import warnings
import numpy as np
from pydantic import BaseModel
from pydantic import BaseModel, Field
import vigra
import model_server.extensions.ilastik.conf
......@@ -14,8 +15,12 @@ from ...base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from ...base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel
class IlastikParams(BaseModel):
project_file: str
duplicate: bool = True
project_file: str = Field(description='(*.ilp) ilastik project filename')
duplicate: bool = Field(
True,
description='Load another instance of the same project file if True; return existing one if False'
)
model_id: Union[str, None] = Field(None, description='Unique identifier of the model, or autogenerate if empty')
class IlastikModel(Model):
......
......@@ -13,43 +13,43 @@ router = APIRouter(
import model_server.extensions.ilastik.pipelines.px_then_ob
router.include_router(model_server.extensions.ilastik.pipelines.px_then_ob.router)
def load_ilastik_model(model_class: ilm.IlastikModel, params: ilm.IlastikParams) -> dict:
@router.put('/seg/load/')
def load_px_model(p: ilm.IlastikPixelClassifierParams) -> dict:
"""
Load an ilastik model of a given class and project filename.
:param model_class:
:param project_file: (*.ilp) ilastik project filename
:param duplicate: load another instance of the same project file if True; return existing one if false
:return: dict containing model's ID
Load an ilastik pixel classifier model from its project file
"""
project_file = params.project_file
if not params.duplicate:
existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True)
if existing_model_id is not None:
session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate')
return {'model_id': existing_model_id}
result = session.load_model(model_class, params)
session.log_info(f'Loaded ilastik model {result} from {project_file}')
return {'model_id': result}
@router.put('/seg/load/')
# TODO: optionally let client name model
def load_px_model(params: ilm.IlastikPixelClassifierParams) -> dict:
return load_ilastik_model(
ilm.IlastikPixelClassifierModel,
params,
p,
)
@router.put('/pxmap_to_obj/load/')
def load_pxmap_to_obj_model(params: ilm.IlastikParams) -> dict:
def load_pxmap_to_obj_model(p: ilm.IlastikParams) -> dict:
"""
Load an ilastik object classifier from pixel predictions model from its project file
"""
return load_ilastik_model(
ilm.IlastikObjectClassifierFromPixelPredictionsModel,
params,
p,
)
@router.put('/seg_to_obj/load/')
def load_seg_to_obj_model(params: ilm.IlastikParams) -> dict:
def load_seg_to_obj_model(p: ilm.IlastikParams) -> dict:
"""
Load an ilastik object classifier from segmentation model from its project file
"""
return load_ilastik_model(
ilm.IlastikObjectClassifierFromSegmentationModel,
params,
p,
)
def load_ilastik_model(model_class: ilm.IlastikModel, p: ilm.IlastikParams) -> dict:
project_file = p.project_file
if not p.duplicate:
existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True)
if existing_model_id is not None:
session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate')
return {'model_id': existing_model_id}
result = session.load_model(model_class, key=p.model_id, params=p)
session.log_info(f'Loaded ilastik model {result} from {project_file}')
return {'model_id': result}
\ No newline at end of file
......@@ -271,6 +271,18 @@ class TestIlastikOverApi(conf.TestServerBaseClass):
self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierFromPixelPredictionsModel')
return model_id
def test_load_ilastik_model_with_model_id(self):
mid = 'new_model_id'
resp_load = self._put(
'ilastik/pxmap_to_obj/load/',
body={
'project_file': str(ilastik_classifiers['pxmap_to_obj']['path']),
'model_id': mid,
},
)
res_mid = resp_load.json()['model_id']
self.assertEqual(res_mid, mid)
def test_load_ilastik_seg_to_obj_model(self):
resp_load = self._put(
'ilastik/seg_to_obj/load/',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment