Skip to content
Snippets Groups Projects
Commit bc76a8c4 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Include endpoint and testing for ilastik model that takes segmentations and solves object maps

parent 346bb8db
No related branches found
No related tags found
No related merge requests found
......@@ -53,9 +53,10 @@ monozstackmask = {
'z': 85
}
ilastik = {
'pixel_classifier': 'demo_px.ilp',
'object_classifier': 'demo_obj.ilp',
ilastik_classifiers = {
'px': 'demo_px.ilp',
'pxmap_to_obj': 'demo_obj.ilp',
'seg_to_obj': 'new_auto_obj.ilp',
}
output_path = root / 'testing_output'
......
......@@ -3,7 +3,7 @@ from fastapi import APIRouter, HTTPException
from model_server.session import Session
from model_server.validators import validate_workflow_inputs
from extensions.ilastik.models import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierFromPixelPredictionsModel
from extensions.ilastik import models as ilm
from model_server.models import ParameterExpectedError
from extensions.ilastik.workflows import infer_px_then_ob_model
......@@ -14,7 +14,7 @@ router = APIRouter(
session = Session()
def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str, duplicate=True) -> dict:
def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: str, duplicate=True) -> dict:
"""
Load an ilastik model of a given class and project filename.
:param model_class:
......@@ -40,13 +40,17 @@ def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str,
)
return result
@router.put('/pixel_classification/load/')
def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(IlastikPixelClassifierModel, project_file, duplicate=duplicate)
@router.put('/px/load/')
def load_px_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(ilm.IlastikPixelClassifierModel, project_file, duplicate=duplicate)
@router.put('/object_classification/load/')
def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(IlastikObjectClassifierFromPixelPredictionsModel, project_file, duplicate=duplicate)
@router.put('/pxmap_to_obj/load/')
def load_pxmap_to_obj_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(ilm.IlastikObjectClassifierFromPixelPredictionsModel, project_file, duplicate=duplicate)
@router.put('/seg_to_obj/load/')
def load_seg_to_obj_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(ilm.IlastikObjectClassifierFromSegmentationModel, project_file, duplicate=duplicate)
@router.put('/pixel_then_object_classification/infer')
def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict:
......
......@@ -3,15 +3,15 @@ import unittest
import numpy as np
import conf.testing
from conf.testing import czifile, ilastik_classifiers, output_path
from model_server.accessors import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from extensions.ilastik.models import IlastikObjectClassifierFromPixelPredictionsModel, IlastikPixelClassifierModel
from extensions.ilastik import models as ilm
from model_server.workflows import infer_image_to_image
from tests.test_api import TestServerBaseClass
class TestIlastikPixelClassification(unittest.TestCase):
def setUp(self) -> None:
self.cf = CziImageFileAccessor(conf.testing.czifile['path'])
self.cf = CziImageFileAccessor(czifile['path'])
def test_faulthandler(self): # recreate error that is messing up ilastik
......@@ -24,8 +24,8 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_raise_error_if_autoload_disabled(self):
model = IlastikPixelClassifierModel(
{'project_file': conf.testing.ilastik['pixel_classifier']},
model = ilm.IlastikPixelClassifierModel(
{'project_file': ilastik_classifiers['px']},
autoload=False
)
w = 512
......@@ -38,8 +38,8 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_run_pixel_classifier_on_random_data(self):
model = IlastikPixelClassifierModel(
{'project_file': conf.testing.ilastik['pixel_classifier']},
model = ilm.IlastikPixelClassifierModel(
{'project_file': ilastik_classifiers['px']},
)
w = 512
h = 256
......@@ -52,16 +52,16 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_run_pixel_classifier(self):
channel = 0
model = IlastikPixelClassifierModel(
{'project_file': conf.testing.ilastik['pixel_classifier']}
model = ilm.IlastikPixelClassifierModel(
{'project_file': ilastik_classifiers['px']}
)
cf = CziImageFileAccessor(
conf.testing.czifile['path']
czifile['path']
)
mono_image = cf.get_one_channel_data(channel)
self.assertEqual(mono_image.shape_dict['X'], conf.testing.czifile['w'])
self.assertEqual(mono_image.shape_dict['Y'], conf.testing.czifile['h'])
self.assertEqual(mono_image.shape_dict['X'], czifile['w'])
self.assertEqual(mono_image.shape_dict['Y'], czifile['h'])
self.assertEqual(mono_image.shape_dict['C'], 1)
self.assertEqual(mono_image.shape_dict['Z'], 1)
......@@ -72,7 +72,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.assertEqual(pxmap.shape_dict['Z'], 1)
self.assertTrue(
write_accessor_data_to_file(
conf.testing.output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
pxmap
)
)
......@@ -82,15 +82,15 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_run_object_classifier(self):
self.test_run_pixel_classifier()
fp = conf.testing.czifile['path']
model = IlastikObjectClassifierFromPixelPredictionsModel(
{'project_file': conf.testing.ilastik['object_classifier']}
fp = czifile['path']
model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
{'project_file': ilastik_classifiers['pxmap_to_obj']}
)
objmap, _ = model.infer(self.mono_image, self.pxmap)
self.assertTrue(
write_accessor_data_to_file(
conf.testing.output_path / f'obmap_{fp.stem}.tif',
output_path / f'obmap_{fp.stem}.tif',
objmap,
)
)
......@@ -98,11 +98,11 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_ilastik_pixel_classification_as_workflow(self):
result = infer_image_to_image(
conf.testing.czifile['path'],
IlastikPixelClassifierModel(
{'project_file': conf.testing.ilastik['pixel_classifier']}
czifile['path'],
ilm.IlastikPixelClassifierModel(
{'project_file': ilastik_classifiers['px']}
),
conf.testing.output_path,
output_path,
channel=0,
)
self.assertTrue(result.success)
......@@ -112,7 +112,7 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_httpexception_if_incorrect_project_file_loaded(self):
resp_load = requests.put(
self.uri + 'ilastik/pixel_classification/load/',
self.uri + 'ilastik/px/load/',
params={'project_file': 'improper.ilp'},
)
self.assertEqual(resp_load.status_code, 404)
......@@ -120,8 +120,8 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_load_ilastik_pixel_model(self):
resp_load = requests.put(
self.uri + 'ilastik/pixel_classification/load/',
params={'project_file': str(conf.testing.ilastik['pixel_classifier'])},
self.uri + 'ilastik/px/load/',
params={'project_file': str(ilastik_classifiers['px'])},
)
model_id = resp_load.json()['model_id']
......@@ -137,18 +137,18 @@ class TestIlastikOverApi(TestServerBaseClass):
resp_list_1st = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_1st), 1, resp_list_1st)
resp_load_2nd = requests.put(
self.uri + 'ilastik/pixel_classification/load/',
self.uri + 'ilastik/px/load/',
params={
'project_file': str(conf.testing.ilastik['pixel_classifier']),
'project_file': str(ilastik_classifiers['px']),
'duplicate': True,
},
)
resp_list_2nd = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd)
resp_load_3rd = requests.put(
self.uri + 'ilastik/pixel_classification/load/',
self.uri + 'ilastik/px/load/',
params={
'project_file': str(conf.testing.ilastik['pixel_classifier']),
'project_file': str(ilastik_classifiers['px']),
'duplicate': False,
},
)
......@@ -156,10 +156,10 @@ class TestIlastikOverApi(TestServerBaseClass):
self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd)
def test_load_ilastik_object_model(self):
def test_load_ilastik_pxmap_to_obj_model(self):
resp_load = requests.put(
self.uri + 'ilastik/object_classification/load/',
params={'project_file': str(conf.testing.ilastik['object_classifier'])},
self.uri + 'ilastik/pxmap_to_obj/load/',
params={'project_file': str(ilastik_classifiers['pxmap_to_obj'])},
)
model_id = resp_load.json()['model_id']
......@@ -170,6 +170,20 @@ class TestIlastikOverApi(TestServerBaseClass):
self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierFromPixelPredictionsModel')
return model_id
def test_load_ilastik_seg_to_obj_model(self):
resp_load = requests.put(
self.uri + 'ilastik/seg_to_obj/load/',
params={'project_file': str(ilastik_classifiers['seg_to_obj'])},
)
model_id = resp_load.json()['model_id']
self.assertEqual(resp_load.status_code, 200, resp_load.json())
resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json()
self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierFromSegmentationModel')
return model_id
def test_ilastik_infer_pixel_probability(self):
self.copy_input_file_to_server()
model_id = self.test_load_ilastik_pixel_model()
......@@ -178,7 +192,7 @@ class TestIlastikOverApi(TestServerBaseClass):
self.uri + f'infer/from_image_file',
params={
'model_id': model_id,
'input_filename': conf.testing.czifile['filename'],
'input_filename': czifile['filename'],
'channel': 0,
},
)
......@@ -187,14 +201,14 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_ilastik_infer_px_then_ob(self):
self.copy_input_file_to_server()
px_model_id = self.test_load_ilastik_pixel_model()
ob_model_id = self.test_load_ilastik_object_model()
ob_model_id = self.test_load_ilastik_pxmap_to_obj_model()
resp_infer = requests.put(
self.uri + f'ilastik/pixel_then_object_classification/infer/',
params={
'px_model_id': px_model_id,
'ob_model_id': ob_model_id,
'input_filename': conf.testing.czifile['filename'],
'input_filename': czifile['filename'],
'channel': 0,
}
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment