diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 90edcec53164158d6c3fe9f4e4be3e74e489600e..2dea13329513d0256baf7e9f2c698d4731301893 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -162,8 +162,15 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment return ObjectClassificationWorkflowBinary def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): - if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): - raise IlastikInputShapeError() + if self.model_chroma != input_img.chroma: + raise IlastikInputShapeError( + f'Model {self} expects {self.model_chroma} input channels but received only {input_img.chroma}' + ) + if self.model_3d != input_img.is_3d(): + if self.model_3d: + raise IlastikInputShapeError(f'Model is 3D but input image is 2D') + else: + raise IlastikInputShapeError(f'Model is 2D but input image is 3D') assert segmentation_img.is_mask() if isinstance(input_img, PatchStack): @@ -225,8 +232,13 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): raise IlastikInputShapeError() - tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') - tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') + if isinstance(input_img, PatchStack): + assert isinstance(pxmap_img, PatchStack) + tagged_input_data = vigra.taggedView(input_img.pczyx, 'tczyx') + tagged_pxmap_data = vigra.taggedView(pxmap_img.pczyx, 'tczyx') + else: + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') + tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') dsi = [ { @@ -239,12 +251,20 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag assert len(obmaps) == 1, 'ilastik generated more than one object map' - yxcz = np.moveaxis( - obmaps[0], - [1, 2, 3, 0], - [0, 1, 2, 3] - ) - return InMemoryDataAccessor(data=yxcz), {'success': True} + if isinstance(input_img, PatchStack): + pyxcz = np.moveaxis( + obmaps[0], + [0, 1, 2, 3, 4], + [0, 4, 1, 2, 3] + ) + return PatchStack(data=pyxcz), {'success': True} + else: + yxcz = np.moveaxis( + obmaps[0], + [1, 2, 3, 0], + [0, 1, 2, 3] + ) + return InMemoryDataAccessor(data=yxcz), {'success': True} def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs): @@ -262,10 +282,30 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape') pxch = kwargs.get('pixel_classification_channel', 0) pxtr = kwargs.get('pixel_classification_threshold', 0.5) - mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr) + mask = img._derived_accessor(pxmap.get_one_channel_data(pxch).data > pxtr) obmap, _ = self.infer(img, mask) return obmap + def make_instance_segmentation_model(self, px_ch: int): + """ + Generate an instance segmentation model, i.e. one that takes binary masks instead of pixel probabilities as a + second input. + :param px_ch: channel of pixel probability map to use + :return: + InstanceSegmentationModel object + """ + class _Mod(self.__class__, InstanceSegmentationModel): + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + if mask.dtype == 'bool': + norm_mask = 1.0 * mask.data + else: + norm_mask = mask.data / np.iinfo(mask.dtype).max + norm_mask_acc = mask._derived_accessor(norm_mask.astype('float32')) + return super().label_instance_class(img, norm_mask_acc, pixel_classification_channel=px_ch) + return _Mod(params={'project_file': self.project_file}) + class Error(Exception): diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py index 151e37915c16dd8a3943dc0b07e1e188d45606a3..c40679c02a95bb386aaad613c26cc29e28b48055 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/extensions/ilastik/router.py @@ -27,14 +27,8 @@ def load_ilastik_model(model_class: ilm.IlastikModel, project_file: str, duplica if existing_model_id is not None: session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate') return {'model_id': existing_model_id} - try: - result = session.load_model(model_class, {'project_file': project_file}) - session.log_info(f'Loaded ilastik model {result} from {project_file}') - except (FileNotFoundError, ParameterExpectedError): - raise HTTPException( - status_code=404, - detail=f'Could not load project file {project_file}', - ) + result = session.load_model(model_class, {'project_file': project_file}) + session.log_info(f'Loaded ilastik model {result} from {project_file}') return {'model_id': result} @router.put('/seg/load/') diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index 7a3a9c38bd495af6f146d63e6adb96f4215767c2..a1a4983efbded47e54804521c04394c55795439c 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -199,7 +199,7 @@ class TestIlastikOverApi(TestServerBaseClass): 'ilastik/seg/load/', query={'project_file': 'improper.ilp'}, ) - self.assertEqual(resp_load.status_code, 404) + self.assertEqual(resp_load.status_code, 500) def test_load_ilastik_pixel_model(self): @@ -319,6 +319,7 @@ class TestIlastikOverApi(TestServerBaseClass): ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) + class TestIlastikObjectClassification(unittest.TestCase): def setUp(self): stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path'])