Skip to content
Snippets Groups Projects
Commit 3fa8292a authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

API maps RoiSet products as accessors

parent 0bf35725
No related branches found
No related tags found
No related merge requests found
......@@ -50,7 +50,7 @@ class Model(ABC):
@property
def name(self):
return f'{self.__class__}'
return f'{self.__class__.__name__}'
......@@ -131,7 +131,6 @@ class InstanceSegmentationModel(ImageToImageModel):
class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
# TODO: also allow relative threshold
def __init__(self, tr: float = 0.5):
self.tr = tr
......
......@@ -53,20 +53,7 @@ class RoiSetObjectMapParams(PipelineParams):
},
'expand_box_by': [128, 2]
})
# TODO: maybe don't support all these exports here; instead leverage interm accessors
export_params: RoiSetExportParams = RoiSetExportParams(**{
'annotated_patches_2d': {
'draw_bounding_box': True,
'pad_to': 256,
},
'patches_2d': {
'draw_bounding_box': True,
'draw_mask': False,
},
'patch_masks': {},
'object_classes': True,
'dataframe': True,
})
export_params: RoiSetExportParams = RoiSetExportParams()
derived_channels_input_channel: Union[int, None] = Field(
None,
description='Channel of input image from which to compute derived channels; use all if empty'
......@@ -114,6 +101,10 @@ def roiset_object_map_pipeline(
d['labeled'] = get_label_ids(d.last)
rois = RoiSet.from_object_ids(d['input'], d['labeled'], RoiSetMetaParams(**k['roi_params']))
# optionally append RoiSet products
for ki, vi in rois.get_export_product_accessors(k['patches_channel'], RoiSetExportParams(**k['export_params'])).items():
d[ki] = vi
# optionally run an object classifier if specified
if obmod := models.get('object_classifier_model'):
obmod_name = k['object_classifier_model_id']
......
......@@ -10,7 +10,7 @@ from typing import Union
import pandas as pd
from ..conf import defaults
from .accessors import GenericImageDataAccessor
from .accessors import GenericImageDataAccessor, PatchStack
from .models import Model
logger = logging.getLogger(__name__)
......@@ -166,7 +166,10 @@ class _Session(object):
f'Cannot overwrite accessor that is already written to {old_fp}'
)
acc.write(fp)
if isinstance(acc, PatchStack):
acc.export_pyxcz(fp)
else:
acc.write(fp)
self.accessors[acc_id]['filepath'] = fp.__str__()
return fp.name
......
......@@ -39,8 +39,7 @@ class BaseTestRoiSetMonoProducts(object):
def _get_export_params(self):
return {
'pixel_probabilities': True,
'patches_3d': {},
'patches_3d': None,
'annotated_patches_2d': {
'draw_bounding_box': True,
'rgb_overlay_channels': [3, None, None],
......@@ -51,12 +50,8 @@ class BaseTestRoiSetMonoProducts(object):
'draw_bounding_box': False,
'draw_mask': False,
},
'patch_masks': {
'pad_to': 256,
},
'annotated_zstacks': {},
'annotated_zstacks': None,
'object_classes': True,
'dataframe': True,
}
def _get_roi_params(self):
......@@ -119,7 +114,8 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
{f'{k}_model': v['model'] for k, v in self._get_models().items()},
**params.dict()
)
self.assertEqual(trace.pop('annotated_patches_2d').count, 13)
self.assertEqual(trace.pop('patches_2d').count, 13)
trace.write_interm(Path(output_path) / 'trace', 'roiset_worfklow_trace', skip_first=False, skip_last=False)
self.assertTrue('ob_id' in trace.keys())
self.assertEqual(len(trace['labeled'].unique()[0]), 14)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment