From 04b4d5231329b447b4863832955d1006dcb346ed Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 25 Oct 2024 12:16:48 +0200 Subject: [PATCH] Models' inference methods now just return result, not tuple with {'success': True'} --- model_server/base/session.py | 2 +- model_server/extensions/ilastik/models.py | 22 +++++++++---------- .../ilastik/pipelines/px_then_ob.py | 4 ++-- tests/test_ilastik/test_ilastik.py | 6 ++--- tests/test_ilastik/test_roiset_workflow.py | 8 ++----- 5 files changed, 18 insertions(+), 24 deletions(-) diff --git a/model_server/base/session.py b/model_server/base/session.py index 4aa88a1e..be4911f0 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -251,7 +251,7 @@ class _Session(object): :param params: optional parameters that are passed to the model's construct :return: model_id of loaded model """ - mi = ModelClass(params=params.dict()) + mi = ModelClass(params=params.dict() if params else None) assert mi.loaded, f'Error loading instance of {ModelClass.__name__}' ii = 0 diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index dccb269a..5e45bab4 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -137,7 +137,7 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): [1, 2, 3, 0], [0, 1, 2, 3] ) - return InMemoryDataAccessor(data=yxcz), {'success': True} + return InMemoryDataAccessor(data=yxcz) def infer_patch_stack(self, img: PatchStack, **kwargs) -> (np.ndarray, dict): """ @@ -150,13 +150,13 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): for i in range(0, img.count): sl = img.get_slice_at(i) try: - data[i][sl[0], sl[1], :, sl[3]] = self.infer(img.iat(i, crop=True))[0].data + data[i][sl[0], sl[1], :, sl[3]] = self.infer(img.iat(i, crop=True)).data except FeatureSelectionConstraintError: # occurs occasionally on small patches continue - return PatchStack(data), {'success': True} + return PatchStack(data) def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs): - pxmap, _ = self.infer(img) + pxmap = self.infer(img) mask = pxmap.get_mono(self.params['px_class']).apply(lambda x: x > self.params['px_prob_threshold']) return mask @@ -219,19 +219,18 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment [0, 1, 2, 3, 4], [0, 4, 1, 2, 3] ) - return PatchStack(data=pyxcz), {'success': True} + return PatchStack(data=pyxcz) else: yxcz = np.moveaxis( obmaps[0], [1, 2, 3, 0], [0, 1, 2, 3] ) - return InMemoryDataAccessor(data=yxcz), {'success': True} + return InMemoryDataAccessor(data=yxcz) def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) - obmap, _ = self.infer(img, mask) - return obmap + return self.infer(img, mask) class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImageModel): @@ -277,14 +276,14 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag [0, 1, 2, 3, 4], [0, 4, 1, 2, 3] ) - return PatchStack(data=pyxcz), {'success': True} + return PatchStack(data=pyxcz) else: yxcz = np.moveaxis( obmaps[0], [1, 2, 3, 0], [0, 1, 2, 3] ) - return InMemoryDataAccessor(data=yxcz), {'success': True} + return InMemoryDataAccessor(data=yxcz) def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs): @@ -305,8 +304,7 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag pxch = kwargs.get('pixel_classification_channel', 0) pxtr = kwargs.get('pixel_classification_threshold', 0.5) mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr) - obmap, _ = self.infer(img, mask) - return obmap + return self.infer(img, mask) class Error(Exception): diff --git a/model_server/extensions/ilastik/pipelines/px_then_ob.py b/model_server/extensions/ilastik/pipelines/px_then_ob.py index 1aa64615..2e9fd6e4 100644 --- a/model_server/extensions/ilastik/pipelines/px_then_ob.py +++ b/model_server/extensions/ilastik/pipelines/px_then_ob.py @@ -57,8 +57,8 @@ def pixel_then_object_classification_pipeline( else: channels = range(0, d['input'].chroma) d['select_channels'] = d.last.get_channels(channels, mip=k.get('mip', False)) - d['pxmap'], _ = models['px_model'].infer(d.last) - d['ob_map'], _ = models['ob_model'].infer(d['select_channels'], d['pxmap']) + d['pxmap'] = models['px_model'].infer(d.last) + d['ob_map'] = models['ob_model'].infer(d['select_channels'], d['pxmap']) return d diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py index a6f467f9..4eedf981 100644 --- a/tests/test_ilastik/test_ilastik.py +++ b/tests/test_ilastik/test_ilastik.py @@ -148,7 +148,7 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertEqual(mask.nz, acc.nz) self.assertEqual(mask.count, acc.count) - pxmap, _ = self.model.infer_patch_stack(acc) + pxmap = self.model.infer_patch_stack(acc) self.assertEqual(pxmap.dtype, float) self.assertEqual(pxmap.chroma, len(self.model.labels)) self.assertEqual(pxmap.hw, acc.hw) @@ -162,7 +162,7 @@ class TestIlastikPixelClassification(unittest.TestCase): params={'project_file': ilastik_classifiers['pxmap_to_obj']['path'].__str__()} ) mask = self.model.label_pixel_class(self.mono_image) - objmap, _ = model.infer(self.mono_image, mask) + objmap = model.infer(self.mono_image, mask) self.assertTrue( write_accessor_data_to_file( @@ -333,7 +333,7 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase): img = generate_file_accessor(self.pa_input_image) self.assertGreater(img.chroma, 1) mod = ilm.IlastikPixelClassifierModel({'project_file': self.pa_px_classifier.__str__()}) - pxmap = mod.infer(img)[0] + pxmap = mod.infer(img) self.assertEqual(pxmap.hw, img.hw) self.assertEqual(pxmap.nz, img.nz) return pxmap diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/test_ilastik/test_roiset_workflow.py index 5f799089..4adf24bb 100644 --- a/tests/test_ilastik/test_roiset_workflow.py +++ b/tests/test_ilastik/test_roiset_workflow.py @@ -72,18 +72,14 @@ class BaseTestRoiSetMonoProducts(object): 'name': 'ilastik_px_mod', 'project_file': fp_px, 'model': ilm.IlastikPixelClassifierModel( - ilm.IlastikPixelClassifierParams( - project_file=fp_px, - ) + {'project_file': fp_px}, ) }, 'object_classifier': { 'name': 'ilastik_ob_mod', 'project_file': fp_ob, 'model': ilm.IlastikObjectClassifierFromSegmentationModel( - ilm.IlastikParams( - project_file=fp_ob - ) + {'project_file': fp_ob}, ) }, } -- GitLab