Skip to content
Snippets Groups Projects
Commit 67e20ac1 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

ilastik extension now manages its own API

parent de6f2ecd
No related branches found
No related tags found
No related merge requests found
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI
from extensions.ilastik.models import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierModel
from model_server.models import DummyImageToImageModel, ParameterExpectedError
from model_server.models import DummyImageToImageModel
from model_server.session import Session
from model_server.validators import validate_workflow_inputs
from model_server.workflows import infer_image_to_image
from extensions.ilastik.workflows import infer_px_then_ob_model
app = FastAPI(debug=True)
session = Session()
import extensions.ilastik.router
app.include_router(extensions.ilastik.router.router)
@app.on_event("startup")
def startup():
pass
......@@ -38,54 +41,6 @@ def list_active_models():
def load_dummy_model() -> dict:
return {'model_id': session.load_model(DummyImageToImageModel)}
def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str, duplicate=True) -> dict:
"""
Load an ilastik model of a given class and project filename.
:param model_class:
:param project_file: (*.ilp) ilastik project filename
:param duplicate: load another instance of the same project file if True; return existing one if false
:return: dictionary with single key describing model's ID
"""
if not duplicate:
existing_model = session.find_param_in_loaded_models('project_file', project_file)
if existing_model is not None:
return existing_model
try:
result = {
'model_id': session.load_model(
model_class,
{'project_file': project_file}
)
}
except (FileNotFoundError, ParameterExpectedError):
raise HTTPException(
status_code=404,
detail=f'Could not load project file {project_file}',
)
return result
@app.put('/models/ilastik/pixel_classification/load/')
def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(IlastikPixelClassifierModel, project_file, duplicate=duplicate)
@app.put('/models/ilastik/object_classification/load/')
def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(IlastikObjectClassifierModel, project_file, duplicate=duplicate)
def validate_workflow_inputs(model_ids, inpaths):
for mid in model_ids:
if mid not in session.describe_loaded_models().keys():
raise HTTPException(
status_code=409,
detail=f'Model {mid} has not been loaded'
)
for inpa in inpaths:
if not inpa.exists():
raise HTTPException(
status_code=404,
detail=f'Could not find file:\n{inpa}'
)
@app.put('/infer/from_image_file')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
inpath = session.paths['inbound_images'] / input_filename
......@@ -97,20 +52,4 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
channel=channel,
)
session.record_workflow_run(record)
return record
@app.put('/models/ilastik/pixel_then_object_classification/infer')
def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict:
inpath = session.paths['inbound_images'] / input_filename
validate_workflow_inputs([px_model_id, ob_model_id], [inpath])
try:
record = infer_px_then_ob_model(
inpath,
session.models[px_model_id]['object'],
session.models[ob_model_id]['object'],
session.paths['outbound_images'],
channel=channel
)
except AssertionError:
raise HTTPException(f'Incompatible models {px_model_id} and/or {ob_model_id}')
return record
\ No newline at end of file
from fastapi import APIRouter, HTTPException
from model_server.session import Session
from model_server.validators import validate_workflow_inputs
from extensions.ilastik.models import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierModel
from model_server.models import ParameterExpectedError
from extensions.ilastik.workflows import infer_px_then_ob_model
router = APIRouter(
prefix='/ilastik',
tags=['ilastik'],
)
session = Session()
def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str, duplicate=True) -> dict:
"""
Load an ilastik model of a given class and project filename.
:param model_class:
:param project_file: (*.ilp) ilastik project filename
:param duplicate: load another instance of the same project file if True; return existing one if false
:return: dictionary with single key describing model's ID
"""
if not duplicate:
existing_model = session.find_param_in_loaded_models('project_file', project_file)
if existing_model is not None:
return existing_model
try:
result = {
'model_id': session.load_model(
model_class,
{'project_file': project_file}
)
}
except (FileNotFoundError, ParameterExpectedError):
raise HTTPException(
status_code=404,
detail=f'Could not load project file {project_file}',
)
return result
@router.put('/pixel_classification/load/')
def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(IlastikPixelClassifierModel, project_file, duplicate=duplicate)
@router.put('/object_classification/load/')
def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(IlastikObjectClassifierModel, project_file, duplicate=duplicate)
@router.put('/pixel_then_object_classification/infer')
def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict:
inpath = session.paths['inbound_images'] / input_filename
validate_workflow_inputs([px_model_id, ob_model_id], [inpath])
try:
record = infer_px_then_ob_model(
inpath,
session.models[px_model_id]['object'],
session.models[ob_model_id]['object'],
session.paths['outbound_images'],
channel=channel
)
except AssertionError:
raise HTTPException(f'Incompatible models {px_model_id} and/or {ob_model_id}')
return record
\ No newline at end of file
......@@ -112,7 +112,7 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_httpexception_if_incorrect_project_file_loaded(self):
resp_load = requests.put(
self.uri + 'models/ilastik/pixel_classification/load/',
self.uri + 'ilastik/pixel_classification/load/',
params={'project_file': 'improper.ilp'},
)
self.assertEqual(resp_load.status_code, 404)
......@@ -120,7 +120,7 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_load_ilastik_pixel_model(self):
resp_load = requests.put(
self.uri + 'models/ilastik/pixel_classification/load/',
self.uri + 'ilastik/pixel_classification/load/',
params={'project_file': str(conf.testing.ilastik['pixel_classifier'])},
)
model_id = resp_load.json()['model_id']
......@@ -137,7 +137,7 @@ class TestIlastikOverApi(TestServerBaseClass):
resp_list_1st = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_1st), 1, resp_list_1st)
resp_load_2nd = requests.put(
self.uri + 'models/ilastik/pixel_classification/load/',
self.uri + 'ilastik/pixel_classification/load/',
params={
'project_file': str(conf.testing.ilastik['pixel_classifier']),
'duplicate': True,
......@@ -146,7 +146,7 @@ class TestIlastikOverApi(TestServerBaseClass):
resp_list_2nd = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd)
resp_load_3rd = requests.put(
self.uri + 'models/ilastik/pixel_classification/load/',
self.uri + 'ilastik/pixel_classification/load/',
params={
'project_file': str(conf.testing.ilastik['pixel_classifier']),
'duplicate': False,
......@@ -158,7 +158,7 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_load_ilastik_object_model(self):
resp_load = requests.put(
self.uri + 'models/ilastik/object_classification/load/',
self.uri + 'ilastik/object_classification/load/',
params={'project_file': str(conf.testing.ilastik['object_classifier'])},
)
model_id = resp_load.json()['model_id']
......@@ -190,7 +190,7 @@ class TestIlastikOverApi(TestServerBaseClass):
ob_model_id = self.test_load_ilastik_object_model()
resp_infer = requests.put(
self.uri + f'models/ilastik/pixel_then_object_classification/infer/',
self.uri + f'ilastik/pixel_then_object_classification/infer/',
params={
'px_model_id': px_model_id,
'ob_model_id': ob_model_id,
......
from fastapi import HTTPException
from model_server.session import Session
session = Session()
def validate_workflow_inputs(model_ids, inpaths):
for mid in model_ids:
if mid not in session.describe_loaded_models().keys():
raise HTTPException(
status_code=409,
detail=f'Model {mid} has not been loaded'
)
for inpa in inpaths:
if not inpa.exists():
raise HTTPException(
status_code=404,
detail=f'Could not find file:\n{inpa}'
)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment