Skip to content
Snippets Groups Projects
Commit f9f86211 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Moved intensity gate model to RoiSet module

parent 623f707c
No related branches found
No related tags found
No related merge requests found
......@@ -3,8 +3,8 @@ from typing import List, Union
from fastapi import FastAPI, HTTPException
from .accessors import generate_file_accessor
from .models import BinaryThresholdSegmentationModel, IntensityThresholdInstanceMaskSegmentationModel
from .roiset import RoiSetExportParams, SerializeRoiSetError
from .models import BinaryThresholdSegmentationModel
from .roiset import IntensityThresholdInstanceMaskSegmentationModel, RoiSetExportParams, SerializeRoiSetError
from .session import session, AccessorIdError, InvalidPathError, RoiSetIdError, WriteAccessorError
app = FastAPI(debug=True)
......
......@@ -158,31 +158,6 @@ class InstanceMaskSegmentationModel(ImageToImageModel):
return PatchStack(data)
class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel):
def __init__(self, tr: float = 0.5, channel: int = 0):
"""
Model that labels all objects as class 1 if the intensity in specified channel exceeds a threshold; labels all
objects as class 1 if threshold = 0.0
:param tr: threshold in range of 0.0 to 1.0; model handles normalization to full pixel intensity range
:param channel: channel to use for thresholding
"""
self.tr = tr
self.channel = channel
self.loaded = self.load()
super().__init__(info={'tr': tr, 'channel': channel})
def load(self):
return True
def infer(self, acc: GenericImageDataAccessor, mask: GenericImageDataAccessor) -> GenericImageDataAccessor:
return mask.apply(lambda x: (1 * (x > self.tr)).astype(acc.dtype))
def label_instance_class(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
) -> GenericImageDataAccessor:
super().label_instance_class(img, mask, **kwargs)
return self.infer(img, mask)
class Error(Exception):
pass
......
......@@ -1081,6 +1081,7 @@ class RoiSet(object):
class RoiSetWithDerivedChannelsExportParams(RoiSetExportParams):
derived_channels: bool = False
class RoiSetWithDerivedChannels(RoiSet):
def __init__(self, *a, **k):
......@@ -1165,32 +1166,69 @@ class RoiSetWithDerivedChannels(RoiSet):
record[k].append(str(fp))
return record
class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel):
def __init__(self, tr: float = 0.5, channel: int = 0):
"""
Model that labels all objects as class 1 if the intensity in specified channel exceeds a threshold; labels all
objects as class 1 if threshold = 0.0
:param tr: threshold in range of 0.0 to 1.0; model handles normalization to full pixel intensity range
:param channel: channel to use for thresholding
"""
self.tr = tr
self.channel = channel
self.loaded = self.load()
super().__init__(info={'tr': tr, 'channel': channel})
def load(self):
return True
def infer(self, acc: GenericImageDataAccessor, mask: GenericImageDataAccessor) -> GenericImageDataAccessor:
# TODO: gate by intensity in image, not threshold in mask channel; this will require iterating on labeled objects
return mask.apply(lambda x: (1 * (x > self.tr)).astype(acc.dtype))
def label_instance_class(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
) -> GenericImageDataAccessor:
super().label_instance_class(img, mask, **kwargs)
return self.infer(img, mask)
class Error(Exception):
pass
class BoundingBoxError(Error):
pass
class DeserializeRoiSetError(Error):
pass
class SerializeRoiSetError(Error):
pass
class NoDeprojectChannelSpecifiedError(Error):
pass
class DerivedChannelError(Error):
pass
class MissingSegmentationError(Error):
pass
class PatchMaskShapeError(Error):
pass
class ShapeMismatchError(Error):
pass
class MissingInstanceLabelsError(Error):
pass
\ No newline at end of file
......@@ -68,9 +68,3 @@ class TestCziImageFileAccess(unittest.TestCase):
obmap = model.label_instance_class(img, mask)
self.assertTrue(all(obmap.unique()[0] == [0, 1]))
self.assertTrue(all(obmap.unique()[1] > 0))
def test_permissive_instance_segmentation(self):
img, mask = self.test_dummy_pixel_segmentation()
model = IntensityThresholdInstanceMaskSegmentationModel()
obmap = model.label_instance_class(img, mask)
self.assertTrue(np.all(mask.data == 255 * obmap.data))
......@@ -7,9 +7,8 @@ from pathlib import Path
import pandas as pd
from model_server.base.process import smooth
from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, RoiSetExportParams, RoiSetMetaParams
from model_server.base.roiset import RoiSet
from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack
from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, RoiSet, RoiSetExportParams, RoiSetMetaParams
from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
import model_server.conf.testing as conf
from model_server.conf.testing import DummyInstanceMaskSegmentationModel
......@@ -819,3 +818,15 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
self.assertTrue((res.bbox_intersec == 2).all())
self.assertTrue((res.loc[res.seg_overlaps, :].index == [1]).all())
self.assertTrue((res.loc[res.seg_overlaps, 'seg_iou'] == 0.4).all())
class TestIntensityThresholdObjectModel(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_permissive_instance_segmentation(self):
from model_server.base.roiset import IntensityThresholdInstanceMaskSegmentationModel
img = self.stack.get_mono(channel=0, mip=True)
mask = self.seg_mask
# TODO: test on actual binary mask, not dummy square mask
model = IntensityThresholdInstanceMaskSegmentationModel()
obmap = model.label_instance_class(img, mask)
self.assertTrue(np.all(mask.data == 255 * obmap.data))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment