diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index 1dfbde1685218234f152ace5968a1aae4a8a9fff..5881679b98ee1bdf7b0013250862eabce2d99cbb 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -7,7 +7,7 @@ import numpy as np from conf.testing import czifile, ilastik_classifiers, output_path from model_server.accessors import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file from extensions.ilastik import models as ilm -from model_server.workflows import infer_image_to_image +from model_server.workflows import classify_pixels from tests.test_api import TestServerBaseClass class TestIlastikPixelClassification(unittest.TestCase): @@ -115,7 +115,7 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertEqual(objmap.data.max(), 3) def test_ilastik_pixel_classification_as_workflow(self): - result = infer_image_to_image( + result = classify_pixels( czifile['path'], ilm.IlastikPixelClassifierModel( {'project_file': ilastik_classifiers['px']} @@ -243,7 +243,7 @@ class TestIlastikOverApi(TestServerBaseClass): model_id = self.test_load_ilastik_pixel_model() resp_infer = requests.put( - self.uri + f'infer/from_image_file', + self.uri + f'workflows/segment', params={ 'model_id': model_id, 'input_filename': czifile['filename'], diff --git a/model_server/api.py b/model_server/api.py index bc9a57fe79288c2f63ebf79e055bf5785be2f823..a1c785313881d594e49fa8f335f24664e4429bf2 100644 --- a/model_server/api.py +++ b/model_server/api.py @@ -1,9 +1,9 @@ from fastapi import FastAPI, HTTPException -from model_server.models import DummyImageToImageModel +from model_server.models import DummySegmentationModel from model_server.session import Session, InvalidPathError from model_server.validators import validate_workflow_inputs -from model_server.workflows import infer_image_to_image +from model_server.workflows import classify_pixels from extensions.ilastik.workflows import infer_px_then_ob_model app = FastAPI(debug=True) @@ -67,13 +67,13 @@ def list_active_models(): @app.put('/models/dummy/load/') def load_dummy_model() -> dict: - return {'model_id': session.load_model(DummyImageToImageModel)} + return {'model_id': session.load_model(DummySegmentationModel)} -@app.put('/infer/from_image_file') +@app.put('/workflows/segment') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: inpath = session.paths['inbound_images'] / input_filename validate_workflow_inputs([model_id], [inpath]) - record = infer_image_to_image( + record = classify_pixels( inpath, session.models[model_id]['object'], session.paths['outbound_images'], diff --git a/model_server/models.py b/model_server/models.py index 4f07da5730a3d06d93f6a310bafead67f740a202..f0864147279dbef501138307ae885a742d390757 100644 --- a/model_server/models.py +++ b/model_server/models.py @@ -91,7 +91,7 @@ class InstanceSegmentationModel(ImageToImageModel): -class DummyImageToImageModel(ImageToImageModel): +class DummySegmentationModel(SemanticSegmentationModel): model_id = 'dummy_make_white_square' @@ -106,6 +106,11 @@ class DummyImageToImageModel(ImageToImageModel): result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255 return InMemoryDataAccessor(data=result), {'success': True} + def label_pixel_class( + self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: + mask, _ = self.infer(img) + return mask + class Error(Exception): pass diff --git a/model_server/workflows.py b/model_server/workflows.py index 6eb7b50d1b6e78518cc500ac5c1936cd266a333b..9c6da2142ff557fdaf5d0cb521169f03816eb512 100644 --- a/model_server/workflows.py +++ b/model_server/workflows.py @@ -6,7 +6,7 @@ from time import perf_counter from typing import Dict from model_server.accessors import generate_file_accessor, write_accessor_data_to_file -from model_server.models import Model +from model_server.models import SemanticSegmentationModel from pydantic import BaseModel @@ -29,11 +29,12 @@ class WorkflowRunRecord(BaseModel): timer_results: Dict[str, float] -def infer_image_to_image(fpi: Path, model: Model, where_output: Path, **kwargs) -> WorkflowRunRecord: +def classify_pixels(fpi: Path, model: SemanticSegmentationModel, where_output: Path, **kwargs) -> WorkflowRunRecord: """ - Generic workflow where a model processes an input image into an output image + Run a semantic segmentation model to compute a binary mask from an input image + :param fpi: Path object that references input image file - :param model: model object + :param model: semantic segmentation model instance :param where_output: Path object that references output image directory :param kwargs: variable-length keyword arguments :return: record object diff --git a/tests/test_api.py b/tests/test_api.py index e6c0f7997ea1a6396d30530d964f40cb7213fbe6..e846be74218912a97b346d862bdce4850c0eb733 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -4,7 +4,7 @@ import requests import unittest from conf.testing import czifile -from model_server.models import DummyImageToImageModel +from model_server.models import DummySegmentationModel class TestServerBaseClass(unittest.TestCase): def setUp(self) -> None: @@ -70,7 +70,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() - self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') + self.assertEqual(rj[model_id]['class'], 'DummySegmentationModel') return model_id def test_respond_with_error_when_invalid_filepath_requested(self): @@ -89,7 +89,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): def test_i2i_inference_errors_when_model_not_found(self): model_id = 'not_a_real_model' resp = requests.put( - self.uri + f'infer/from_image_file', + self.uri + f'workflows/segment', params={ 'model_id': model_id, 'input_filename': 'not_a_real_file.name' @@ -101,7 +101,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): model_id = self.test_load_dummy_model() self.copy_input_file_to_server() resp_infer = requests.put( - self.uri + f'infer/from_image_file', + self.uri + f'workflows/segment', params={ 'model_id': model_id, 'input_filename': czifile['filename'], diff --git a/tests/test_model.py b/tests/test_model.py index 661aeadf93825afc430e55f56326def0fe3f43b0..0d5f98aea4aee70da25ef3edf71b36eebac01a3e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,22 +1,22 @@ import unittest from conf.testing import czifile from model_server.accessors import CziImageFileAccessor -from model_server.models import DummyImageToImageModel, CouldNotLoadModelError +from model_server.models import DummySegmentationModel, CouldNotLoadModelError class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) def test_instantiate_model(self): - model = DummyImageToImageModel(params=None) + model = DummySegmentationModel(params=None) self.assertTrue(model.loaded) def test_instantiate_model_with_nondefault_kwarg(self): - model = DummyImageToImageModel(autoload=False) + model = DummySegmentationModel(autoload=False) self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.') def test_raise_error_if_cannot_load_model(self): - class UnloadableDummyImageToImageModel(DummyImageToImageModel): + class UnloadableDummyImageToImageModel(DummySegmentationModel): def load(self): return False @@ -24,7 +24,7 @@ class TestCziImageFileAccess(unittest.TestCase): mi = UnloadableDummyImageToImageModel() def test_czifile_is_correct_shape(self): - model = DummyImageToImageModel() + model = DummySegmentationModel() img, _ = model.infer(self.cf) w = czifile['w'] diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 2cab499d04c222651894d9064c1f7f990050d626..a88d8791eb4cba9ce6544124e51edacfa3a79c0a 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,16 +1,16 @@ import unittest from conf.testing import czifile, output_path -from model_server.models import DummyImageToImageModel -from model_server.workflows import infer_image_to_image +from model_server.models import DummySegmentationModel +from model_server.workflows import classify_pixels class TestGetSessionObject(unittest.TestCase): def setUp(self) -> None: - self.model = DummyImageToImageModel() + self.model = DummySegmentationModel() def test_single_session_instance(self): - result = infer_image_to_image(czifile['path'], self.model, output_path, channel=2) + result = classify_pixels(czifile['path'], self.model, output_path, channel=2) self.assertTrue(result.success) import tifffile