From f9f8621180a9dd5e837b3f62f1e0d066e03752ec Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 30 Oct 2024 10:17:36 +0100 Subject: [PATCH] Moved intensity gate model to RoiSet module --- model_server/base/api.py | 4 ++-- model_server/base/models.py | 25 ------------------------ model_server/base/roiset.py | 38 +++++++++++++++++++++++++++++++++++++ tests/base/test_model.py | 6 ------ tests/base/test_roiset.py | 17 ++++++++++++++--- 5 files changed, 54 insertions(+), 36 deletions(-) diff --git a/model_server/base/api.py b/model_server/base/api.py index 6d9f5814..081b3b03 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -3,8 +3,8 @@ from typing import List, Union from fastapi import FastAPI, HTTPException from .accessors import generate_file_accessor -from .models import BinaryThresholdSegmentationModel, IntensityThresholdInstanceMaskSegmentationModel -from .roiset import RoiSetExportParams, SerializeRoiSetError +from .models import BinaryThresholdSegmentationModel +from .roiset import IntensityThresholdInstanceMaskSegmentationModel, RoiSetExportParams, SerializeRoiSetError from .session import session, AccessorIdError, InvalidPathError, RoiSetIdError, WriteAccessorError app = FastAPI(debug=True) diff --git a/model_server/base/models.py b/model_server/base/models.py index 3e026143..800ec951 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -158,31 +158,6 @@ class InstanceMaskSegmentationModel(ImageToImageModel): return PatchStack(data) -class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel): - def __init__(self, tr: float = 0.5, channel: int = 0): - """ - Model that labels all objects as class 1 if the intensity in specified channel exceeds a threshold; labels all - objects as class 1 if threshold = 0.0 - :param tr: threshold in range of 0.0 to 1.0; model handles normalization to full pixel intensity range - :param channel: channel to use for thresholding - """ - self.tr = tr - self.channel = channel - self.loaded = self.load() - super().__init__(info={'tr': tr, 'channel': channel}) - - def load(self): - return True - - def infer(self, acc: GenericImageDataAccessor, mask: GenericImageDataAccessor) -> GenericImageDataAccessor: - return mask.apply(lambda x: (1 * (x > self.tr)).astype(acc.dtype)) - - def label_instance_class( - self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs - ) -> GenericImageDataAccessor: - super().label_instance_class(img, mask, **kwargs) - return self.infer(img, mask) - class Error(Exception): pass diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index ba7c6b06..ac93238b 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -1081,6 +1081,7 @@ class RoiSet(object): class RoiSetWithDerivedChannelsExportParams(RoiSetExportParams): derived_channels: bool = False + class RoiSetWithDerivedChannels(RoiSet): def __init__(self, *a, **k): @@ -1165,32 +1166,69 @@ class RoiSetWithDerivedChannels(RoiSet): record[k].append(str(fp)) return record + +class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel): + def __init__(self, tr: float = 0.5, channel: int = 0): + """ + Model that labels all objects as class 1 if the intensity in specified channel exceeds a threshold; labels all + objects as class 1 if threshold = 0.0 + :param tr: threshold in range of 0.0 to 1.0; model handles normalization to full pixel intensity range + :param channel: channel to use for thresholding + """ + self.tr = tr + self.channel = channel + self.loaded = self.load() + super().__init__(info={'tr': tr, 'channel': channel}) + + def load(self): + return True + + def infer(self, acc: GenericImageDataAccessor, mask: GenericImageDataAccessor) -> GenericImageDataAccessor: + # TODO: gate by intensity in image, not threshold in mask channel; this will require iterating on labeled objects + return mask.apply(lambda x: (1 * (x > self.tr)).astype(acc.dtype)) + + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + super().label_instance_class(img, mask, **kwargs) + return self.infer(img, mask) + + class Error(Exception): pass + class BoundingBoxError(Error): pass + class DeserializeRoiSetError(Error): pass + class SerializeRoiSetError(Error): pass + class NoDeprojectChannelSpecifiedError(Error): pass + class DerivedChannelError(Error): pass + class MissingSegmentationError(Error): pass + class PatchMaskShapeError(Error): pass + class ShapeMismatchError(Error): pass + class MissingInstanceLabelsError(Error): pass \ No newline at end of file diff --git a/tests/base/test_model.py b/tests/base/test_model.py index e5c96894..8b4aa3a7 100644 --- a/tests/base/test_model.py +++ b/tests/base/test_model.py @@ -68,9 +68,3 @@ class TestCziImageFileAccess(unittest.TestCase): obmap = model.label_instance_class(img, mask) self.assertTrue(all(obmap.unique()[0] == [0, 1])) self.assertTrue(all(obmap.unique()[1] > 0)) - - def test_permissive_instance_segmentation(self): - img, mask = self.test_dummy_pixel_segmentation() - model = IntensityThresholdInstanceMaskSegmentationModel() - obmap = model.label_instance_class(img, mask) - self.assertTrue(np.all(mask.data == 255 * obmap.data)) diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index c510510a..b9b89410 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -7,9 +7,8 @@ from pathlib import Path import pandas as pd from model_server.base.process import smooth -from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, RoiSetExportParams, RoiSetMetaParams -from model_server.base.roiset import RoiSet -from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack +from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, RoiSet, RoiSetExportParams, RoiSetMetaParams +from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file import model_server.conf.testing as conf from model_server.conf.testing import DummyInstanceMaskSegmentationModel @@ -819,3 +818,15 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertTrue((res.bbox_intersec == 2).all()) self.assertTrue((res.loc[res.seg_overlaps, :].index == [1]).all()) self.assertTrue((res.loc[res.seg_overlaps, 'seg_iou'] == 0.4).all()) + +class TestIntensityThresholdObjectModel(BaseTestRoiSetMonoProducts, unittest.TestCase): + def test_permissive_instance_segmentation(self): + from model_server.base.roiset import IntensityThresholdInstanceMaskSegmentationModel + + img = self.stack.get_mono(channel=0, mip=True) + mask = self.seg_mask + + # TODO: test on actual binary mask, not dummy square mask + model = IntensityThresholdInstanceMaskSegmentationModel() + obmap = model.label_instance_class(img, mask) + self.assertTrue(np.all(mask.data == 255 * obmap.data)) \ No newline at end of file -- GitLab