Skip to content
Snippets Groups Projects
Commit c1cd28eb authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Test covers derived inputs pipeline method directly

parent 29c2818c
No related branches found
No related tags found
No related merge requests found
......@@ -139,7 +139,7 @@ def roiset_object_map_pipeline(
rois = RoiSet.from_binary_mask(d['raw'], d.last, RoiSetMetaParams(**k['roi_params']))
# optionally derive additional inputs for object classification
if dpmod := models.get('pixel_classifier_derived'):
if dpmod := models.get('pixel_classifier_derived_model'):
ic = k['derived_channels_input_channel']
ocs = k['derived_channels_output_channels']
assert ic < d['raw'].chroma
......@@ -157,7 +157,7 @@ def roiset_object_map_pipeline(
else:
derived_channel_handles = None
# optionally classify if an object classifier is passed
# optionally run an object classifier if specified
if obmod := models.get('object_classifier_model'):
obmod_name = k['object_classifier_model_id']
assert isinstance(obmod, InstanceSegmentationModel)
......@@ -167,8 +167,14 @@ def roiset_object_map_pipeline(
obmod,
derived_channel_functions=derived_channel_handles
)
# add derived channels to intermediate data
for di, dacc in enumerate(rois.accs_derived):
d[f'derived_{di:02d}'] = dacc
d[obmod_name] = rois.get_object_class_map(obmod_name)
# TODO: subclass this boundary pipeline in separate module
# if p.label_params and p.export_label_interm:
# fp_la_interm = Path(output_folder_path) / 'mip_masks' / f'interm_{fstem}.png'
......
import shutil
from pathlib import Path
from shutil import copyfile
import unittest
import numpy as np
import model_server.conf.testing as conf
from model_server.base.accessors import generate_file_accessor
from tests.base.test_model import DummyInstanceSegmentationModel
from model_server.base.roiset import RoiSetMetaParams, RoiSetExportParams
import model_server.conf.testing as conf
from model_server.base.pipelines.roiset_obmap import RoiSetObjectMapParams, roiset_object_map_pipeline
......@@ -121,6 +117,7 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
self.assertTrue('ob_id' in trace.keys())
self.assertEqual(len(trace['ob_id']._unique()[0]), 2)
def test_object_map_workflow_with_derived_channels(self):
acc_in = generate_file_accessor(self.fpi)
params = RoiSetObjectMapParams(
......@@ -140,7 +137,8 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
self.assertTrue('ob_id' in trace.keys())
self.assertEqual(len(trace['ob_id']._unique()[0]), 2)
self.assertTrue(all([Path(pa).exists() for pa in record['interm']['derived_channels']]))
self.assertTrue('derived_00' in trace.keys())
self.assertEqual(trace['derived_00'].chroma, 1)
class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts):
......@@ -255,69 +253,3 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd
# self.assertTrue(all(np.unique(acc_obmap.data) == [0, 1]))
class TestRoiSetWithDerivedChannels(conf.TestServerBaseClass):
app_name = 'api:app'
input_data = data['multichannel_zstack_raw']
# def test_object_map_workflow_with_derived_channels(self):
# models = _get_models()
# models['pixel_classifier_derived'] = models['pixel_classifier_segmentation'] # re-use same classifier
# where_out = output_path / 'roiset' / 'workflow'
#
# p = ClassifyZStackApiParams(
# input_filename=str(data['multichannel_zstack_raw']['path']),
# model_ids={
# 'pixel_classifier_segmentation': 'id_px',
# 'object_classifier': 'id_ob',
# 'pixel_classifier_derived': 'id_px',
# },
# derived_channels_input_channel=0,
# derived_channels_output_channels=[1],
# segmentation_channel=test_params['segmentation_channel'],
# patches_channel=test_params['patches_channel'],
# roi_params=_get_roi_params(),
# export_params=_get_export_params(),
# )
#
# record = export_zstack_roiset(
# data['multichannel_zstack_raw']['path'],
# where_out,
# models,
# p
# )
# self.assertTrue(all([Path(pa).exists() for pa in record['interm']['derived_channels']]))
def test_derived_channels_api(self):
resp = self._put(
'ilastik/seg/load/',
body={'project_file': _get_model_params()['pixel_classifier_segmentation']['project_file'].__str__()},
)
self.assertEqual(resp.status_code, 200)
mid_px = resp.json()['model_id']
resp = self._put(f'models/dummy_instance/load')
mid_ob = resp.json()['model_id']
p = ClassifyZStackApiParams(
input_filename=data['multichannel_zstack_raw']['path'].__str__(),
model_ids={
'pixel_classifier_segmentation': mid_px,
'object_classifier': mid_ob,
'pixel_classifier_derived': mid_px,
},
derived_channels_input_channel=0,
derived_channels_output_channels=[1],
segmentation_channel=test_params['segmentation_channel'],
patches_channel=test_params['patches_channel'],
roi_params=_get_roi_params(),
export_params=_get_export_params(),
)
resp = self._put('chaeo/classify_zstack/infer', body=p.dict())
self.assertEqual(resp.status_code, 200, resp.json())
omfp = resp.json()['object_map_filepath']
self.assertTrue(Path(omfp).exists())
acc_obmap = generate_file_accessor(omfp)
self.assertTrue(all(np.unique(acc_obmap.data) == [0, 1]))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment