Skip to content
Snippets Groups Projects
Commit cb77c5f9 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Expose gaussian blur in ilastik segmentation model

parent 097a6ca2
No related branches found
No related tags found
2 merge requests!50Release 2024.06.03,!42Models are now initialized with pydantic models
...@@ -6,6 +6,7 @@ from math import ceil, floor ...@@ -6,6 +6,7 @@ from math import ceil, floor
import numpy as np import numpy as np
import skimage import skimage
from skimage.exposure import rescale_intensity from skimage.exposure import rescale_intensity
from skimage.filters import gaussian
from skimage.measure import find_contours from skimage.measure import find_contours
def is_mask(img): def is_mask(img):
...@@ -128,6 +129,14 @@ def get_safe_contours(mask): ...@@ -128,6 +129,14 @@ def get_safe_contours(mask):
else: else:
return find_contours(mask) return find_contours(mask)
def smooth(img: np.ndarray, sig: float) -> np.ndarray:
"""
Perform Gaussian smoothing on an image
:param img: image data
:param sig: threshold parameter
:return: smoothed image
"""
return gaussian(img, sig)
class Error(Exception): class Error(Exception):
pass pass
......
...@@ -4,11 +4,13 @@ from pathlib import Path ...@@ -4,11 +4,13 @@ from pathlib import Path
import numpy as np import numpy as np
from pydantic import BaseModel from pydantic import BaseModel
from skimage.filters import gaussian
import vigra import vigra
import model_server.extensions.ilastik.conf import model_server.extensions.ilastik.conf
from model_server.base.accessors import PatchStack from model_server.base.accessors import PatchStack
from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.base.process import smooth
from model_server.base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel from model_server.base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel
class IlastikParams(BaseModel): class IlastikParams(BaseModel):
...@@ -92,7 +94,8 @@ class IlastikModel(Model): ...@@ -92,7 +94,8 @@ class IlastikModel(Model):
class IlastikPixelClassifierParams(IlastikParams): class IlastikPixelClassifierParams(IlastikParams):
px_class: int = 0 px_class: int = 0
px_prob_threshold = 0.5 px_prob_threshold: float = 0.5
px_smoothing: float = 0.0
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
model_id = 'ilastik_pixel_classification' model_id = 'ilastik_pixel_classification'
...@@ -157,7 +160,12 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): ...@@ -157,7 +160,12 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs): def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs):
pxmap, _ = self.infer(img) pxmap, _ = self.infer(img)
mask = pxmap.data[:, :, self.params['px_class'], :] > self.params['px_prob_threshold'] sig = self.params['px_smoothing']
if sig > 0.0:
proc = smooth(img.data, sig)
else:
proc = pxmap.data
mask = proc[:, :, self.params['px_class'], :] > self.params['px_prob_threshold']
return InMemoryDataAccessor(mask) return InMemoryDataAccessor(mask)
......
import pathlib import pathlib
import requests
import unittest import unittest
import numpy as np import numpy as np
...@@ -19,6 +18,11 @@ def _random_int(*args): ...@@ -19,6 +18,11 @@ def _random_int(*args):
class TestIlastikPixelClassification(unittest.TestCase): class TestIlastikPixelClassification(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.cf = CziImageFileAccessor(czifile['path']) self.cf = CziImageFileAccessor(czifile['path'])
self.channel = 0
self.model = ilm.IlastikPixelClassifierModel(
params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__())
)
self.mono_image = self.cf.get_one_channel_data(self.channel)
def test_faulthandler(self): # recreate error that is messing up ilastik def test_faulthandler(self): # recreate error that is messing up ilastik
...@@ -45,59 +49,61 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -45,59 +49,61 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_run_pixel_classifier_on_random_data(self): def test_run_pixel_classifier_on_random_data(self):
model = ilm.IlastikPixelClassifierModel(
params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__()),
)
w = 512 w = 512
h = 256 h = 256
input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1)) input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1))
mask = model.label_pixel_class(input_img) mask = self.model.label_pixel_class(input_img)
self.assertEqual(mask.shape, (h, w, 1, 1)) self.assertEqual(mask.shape, (h, w, 1, 1))
def test_run_pixel_classifier(self): def test_run_pixel_classifier(self):
channel = 0 self.assertEqual(self.mono_image.shape_dict['X'], czifile['w'])
model = ilm.IlastikPixelClassifierModel( self.assertEqual(self.mono_image.shape_dict['Y'], czifile['h'])
params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__()) self.assertEqual(self.mono_image.shape_dict['C'], 1)
) self.assertEqual(self.mono_image.shape_dict['Z'], 1)
cf = CziImageFileAccessor(
czifile['path']
)
mono_image = cf.get_one_channel_data(channel)
self.assertEqual(mono_image.shape_dict['X'], czifile['w']) mask = self.model.label_pixel_class(self.mono_image)
self.assertEqual(mono_image.shape_dict['Y'], czifile['h'])
self.assertEqual(mono_image.shape_dict['C'], 1)
self.assertEqual(mono_image.shape_dict['Z'], 1)
mask = model.label_pixel_class(mono_image)
self.assertTrue(mask.is_mask()) self.assertTrue(mask.is_mask())
self.assertEqual(mask.shape[0:2], cf.shape[0:2]) self.assertEqual(mask.shape[0:2], self.cf.shape[0:2])
self.assertEqual(mask.shape_dict['C'], 1) self.assertEqual(mask.shape_dict['C'], 1)
self.assertEqual(mask.shape_dict['Z'], 1) self.assertEqual(mask.shape_dict['Z'], 1)
self.assertTrue( self.assertTrue(
write_accessor_data_to_file( write_accessor_data_to_file(
output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif', output_path / 'seg' / f'seg_{self.cf.fpath.stem}_ch{self.channel}.tif',
mask mask
) )
) )
self.mono_image = mono_image def test_label_pixels_with_params(self):
self.mask = mask def _run_seg(tr, sig):
mod = ilm.IlastikPixelClassifierModel(
params=ilm.IlastikPixelClassifierParams(
project_file=ilastik_classifiers['px'].__str__(),
px_prob_threshold=tr,
px_smoothing=sig,
),
)
mask = mod.label_pixel_class(self.mono_image)
write_accessor_data_to_file(
output_path / 'seg' / f'seg_tr{int(10*tr)}_sig{int(10*sig)}.tif',
mask
)
return mask
mask1 = _run_seg(0.5, 0.0)
mask2 = _run_seg(0.5, 0.2)
self.assertEqual(mask1.shape, mask2.shape)
def test_pixel_classifier_enforces_input_shape(self): def test_pixel_classifier_enforces_input_shape(self):
model = ilm.IlastikPixelClassifierModel( self.assertEqual(self.model.model_chroma, 1)
params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__()) self.assertEqual(self.model.model_3d, False)
)
self.assertEqual(model.model_chroma, 1)
self.assertEqual(model.model_3d, False)
# correct data # correct data
self.assertIsInstance( self.assertIsInstance(
model.label_pixel_class( self.model.label_pixel_class(
InMemoryDataAccessor( InMemoryDataAccessor(
_random_int(512, 256, 1, 1) _random_int(512, 256, 1, 1)
) )
...@@ -107,7 +113,7 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -107,7 +113,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
# raise except with input of multiple channels # raise except with input of multiple channels
with self.assertRaises(ilm.IlastikInputShapeError): with self.assertRaises(ilm.IlastikInputShapeError):
mask = model.label_pixel_class( mask = self.model.label_pixel_class(
InMemoryDataAccessor( InMemoryDataAccessor(
_random_int(512, 256, 3, 1) _random_int(512, 256, 3, 1)
) )
...@@ -115,7 +121,7 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -115,7 +121,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
# raise except with input of multiple channels # raise except with input of multiple channels
with self.assertRaises(ilm.IlastikInputShapeError): with self.assertRaises(ilm.IlastikInputShapeError):
mask = model.label_pixel_class( mask = self.model.label_pixel_class(
InMemoryDataAccessor( InMemoryDataAccessor(
_random_int(512, 256, 1, 15) _random_int(512, 256, 1, 15)
) )
...@@ -130,20 +136,16 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -130,20 +136,16 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.assertEqual(acc.hw, (512, 512)) self.assertEqual(acc.hw, (512, 512))
self.assertEqual(acc.iat(0, crop=True).hw, (256, 512)) self.assertEqual(acc.iat(0, crop=True).hw, (256, 512))
model = ilm.IlastikPixelClassifierModel( mask = self.model.label_patch_stack(acc)
ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__()),
)
mask = model.label_patch_stack(acc)
self.assertEqual(mask.dtype, bool) self.assertEqual(mask.dtype, bool)
self.assertEqual(mask.chroma, 1) self.assertEqual(mask.chroma, 1)
self.assertEqual(mask.hw, acc.hw) self.assertEqual(mask.hw, acc.hw)
self.assertEqual(mask.nz, acc.nz) self.assertEqual(mask.nz, acc.nz)
self.assertEqual(mask.count, acc.count) self.assertEqual(mask.count, acc.count)
pxmap, _ = model.infer_patch_stack(acc) pxmap, _ = self.model.infer_patch_stack(acc)
self.assertEqual(pxmap.dtype, float) self.assertEqual(pxmap.dtype, float)
self.assertEqual(pxmap.chroma, len(model.labels)) self.assertEqual(pxmap.chroma, len(self.model.labels))
self.assertEqual(pxmap.hw, acc.hw) self.assertEqual(pxmap.hw, acc.hw)
self.assertEqual(pxmap.nz, acc.nz) self.assertEqual(pxmap.nz, acc.nz)
self.assertEqual(pxmap.count, acc.count) self.assertEqual(pxmap.count, acc.count)
...@@ -154,7 +156,8 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -154,7 +156,8 @@ class TestIlastikPixelClassification(unittest.TestCase):
model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
params=ilm.IlastikParams(project_file=ilastik_classifiers['pxmap_to_obj'].__str__()) params=ilm.IlastikParams(project_file=ilastik_classifiers['pxmap_to_obj'].__str__())
) )
objmap, _ = model.infer(self.mono_image, self.mask) mask = self.model.label_pixel_class(self.mono_image)
objmap, _ = model.infer(self.mono_image, mask)
self.assertTrue( self.assertTrue(
write_accessor_data_to_file( write_accessor_data_to_file(
...@@ -171,7 +174,8 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -171,7 +174,8 @@ class TestIlastikPixelClassification(unittest.TestCase):
model = ilm.IlastikObjectClassifierFromSegmentationModel( model = ilm.IlastikObjectClassifierFromSegmentationModel(
params=ilm.IlastikParams(project_file=ilastik_classifiers['seg_to_obj'].__str__()) params=ilm.IlastikParams(project_file=ilastik_classifiers['seg_to_obj'].__str__())
) )
objmap = model.label_instance_class(self.mono_image, self.mask) mask = self.model.label_pixel_class(self.mono_image)
objmap = model.label_instance_class(self.mono_image, mask)
self.assertTrue( self.assertTrue(
write_accessor_data_to_file( write_accessor_data_to_file(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment