From cb77c5f94d85a211b4f16dd3525702a68c982516 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Mon, 6 May 2024 15:20:24 +0200 Subject: [PATCH] Expose gaussian blur in ilastik segmentation model --- model_server/base/process.py | 9 ++ model_server/extensions/ilastik/models.py | 12 ++- .../extensions/ilastik/tests/test_ilastik.py | 84 ++++++++++--------- 3 files changed, 63 insertions(+), 42 deletions(-) diff --git a/model_server/base/process.py b/model_server/base/process.py index 94c6cf62..d475b063 100644 --- a/model_server/base/process.py +++ b/model_server/base/process.py @@ -6,6 +6,7 @@ from math import ceil, floor import numpy as np import skimage from skimage.exposure import rescale_intensity +from skimage.filters import gaussian from skimage.measure import find_contours def is_mask(img): @@ -128,6 +129,14 @@ def get_safe_contours(mask): else: return find_contours(mask) +def smooth(img: np.ndarray, sig: float) -> np.ndarray: + """ + Perform Gaussian smoothing on an image + :param img: image data + :param sig: threshold parameter + :return: smoothed image + """ + return gaussian(img, sig) class Error(Exception): pass diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 15cbd4f0..4a1c4c0a 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -4,11 +4,13 @@ from pathlib import Path import numpy as np from pydantic import BaseModel +from skimage.filters import gaussian import vigra import model_server.extensions.ilastik.conf from model_server.base.accessors import PatchStack from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor +from model_server.base.process import smooth from model_server.base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel class IlastikParams(BaseModel): @@ -92,7 +94,8 @@ class IlastikModel(Model): class IlastikPixelClassifierParams(IlastikParams): px_class: int = 0 - px_prob_threshold = 0.5 + px_prob_threshold: float = 0.5 + px_smoothing: float = 0.0 class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): model_id = 'ilastik_pixel_classification' @@ -157,7 +160,12 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs): pxmap, _ = self.infer(img) - mask = pxmap.data[:, :, self.params['px_class'], :] > self.params['px_prob_threshold'] + sig = self.params['px_smoothing'] + if sig > 0.0: + proc = smooth(img.data, sig) + else: + proc = pxmap.data + mask = proc[:, :, self.params['px_class'], :] > self.params['px_prob_threshold'] return InMemoryDataAccessor(mask) diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index 40719563..a48867c0 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -1,5 +1,4 @@ import pathlib -import requests import unittest import numpy as np @@ -19,6 +18,11 @@ def _random_int(*args): class TestIlastikPixelClassification(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) + self.channel = 0 + self.model = ilm.IlastikPixelClassifierModel( + params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__()) + ) + self.mono_image = self.cf.get_one_channel_data(self.channel) def test_faulthandler(self): # recreate error that is messing up ilastik @@ -45,59 +49,61 @@ class TestIlastikPixelClassification(unittest.TestCase): def test_run_pixel_classifier_on_random_data(self): - model = ilm.IlastikPixelClassifierModel( - params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__()), - ) w = 512 h = 256 input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1)) - mask = model.label_pixel_class(input_img) + mask = self.model.label_pixel_class(input_img) self.assertEqual(mask.shape, (h, w, 1, 1)) def test_run_pixel_classifier(self): - channel = 0 - model = ilm.IlastikPixelClassifierModel( - params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__()) - ) - cf = CziImageFileAccessor( - czifile['path'] - ) - mono_image = cf.get_one_channel_data(channel) + self.assertEqual(self.mono_image.shape_dict['X'], czifile['w']) + self.assertEqual(self.mono_image.shape_dict['Y'], czifile['h']) + self.assertEqual(self.mono_image.shape_dict['C'], 1) + self.assertEqual(self.mono_image.shape_dict['Z'], 1) - self.assertEqual(mono_image.shape_dict['X'], czifile['w']) - self.assertEqual(mono_image.shape_dict['Y'], czifile['h']) - self.assertEqual(mono_image.shape_dict['C'], 1) - self.assertEqual(mono_image.shape_dict['Z'], 1) - - mask = model.label_pixel_class(mono_image) + mask = self.model.label_pixel_class(self.mono_image) self.assertTrue(mask.is_mask()) - self.assertEqual(mask.shape[0:2], cf.shape[0:2]) + self.assertEqual(mask.shape[0:2], self.cf.shape[0:2]) self.assertEqual(mask.shape_dict['C'], 1) self.assertEqual(mask.shape_dict['Z'], 1) self.assertTrue( write_accessor_data_to_file( - output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif', + output_path / 'seg' / f'seg_{self.cf.fpath.stem}_ch{self.channel}.tif', mask ) ) - self.mono_image = mono_image - self.mask = mask + def test_label_pixels_with_params(self): + def _run_seg(tr, sig): + mod = ilm.IlastikPixelClassifierModel( + params=ilm.IlastikPixelClassifierParams( + project_file=ilastik_classifiers['px'].__str__(), + px_prob_threshold=tr, + px_smoothing=sig, + ), + ) + mask = mod.label_pixel_class(self.mono_image) + write_accessor_data_to_file( + output_path / 'seg' / f'seg_tr{int(10*tr)}_sig{int(10*sig)}.tif', + mask + ) + return mask + mask1 = _run_seg(0.5, 0.0) + mask2 = _run_seg(0.5, 0.2) + self.assertEqual(mask1.shape, mask2.shape) + def test_pixel_classifier_enforces_input_shape(self): - model = ilm.IlastikPixelClassifierModel( - params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__()) - ) - self.assertEqual(model.model_chroma, 1) - self.assertEqual(model.model_3d, False) + self.assertEqual(self.model.model_chroma, 1) + self.assertEqual(self.model.model_3d, False) # correct data self.assertIsInstance( - model.label_pixel_class( + self.model.label_pixel_class( InMemoryDataAccessor( _random_int(512, 256, 1, 1) ) @@ -107,7 +113,7 @@ class TestIlastikPixelClassification(unittest.TestCase): # raise except with input of multiple channels with self.assertRaises(ilm.IlastikInputShapeError): - mask = model.label_pixel_class( + mask = self.model.label_pixel_class( InMemoryDataAccessor( _random_int(512, 256, 3, 1) ) @@ -115,7 +121,7 @@ class TestIlastikPixelClassification(unittest.TestCase): # raise except with input of multiple channels with self.assertRaises(ilm.IlastikInputShapeError): - mask = model.label_pixel_class( + mask = self.model.label_pixel_class( InMemoryDataAccessor( _random_int(512, 256, 1, 15) ) @@ -130,20 +136,16 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertEqual(acc.hw, (512, 512)) self.assertEqual(acc.iat(0, crop=True).hw, (256, 512)) - model = ilm.IlastikPixelClassifierModel( - ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__()), - ) - - mask = model.label_patch_stack(acc) + mask = self.model.label_patch_stack(acc) self.assertEqual(mask.dtype, bool) self.assertEqual(mask.chroma, 1) self.assertEqual(mask.hw, acc.hw) self.assertEqual(mask.nz, acc.nz) self.assertEqual(mask.count, acc.count) - pxmap, _ = model.infer_patch_stack(acc) + pxmap, _ = self.model.infer_patch_stack(acc) self.assertEqual(pxmap.dtype, float) - self.assertEqual(pxmap.chroma, len(model.labels)) + self.assertEqual(pxmap.chroma, len(self.model.labels)) self.assertEqual(pxmap.hw, acc.hw) self.assertEqual(pxmap.nz, acc.nz) self.assertEqual(pxmap.count, acc.count) @@ -154,7 +156,8 @@ class TestIlastikPixelClassification(unittest.TestCase): model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( params=ilm.IlastikParams(project_file=ilastik_classifiers['pxmap_to_obj'].__str__()) ) - objmap, _ = model.infer(self.mono_image, self.mask) + mask = self.model.label_pixel_class(self.mono_image) + objmap, _ = model.infer(self.mono_image, mask) self.assertTrue( write_accessor_data_to_file( @@ -171,7 +174,8 @@ class TestIlastikPixelClassification(unittest.TestCase): model = ilm.IlastikObjectClassifierFromSegmentationModel( params=ilm.IlastikParams(project_file=ilastik_classifiers['seg_to_obj'].__str__()) ) - objmap = model.label_instance_class(self.mono_image, self.mask) + mask = self.model.label_pixel_class(self.mono_image) + objmap = model.label_instance_class(self.mono_image, mask) self.assertTrue( write_accessor_data_to_file( -- GitLab