diff --git a/api.py b/api.py index e28446ad27ce88ff56f286f8badfa0c42f51dce8..b51364991987556684cd4156ca515b519486a726 100644 --- a/api.py +++ b/api.py @@ -6,6 +6,7 @@ from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClass from model_server.model import DummyImageToImageModel, ParameterExpectedError from model_server.session import Session from model_server.workflow import infer_image_to_image +from model_server.workflow_ilastik import infer_px_then_ob_model app = FastAPI(debug=True) session = Session() @@ -62,19 +63,24 @@ def load_ilastik_pixel_classification_model(project_file: str) -> dict: def load_ilastik_object_classification_model(project_file: str) -> dict: return load_ilastik_model(IlastikObjectClassifierModel, project_file) +def validate_workflow_inputs(model_ids, inpaths): + for mid in model_ids: + if mid not in session.describe_loaded_models().keys(): + raise HTTPException( + status_code=409, + detail=f'Model {mid} has not been loaded' + ) + for inpa in inpaths: + if not inpa.exists(): + raise HTTPException( + status_code=404, + detail=f'Could not find file:\n{inpa}' + ) + @app.put('/infer/from_image_file') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: - if model_id not in session.describe_loaded_models().keys(): - raise HTTPException( - status_code=409, - detail=f'Model {model_id} has not been loaded' - ) inpath = session.paths['inbound_images'] / input_filename - if not inpath.exists(): - raise HTTPException( - status_code=404, - detail=f'Could not find file:\n{inpath}' - ) + validate_workflow_inputs([model_id], [inpath]) record = infer_image_to_image( inpath, session.models[model_id]['object'], @@ -82,4 +88,20 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: channel=channel, ) session.record_workflow_run(record) + return record + +@app.put('/models/ilastik/pixel_then_object_classification/infer') +def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict: + inpath = session.paths['inbound_images'] / input_filename + validate_workflow_inputs([px_model_id, ob_model_id], [inpath]) + try: + record = infer_px_then_ob_model( + inpath, + session.models[px_model_id]['object'], + session.models[ob_model_id]['object'], + session.paths['outbound_images'], + channel=channel + ) + except AssertionError: + raise HTTPException(f'Incompatible models {px_model_id} and/or {ob_model_id}') return record \ No newline at end of file diff --git a/model_server/workflow.py b/model_server/workflow.py index a89f900a95536da22bb3acea88369648fbb7bec8..7c09bc46835eacb04a31871534a91786991267ec 100644 --- a/model_server/workflow.py +++ b/model_server/workflow.py @@ -1,11 +1,12 @@ """ Implementation of image analysis work behind API endpoints, without knowledge of persistent data in server session. """ - +from pathlib import Path from time import perf_counter from typing import Dict from model_server.image import generate_file_accessor, write_accessor_data_to_file +from model_server.model import Model from pydantic import BaseModel @@ -28,7 +29,15 @@ class WorkflowRunRecord(BaseModel): timer_results: Dict[str, float] -def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict: +def infer_image_to_image(fpi: Path, model: Model, where_output: Path, **kwargs) -> WorkflowRunRecord: + """ + Generic workflow where a model processes an input image into an output image + :param fpi: Path object that references input image file + :param model: model object + :param where_output: Path object that references output image directory + :param kwargs: variable-length keyword arguments + :return: record object + """ ti = Timer() ch = kwargs.get('channel') img = generate_file_accessor(fpi).get_one_channel_data(ch) diff --git a/model_server/workflow_ilastik.py b/model_server/workflow_ilastik.py new file mode 100644 index 0000000000000000000000000000000000000000..db412bb4b9b2ba1e33b0ba42d9d5108bbfcdce5d --- /dev/null +++ b/model_server/workflow_ilastik.py @@ -0,0 +1,73 @@ +""" +Implementation of image analysis work behind API endpoints, without knowledge of persistent data in server session. +""" +from pathlib import Path +from time import perf_counter +from typing import Dict + +from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClassifierModel +from model_server.image import generate_file_accessor, write_accessor_data_to_file +from model_server.model import Model +from model_server.workflow import Timer + +from pydantic import BaseModel + +class WorkflowRunRecord(BaseModel): + pixel_model_id: str + object_model_id: str + input_filepath: str + pixel_map_filepath: str + object_map_filepath: str + success: bool + timer_results: Dict[str, float] + + +def infer_px_then_ob_model( + fpi: Path, + px_model: IlastikPixelClassifierModel, + ob_model: IlastikObjectClassifierModel, + where_output: Path, + **kwargs +) -> WorkflowRunRecord: + """ + Workflow that specifically runs an ilastik pixel classifier, then passes results to an object classifier, + saving intermediate images + :param fpi: Path object that references input image file + :param px_model: model instance for pixel classification + :param ob_model: model instance for object classification + :param where_output: Path object that references output image directory + :param kwargs: variable-length keyword arguments + :return: + """ + assert isinstance(px_model, IlastikPixelClassifierModel) + assert isinstance(ob_model, IlastikObjectClassifierModel) + + ti = Timer() + ch = kwargs.get('channel') + img = generate_file_accessor(fpi).get_one_channel_data(ch) + ti.click('file_input') + + px_map, _ = px_model.infer(img) + ti.click('pixel_probability_inference') + + px_map_path = where_output / (px_model.model_id + '_pxmap_' + fpi.stem + '.tif') + write_accessor_data_to_file(px_map_path, px_map) + ti.click('pixel_map_output') + + ob_map, _ = ob_model.infer(img, px_map) + ti.click('object_classification') + + ob_map_path = where_output / (ob_model.model_id + '_obmap_' + fpi.stem + '.tif') + write_accessor_data_to_file(ob_map_path, ob_map) + ti.click('object_map_output') + + return WorkflowRunRecord( + pixel_model_id=px_model.model_id, + object_model_id=ob_model.model_id, + input_filepath=str(fpi), + pixel_map_filepath=str(px_map_path), + object_map_filepath=str(ob_map_path), + success=True, + timer_results=ti.events, + ) + diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py index 1c55a9a63f4dfd64900e1dd25fdad7800561dee0..25c876b945fdac880f78af76e7e96edb0e4e1cf3 100644 --- a/tests/test_ilastik.py +++ b/tests/test_ilastik.py @@ -130,7 +130,6 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() self.assertEqual(rj[model_id]['class'], 'IlastikPixelClassifierModel') - return model_id @@ -146,6 +145,7 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierModel') + return model_id def test_ilastik_infer_pixel_probability(self): self.copy_input_file_to_server() @@ -160,3 +160,19 @@ class TestIlastikOverApi(TestServerBaseClass): }, ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) + + def test_ilastik_infer_px_then_ob(self): + self.copy_input_file_to_server() + px_model_id = self.test_load_ilastik_pixel_model() + ob_model_id = self.test_load_ilastik_object_model() + + resp_infer = requests.put( + self.uri + f'models/ilastik/pixel_then_object_classification/infer/', + params={ + 'px_model_id': px_model_id, + 'ob_model_id': ob_model_id, + 'input_filename': conf.testing.czifile['filename'], + 'channel': 0, + } + ) + self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 9bca0b4af3d440206041ddb16c6c56891a18ed56..6bc8d56d88f960d63cdada4efe68e65c6ecd572f 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -34,6 +34,4 @@ class TestGetSessionObject(unittest.TestCase): img[0, 0], 0, 'First pixel is not black as expected' - ) - - print(result.timer_results) \ No newline at end of file + ) \ No newline at end of file