Skip to content
Snippets Groups Projects
Commit a0b7809f authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Merge in ilastik API support for segmentation-to-object workflow

parent db420797
No related branches found
No related tags found
No related merge requests found
*/.idea/*
*__pycache__*
\ No newline at end of file
*__pycache__*
/clients/imagej/.idea/workspace.xml
/clients/imagej/.idea/
......@@ -53,9 +53,10 @@ monozstackmask = {
'z': 85
}
ilastik = {
'pixel_classifier': 'demo_px.ilp',
'object_classifier': 'demo_obj.ilp',
ilastik_classifiers = {
'px': 'demo_px.ilp',
'pxmap_to_obj': 'demo_obj.ilp',
'seg_to_obj': 'new_auto_obj.ilp',
}
output_path = root / 'testing_output'
......
......@@ -3,7 +3,7 @@ from fastapi import APIRouter, HTTPException
from model_server.session import Session
from model_server.validators import validate_workflow_inputs
from extensions.ilastik.models import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierFromPixelPredictionsModel
from extensions.ilastik import models as ilm
from model_server.models import ParameterExpectedError
from extensions.ilastik.workflows import infer_px_then_ob_model
......@@ -14,7 +14,7 @@ router = APIRouter(
session = Session()
def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str, duplicate=True) -> dict:
def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: str, duplicate=True) -> dict:
"""
Load an ilastik model of a given class and project filename.
:param model_class:
......@@ -40,13 +40,17 @@ def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str,
)
return result
@router.put('/pixel_classification/load/')
def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(IlastikPixelClassifierModel, project_file, duplicate=duplicate)
@router.put('/px/load/')
def load_px_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(ilm.IlastikPixelClassifierModel, project_file, duplicate=duplicate)
@router.put('/object_classification/load/')
def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(IlastikObjectClassifierFromPixelPredictionsModel, project_file, duplicate=duplicate)
@router.put('/pxmap_to_obj/load/')
def load_pxmap_to_obj_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(ilm.IlastikObjectClassifierFromPixelPredictionsModel, project_file, duplicate=duplicate)
@router.put('/seg_to_obj/load/')
def load_seg_to_obj_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(ilm.IlastikObjectClassifierFromSegmentationModel, project_file, duplicate=duplicate)
@router.put('/pixel_then_object_classification/infer')
def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict:
......
......@@ -3,15 +3,15 @@ import unittest
import numpy as np
import conf.testing
from conf.testing import czifile, ilastik_classifiers, output_path
from model_server.accessors import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from extensions.ilastik.models import IlastikObjectClassifierFromPixelPredictionsModel, IlastikPixelClassifierModel
from extensions.ilastik import models as ilm
from model_server.workflows import infer_image_to_image
from tests.test_api import TestServerBaseClass
class TestIlastikPixelClassification(unittest.TestCase):
def setUp(self) -> None:
self.cf = CziImageFileAccessor(conf.testing.czifile['path'])
self.cf = CziImageFileAccessor(czifile['path'])
def test_faulthandler(self): # recreate error that is messing up ilastik
......@@ -24,8 +24,8 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_raise_error_if_autoload_disabled(self):
model = IlastikPixelClassifierModel(
{'project_file': conf.testing.ilastik['pixel_classifier']},
model = ilm.IlastikPixelClassifierModel(
{'project_file': ilastik_classifiers['px']},
autoload=False
)
w = 512
......@@ -38,8 +38,8 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_run_pixel_classifier_on_random_data(self):
model = IlastikPixelClassifierModel(
{'project_file': conf.testing.ilastik['pixel_classifier']},
model = ilm.IlastikPixelClassifierModel(
{'project_file': ilastik_classifiers['px']},
)
w = 512
h = 256
......@@ -52,16 +52,16 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_run_pixel_classifier(self):
channel = 0
model = IlastikPixelClassifierModel(
{'project_file': conf.testing.ilastik['pixel_classifier']}
model = ilm.IlastikPixelClassifierModel(
{'project_file': ilastik_classifiers['px']}
)
cf = CziImageFileAccessor(
conf.testing.czifile['path']
czifile['path']
)
mono_image = cf.get_one_channel_data(channel)
self.assertEqual(mono_image.shape_dict['X'], conf.testing.czifile['w'])
self.assertEqual(mono_image.shape_dict['Y'], conf.testing.czifile['h'])
self.assertEqual(mono_image.shape_dict['X'], czifile['w'])
self.assertEqual(mono_image.shape_dict['Y'], czifile['h'])
self.assertEqual(mono_image.shape_dict['C'], 1)
self.assertEqual(mono_image.shape_dict['Z'], 1)
......@@ -72,7 +72,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.assertEqual(pxmap.shape_dict['Z'], 1)
self.assertTrue(
write_accessor_data_to_file(
conf.testing.output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
pxmap
)
)
......@@ -82,15 +82,15 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_run_object_classifier(self):
self.test_run_pixel_classifier()
fp = conf.testing.czifile['path']
model = IlastikObjectClassifierFromPixelPredictionsModel(
{'project_file': conf.testing.ilastik['object_classifier']}
fp = czifile['path']
model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
{'project_file': ilastik_classifiers['pxmap_to_obj']}
)
objmap, _ = model.infer(self.mono_image, self.pxmap)
self.assertTrue(
write_accessor_data_to_file(
conf.testing.output_path / f'obmap_{fp.stem}.tif',
output_path / f'obmap_{fp.stem}.tif',
objmap,
)
)
......@@ -98,11 +98,11 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_ilastik_pixel_classification_as_workflow(self):
result = infer_image_to_image(
conf.testing.czifile['path'],
IlastikPixelClassifierModel(
{'project_file': conf.testing.ilastik['pixel_classifier']}
czifile['path'],
ilm.IlastikPixelClassifierModel(
{'project_file': ilastik_classifiers['px']}
),
conf.testing.output_path,
output_path,
channel=0,
)
self.assertTrue(result.success)
......@@ -112,7 +112,7 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_httpexception_if_incorrect_project_file_loaded(self):
resp_load = requests.put(
self.uri + 'ilastik/pixel_classification/load/',
self.uri + 'ilastik/px/load/',
params={'project_file': 'improper.ilp'},
)
self.assertEqual(resp_load.status_code, 404)
......@@ -120,8 +120,8 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_load_ilastik_pixel_model(self):
resp_load = requests.put(
self.uri + 'ilastik/pixel_classification/load/',
params={'project_file': str(conf.testing.ilastik['pixel_classifier'])},
self.uri + 'ilastik/px/load/',
params={'project_file': str(ilastik_classifiers['px'])},
)
model_id = resp_load.json()['model_id']
......@@ -137,18 +137,18 @@ class TestIlastikOverApi(TestServerBaseClass):
resp_list_1st = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_1st), 1, resp_list_1st)
resp_load_2nd = requests.put(
self.uri + 'ilastik/pixel_classification/load/',
self.uri + 'ilastik/px/load/',
params={
'project_file': str(conf.testing.ilastik['pixel_classifier']),
'project_file': str(ilastik_classifiers['px']),
'duplicate': True,
},
)
resp_list_2nd = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd)
resp_load_3rd = requests.put(
self.uri + 'ilastik/pixel_classification/load/',
self.uri + 'ilastik/px/load/',
params={
'project_file': str(conf.testing.ilastik['pixel_classifier']),
'project_file': str(ilastik_classifiers['px']),
'duplicate': False,
},
)
......@@ -156,10 +156,10 @@ class TestIlastikOverApi(TestServerBaseClass):
self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd)
def test_load_ilastik_object_model(self):
def test_load_ilastik_pxmap_to_obj_model(self):
resp_load = requests.put(
self.uri + 'ilastik/object_classification/load/',
params={'project_file': str(conf.testing.ilastik['object_classifier'])},
self.uri + 'ilastik/pxmap_to_obj/load/',
params={'project_file': str(ilastik_classifiers['pxmap_to_obj'])},
)
model_id = resp_load.json()['model_id']
......@@ -170,6 +170,20 @@ class TestIlastikOverApi(TestServerBaseClass):
self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierFromPixelPredictionsModel')
return model_id
def test_load_ilastik_seg_to_obj_model(self):
resp_load = requests.put(
self.uri + 'ilastik/seg_to_obj/load/',
params={'project_file': str(ilastik_classifiers['seg_to_obj'])},
)
model_id = resp_load.json()['model_id']
self.assertEqual(resp_load.status_code, 200, resp_load.json())
resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json()
self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierFromSegmentationModel')
return model_id
def test_ilastik_infer_pixel_probability(self):
self.copy_input_file_to_server()
model_id = self.test_load_ilastik_pixel_model()
......@@ -178,7 +192,7 @@ class TestIlastikOverApi(TestServerBaseClass):
self.uri + f'infer/from_image_file',
params={
'model_id': model_id,
'input_filename': conf.testing.czifile['filename'],
'input_filename': czifile['filename'],
'channel': 0,
},
)
......@@ -187,14 +201,14 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_ilastik_infer_px_then_ob(self):
self.copy_input_file_to_server()
px_model_id = self.test_load_ilastik_pixel_model()
ob_model_id = self.test_load_ilastik_object_model()
ob_model_id = self.test_load_ilastik_pxmap_to_obj_model()
resp_infer = requests.put(
self.uri + f'ilastik/pixel_then_object_classification/infer/',
params={
'px_model_id': px_model_id,
'ob_model_id': ob_model_id,
'input_filename': conf.testing.czifile['filename'],
'input_filename': czifile['filename'],
'channel': 0,
}
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment