Skip to content
Snippets Groups Projects
Commit e629a6bc authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Added superclass for semantic segmentation models and modified ilastik pixel...

Added superclass for semantic segmentation models and modified ilastik pixel classification to inherit from it
parent a3043334
No related branches found
No related tags found
No related merge requests found
......@@ -6,10 +6,10 @@ import vigra
import extensions.ilastik.conf
from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.models import ImageToImageModel, ParameterExpectedError
from model_server.models import Model, ParameterExpectedError, SemanticSegmentationModel
class IlastikImageToImageModel(ImageToImageModel):
class IlastikModel(Model):
def __init__(self, params, autoload=True):
self.project_file = Path(params['project_file'])
......@@ -27,7 +27,6 @@ class IlastikImageToImageModel(ImageToImageModel):
self.shell = None
super().__init__(autoload, params)
def load(self):
from ilastik import app
from ilastik.applets.dataSelection.opDataSelection import PreloadedArrayDatasetInfo
......@@ -51,7 +50,7 @@ class IlastikImageToImageModel(ImageToImageModel):
return True
class IlastikPixelClassifierModel(IlastikImageToImageModel):
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
model_id = 'ilastik_pixel_classification'
operations = ['segment', ]
......@@ -78,43 +77,44 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel):
)
return InMemoryDataAccessor(data=yxcz), {'success': True}
def segment(self, input_img, thresh, channel):
return InMemoryDataAccessor(
self.infer(input_img).data[:, :, channel, :] > thresh
def segment(self, img: GenericImageDataAccessor, pixel_class: int, pixel_probability_threshold=0.5):
pxmap, _ = self.infer(img)
mask = pxmap.data[:, :, pixel_class, :] > pixel_probability_threshold
return InMemoryDataAccessor(mask)
# TODO: deprecate
class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel):
model_id = 'ilastik_object_classification_from_pixel_predictions'
@staticmethod
def get_workflow():
from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction
return ObjectClassificationWorkflowPrediction
def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
dsi = [
{
'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data),
}
]
obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
assert (len(obmaps) == 1, 'ilastik generated more than one object map')
yxcz = np.moveaxis(
obmaps[0],
[1, 2, 3, 0],
[0, 1, 2, 3]
)
return InMemoryDataAccessor(data=yxcz), {'success': True}
# class IlastikObjectClassifierFromPixelPredictionsModel(IlastikImageToImageModel):
# model_id = 'ilastik_object_classification_from_pixel_predictions'
#
# @staticmethod
# def get_workflow():
# from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction
# return ObjectClassificationWorkflowPrediction
#
# def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
# tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
# tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
#
# dsi = [
# {
# 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
# 'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data),
# }
# ]
#
# obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
#
# assert (len(obmaps) == 1, 'ilastik generated more than one object map')
#
# yxcz = np.moveaxis(
# obmaps[0],
# [1, 2, 3, 0],
# [0, 1, 2, 3]
# )
# return InMemoryDataAccessor(data=yxcz), {'success': True}
class IlastikObjectClassifierFromSegmentationModel(IlastikImageToImageModel):
class IlastikObjectClassifierFromSegmentationModel(IlastikModel):
model_id = 'ilastik_object_classification_from_segmentation'
@staticmethod
......
......@@ -14,7 +14,7 @@ router = APIRouter(
session = Session()
def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: str, duplicate=True) -> dict:
def load_ilastik_model(model_class: ilm.IlastikModel, project_file: str, duplicate=True) -> dict:
"""
Load an ilastik model of a given class and project filename.
:param model_class:
......@@ -35,10 +35,6 @@ def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file:
)
return {'model_id': result}
# @router.put('/px/load/')
# def load_px_model(project_file: str, duplicate: bool = True) -> dict:
# return load_ilastik_model(ilm.IlastikPixelClassifierModel, project_file, duplicate=duplicate)
@router.put('/seg/load/')
def load_px_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(ilm.IlastikPixelClassifierModel, project_file, duplicate=duplicate)
......
......@@ -113,7 +113,7 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_httpexception_if_incorrect_project_file_loaded(self):
resp_load = requests.put(
self.uri + 'ilastik/px/load/',
self.uri + 'ilastik/seg/load/',
params={'project_file': 'improper.ilp'},
)
self.assertEqual(resp_load.status_code, 404)
......@@ -121,7 +121,7 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_load_ilastik_pixel_model(self):
resp_load = requests.put(
self.uri + 'ilastik/px/load/',
self.uri + 'ilastik/seg/load/',
params={'project_file': str(ilastik_classifiers['px'])},
)
self.assertEqual(resp_load.status_code, 200, resp_load.json())
......@@ -137,7 +137,7 @@ class TestIlastikOverApi(TestServerBaseClass):
resp_list_1st = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_1st), 1, resp_list_1st)
resp_load_2nd = requests.put(
self.uri + 'ilastik/px/load/',
self.uri + 'ilastik/seg/load/',
params={
'project_file': str(ilastik_classifiers['px']),
'duplicate': True,
......@@ -146,7 +146,7 @@ class TestIlastikOverApi(TestServerBaseClass):
resp_list_2nd = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd)
resp_load_3rd = requests.put(
self.uri + 'ilastik/px/load/',
self.uri + 'ilastik/seg/load/',
params={
'project_file': str(ilastik_classifiers['px']),
'duplicate': False,
......@@ -172,14 +172,14 @@ class TestIlastikOverApi(TestServerBaseClass):
# load models with these paths
resp1 = requests.put(
self.uri + 'ilastik/px/load/',
self.uri + 'ilastik/seg/load/',
params={
'project_file': ilp_win,
'duplicate': False,
},
)
resp2 = requests.put(
self.uri + 'ilastik/px/load/',
self.uri + 'ilastik/seg/load/',
params={
'project_file': ilp_posx,
'duplicate': False,
......
......@@ -36,7 +36,7 @@ class Model(ABC):
pass
@abstractmethod
def infer(self, img: GenericImageDataAccessor) -> (object, dict): # return json describing inference result
def infer(self, *args) -> (object, dict): # return json describing inference result
pass
def reload(self):
......@@ -52,6 +52,15 @@ class ImageToImageModel(Model):
def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
pass
class SemanticSegmentationModel(ImageToImageModel):
"""
Model that exposes a method that returns a binary mask for a given input image and pixel class
"""
@abstractmethod
def segment(self, img: GenericImageDataAccessor, pixel_class: int, **kwargs) -> (GenericImageDataAccessor, dict):
pass
class DummyImageToImageModel(ImageToImageModel):
model_id = 'dummy_make_white_square'
......@@ -67,6 +76,7 @@ class DummyImageToImageModel(ImageToImageModel):
result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255
return InMemoryDataAccessor(data=result), {'success': True}
class Error(Exception):
pass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment