Skip to content
Snippets Groups Projects
Commit cb77c5f9 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Expose gaussian blur in ilastik segmentation model

parent 097a6ca2
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,7 @@ from math import ceil, floor
import numpy as np
import skimage
from skimage.exposure import rescale_intensity
from skimage.filters import gaussian
from skimage.measure import find_contours
def is_mask(img):
......@@ -128,6 +129,14 @@ def get_safe_contours(mask):
else:
return find_contours(mask)
def smooth(img: np.ndarray, sig: float) -> np.ndarray:
"""
Perform Gaussian smoothing on an image
:param img: image data
:param sig: threshold parameter
:return: smoothed image
"""
return gaussian(img, sig)
class Error(Exception):
pass
......
......@@ -4,11 +4,13 @@ from pathlib import Path
import numpy as np
from pydantic import BaseModel
from skimage.filters import gaussian
import vigra
import model_server.extensions.ilastik.conf
from model_server.base.accessors import PatchStack
from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.base.process import smooth
from model_server.base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel
class IlastikParams(BaseModel):
......@@ -92,7 +94,8 @@ class IlastikModel(Model):
class IlastikPixelClassifierParams(IlastikParams):
px_class: int = 0
px_prob_threshold = 0.5
px_prob_threshold: float = 0.5
px_smoothing: float = 0.0
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
model_id = 'ilastik_pixel_classification'
......@@ -157,7 +160,12 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs):
pxmap, _ = self.infer(img)
mask = pxmap.data[:, :, self.params['px_class'], :] > self.params['px_prob_threshold']
sig = self.params['px_smoothing']
if sig > 0.0:
proc = smooth(img.data, sig)
else:
proc = pxmap.data
mask = proc[:, :, self.params['px_class'], :] > self.params['px_prob_threshold']
return InMemoryDataAccessor(mask)
......
import pathlib
import requests
import unittest
import numpy as np
......@@ -19,6 +18,11 @@ def _random_int(*args):
class TestIlastikPixelClassification(unittest.TestCase):
def setUp(self) -> None:
self.cf = CziImageFileAccessor(czifile['path'])
self.channel = 0
self.model = ilm.IlastikPixelClassifierModel(
params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__())
)
self.mono_image = self.cf.get_one_channel_data(self.channel)
def test_faulthandler(self): # recreate error that is messing up ilastik
......@@ -45,59 +49,61 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_run_pixel_classifier_on_random_data(self):
model = ilm.IlastikPixelClassifierModel(
params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__()),
)
w = 512
h = 256
input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1))
mask = model.label_pixel_class(input_img)
mask = self.model.label_pixel_class(input_img)
self.assertEqual(mask.shape, (h, w, 1, 1))
def test_run_pixel_classifier(self):
channel = 0
model = ilm.IlastikPixelClassifierModel(
params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__())
)
cf = CziImageFileAccessor(
czifile['path']
)
mono_image = cf.get_one_channel_data(channel)
self.assertEqual(self.mono_image.shape_dict['X'], czifile['w'])
self.assertEqual(self.mono_image.shape_dict['Y'], czifile['h'])
self.assertEqual(self.mono_image.shape_dict['C'], 1)
self.assertEqual(self.mono_image.shape_dict['Z'], 1)
self.assertEqual(mono_image.shape_dict['X'], czifile['w'])
self.assertEqual(mono_image.shape_dict['Y'], czifile['h'])
self.assertEqual(mono_image.shape_dict['C'], 1)
self.assertEqual(mono_image.shape_dict['Z'], 1)
mask = model.label_pixel_class(mono_image)
mask = self.model.label_pixel_class(self.mono_image)
self.assertTrue(mask.is_mask())
self.assertEqual(mask.shape[0:2], cf.shape[0:2])
self.assertEqual(mask.shape[0:2], self.cf.shape[0:2])
self.assertEqual(mask.shape_dict['C'], 1)
self.assertEqual(mask.shape_dict['Z'], 1)
self.assertTrue(
write_accessor_data_to_file(
output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
output_path / 'seg' / f'seg_{self.cf.fpath.stem}_ch{self.channel}.tif',
mask
)
)
self.mono_image = mono_image
self.mask = mask
def test_label_pixels_with_params(self):
def _run_seg(tr, sig):
mod = ilm.IlastikPixelClassifierModel(
params=ilm.IlastikPixelClassifierParams(
project_file=ilastik_classifiers['px'].__str__(),
px_prob_threshold=tr,
px_smoothing=sig,
),
)
mask = mod.label_pixel_class(self.mono_image)
write_accessor_data_to_file(
output_path / 'seg' / f'seg_tr{int(10*tr)}_sig{int(10*sig)}.tif',
mask
)
return mask
mask1 = _run_seg(0.5, 0.0)
mask2 = _run_seg(0.5, 0.2)
self.assertEqual(mask1.shape, mask2.shape)
def test_pixel_classifier_enforces_input_shape(self):
model = ilm.IlastikPixelClassifierModel(
params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__())
)
self.assertEqual(model.model_chroma, 1)
self.assertEqual(model.model_3d, False)
self.assertEqual(self.model.model_chroma, 1)
self.assertEqual(self.model.model_3d, False)
# correct data
self.assertIsInstance(
model.label_pixel_class(
self.model.label_pixel_class(
InMemoryDataAccessor(
_random_int(512, 256, 1, 1)
)
......@@ -107,7 +113,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
# raise except with input of multiple channels
with self.assertRaises(ilm.IlastikInputShapeError):
mask = model.label_pixel_class(
mask = self.model.label_pixel_class(
InMemoryDataAccessor(
_random_int(512, 256, 3, 1)
)
......@@ -115,7 +121,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
# raise except with input of multiple channels
with self.assertRaises(ilm.IlastikInputShapeError):
mask = model.label_pixel_class(
mask = self.model.label_pixel_class(
InMemoryDataAccessor(
_random_int(512, 256, 1, 15)
)
......@@ -130,20 +136,16 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.assertEqual(acc.hw, (512, 512))
self.assertEqual(acc.iat(0, crop=True).hw, (256, 512))
model = ilm.IlastikPixelClassifierModel(
ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px'].__str__()),
)
mask = model.label_patch_stack(acc)
mask = self.model.label_patch_stack(acc)
self.assertEqual(mask.dtype, bool)
self.assertEqual(mask.chroma, 1)
self.assertEqual(mask.hw, acc.hw)
self.assertEqual(mask.nz, acc.nz)
self.assertEqual(mask.count, acc.count)
pxmap, _ = model.infer_patch_stack(acc)
pxmap, _ = self.model.infer_patch_stack(acc)
self.assertEqual(pxmap.dtype, float)
self.assertEqual(pxmap.chroma, len(model.labels))
self.assertEqual(pxmap.chroma, len(self.model.labels))
self.assertEqual(pxmap.hw, acc.hw)
self.assertEqual(pxmap.nz, acc.nz)
self.assertEqual(pxmap.count, acc.count)
......@@ -154,7 +156,8 @@ class TestIlastikPixelClassification(unittest.TestCase):
model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
params=ilm.IlastikParams(project_file=ilastik_classifiers['pxmap_to_obj'].__str__())
)
objmap, _ = model.infer(self.mono_image, self.mask)
mask = self.model.label_pixel_class(self.mono_image)
objmap, _ = model.infer(self.mono_image, mask)
self.assertTrue(
write_accessor_data_to_file(
......@@ -171,7 +174,8 @@ class TestIlastikPixelClassification(unittest.TestCase):
model = ilm.IlastikObjectClassifierFromSegmentationModel(
params=ilm.IlastikParams(project_file=ilastik_classifiers['seg_to_obj'].__str__())
)
objmap = model.label_instance_class(self.mono_image, self.mask)
mask = self.model.label_pixel_class(self.mono_image)
objmap = model.label_instance_class(self.mono_image, mask)
self.assertTrue(
write_accessor_data_to_file(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment