Skip to content
Snippets Groups Projects
Commit 82a1ca73 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Pipeline handles channel selection, not model

parent 6a0505a4
No related branches found
No related tags found
No related merge requests found
...@@ -69,16 +69,18 @@ def list_session_log() -> list: ...@@ -69,16 +69,18 @@ def list_session_log() -> list:
def list_active_models(): def list_active_models():
return session.describe_loaded_models() return session.describe_loaded_models()
class BinaryThresholdSegmentationParams(BaseModel): class BinaryThresholdSegmentationParams(BaseModel):
channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.')
tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation') tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation')
@app.put('/models/seg/binary_threshold/load/') @app.put('/models/seg/binary_threshold/load/')
def load_binary_threshold_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict: def load_binary_threshold_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict:
result = session.load_model(BinaryThresholdSegmentationModel, key=model_id, params=p) result = session.load_model(BinaryThresholdSegmentationModel, key=model_id, params=p)
session.log_info(f'Loaded binary threshold segmentation model {result}') session.log_info(f'Loaded binary threshold segmentation model {result}')
return {'model_id': result} return {'model_id': result}
@app.get('/accessors') @app.get('/accessors')
def list_accessors(): def list_accessors():
return session.list_accessors() return session.list_accessors()
......
...@@ -134,18 +134,13 @@ class BinaryThresholdSegmentationModel(SemanticSegmentationModel): ...@@ -134,18 +134,13 @@ class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
def __init__(self, params=None): def __init__(self, params=None):
self.tr = params['tr'] self.tr = params['tr']
self.channel = params['channel']
self.loaded = True self.loaded = True
def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): def infer(self, acc: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
if self.channel: return acc.apply(lambda x: x > self.tr)
acc = img.get_mono(self.channel)
else:
acc = img
return acc.get_mono(self.channel).apply(lambda x: x > self.tr)
def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: def label_pixel_class(self, acc: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
return self.infer(img, **kwargs) return self.infer(acc, **kwargs)
def load(self): def load(self):
pass pass
......
...@@ -56,7 +56,7 @@ class TestCziImageFileAccess(unittest.TestCase): ...@@ -56,7 +56,7 @@ class TestCziImageFileAccess(unittest.TestCase):
return img, mask return img, mask
def test_binary_segmentation(self): def test_binary_segmentation(self):
model = BinaryThresholdSegmentationModel({'tr': 3e4, 'channel': 0}) model = BinaryThresholdSegmentationModel({'tr': 3e4})
res = model.label_pixel_class(self.cf) res = model.label_pixel_class(self.cf)
self.assertTrue(res.is_mask()) self.assertTrue(res.is_mask())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment