Skip to content
Snippets Groups Projects
Commit dae64e30 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Can load binary segmentation model

parent bfcd18af
No related branches found
No related tags found
No related merge requests found
from pydantic import BaseModel, Field
from typing import Union from typing import Union
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from .accessors import generate_file_accessor from .accessors import generate_file_accessor
from .models import BinaryThresholdSegmentationModel
from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError
app = FastAPI(debug=True) app = FastAPI(debug=True)
from .pipelines.router import router from .pipelines.router import router
...@@ -68,6 +69,15 @@ def list_session_log() -> list: ...@@ -68,6 +69,15 @@ def list_session_log() -> list:
def list_active_models(): def list_active_models():
return session.describe_loaded_models() return session.describe_loaded_models()
class BinaryThresholdSegmentationParams(BaseModel):
channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.')
tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation')
@app.put('/models/seg/binary_threshold/load/')
def load_binary_threshold_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict:
result = session.load_model(BinaryThresholdSegmentationModel, key=model_id, params=p)
session.log_info(f'Loaded binary threshold segmentation model {result}')
return {'model_id': result}
@app.get('/accessors') @app.get('/accessors')
def list_accessors(): def list_accessors():
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from math import floor
import numpy as np import numpy as np
from pydantic import BaseModel
from .accessors import GenericImageDataAccessor, InMemoryDataAccessor, PatchStack from .accessors import GenericImageDataAccessor, PatchStack
class Model(ABC): class Model(ABC):
def __init__(self, autoload=True, params: BaseModel = None): def __init__(self, autoload=True, params: dict = None):
""" """
Abstract base class for an inference model that uses image data as an input. Abstract base class for an inference model that uses image data as an input.
...@@ -18,7 +16,7 @@ class Model(ABC): ...@@ -18,7 +16,7 @@ class Model(ABC):
""" """
self.autoload = autoload self.autoload = autoload
if params: if params:
self.params = params.dict() self.params = params
self.loaded = False self.loaded = False
if not autoload: if not autoload:
return None return None
...@@ -130,15 +128,24 @@ class InstanceSegmentationModel(ImageToImageModel): ...@@ -130,15 +128,24 @@ class InstanceSegmentationModel(ImageToImageModel):
class BinaryThresholdSegmentationModel(SemanticSegmentationModel): class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
"""
Trivial but functional model that labels all pixels above an intensity threshold as class 1
"""
def __init__(self, tr: float = 0.5): def __init__(self, params=None):
self.tr = tr self.tr = params['tr']
self.channel = params['channel']
self.loaded = True
def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
return img.apply(lambda x: x > self.tr), {'success': True} if self.channel:
acc = img.get_mono(self.channel)
else:
acc = img
return acc.get_mono(self.channel).apply(lambda x: x > self.tr)
def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
return self.infer(img, **kwargs)[0] return self.infer(img, **kwargs)
def load(self): def load(self):
pass pass
......
...@@ -182,4 +182,10 @@ class TestApiFromAutomatedClient(TestServerBaseClass): ...@@ -182,4 +182,10 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
sd = self.assertGetSuccess(f'accessors/{acc_id}')['shape_dict'] sd = self.assertGetSuccess(f'accessors/{acc_id}')['shape_dict']
self.assertEqual(self.assertGetSuccess(f'accessors/{acc_id}')['filepath'], '') self.assertEqual(self.assertGetSuccess(f'accessors/{acc_id}')['filepath'], '')
acc_out = self.get_accessor(accessor_id=acc_id, filename='test_output.tif') acc_out = self.get_accessor(accessor_id=acc_id, filename='test_output.tif')
self.assertEqual(sd, acc_out.shape_dict) self.assertEqual(sd, acc_out.shape_dict)
\ No newline at end of file
def test_load_binary_segmentation_model(self):
mid = self.assertPutSuccess(
'/models/seg/binary_threshold/load/', body={'channel': 0, 'tr': 10}
)['model_id']
return mid
\ No newline at end of file
...@@ -56,9 +56,8 @@ class TestCziImageFileAccess(unittest.TestCase): ...@@ -56,9 +56,8 @@ class TestCziImageFileAccess(unittest.TestCase):
return img, mask return img, mask
def test_binary_segmentation(self): def test_binary_segmentation(self):
model = BinaryThresholdSegmentationModel(tr=3e4) model = BinaryThresholdSegmentationModel({'tr': 3e4, 'channel': 0})
img = self.cf.get_mono(0) res = model.label_pixel_class(self.cf)
res = model.label_pixel_class(img)
self.assertTrue(res.is_mask()) self.assertTrue(res.is_mask())
def test_dummy_instance_segmentation(self): def test_dummy_instance_segmentation(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment