Skip to content
Snippets Groups Projects
Commit 73754817 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Adapted base segmentation worklow and API endpoint for generic semantic...

Adapted base segmentation worklow and API endpoint for generic semantic segmentation model; made dummy model this type too
parent a372c464
No related branches found
No related tags found
No related merge requests found
......@@ -7,7 +7,7 @@ import numpy as np
from conf.testing import czifile, ilastik_classifiers, output_path
from model_server.accessors import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from extensions.ilastik import models as ilm
from model_server.workflows import infer_image_to_image
from model_server.workflows import classify_pixels
from tests.test_api import TestServerBaseClass
class TestIlastikPixelClassification(unittest.TestCase):
......@@ -115,7 +115,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.assertEqual(objmap.data.max(), 3)
def test_ilastik_pixel_classification_as_workflow(self):
result = infer_image_to_image(
result = classify_pixels(
czifile['path'],
ilm.IlastikPixelClassifierModel(
{'project_file': ilastik_classifiers['px']}
......@@ -243,7 +243,7 @@ class TestIlastikOverApi(TestServerBaseClass):
model_id = self.test_load_ilastik_pixel_model()
resp_infer = requests.put(
self.uri + f'infer/from_image_file',
self.uri + f'workflows/segment',
params={
'model_id': model_id,
'input_filename': czifile['filename'],
......
from fastapi import FastAPI, HTTPException
from model_server.models import DummyImageToImageModel
from model_server.models import DummySegmentationModel
from model_server.session import Session, InvalidPathError
from model_server.validators import validate_workflow_inputs
from model_server.workflows import infer_image_to_image
from model_server.workflows import classify_pixels
from extensions.ilastik.workflows import infer_px_then_ob_model
app = FastAPI(debug=True)
......@@ -67,13 +67,13 @@ def list_active_models():
@app.put('/models/dummy/load/')
def load_dummy_model() -> dict:
return {'model_id': session.load_model(DummyImageToImageModel)}
return {'model_id': session.load_model(DummySegmentationModel)}
@app.put('/infer/from_image_file')
@app.put('/workflows/segment')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
inpath = session.paths['inbound_images'] / input_filename
validate_workflow_inputs([model_id], [inpath])
record = infer_image_to_image(
record = classify_pixels(
inpath,
session.models[model_id]['object'],
session.paths['outbound_images'],
......
......@@ -91,7 +91,7 @@ class InstanceSegmentationModel(ImageToImageModel):
class DummyImageToImageModel(ImageToImageModel):
class DummySegmentationModel(SemanticSegmentationModel):
model_id = 'dummy_make_white_square'
......@@ -106,6 +106,11 @@ class DummyImageToImageModel(ImageToImageModel):
result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255
return InMemoryDataAccessor(data=result), {'success': True}
def label_pixel_class(
self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
mask, _ = self.infer(img)
return mask
class Error(Exception):
pass
......
......@@ -6,7 +6,7 @@ from time import perf_counter
from typing import Dict
from model_server.accessors import generate_file_accessor, write_accessor_data_to_file
from model_server.models import Model
from model_server.models import SemanticSegmentationModel
from pydantic import BaseModel
......@@ -29,11 +29,12 @@ class WorkflowRunRecord(BaseModel):
timer_results: Dict[str, float]
def infer_image_to_image(fpi: Path, model: Model, where_output: Path, **kwargs) -> WorkflowRunRecord:
def classify_pixels(fpi: Path, model: SemanticSegmentationModel, where_output: Path, **kwargs) -> WorkflowRunRecord:
"""
Generic workflow where a model processes an input image into an output image
Run a semantic segmentation model to compute a binary mask from an input image
:param fpi: Path object that references input image file
:param model: model object
:param model: semantic segmentation model instance
:param where_output: Path object that references output image directory
:param kwargs: variable-length keyword arguments
:return: record object
......
......@@ -4,7 +4,7 @@ import requests
import unittest
from conf.testing import czifile
from model_server.models import DummyImageToImageModel
from model_server.models import DummySegmentationModel
class TestServerBaseClass(unittest.TestCase):
def setUp(self) -> None:
......@@ -70,7 +70,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json()
self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
self.assertEqual(rj[model_id]['class'], 'DummySegmentationModel')
return model_id
def test_respond_with_error_when_invalid_filepath_requested(self):
......@@ -89,7 +89,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
def test_i2i_inference_errors_when_model_not_found(self):
model_id = 'not_a_real_model'
resp = requests.put(
self.uri + f'infer/from_image_file',
self.uri + f'workflows/segment',
params={
'model_id': model_id,
'input_filename': 'not_a_real_file.name'
......@@ -101,7 +101,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
model_id = self.test_load_dummy_model()
self.copy_input_file_to_server()
resp_infer = requests.put(
self.uri + f'infer/from_image_file',
self.uri + f'workflows/segment',
params={
'model_id': model_id,
'input_filename': czifile['filename'],
......
import unittest
from conf.testing import czifile
from model_server.accessors import CziImageFileAccessor
from model_server.models import DummyImageToImageModel, CouldNotLoadModelError
from model_server.models import DummySegmentationModel, CouldNotLoadModelError
class TestCziImageFileAccess(unittest.TestCase):
def setUp(self) -> None:
self.cf = CziImageFileAccessor(czifile['path'])
def test_instantiate_model(self):
model = DummyImageToImageModel(params=None)
model = DummySegmentationModel(params=None)
self.assertTrue(model.loaded)
def test_instantiate_model_with_nondefault_kwarg(self):
model = DummyImageToImageModel(autoload=False)
model = DummySegmentationModel(autoload=False)
self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.')
def test_raise_error_if_cannot_load_model(self):
class UnloadableDummyImageToImageModel(DummyImageToImageModel):
class UnloadableDummyImageToImageModel(DummySegmentationModel):
def load(self):
return False
......@@ -24,7 +24,7 @@ class TestCziImageFileAccess(unittest.TestCase):
mi = UnloadableDummyImageToImageModel()
def test_czifile_is_correct_shape(self):
model = DummyImageToImageModel()
model = DummySegmentationModel()
img, _ = model.infer(self.cf)
w = czifile['w']
......
import unittest
from conf.testing import czifile, output_path
from model_server.models import DummyImageToImageModel
from model_server.workflows import infer_image_to_image
from model_server.models import DummySegmentationModel
from model_server.workflows import classify_pixels
class TestGetSessionObject(unittest.TestCase):
def setUp(self) -> None:
self.model = DummyImageToImageModel()
self.model = DummySegmentationModel()
def test_single_session_instance(self):
result = infer_image_to_image(czifile['path'], self.model, output_path, channel=2)
result = classify_pixels(czifile['path'], self.model, output_path, channel=2)
self.assertTrue(result.success)
import tifffile
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment