From e629a6bc98da2460273bb5a619c472c833dd4f73 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 20 Dec 2023 10:27:04 +0100 Subject: [PATCH] Added superclass for semantic segmentation models and modified ilastik pixel classification to inherit from it --- extensions/ilastik/models.py | 78 ++++++++++++------------ extensions/ilastik/router.py | 6 +- extensions/ilastik/tests/test_ilastik.py | 12 ++-- model_server/models.py | 12 +++- 4 files changed, 57 insertions(+), 51 deletions(-) diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py index 21278461..8f22c3ad 100644 --- a/extensions/ilastik/models.py +++ b/extensions/ilastik/models.py @@ -6,10 +6,10 @@ import vigra import extensions.ilastik.conf from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor -from model_server.models import ImageToImageModel, ParameterExpectedError +from model_server.models import Model, ParameterExpectedError, SemanticSegmentationModel -class IlastikImageToImageModel(ImageToImageModel): +class IlastikModel(Model): def __init__(self, params, autoload=True): self.project_file = Path(params['project_file']) @@ -27,7 +27,6 @@ class IlastikImageToImageModel(ImageToImageModel): self.shell = None super().__init__(autoload, params) - def load(self): from ilastik import app from ilastik.applets.dataSelection.opDataSelection import PreloadedArrayDatasetInfo @@ -51,7 +50,7 @@ class IlastikImageToImageModel(ImageToImageModel): return True -class IlastikPixelClassifierModel(IlastikImageToImageModel): +class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): model_id = 'ilastik_pixel_classification' operations = ['segment', ] @@ -78,43 +77,44 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel): ) return InMemoryDataAccessor(data=yxcz), {'success': True} - def segment(self, input_img, thresh, channel): - return InMemoryDataAccessor( - self.infer(input_img).data[:, :, channel, :] > thresh + def segment(self, img: GenericImageDataAccessor, pixel_class: int, pixel_probability_threshold=0.5): + pxmap, _ = self.infer(img) + mask = pxmap.data[:, :, pixel_class, :] > pixel_probability_threshold + return InMemoryDataAccessor(mask) + +# TODO: deprecate +class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel): + model_id = 'ilastik_object_classification_from_pixel_predictions' + + @staticmethod + def get_workflow(): + from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction + return ObjectClassificationWorkflowPrediction + + def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') + tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') + + dsi = [ + { + 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), + 'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data), + } + ] + + obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] + + assert (len(obmaps) == 1, 'ilastik generated more than one object map') + + yxcz = np.moveaxis( + obmaps[0], + [1, 2, 3, 0], + [0, 1, 2, 3] ) + return InMemoryDataAccessor(data=yxcz), {'success': True} + -# class IlastikObjectClassifierFromPixelPredictionsModel(IlastikImageToImageModel): -# model_id = 'ilastik_object_classification_from_pixel_predictions' -# -# @staticmethod -# def get_workflow(): -# from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction -# return ObjectClassificationWorkflowPrediction -# -# def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): -# tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') -# tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') -# -# dsi = [ -# { -# 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), -# 'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data), -# } -# ] -# -# obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] -# -# assert (len(obmaps) == 1, 'ilastik generated more than one object map') -# -# yxcz = np.moveaxis( -# obmaps[0], -# [1, 2, 3, 0], -# [0, 1, 2, 3] -# ) -# return InMemoryDataAccessor(data=yxcz), {'success': True} - - -class IlastikObjectClassifierFromSegmentationModel(IlastikImageToImageModel): +class IlastikObjectClassifierFromSegmentationModel(IlastikModel): model_id = 'ilastik_object_classification_from_segmentation' @staticmethod diff --git a/extensions/ilastik/router.py b/extensions/ilastik/router.py index 4d9a8b28..15a59b28 100644 --- a/extensions/ilastik/router.py +++ b/extensions/ilastik/router.py @@ -14,7 +14,7 @@ router = APIRouter( session = Session() -def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: str, duplicate=True) -> dict: +def load_ilastik_model(model_class: ilm.IlastikModel, project_file: str, duplicate=True) -> dict: """ Load an ilastik model of a given class and project filename. :param model_class: @@ -35,10 +35,6 @@ def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: ) return {'model_id': result} -# @router.put('/px/load/') -# def load_px_model(project_file: str, duplicate: bool = True) -> dict: -# return load_ilastik_model(ilm.IlastikPixelClassifierModel, project_file, duplicate=duplicate) - @router.put('/seg/load/') def load_px_model(project_file: str, duplicate: bool = True) -> dict: return load_ilastik_model(ilm.IlastikPixelClassifierModel, project_file, duplicate=duplicate) diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index 79cbe3b4..f4f400a6 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -113,7 +113,7 @@ class TestIlastikOverApi(TestServerBaseClass): def test_httpexception_if_incorrect_project_file_loaded(self): resp_load = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={'project_file': 'improper.ilp'}, ) self.assertEqual(resp_load.status_code, 404) @@ -121,7 +121,7 @@ class TestIlastikOverApi(TestServerBaseClass): def test_load_ilastik_pixel_model(self): resp_load = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={'project_file': str(ilastik_classifiers['px'])}, ) self.assertEqual(resp_load.status_code, 200, resp_load.json()) @@ -137,7 +137,7 @@ class TestIlastikOverApi(TestServerBaseClass): resp_list_1st = requests.get(self.uri + 'models').json() self.assertEqual(len(resp_list_1st), 1, resp_list_1st) resp_load_2nd = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={ 'project_file': str(ilastik_classifiers['px']), 'duplicate': True, @@ -146,7 +146,7 @@ class TestIlastikOverApi(TestServerBaseClass): resp_list_2nd = requests.get(self.uri + 'models').json() self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd) resp_load_3rd = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={ 'project_file': str(ilastik_classifiers['px']), 'duplicate': False, @@ -172,14 +172,14 @@ class TestIlastikOverApi(TestServerBaseClass): # load models with these paths resp1 = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={ 'project_file': ilp_win, 'duplicate': False, }, ) resp2 = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={ 'project_file': ilp_posx, 'duplicate': False, diff --git a/model_server/models.py b/model_server/models.py index 4285b49c..07fc31df 100644 --- a/model_server/models.py +++ b/model_server/models.py @@ -36,7 +36,7 @@ class Model(ABC): pass @abstractmethod - def infer(self, img: GenericImageDataAccessor) -> (object, dict): # return json describing inference result + def infer(self, *args) -> (object, dict): # return json describing inference result pass def reload(self): @@ -52,6 +52,15 @@ class ImageToImageModel(Model): def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): pass +class SemanticSegmentationModel(ImageToImageModel): + """ + Model that exposes a method that returns a binary mask for a given input image and pixel class + """ + + @abstractmethod + def segment(self, img: GenericImageDataAccessor, pixel_class: int, **kwargs) -> (GenericImageDataAccessor, dict): + pass + class DummyImageToImageModel(ImageToImageModel): model_id = 'dummy_make_white_square' @@ -67,6 +76,7 @@ class DummyImageToImageModel(ImageToImageModel): result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255 return InMemoryDataAccessor(data=result), {'success': True} + class Error(Exception): pass -- GitLab