Skip to content
Snippets Groups Projects
Commit dae64e30 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Can load binary segmentation model

parent bfcd18af
No related branches found
No related tags found
No related merge requests found
from pydantic import BaseModel, Field
from typing import Union
from fastapi import FastAPI, HTTPException
from .accessors import generate_file_accessor
from .models import BinaryThresholdSegmentationModel
from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError
app = FastAPI(debug=True)
from .pipelines.router import router
......@@ -68,6 +69,15 @@ def list_session_log() -> list:
def list_active_models():
return session.describe_loaded_models()
class BinaryThresholdSegmentationParams(BaseModel):
channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.')
tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation')
@app.put('/models/seg/binary_threshold/load/')
def load_binary_threshold_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict:
result = session.load_model(BinaryThresholdSegmentationModel, key=model_id, params=p)
session.log_info(f'Loaded binary threshold segmentation model {result}')
return {'model_id': result}
@app.get('/accessors')
def list_accessors():
......
from abc import ABC, abstractmethod
from math import floor
import numpy as np
from pydantic import BaseModel
from .accessors import GenericImageDataAccessor, InMemoryDataAccessor, PatchStack
from .accessors import GenericImageDataAccessor, PatchStack
class Model(ABC):
def __init__(self, autoload=True, params: BaseModel = None):
def __init__(self, autoload=True, params: dict = None):
"""
Abstract base class for an inference model that uses image data as an input.
......@@ -18,7 +16,7 @@ class Model(ABC):
"""
self.autoload = autoload
if params:
self.params = params.dict()
self.params = params
self.loaded = False
if not autoload:
return None
......@@ -130,15 +128,24 @@ class InstanceSegmentationModel(ImageToImageModel):
class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
"""
Trivial but functional model that labels all pixels above an intensity threshold as class 1
"""
def __init__(self, tr: float = 0.5):
self.tr = tr
def __init__(self, params=None):
self.tr = params['tr']
self.channel = params['channel']
self.loaded = True
def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
return img.apply(lambda x: x > self.tr), {'success': True}
if self.channel:
acc = img.get_mono(self.channel)
else:
acc = img
return acc.get_mono(self.channel).apply(lambda x: x > self.tr)
def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
return self.infer(img, **kwargs)[0]
return self.infer(img, **kwargs)
def load(self):
pass
......
......@@ -182,4 +182,10 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
sd = self.assertGetSuccess(f'accessors/{acc_id}')['shape_dict']
self.assertEqual(self.assertGetSuccess(f'accessors/{acc_id}')['filepath'], '')
acc_out = self.get_accessor(accessor_id=acc_id, filename='test_output.tif')
self.assertEqual(sd, acc_out.shape_dict)
\ No newline at end of file
self.assertEqual(sd, acc_out.shape_dict)
def test_load_binary_segmentation_model(self):
mid = self.assertPutSuccess(
'/models/seg/binary_threshold/load/', body={'channel': 0, 'tr': 10}
)['model_id']
return mid
\ No newline at end of file
......@@ -56,9 +56,8 @@ class TestCziImageFileAccess(unittest.TestCase):
return img, mask
def test_binary_segmentation(self):
model = BinaryThresholdSegmentationModel(tr=3e4)
img = self.cf.get_mono(0)
res = model.label_pixel_class(img)
model = BinaryThresholdSegmentationModel({'tr': 3e4, 'channel': 0})
res = model.label_pixel_class(self.cf)
self.assertTrue(res.is_mask())
def test_dummy_instance_segmentation(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment