From 73754817a88edd79da8260b2d9e46990946f37c0 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 20 Dec 2023 13:43:19 +0100 Subject: [PATCH] Adapted base segmentation worklow and API endpoint for generic semantic segmentation model; made dummy model this type too --- extensions/ilastik/tests/test_ilastik.py | 6 +++--- model_server/api.py | 10 +++++----- model_server/models.py | 7 ++++++- model_server/workflows.py | 9 +++++---- tests/test_api.py | 8 ++++---- tests/test_model.py | 10 +++++----- tests/test_workflow.py | 8 ++++---- 7 files changed, 32 insertions(+), 26 deletions(-) diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index 1dfbde16..5881679b 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -7,7 +7,7 @@ import numpy as np from conf.testing import czifile, ilastik_classifiers, output_path from model_server.accessors import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file from extensions.ilastik import models as ilm -from model_server.workflows import infer_image_to_image +from model_server.workflows import classify_pixels from tests.test_api import TestServerBaseClass class TestIlastikPixelClassification(unittest.TestCase): @@ -115,7 +115,7 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertEqual(objmap.data.max(), 3) def test_ilastik_pixel_classification_as_workflow(self): - result = infer_image_to_image( + result = classify_pixels( czifile['path'], ilm.IlastikPixelClassifierModel( {'project_file': ilastik_classifiers['px']} @@ -243,7 +243,7 @@ class TestIlastikOverApi(TestServerBaseClass): model_id = self.test_load_ilastik_pixel_model() resp_infer = requests.put( - self.uri + f'infer/from_image_file', + self.uri + f'workflows/segment', params={ 'model_id': model_id, 'input_filename': czifile['filename'], diff --git a/model_server/api.py b/model_server/api.py index bc9a57fe..a1c78531 100644 --- a/model_server/api.py +++ b/model_server/api.py @@ -1,9 +1,9 @@ from fastapi import FastAPI, HTTPException -from model_server.models import DummyImageToImageModel +from model_server.models import DummySegmentationModel from model_server.session import Session, InvalidPathError from model_server.validators import validate_workflow_inputs -from model_server.workflows import infer_image_to_image +from model_server.workflows import classify_pixels from extensions.ilastik.workflows import infer_px_then_ob_model app = FastAPI(debug=True) @@ -67,13 +67,13 @@ def list_active_models(): @app.put('/models/dummy/load/') def load_dummy_model() -> dict: - return {'model_id': session.load_model(DummyImageToImageModel)} + return {'model_id': session.load_model(DummySegmentationModel)} -@app.put('/infer/from_image_file') +@app.put('/workflows/segment') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: inpath = session.paths['inbound_images'] / input_filename validate_workflow_inputs([model_id], [inpath]) - record = infer_image_to_image( + record = classify_pixels( inpath, session.models[model_id]['object'], session.paths['outbound_images'], diff --git a/model_server/models.py b/model_server/models.py index 4f07da57..f0864147 100644 --- a/model_server/models.py +++ b/model_server/models.py @@ -91,7 +91,7 @@ class InstanceSegmentationModel(ImageToImageModel): -class DummyImageToImageModel(ImageToImageModel): +class DummySegmentationModel(SemanticSegmentationModel): model_id = 'dummy_make_white_square' @@ -106,6 +106,11 @@ class DummyImageToImageModel(ImageToImageModel): result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255 return InMemoryDataAccessor(data=result), {'success': True} + def label_pixel_class( + self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: + mask, _ = self.infer(img) + return mask + class Error(Exception): pass diff --git a/model_server/workflows.py b/model_server/workflows.py index 6eb7b50d..9c6da214 100644 --- a/model_server/workflows.py +++ b/model_server/workflows.py @@ -6,7 +6,7 @@ from time import perf_counter from typing import Dict from model_server.accessors import generate_file_accessor, write_accessor_data_to_file -from model_server.models import Model +from model_server.models import SemanticSegmentationModel from pydantic import BaseModel @@ -29,11 +29,12 @@ class WorkflowRunRecord(BaseModel): timer_results: Dict[str, float] -def infer_image_to_image(fpi: Path, model: Model, where_output: Path, **kwargs) -> WorkflowRunRecord: +def classify_pixels(fpi: Path, model: SemanticSegmentationModel, where_output: Path, **kwargs) -> WorkflowRunRecord: """ - Generic workflow where a model processes an input image into an output image + Run a semantic segmentation model to compute a binary mask from an input image + :param fpi: Path object that references input image file - :param model: model object + :param model: semantic segmentation model instance :param where_output: Path object that references output image directory :param kwargs: variable-length keyword arguments :return: record object diff --git a/tests/test_api.py b/tests/test_api.py index e6c0f799..e846be74 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -4,7 +4,7 @@ import requests import unittest from conf.testing import czifile -from model_server.models import DummyImageToImageModel +from model_server.models import DummySegmentationModel class TestServerBaseClass(unittest.TestCase): def setUp(self) -> None: @@ -70,7 +70,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() - self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') + self.assertEqual(rj[model_id]['class'], 'DummySegmentationModel') return model_id def test_respond_with_error_when_invalid_filepath_requested(self): @@ -89,7 +89,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): def test_i2i_inference_errors_when_model_not_found(self): model_id = 'not_a_real_model' resp = requests.put( - self.uri + f'infer/from_image_file', + self.uri + f'workflows/segment', params={ 'model_id': model_id, 'input_filename': 'not_a_real_file.name' @@ -101,7 +101,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): model_id = self.test_load_dummy_model() self.copy_input_file_to_server() resp_infer = requests.put( - self.uri + f'infer/from_image_file', + self.uri + f'workflows/segment', params={ 'model_id': model_id, 'input_filename': czifile['filename'], diff --git a/tests/test_model.py b/tests/test_model.py index 661aeadf..0d5f98ae 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,22 +1,22 @@ import unittest from conf.testing import czifile from model_server.accessors import CziImageFileAccessor -from model_server.models import DummyImageToImageModel, CouldNotLoadModelError +from model_server.models import DummySegmentationModel, CouldNotLoadModelError class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) def test_instantiate_model(self): - model = DummyImageToImageModel(params=None) + model = DummySegmentationModel(params=None) self.assertTrue(model.loaded) def test_instantiate_model_with_nondefault_kwarg(self): - model = DummyImageToImageModel(autoload=False) + model = DummySegmentationModel(autoload=False) self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.') def test_raise_error_if_cannot_load_model(self): - class UnloadableDummyImageToImageModel(DummyImageToImageModel): + class UnloadableDummyImageToImageModel(DummySegmentationModel): def load(self): return False @@ -24,7 +24,7 @@ class TestCziImageFileAccess(unittest.TestCase): mi = UnloadableDummyImageToImageModel() def test_czifile_is_correct_shape(self): - model = DummyImageToImageModel() + model = DummySegmentationModel() img, _ = model.infer(self.cf) w = czifile['w'] diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 2cab499d..a88d8791 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,16 +1,16 @@ import unittest from conf.testing import czifile, output_path -from model_server.models import DummyImageToImageModel -from model_server.workflows import infer_image_to_image +from model_server.models import DummySegmentationModel +from model_server.workflows import classify_pixels class TestGetSessionObject(unittest.TestCase): def setUp(self) -> None: - self.model = DummyImageToImageModel() + self.model = DummySegmentationModel() def test_single_session_instance(self): - result = infer_image_to_image(czifile['path'], self.model, output_path, channel=2) + result = classify_pixels(czifile['path'], self.model, output_path, channel=2) self.assertTrue(result.success) import tifffile -- GitLab