Skip to content
Snippets Groups Projects
Commit 3c4aab06 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Renamed some model classes; expanded scope of trivial object classification...

Renamed some model classes; expanded scope of trivial object classification model to be functional base on intensity threshold, too
parent e894ba3f
No related branches found
No related tags found
No related merge requests found
......@@ -3,7 +3,7 @@ from typing import List, Union
from fastapi import FastAPI, HTTPException
from .accessors import generate_file_accessor
from .models import BinaryThresholdSegmentationModel, PermissiveInstanceSegmentationModel
from .models import BinaryThresholdSegmentationModel, IntensityThresholdInstanceMaskSegmentationModel
from .roiset import RoiSetExportParams, SerializeRoiSetError
from .session import session, AccessorIdError, InvalidPathError, RoiSetIdError, WriteAccessorError
......@@ -89,18 +89,19 @@ def list_active_models():
class BinaryThresholdSegmentationParams(BaseModel):
tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation')
channel: Union[int, int] = Field(0, description='Channel from which to compute binary threshold')
@app.put('/models/seg/binary_threshold/load/')
@app.put('/models/seg/threshold/load/')
def load_binary_threshold_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict:
result = session.load_model(BinaryThresholdSegmentationModel, key=model_id, params=p)
session.log_info(f'Loaded binary threshold segmentation model {result}')
return {'model_id': result}
@app.put('/models/classify/permissive/load')
def load_permissive_instance_classification_model(model_id=None) -> dict:
result = session.load_model(PermissiveInstanceSegmentationModel, key=model_id, params={})
@app.put('/models/classify/threshold/load')
def load_permissive_instance_classification_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict:
result = session.load_model(IntensityThresholdInstanceMaskSegmentationModel, key=model_id, params=p)
session.log_info(f'Loaded permissive instance segmentation model {result}')
return {'model_id': result}
......
......@@ -58,7 +58,7 @@ class ImageToImageModel(Model):
"""
@abstractmethod
def infer(self, img: GenericImageDataAccessor) -> GenericImageDataAccessor:
def infer(self, img: GenericImageDataAccessor, *args) -> GenericImageDataAccessor:
pass
......@@ -92,10 +92,12 @@ class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
"""
def __init__(self, params=None):
self.tr = params['tr']
self.tr = params.get('tr', 0.5)
self.channel = params.get('channel', 0)
self.loaded = self.load()
def infer(self, acc: GenericImageDataAccessor) -> GenericImageDataAccessor:
# TODO: interpret thresh as normalized to dtype range
return acc.apply(lambda x: x > self.tr)
def label_pixel_class(self, acc: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
......@@ -105,11 +107,15 @@ class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
return True
class InstanceSegmentationModel(ImageToImageModel):
class InstanceMaskSegmentationModel(ImageToImageModel):
"""
Base model that exposes a method that returns an instance classification map for a given input image and mask
"""
@abstractmethod
def infer(self, img: GenericImageDataAccessor, *args) -> GenericImageDataAccessor:
pass
@abstractmethod
def label_instance_class(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
......@@ -122,6 +128,8 @@ class InstanceSegmentationModel(ImageToImageModel):
if img.hw != mask.hw or img.nz != mask.nz:
raise InvalidInputImageError('Expect input image and mask to be the same shape')
return self.infer(img, mask)
def label_patch_stack(self, img: PatchStack, mask: PatchStack, allow_multiple=True, force_single=False, **kwargs):
"""
Call inference on all patches in a PatchStack at once
......@@ -146,19 +154,23 @@ class InstanceSegmentationModel(ImageToImageModel):
return PatchStack(data)
class PermissiveInstanceSegmentationModel(InstanceSegmentationModel):
class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel):
"""
Trivial but functional model that labels all objects as class 1
Model that labels all objects as class 1 if the intensity in a given channel exceeds threshold;
threshold = 0.0 means that all objects are returned class 1.
"""
def __init__(self, params=None):
# TODO: make model params kwargs to constructor, makes debugging much easier
self.tr = params.get('tr', 0.5)
self.channel = params.get('channel', 0)
self.loaded = self.load()
def load(self):
return True
def infer(self, acc: GenericImageDataAccessor, mask: GenericImageDataAccessor) -> GenericImageDataAccessor:
return mask.apply(lambda x: (1 * (x > 0)).astype(acc.dtype))
return mask.apply(lambda x: (1 * (x > self.tr)).astype(acc.dtype))
def label_instance_class(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
......
......@@ -11,7 +11,7 @@ from ..session import session
from ..pipelines.shared import PipelineTrace, PipelineParams, PipelineRecord
from ..models import Model, InstanceSegmentationModel
from ..models import Model, InstanceMaskSegmentationModel
class RoiSetObjectMapParams(PipelineParams):
......@@ -89,7 +89,7 @@ def roiset_object_map_pipeline(
# optionally run an object classifier if specified
if obmod := models.get('object_classifier_'):
obmod_name = k['object_classifier_model_id']
assert isinstance(obmod, InstanceSegmentationModel)
assert isinstance(obmod, InstanceMaskSegmentationModel)
rois.classify_by(
obmod_name,
[k['patches_channel']],
......
......@@ -17,7 +17,7 @@ from skimage.measure import approximate_polygon, find_contours, label, points_in
from skimage.morphology import binary_dilation, disk
from .accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from .models import InstanceSegmentationModel
from .models import InstanceMaskSegmentationModel
from .process import get_safe_contours, pad, rescale, resample_to_8bit, make_rgb
from .annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image
from .accessors import generate_file_accessor, PatchStack
......@@ -577,7 +577,7 @@ class RoiSet(object):
def classify_by(
self, name: str, channels: list[int],
object_classification_model: InstanceSegmentationModel,
object_classification_model: InstanceMaskSegmentationModel,
):
"""
Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing
......@@ -1089,7 +1089,7 @@ class RoiSetWithDerivedChannels(RoiSet):
def classify_by(
self, name: str, channels: list[int],
object_classification_model: InstanceSegmentationModel,
object_classification_model: InstanceMaskSegmentationModel,
derived_channel_functions: list[callable] = None
):
"""
......
......@@ -14,7 +14,7 @@ from urllib3 import Retry
from .fastapi import app
from ..base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from ..base.models import SemanticSegmentationModel, InstanceSegmentationModel
from ..base.models import SemanticSegmentationModel, InstanceMaskSegmentationModel
from ..base.session import session
from ..base.accessors import generate_file_accessor
......@@ -52,7 +52,7 @@ def load_dummy_model() -> dict:
@test_router.put('/models/dummy_instance/load/')
def load_dummy_model() -> dict:
mid = session.load_model(DummyInstanceSegmentationModel)
mid = session.load_model(DummyInstanceMaskSegmentationModel)
session.log_info(f'Loaded model {mid}')
return {'model_id': mid}
......@@ -99,7 +99,7 @@ class TestServerBaseClass(unittest.TestCase):
self.assertEqual(resp.status_code, code)
return resp
def assertPutSuccess(self, endpoint, query=None, body=None):
def assertPutSuccess(self, endpoint, query={}, body={}):
resp = self._get_sesh().put(
self.uri + endpoint,
params=query,
......@@ -207,7 +207,7 @@ class DummySemanticSegmentationModel(SemanticSegmentationModel):
def load(self):
return True
def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
def infer(self, img: GenericImageDataAccessor) -> GenericImageDataAccessor:
super().infer(img)
w = img.shape_dict['X']
h = img.shape_dict['Y']
......@@ -221,7 +221,7 @@ class DummySemanticSegmentationModel(SemanticSegmentationModel):
return mask
class DummyInstanceSegmentationModel(InstanceSegmentationModel):
class DummyInstanceMaskSegmentationModel(InstanceMaskSegmentationModel):
model_id = 'dummy_pass_input_mask'
......@@ -241,5 +241,5 @@ class DummyInstanceSegmentationModel(InstanceSegmentationModel):
"""
Returns a trivial segmentation, i.e. the input mask with value 1
"""
super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs)
super(DummyInstanceMaskSegmentationModel, self).label_instance_class(img, mask, **kwargs)
return self.infer(img, mask)
......@@ -10,7 +10,7 @@ import vigra
import model_server.extensions.ilastik.conf
from ...base.accessors import PatchStack
from ...base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from ...base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel
from ...base.models import Model, ImageToImageModel, InstanceMaskSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel
class IlastikModel(Model):
......@@ -165,7 +165,7 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
return mask
class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
class IlastikObjectClassifierFromMaskSegmentationModel(IlastikModel, InstanceMaskSegmentationModel):
@staticmethod
def _make_8bit_mask(nda):
......@@ -233,7 +233,7 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
return InMemoryDataAccessor(data=yxcz)
def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
super(IlastikObjectClassifierFromMaskSegmentationModel, self).label_instance_class(img, mask, **kwargs)
return self.infer(img, mask)
......
......@@ -54,7 +54,7 @@ def load_seg_to_obj_model(p: IlastikParams, model_id=None) -> dict:
Load an ilastik object classifier from segmentation model from its project file
"""
return load_ilastik_model(
ilm.IlastikObjectClassifierFromSegmentationModel,
ilm.IlastikObjectClassifierFromMaskSegmentationModel,
p,
model_id=model_id,
)
......
......@@ -40,7 +40,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
def test_load_dummy_instance_model(self):
mid = self.assertPutSuccess(f'testing/models/dummy_instance/load')['model_id']
rl = self.assertGetSuccess('models')
self.assertEqual(rl[mid]['class'], 'DummyInstanceSegmentationModel')
self.assertEqual(rl[mid]['class'], 'DummyInstanceMaskSegmentationModel')
return mid
def test_pipeline_errors_when_ids_not_found(self):
......@@ -175,7 +175,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
def test_binary_segmentation_model(self):
mid = self.assertPutSuccess(
'/models/seg/binary_threshold/load/', body={'tr': 10}
'/models/seg/threshold/load/', body={'tr': 10}
)['model_id']
fname = self.copy_input_file_to_server()
......@@ -195,7 +195,8 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
def test_permissive_instance_segmentation_model(self):
self.assertPutSuccess(
'/models/classify/permissive/load',
'/models/classify/threshold/load',
body={}
)
......@@ -3,9 +3,9 @@ import unittest
import numpy as np
import model_server.conf.testing as conf
from model_server.conf.testing import DummySemanticSegmentationModel, DummyInstanceSegmentationModel
from model_server.conf.testing import DummySemanticSegmentationModel, DummyInstanceMaskSegmentationModel
from model_server.base.accessors import CziImageFileAccessor
from model_server.base.models import CouldNotLoadModelError, BinaryThresholdSegmentationModel, PermissiveInstanceSegmentationModel
from model_server.base.models import CouldNotLoadModelError, BinaryThresholdSegmentationModel, IntensityThresholdInstanceMaskSegmentationModel
czifile = conf.meta['image_files']['czifile']
......@@ -64,13 +64,13 @@ class TestCziImageFileAccess(unittest.TestCase):
def test_dummy_instance_segmentation(self):
img, mask = self.test_dummy_pixel_segmentation()
model = DummyInstanceSegmentationModel()
model = DummyInstanceMaskSegmentationModel()
obmap = model.label_instance_class(img, mask)
self.assertTrue(all(obmap.unique()[0] == [0, 1]))
self.assertTrue(all(obmap.unique()[1] > 0))
def test_permissive_instance_segmentation(self):
img, mask = self.test_dummy_pixel_segmentation()
model = PermissiveInstanceSegmentationModel()
model = IntensityThresholdInstanceMaskSegmentationModel(params={})
obmap = model.label_instance_class(img, mask)
self.assertTrue(np.all(mask.data == 255 * obmap.data))
......@@ -11,7 +11,7 @@ from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_s
from model_server.base.roiset import RoiSet
from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack
import model_server.conf.testing as conf
from model_server.conf.testing import DummyInstanceSegmentationModel
from model_server.conf.testing import DummyInstanceMaskSegmentationModel
data = conf.meta['image_files']
output_path = conf.meta['output_path']
......@@ -82,7 +82,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
zero_obmap = InMemoryDataAccessor(np.zeros(self.seg_mask.shape, self.seg_mask.dtype))
roiset = RoiSet.from_object_ids(self.stack_ch_pa, zero_obmap)
self.assertEqual(roiset.count, 0)
roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel())
roiset.classify_by('dummy_class', [0], DummyInstanceMaskSegmentationModel())
self.assertTrue('classify_by_dummy_class' in roiset.get_df().columns)
def test_slices_are_valid(self):
......@@ -183,14 +183,14 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_classify_by(self):
roiset = self._make_roi_set()
roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel())
roiset.classify_by('dummy_class', [0], DummyInstanceMaskSegmentationModel())
self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1]))
return roiset
def test_classify_by_multiple_channels(self):
roiset = RoiSet.from_binary_mask(self.stack, self.seg_mask, params=RoiSetMetaParams(deproject_channel=0))
roiset.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel())
roiset.classify_by('dummy_class', [0, 1], DummyInstanceMaskSegmentationModel())
self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1]))
return roiset
......@@ -207,7 +207,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
self.assertGreater(total_iou, 0.6)
# classify first RoiSet
roiset1.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel())
roiset1.classify_by('dummy_class', [0, 1], DummyInstanceMaskSegmentationModel())
self.assertTrue('dummy_class' in roiset1.classification_columns)
self.assertFalse('dummy_class' in roiset2.classification_columns)
......@@ -455,7 +455,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
'annotated_zstacks': {},
'object_classes': True,
})
self.roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel())
self.roiset.classify_by('dummy_class', [0], DummyInstanceMaskSegmentationModel())
interm = self.roiset.get_export_product_accessors(
channel=3,
params=p
......
......@@ -7,7 +7,7 @@ from model_server.base.roiset import RoiSetWithDerivedChannelsExportParams, RoiS
from model_server.base.roiset import RoiSetWithDerivedChannels
from model_server.base.accessors import generate_file_accessor, PatchStack
import model_server.conf.testing as conf
from model_server.conf.testing import DummyInstanceSegmentationModel
from model_server.conf.testing import DummyInstanceMaskSegmentationModel
data = conf.meta['image_files']
params = conf.meta['roiset']
......@@ -20,7 +20,7 @@ class TestDerivedChannels(unittest.TestCase):
self.seg_mask = generate_file_accessor(data['multichannel_zstack_mask2d']['path'])
def test_classify_by_with_derived_channel(self):
class ModelWithDerivedInputs(DummyInstanceSegmentationModel):
class ModelWithDerivedInputs(DummyInstanceMaskSegmentationModel):
def infer(self, img, mask):
return PatchStack(super().infer(img, mask).data * img.chroma)
......
......@@ -58,7 +58,7 @@ class BaseTestRoiSetMonoProducts(object):
}
def _get_models(self):
from model_server.base.models import BinaryThresholdSegmentationModel, PermissiveInstanceSegmentationModel
from model_server.base.models import BinaryThresholdSegmentationModel, IntensityThresholdInstanceMaskSegmentationModel
return {
'pixel_classifier_segmentation': {
'name': 'min_px_mod',
......@@ -66,7 +66,7 @@ class BaseTestRoiSetMonoProducts(object):
},
'object_classifier': {
'name': 'min_ob_mod',
'model': PermissiveInstanceSegmentationModel()
'model': IntensityThresholdInstanceMaskSegmentationModel({}),
},
}
......@@ -125,7 +125,7 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd
def test_load_pixel_classifier(self):
mid = self.assertPutSuccess(
'models/seg/binary_threshold/load/',
'models/seg/threshold/load/',
body={'tr': 1e4},
)['model_id']
self.assertTrue(mid.startswith('BinaryThresholdSegmentationModel'))
......@@ -133,9 +133,9 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd
def test_load_object_classifier(self):
mid = self.assertPutSuccess(
'models/classify/permissive/load/',
'models/classify/threshold/load/',
)['model_id']
self.assertTrue(mid.startswith('Permissive'))
self.assertTrue(mid.startswith('IntensityThresholdInstanceMaskSegmentation'))
return mid
def _object_map_workflow(self, ob_classifer_id):
......
......@@ -176,7 +176,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_run_object_classifier_from_segmentation(self):
self.test_run_pixel_classifier()
fp = czifile['path']
model = ilm.IlastikObjectClassifierFromSegmentationModel(
model = ilm.IlastikObjectClassifierFromMaskSegmentationModel(
params={'project_file': ilastik_classifiers['seg_to_obj']['path'].__str__()}
)
mask = self.model.label_pixel_class(self.mono_image)
......@@ -287,7 +287,7 @@ class TestIlastikOverApi(TestServerTestCase):
body={'project_file': str(ilastik_classifiers['seg_to_obj']['path'])},
)['model_id']
rl = self.assertGetSuccess('models')
self.assertEqual(rl[mid]['class'], 'IlastikObjectClassifierFromSegmentationModel')
self.assertEqual(rl[mid]['class'], 'IlastikObjectClassifierFromMaskSegmentationModel')
return mid
def test_ilastik_infer_pixel_probability(self):
......@@ -432,7 +432,7 @@ class TestIlastikObjectClassification(unittest.TestCase):
)
)
self.classifier = ilm.IlastikObjectClassifierFromSegmentationModel(
self.classifier = ilm.IlastikObjectClassifierFromMaskSegmentationModel(
params={'project_file': ilastik_classifiers['seg_to_obj']['path'].__str__()},
)
self.raw = self.roiset.get_patches_acc()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment