Skip to content
Snippets Groups Projects
Commit 5dc646d9 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Made the one implemented ilastik object classification model accomodate only...

Made the one implemented ilastik object classification model accomodate only the ilastik workflow that requires pixel prediction maps
parent 5eb99ebd
No related branches found
No related tags found
No related merge requests found
...@@ -77,13 +77,13 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel): ...@@ -77,13 +77,13 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel):
) )
return InMemoryDataAccessor(data=yxcz), {'success': True} return InMemoryDataAccessor(data=yxcz), {'success': True}
class IlastikObjectClassifierModel(IlastikImageToImageModel): class IlastikObjectClassifierFromPixelPredictionsModel(IlastikImageToImageModel):
model_id = 'ilastik_object_classification' model_id = 'ilastik_object_classification_from_pixel_predictions'
@staticmethod @staticmethod
def get_workflow(): def get_workflow():
from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflow from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction
return ObjectClassificationWorkflow return ObjectClassificationWorkflowPrediction
def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
......
...@@ -3,7 +3,7 @@ from fastapi import APIRouter, HTTPException ...@@ -3,7 +3,7 @@ from fastapi import APIRouter, HTTPException
from model_server.session import Session from model_server.session import Session
from model_server.validators import validate_workflow_inputs from model_server.validators import validate_workflow_inputs
from extensions.ilastik.models import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierModel from extensions.ilastik.models import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierFromPixelPredictionsModel
from model_server.models import ParameterExpectedError from model_server.models import ParameterExpectedError
from extensions.ilastik.workflows import infer_px_then_ob_model from extensions.ilastik.workflows import infer_px_then_ob_model
...@@ -46,7 +46,7 @@ def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = ...@@ -46,7 +46,7 @@ def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool =
@router.put('/object_classification/load/') @router.put('/object_classification/load/')
def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict: def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(IlastikObjectClassifierModel, project_file, duplicate=duplicate) return load_ilastik_model(IlastikObjectClassifierFromPixelPredictionsModel, project_file, duplicate=duplicate)
@router.put('/pixel_then_object_classification/infer') @router.put('/pixel_then_object_classification/infer')
def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict: def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict:
......
...@@ -5,7 +5,7 @@ import numpy as np ...@@ -5,7 +5,7 @@ import numpy as np
import conf.testing import conf.testing
from model_server.accessors import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.accessors import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from extensions.ilastik.models import IlastikObjectClassifierModel, IlastikPixelClassifierModel from extensions.ilastik.models import IlastikObjectClassifierFromPixelPredictionsModel, IlastikPixelClassifierModel
from model_server.workflows import infer_image_to_image from model_server.workflows import infer_image_to_image
from tests.test_api import TestServerBaseClass from tests.test_api import TestServerBaseClass
...@@ -83,7 +83,7 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -83,7 +83,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_run_object_classifier(self): def test_run_object_classifier(self):
self.test_run_pixel_classifier() self.test_run_pixel_classifier()
fp = conf.testing.czifile['path'] fp = conf.testing.czifile['path']
model = IlastikObjectClassifierModel( model = IlastikObjectClassifierFromPixelPredictionsModel(
{'project_file': conf.testing.ilastik['object_classifier']} {'project_file': conf.testing.ilastik['object_classifier']}
) )
objmap, _ = model.infer(self.mono_image, self.pxmap) objmap, _ = model.infer(self.mono_image, self.pxmap)
...@@ -167,7 +167,7 @@ class TestIlastikOverApi(TestServerBaseClass): ...@@ -167,7 +167,7 @@ class TestIlastikOverApi(TestServerBaseClass):
resp_list = requests.get(self.uri + 'models') resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200) self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json() rj = resp_list.json()
self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierModel') self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierFromPixelPredictionsModel')
return model_id return model_id
def test_ilastik_infer_pixel_probability(self): def test_ilastik_infer_pixel_probability(self):
......
...@@ -4,7 +4,7 @@ Implementation of image analysis work behind API endpoints, without knowledge of ...@@ -4,7 +4,7 @@ Implementation of image analysis work behind API endpoints, without knowledge of
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict
from extensions.ilastik.models import IlastikPixelClassifierModel, IlastikObjectClassifierModel from extensions.ilastik.models import IlastikPixelClassifierModel, IlastikObjectClassifierFromPixelPredictionsModel
from model_server.accessors import generate_file_accessor, write_accessor_data_to_file from model_server.accessors import generate_file_accessor, write_accessor_data_to_file
from model_server.workflows import Timer from model_server.workflows import Timer
...@@ -23,7 +23,7 @@ class WorkflowRunRecord(BaseModel): ...@@ -23,7 +23,7 @@ class WorkflowRunRecord(BaseModel):
def infer_px_then_ob_model( def infer_px_then_ob_model(
fpi: Path, fpi: Path,
px_model: IlastikPixelClassifierModel, px_model: IlastikPixelClassifierModel,
ob_model: IlastikObjectClassifierModel, ob_model: IlastikObjectClassifierFromPixelPredictionsModel,
where_output: Path, where_output: Path,
**kwargs **kwargs
) -> WorkflowRunRecord: ) -> WorkflowRunRecord:
...@@ -38,7 +38,7 @@ def infer_px_then_ob_model( ...@@ -38,7 +38,7 @@ def infer_px_then_ob_model(
:return: :return:
""" """
assert isinstance(px_model, IlastikPixelClassifierModel) assert isinstance(px_model, IlastikPixelClassifierModel)
assert isinstance(ob_model, IlastikObjectClassifierModel) assert isinstance(ob_model, IlastikObjectClassifierFromPixelPredictionsModel)
ti = Timer() ti = Timer()
ch = kwargs.get('channel') ch = kwargs.get('channel')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment