From 5dc646d95c2784e6ff3bedb1db0b8b1af409d50d Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Mon, 16 Oct 2023 16:27:52 +0200 Subject: [PATCH] Made the one implemented ilastik object classification model accomodate only the ilastik workflow that requires pixel prediction maps --- extensions/ilastik/models.py | 8 ++++---- extensions/ilastik/router.py | 4 ++-- extensions/ilastik/tests/test_ilastik.py | 6 +++--- extensions/ilastik/workflows.py | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py index 47c81abb..a6daf81a 100644 --- a/extensions/ilastik/models.py +++ b/extensions/ilastik/models.py @@ -77,13 +77,13 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel): ) return InMemoryDataAccessor(data=yxcz), {'success': True} -class IlastikObjectClassifierModel(IlastikImageToImageModel): - model_id = 'ilastik_object_classification' +class IlastikObjectClassifierFromPixelPredictionsModel(IlastikImageToImageModel): + model_id = 'ilastik_object_classification_from_pixel_predictions' @staticmethod def get_workflow(): - from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflow - return ObjectClassificationWorkflow + from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction + return ObjectClassificationWorkflowPrediction def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') diff --git a/extensions/ilastik/router.py b/extensions/ilastik/router.py index 155d8fa0..3287dc47 100644 --- a/extensions/ilastik/router.py +++ b/extensions/ilastik/router.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, HTTPException from model_server.session import Session from model_server.validators import validate_workflow_inputs -from extensions.ilastik.models import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierModel +from extensions.ilastik.models import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierFromPixelPredictionsModel from model_server.models import ParameterExpectedError from extensions.ilastik.workflows import infer_px_then_ob_model @@ -46,7 +46,7 @@ def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = @router.put('/object_classification/load/') def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict: - return load_ilastik_model(IlastikObjectClassifierModel, project_file, duplicate=duplicate) + return load_ilastik_model(IlastikObjectClassifierFromPixelPredictionsModel, project_file, duplicate=duplicate) @router.put('/pixel_then_object_classification/infer') def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict: diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index 555d8286..348f955e 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -5,7 +5,7 @@ import numpy as np import conf.testing from model_server.accessors import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file -from extensions.ilastik.models import IlastikObjectClassifierModel, IlastikPixelClassifierModel +from extensions.ilastik.models import IlastikObjectClassifierFromPixelPredictionsModel, IlastikPixelClassifierModel from model_server.workflows import infer_image_to_image from tests.test_api import TestServerBaseClass @@ -83,7 +83,7 @@ class TestIlastikPixelClassification(unittest.TestCase): def test_run_object_classifier(self): self.test_run_pixel_classifier() fp = conf.testing.czifile['path'] - model = IlastikObjectClassifierModel( + model = IlastikObjectClassifierFromPixelPredictionsModel( {'project_file': conf.testing.ilastik['object_classifier']} ) objmap, _ = model.infer(self.mono_image, self.pxmap) @@ -167,7 +167,7 @@ class TestIlastikOverApi(TestServerBaseClass): resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() - self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierModel') + self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierFromPixelPredictionsModel') return model_id def test_ilastik_infer_pixel_probability(self): diff --git a/extensions/ilastik/workflows.py b/extensions/ilastik/workflows.py index 45763ad7..1f622b67 100644 --- a/extensions/ilastik/workflows.py +++ b/extensions/ilastik/workflows.py @@ -4,7 +4,7 @@ Implementation of image analysis work behind API endpoints, without knowledge of from pathlib import Path from typing import Dict -from extensions.ilastik.models import IlastikPixelClassifierModel, IlastikObjectClassifierModel +from extensions.ilastik.models import IlastikPixelClassifierModel, IlastikObjectClassifierFromPixelPredictionsModel from model_server.accessors import generate_file_accessor, write_accessor_data_to_file from model_server.workflows import Timer @@ -23,7 +23,7 @@ class WorkflowRunRecord(BaseModel): def infer_px_then_ob_model( fpi: Path, px_model: IlastikPixelClassifierModel, - ob_model: IlastikObjectClassifierModel, + ob_model: IlastikObjectClassifierFromPixelPredictionsModel, where_output: Path, **kwargs ) -> WorkflowRunRecord: @@ -38,7 +38,7 @@ def infer_px_then_ob_model( :return: """ assert isinstance(px_model, IlastikPixelClassifierModel) - assert isinstance(ob_model, IlastikObjectClassifierModel) + assert isinstance(ob_model, IlastikObjectClassifierFromPixelPredictionsModel) ti = Timer() ch = kwargs.get('channel') -- GitLab