From bc76a8c42c5d289b94d0d464c86a544847b2b7b4 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Tue, 31 Oct 2023 15:08:35 +0100 Subject: [PATCH] Include endpoint and testing for ilastik model that takes segmentations and solves object maps --- conf/testing.py | 7 +- extensions/ilastik/router.py | 20 +++--- extensions/ilastik/tests/test_ilastik.py | 82 ++++++++++++++---------- 3 files changed, 64 insertions(+), 45 deletions(-) diff --git a/conf/testing.py b/conf/testing.py index 3c52b95c..1c2b0cf3 100644 --- a/conf/testing.py +++ b/conf/testing.py @@ -53,9 +53,10 @@ monozstackmask = { 'z': 85 } -ilastik = { - 'pixel_classifier': 'demo_px.ilp', - 'object_classifier': 'demo_obj.ilp', +ilastik_classifiers = { + 'px': 'demo_px.ilp', + 'pxmap_to_obj': 'demo_obj.ilp', + 'seg_to_obj': 'new_auto_obj.ilp', } output_path = root / 'testing_output' diff --git a/extensions/ilastik/router.py b/extensions/ilastik/router.py index 3287dc47..ba01d1b9 100644 --- a/extensions/ilastik/router.py +++ b/extensions/ilastik/router.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, HTTPException from model_server.session import Session from model_server.validators import validate_workflow_inputs -from extensions.ilastik.models import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierFromPixelPredictionsModel +from extensions.ilastik import models as ilm from model_server.models import ParameterExpectedError from extensions.ilastik.workflows import infer_px_then_ob_model @@ -14,7 +14,7 @@ router = APIRouter( session = Session() -def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str, duplicate=True) -> dict: +def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: str, duplicate=True) -> dict: """ Load an ilastik model of a given class and project filename. :param model_class: @@ -40,13 +40,17 @@ def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str, ) return result -@router.put('/pixel_classification/load/') -def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = True) -> dict: - return load_ilastik_model(IlastikPixelClassifierModel, project_file, duplicate=duplicate) +@router.put('/px/load/') +def load_px_model(project_file: str, duplicate: bool = True) -> dict: + return load_ilastik_model(ilm.IlastikPixelClassifierModel, project_file, duplicate=duplicate) -@router.put('/object_classification/load/') -def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict: - return load_ilastik_model(IlastikObjectClassifierFromPixelPredictionsModel, project_file, duplicate=duplicate) +@router.put('/pxmap_to_obj/load/') +def load_pxmap_to_obj_model(project_file: str, duplicate: bool = True) -> dict: + return load_ilastik_model(ilm.IlastikObjectClassifierFromPixelPredictionsModel, project_file, duplicate=duplicate) + +@router.put('/seg_to_obj/load/') +def load_seg_to_obj_model(project_file: str, duplicate: bool = True) -> dict: + return load_ilastik_model(ilm.IlastikObjectClassifierFromSegmentationModel, project_file, duplicate=duplicate) @router.put('/pixel_then_object_classification/infer') def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict: diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index 217b00c9..d2498af2 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -3,15 +3,15 @@ import unittest import numpy as np -import conf.testing +from conf.testing import czifile, ilastik_classifiers, output_path from model_server.accessors import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file -from extensions.ilastik.models import IlastikObjectClassifierFromPixelPredictionsModel, IlastikPixelClassifierModel +from extensions.ilastik import models as ilm from model_server.workflows import infer_image_to_image from tests.test_api import TestServerBaseClass class TestIlastikPixelClassification(unittest.TestCase): def setUp(self) -> None: - self.cf = CziImageFileAccessor(conf.testing.czifile['path']) + self.cf = CziImageFileAccessor(czifile['path']) def test_faulthandler(self): # recreate error that is messing up ilastik @@ -24,8 +24,8 @@ class TestIlastikPixelClassification(unittest.TestCase): def test_raise_error_if_autoload_disabled(self): - model = IlastikPixelClassifierModel( - {'project_file': conf.testing.ilastik['pixel_classifier']}, + model = ilm.IlastikPixelClassifierModel( + {'project_file': ilastik_classifiers['px']}, autoload=False ) w = 512 @@ -38,8 +38,8 @@ class TestIlastikPixelClassification(unittest.TestCase): def test_run_pixel_classifier_on_random_data(self): - model = IlastikPixelClassifierModel( - {'project_file': conf.testing.ilastik['pixel_classifier']}, + model = ilm.IlastikPixelClassifierModel( + {'project_file': ilastik_classifiers['px']}, ) w = 512 h = 256 @@ -52,16 +52,16 @@ class TestIlastikPixelClassification(unittest.TestCase): def test_run_pixel_classifier(self): channel = 0 - model = IlastikPixelClassifierModel( - {'project_file': conf.testing.ilastik['pixel_classifier']} + model = ilm.IlastikPixelClassifierModel( + {'project_file': ilastik_classifiers['px']} ) cf = CziImageFileAccessor( - conf.testing.czifile['path'] + czifile['path'] ) mono_image = cf.get_one_channel_data(channel) - self.assertEqual(mono_image.shape_dict['X'], conf.testing.czifile['w']) - self.assertEqual(mono_image.shape_dict['Y'], conf.testing.czifile['h']) + self.assertEqual(mono_image.shape_dict['X'], czifile['w']) + self.assertEqual(mono_image.shape_dict['Y'], czifile['h']) self.assertEqual(mono_image.shape_dict['C'], 1) self.assertEqual(mono_image.shape_dict['Z'], 1) @@ -72,7 +72,7 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertEqual(pxmap.shape_dict['Z'], 1) self.assertTrue( write_accessor_data_to_file( - conf.testing.output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif', + output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif', pxmap ) ) @@ -82,15 +82,15 @@ class TestIlastikPixelClassification(unittest.TestCase): def test_run_object_classifier(self): self.test_run_pixel_classifier() - fp = conf.testing.czifile['path'] - model = IlastikObjectClassifierFromPixelPredictionsModel( - {'project_file': conf.testing.ilastik['object_classifier']} + fp = czifile['path'] + model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( + {'project_file': ilastik_classifiers['pxmap_to_obj']} ) objmap, _ = model.infer(self.mono_image, self.pxmap) self.assertTrue( write_accessor_data_to_file( - conf.testing.output_path / f'obmap_{fp.stem}.tif', + output_path / f'obmap_{fp.stem}.tif', objmap, ) ) @@ -98,11 +98,11 @@ class TestIlastikPixelClassification(unittest.TestCase): def test_ilastik_pixel_classification_as_workflow(self): result = infer_image_to_image( - conf.testing.czifile['path'], - IlastikPixelClassifierModel( - {'project_file': conf.testing.ilastik['pixel_classifier']} + czifile['path'], + ilm.IlastikPixelClassifierModel( + {'project_file': ilastik_classifiers['px']} ), - conf.testing.output_path, + output_path, channel=0, ) self.assertTrue(result.success) @@ -112,7 +112,7 @@ class TestIlastikOverApi(TestServerBaseClass): def test_httpexception_if_incorrect_project_file_loaded(self): resp_load = requests.put( - self.uri + 'ilastik/pixel_classification/load/', + self.uri + 'ilastik/px/load/', params={'project_file': 'improper.ilp'}, ) self.assertEqual(resp_load.status_code, 404) @@ -120,8 +120,8 @@ class TestIlastikOverApi(TestServerBaseClass): def test_load_ilastik_pixel_model(self): resp_load = requests.put( - self.uri + 'ilastik/pixel_classification/load/', - params={'project_file': str(conf.testing.ilastik['pixel_classifier'])}, + self.uri + 'ilastik/px/load/', + params={'project_file': str(ilastik_classifiers['px'])}, ) model_id = resp_load.json()['model_id'] @@ -137,18 +137,18 @@ class TestIlastikOverApi(TestServerBaseClass): resp_list_1st = requests.get(self.uri + 'models').json() self.assertEqual(len(resp_list_1st), 1, resp_list_1st) resp_load_2nd = requests.put( - self.uri + 'ilastik/pixel_classification/load/', + self.uri + 'ilastik/px/load/', params={ - 'project_file': str(conf.testing.ilastik['pixel_classifier']), + 'project_file': str(ilastik_classifiers['px']), 'duplicate': True, }, ) resp_list_2nd = requests.get(self.uri + 'models').json() self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd) resp_load_3rd = requests.put( - self.uri + 'ilastik/pixel_classification/load/', + self.uri + 'ilastik/px/load/', params={ - 'project_file': str(conf.testing.ilastik['pixel_classifier']), + 'project_file': str(ilastik_classifiers['px']), 'duplicate': False, }, ) @@ -156,10 +156,10 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd) - def test_load_ilastik_object_model(self): + def test_load_ilastik_pxmap_to_obj_model(self): resp_load = requests.put( - self.uri + 'ilastik/object_classification/load/', - params={'project_file': str(conf.testing.ilastik['object_classifier'])}, + self.uri + 'ilastik/pxmap_to_obj/load/', + params={'project_file': str(ilastik_classifiers['pxmap_to_obj'])}, ) model_id = resp_load.json()['model_id'] @@ -170,6 +170,20 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierFromPixelPredictionsModel') return model_id + def test_load_ilastik_seg_to_obj_model(self): + resp_load = requests.put( + self.uri + 'ilastik/seg_to_obj/load/', + params={'project_file': str(ilastik_classifiers['seg_to_obj'])}, + ) + model_id = resp_load.json()['model_id'] + + self.assertEqual(resp_load.status_code, 200, resp_load.json()) + resp_list = requests.get(self.uri + 'models') + self.assertEqual(resp_list.status_code, 200) + rj = resp_list.json() + self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierFromSegmentationModel') + return model_id + def test_ilastik_infer_pixel_probability(self): self.copy_input_file_to_server() model_id = self.test_load_ilastik_pixel_model() @@ -178,7 +192,7 @@ class TestIlastikOverApi(TestServerBaseClass): self.uri + f'infer/from_image_file', params={ 'model_id': model_id, - 'input_filename': conf.testing.czifile['filename'], + 'input_filename': czifile['filename'], 'channel': 0, }, ) @@ -187,14 +201,14 @@ class TestIlastikOverApi(TestServerBaseClass): def test_ilastik_infer_px_then_ob(self): self.copy_input_file_to_server() px_model_id = self.test_load_ilastik_pixel_model() - ob_model_id = self.test_load_ilastik_object_model() + ob_model_id = self.test_load_ilastik_pxmap_to_obj_model() resp_infer = requests.put( self.uri + f'ilastik/pixel_then_object_classification/infer/', params={ 'px_model_id': px_model_id, 'ob_model_id': ob_model_id, - 'input_filename': conf.testing.czifile['filename'], + 'input_filename': czifile['filename'], 'channel': 0, } ) -- GitLab