Skip to content
Snippets Groups Projects
Commit 2f644dfc authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Pixel classification pipeline is tested with smoothing parameter

parent 2a8db3e2
No related branches found
No related tags found
No related merge requests found
from pathlib import Path
from fastapi import APIRouter, HTTPException
from model_server.base.accessors import GenericImageDataAccessor, generate_file_accessor, write_accessor_data_to_file
from model_server.base.models import SemanticSegmentationModel
from model_server.base.util import PipelineTrace
from model_server.base.session import session
from ..accessors import GenericImageDataAccessor, generate_file_accessor, write_accessor_data_to_file
from ..models import SemanticSegmentationModel
from ..process import smooth
from ..util import PipelineTrace
from ..session import session
from pydantic import BaseModel, validator
......@@ -14,7 +13,7 @@ router = APIRouter(
tags=['pipelines'],
)
class FileCallParams(BaseModel):
class Params(BaseModel):
model_id: str
input_filename: str
channel: int = None
......@@ -32,7 +31,7 @@ class FileCallParams(BaseModel):
return v
class FileCallRecord(BaseModel):
class Record(BaseModel):
model_id: str
input_filepath: str
output_filepath: str
......@@ -41,18 +40,18 @@ class FileCallRecord(BaseModel):
@router.put('/segment')
def file_call(p: FileCallParams) -> FileCallRecord:
def file_call(p: Params) -> Record:
inpath = session.paths['inbound_images'] / p.input_filename
steps = classify_pixels(
generate_file_accessor(inpath),
session.models[p.model_id]['object'],
channel=p.channel,
**p,
)
outpath = session.paths['outbound_images'] / (p.model_id + '_' + inpath.stem + '.tif')
write_accessor_data_to_file(outpath, steps.last)
session.log_info(f'Completed segmentation of {p.input_filename}')
return FileCallRecord(
return Record(
model_id=p.model_id,
input_filepath=inpath.__str__(),
output_filepath=outpath.__str__(),
......@@ -60,8 +59,7 @@ def file_call(p: FileCallParams) -> FileCallRecord:
timer=steps.times
)
# TODO: implement smoothing, move this away from model itself
def classify_pixels(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **kwargs) -> PipelineTrace:
def classify_pixels(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **k) -> PipelineTrace:
"""
Run a semantic segmentation model to compute a binary mask from an input image
......@@ -72,8 +70,11 @@ def classify_pixels(acc_in: GenericImageDataAccessor, model: SemanticSegmentatio
:return: record object
"""
d = PipelineTrace()
ch = kwargs.get('channel')
d['input'] = acc_in.get_mono(ch)
d['input'] = acc_in
if ch := k.get('channel'):
d['mono'] = acc_in.get_mono(ch)
d['inference'] = model.label_pixel_class(d.last)
if sm := k.get('smooth'):
d['smooth'] = d.last.apply(lambda x: smooth(x, sm))
d['output'] = d.last
return d
\ No newline at end of file
import unittest
from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file
from model_server.base.pipelines.segment import classify_pixels
from model_server.base.pipelines.segment import classify_pixels, Params
import model_server.conf.testing as conf
from tests.base.test_model import DummySemanticSegmentationModel
......@@ -14,7 +14,7 @@ class TestGetSessionObject(unittest.TestCase):
def test_single_session_instance(self):
acc = generate_file_accessor(czifile['path'])
trace = classify_pixels(acc, self.model, channel=2)
trace = classify_pixels(acc, self.model, channel=2, smooth=3)
outfp = output_path / 'classify_pixels.tif'
write_accessor_data_to_file(outfp, trace.last)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment