Skip to content
Snippets Groups Projects
Commit 82a1ca73 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Pipeline handles channel selection, not model

parent 6a0505a4
No related branches found
No related tags found
No related merge requests found
......@@ -69,16 +69,18 @@ def list_session_log() -> list:
def list_active_models():
return session.describe_loaded_models()
class BinaryThresholdSegmentationParams(BaseModel):
channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.')
tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation')
@app.put('/models/seg/binary_threshold/load/')
def load_binary_threshold_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict:
result = session.load_model(BinaryThresholdSegmentationModel, key=model_id, params=p)
session.log_info(f'Loaded binary threshold segmentation model {result}')
return {'model_id': result}
@app.get('/accessors')
def list_accessors():
return session.list_accessors()
......
......@@ -134,18 +134,13 @@ class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
def __init__(self, params=None):
self.tr = params['tr']
self.channel = params['channel']
self.loaded = True
def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
if self.channel:
acc = img.get_mono(self.channel)
else:
acc = img
return acc.get_mono(self.channel).apply(lambda x: x > self.tr)
def infer(self, acc: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
return acc.apply(lambda x: x > self.tr)
def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
return self.infer(img, **kwargs)
def label_pixel_class(self, acc: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
return self.infer(acc, **kwargs)
def load(self):
pass
......
......@@ -56,7 +56,7 @@ class TestCziImageFileAccess(unittest.TestCase):
return img, mask
def test_binary_segmentation(self):
model = BinaryThresholdSegmentationModel({'tr': 3e4, 'channel': 0})
model = BinaryThresholdSegmentationModel({'tr': 3e4})
res = model.label_pixel_class(self.cf)
self.assertTrue(res.is_mask())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment