Skip to content
Snippets Groups Projects
Commit b8635fdb authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Implemented workflow and endpoint for two-stage (px -> ob) classification in...

Implemented workflow and endpoint for two-stage (px -> ob) classification in ilastik; tests pass and image data is correct
parent 74f655f8
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,7 @@ from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClass
from model_server.model import DummyImageToImageModel, ParameterExpectedError
from model_server.session import Session
from model_server.workflow import infer_image_to_image
from model_server.workflow_ilastik import infer_px_then_ob_model
app = FastAPI(debug=True)
session = Session()
......@@ -62,19 +63,24 @@ def load_ilastik_pixel_classification_model(project_file: str) -> dict:
def load_ilastik_object_classification_model(project_file: str) -> dict:
return load_ilastik_model(IlastikObjectClassifierModel, project_file)
def validate_workflow_inputs(model_ids, inpaths):
for mid in model_ids:
if mid not in session.describe_loaded_models().keys():
raise HTTPException(
status_code=409,
detail=f'Model {mid} has not been loaded'
)
for inpa in inpaths:
if not inpa.exists():
raise HTTPException(
status_code=404,
detail=f'Could not find file:\n{inpa}'
)
@app.put('/infer/from_image_file')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
if model_id not in session.describe_loaded_models().keys():
raise HTTPException(
status_code=409,
detail=f'Model {model_id} has not been loaded'
)
inpath = session.paths['inbound_images'] / input_filename
if not inpath.exists():
raise HTTPException(
status_code=404,
detail=f'Could not find file:\n{inpath}'
)
validate_workflow_inputs([model_id], [inpath])
record = infer_image_to_image(
inpath,
session.models[model_id]['object'],
......@@ -82,4 +88,20 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
channel=channel,
)
session.record_workflow_run(record)
return record
@app.put('/models/ilastik/pixel_then_object_classification/infer')
def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict:
inpath = session.paths['inbound_images'] / input_filename
validate_workflow_inputs([px_model_id, ob_model_id], [inpath])
try:
record = infer_px_then_ob_model(
inpath,
session.models[px_model_id]['object'],
session.models[ob_model_id]['object'],
session.paths['outbound_images'],
channel=channel
)
except AssertionError:
raise HTTPException(f'Incompatible models {px_model_id} and/or {ob_model_id}')
return record
\ No newline at end of file
"""
Implementation of image analysis work behind API endpoints, without knowledge of persistent data in server session.
"""
from pathlib import Path
from time import perf_counter
from typing import Dict
from model_server.image import generate_file_accessor, write_accessor_data_to_file
from model_server.model import Model
from pydantic import BaseModel
......@@ -28,7 +29,15 @@ class WorkflowRunRecord(BaseModel):
timer_results: Dict[str, float]
def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict:
def infer_image_to_image(fpi: Path, model: Model, where_output: Path, **kwargs) -> WorkflowRunRecord:
"""
Generic workflow where a model processes an input image into an output image
:param fpi: Path object that references input image file
:param model: model object
:param where_output: Path object that references output image directory
:param kwargs: variable-length keyword arguments
:return: record object
"""
ti = Timer()
ch = kwargs.get('channel')
img = generate_file_accessor(fpi).get_one_channel_data(ch)
......
"""
Implementation of image analysis work behind API endpoints, without knowledge of persistent data in server session.
"""
from pathlib import Path
from time import perf_counter
from typing import Dict
from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClassifierModel
from model_server.image import generate_file_accessor, write_accessor_data_to_file
from model_server.model import Model
from model_server.workflow import Timer
from pydantic import BaseModel
class WorkflowRunRecord(BaseModel):
pixel_model_id: str
object_model_id: str
input_filepath: str
pixel_map_filepath: str
object_map_filepath: str
success: bool
timer_results: Dict[str, float]
def infer_px_then_ob_model(
fpi: Path,
px_model: IlastikPixelClassifierModel,
ob_model: IlastikObjectClassifierModel,
where_output: Path,
**kwargs
) -> WorkflowRunRecord:
"""
Workflow that specifically runs an ilastik pixel classifier, then passes results to an object classifier,
saving intermediate images
:param fpi: Path object that references input image file
:param px_model: model instance for pixel classification
:param ob_model: model instance for object classification
:param where_output: Path object that references output image directory
:param kwargs: variable-length keyword arguments
:return:
"""
assert isinstance(px_model, IlastikPixelClassifierModel)
assert isinstance(ob_model, IlastikObjectClassifierModel)
ti = Timer()
ch = kwargs.get('channel')
img = generate_file_accessor(fpi).get_one_channel_data(ch)
ti.click('file_input')
px_map, _ = px_model.infer(img)
ti.click('pixel_probability_inference')
px_map_path = where_output / (px_model.model_id + '_pxmap_' + fpi.stem + '.tif')
write_accessor_data_to_file(px_map_path, px_map)
ti.click('pixel_map_output')
ob_map, _ = ob_model.infer(img, px_map)
ti.click('object_classification')
ob_map_path = where_output / (ob_model.model_id + '_obmap_' + fpi.stem + '.tif')
write_accessor_data_to_file(ob_map_path, ob_map)
ti.click('object_map_output')
return WorkflowRunRecord(
pixel_model_id=px_model.model_id,
object_model_id=ob_model.model_id,
input_filepath=str(fpi),
pixel_map_filepath=str(px_map_path),
object_map_filepath=str(ob_map_path),
success=True,
timer_results=ti.events,
)
......@@ -130,7 +130,6 @@ class TestIlastikOverApi(TestServerBaseClass):
self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json()
self.assertEqual(rj[model_id]['class'], 'IlastikPixelClassifierModel')
return model_id
......@@ -146,6 +145,7 @@ class TestIlastikOverApi(TestServerBaseClass):
self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json()
self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierModel')
return model_id
def test_ilastik_infer_pixel_probability(self):
self.copy_input_file_to_server()
......@@ -160,3 +160,19 @@ class TestIlastikOverApi(TestServerBaseClass):
},
)
self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
def test_ilastik_infer_px_then_ob(self):
self.copy_input_file_to_server()
px_model_id = self.test_load_ilastik_pixel_model()
ob_model_id = self.test_load_ilastik_object_model()
resp_infer = requests.put(
self.uri + f'models/ilastik/pixel_then_object_classification/infer/',
params={
'px_model_id': px_model_id,
'ob_model_id': ob_model_id,
'input_filename': conf.testing.czifile['filename'],
'channel': 0,
}
)
self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
......@@ -34,6 +34,4 @@ class TestGetSessionObject(unittest.TestCase):
img[0, 0],
0,
'First pixel is not black as expected'
)
print(result.timer_results)
\ No newline at end of file
)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment