Skip to content
Snippets Groups Projects
Commit 6a24b7e2 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Removed direct pixel map access from workflow function

parent cb8000e3
No related branches found
No related tags found
No related merge requests found
from typing import Dict, List, Union
from typing import List, Union
from pydantic import BaseModel
......@@ -17,11 +17,22 @@ class PatchParams(BaseModel):
class AnnotatedZStackParams(BaseModel):
draw_label: bool = False
class RoiSetExportMetaParams(BaseModel):
class RoiFilterRange(BaseModel):
min: float
max: float
class RoiFilter(BaseModel):
area: Union[RoiFilterRange, None] = None
solidity: Union[RoiFilterRange, None] = None
class RoiSetMetaParams(BaseModel):
mask_type: str = 'boxes'
filters: RoiFilter = None
expand_box_by: List[int] = [128, 0]
class RoiSetExportParams(BaseModel):
meta: RoiSetExportMetaParams = RoiSetExportMetaParams()
pixel_probabilities: bool = False
patches_3d: Union[PatchParams, None] = None
patches_2d_for_annotation: Union[PatchParams, None] = None
......@@ -30,11 +41,3 @@ class RoiSetExportParams(BaseModel):
annotated_zstacks: Union[AnnotatedZStackParams, None] = None
class RoiFilterRange(BaseModel):
min: float
max: float
class RoiFilter(BaseModel):
area: Union[RoiFilterRange, None] = None
solidity: Union[RoiFilterRange, None] = None
......@@ -5,7 +5,7 @@ import numpy as np
from conf.testing import output_path
from extensions.chaeo.conf.testing import multichannel_zstack, pixel_classifier, pipeline_params
from extensions.chaeo.params import RoiSetExportParams
from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams
from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack
from extensions.chaeo.workflows import infer_object_map_from_zstack
from extensions.chaeo.zmask import build_zmask_from_object_mask
......@@ -208,8 +208,14 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
self.pxmodel,
DummyInstanceSegmentationModel(),
]
roi_params = RoiSetMetaParams(**{
'mask_type': 'boxes',
'filters': {},
'expand_box_by': [128, 2]
})
export_params = RoiSetExportParams(**{
'expand_box_by': [128, 2],
'pixel_probabilities': True,
'patches_3d': {},
'patches_2d_for_annotation': {
......@@ -229,10 +235,10 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
multichannel_zstack['path'],
output_path / 'roiset' / 'workflow',
models,
pxmap_foreground_channel=pp['pxmap_channel'],
pxmap_threshold=pp['pxmap_threshold'],
pixel_class=pp['pxmap_channel'],
pixel_probability_threshold=pp['pxmap_threshold'],
segmentation_channel=pp['segmentation_channel'],
patches_channel=pp['patches_channel'],
exports=export_params,
export_params=export_params,
)
......@@ -9,7 +9,7 @@ from skimage.measure import label, regionprops_table
from skimage.morphology import dilation
from sklearn.model_selection import train_test_split
from extensions.chaeo.params import RoiSetExportParams
from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams
from extensions.chaeo.process import mask_largest_object
from extensions.chaeo.zmask import RoiSet
......@@ -23,15 +23,14 @@ def infer_object_map_from_zstack(
input_file_path: str,
output_folder_path: str,
models: List[Model],
pxmap_foreground_channel: int,
pxmap_threshold: float,
segmentation_channel: int,
patches_channel: int,
zmask_zindex: int = None, # None for MIP,
zmask_clip: int = None,
zmask_type: str = 'boxes',
zmask_filters: Dict = None,
exports: RoiSetExportParams = RoiSetExportParams(),
roi_params: RoiSetMetaParams = RoiSetMetaParams(),
export_params: RoiSetExportParams = RoiSetExportParams(),
pixel_class=0,
pixel_probability_threshold=0.6,
) -> Dict:
assert len(models) == 2
pixel_classifier = models[0]
......@@ -55,29 +54,13 @@ def infer_object_map_from_zstack(
mip = InMemoryDataAccessor(
zmask_data,
)
pxmap, _ = pixel_classifier.infer(mip)
ti.click('infer_pixel_probability')
# if exports.pixel_probabilities:
# write_accessor_data_to_file(
# Path(output_folder_path) / 'pixel_probabilities' / (fstem + '.tif'),
# pxmap
# )
# ti.click('export_pixel_probability')
obmask = InMemoryDataAccessor(
pxmap.data > pxmap_threshold
)
ti.click('threshold_pixel_mask')
# pxmap, _ = pixel_classifier.infer(mip)
mip_mask = pixel_classifier.label_pixel_class(mip, pixel_class, pixel_probability_threshold,)
ti.click('classify_pixels')
# make zmask
rois = RoiSet(
obmask.get_one_channel_data(pxmap_foreground_channel),
stack,
mask_type=zmask_type,
filters=zmask_filters,
expand_box_by=exports.meta.expand_box_by,
)
# rois = RoiSet(mip_mask, stack, mask_type=zmask_type, filters=zmask_filters, expand_box_by=meta.expand_box_by)
rois = RoiSet(mip_mask, stack, params=roi_params)
ti.click('generate_zmasks')
object_class_map = rois.classify_by(patches_channel, object_classifier)
......@@ -90,7 +73,8 @@ def infer_object_map_from_zstack(
)
ti.click('export_object_classes')
rois.run_exports(Path(output_folder_path), patches_channel, fstem, exports)
rois.run_exports(Path(output_folder_path), patches_channel, fstem, export_params)
ti.click('export_roi_products')
return {
'timer_results': ti.events,
......@@ -100,7 +84,6 @@ def infer_object_map_from_zstack(
}
def transfer_ecotaxa_labels_to_patch_stacks(
where_masks: str,
where_patches: str,
......
......@@ -9,7 +9,7 @@ from sklearn.linear_model import LinearRegression
from extensions.chaeo.annotators import draw_boxes_on_3d_image
from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta
from extensions.chaeo.params import AnnotatedZStackParams, PatchParams, RoiFilter, RoiSetExportParams
from extensions.chaeo.params import AnnotatedZStackParams, PatchParams, RoiFilter, RoiSetMetaParams, RoiSetExportParams
from extensions.chaeo.process import mask_largest_object
from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.models import InstanceSegmentationModel
......@@ -21,16 +21,14 @@ class RoiSet(object):
self,
acc_mask: GenericImageDataAccessor,
acc_raw: GenericImageDataAccessor,
filters=None,
mask_type='contours',
expand_box_by=(0, 0),
params: RoiSetMetaParams = RoiSetMetaParams(),
):
self.zmask, self.zmask_meta, self.df, self.interm = build_zmask_from_object_mask(
acc_mask,
acc_raw,
filters=filters,
mask_type=mask_type,
expand_box_by=expand_box_by
filters=params.filters,
mask_type=params.mask_type,
expand_box_by=params.expand_box_by
)
self.acc_raw = acc_raw
......@@ -106,7 +104,7 @@ class RoiSet(object):
subdir = where / k
pr = prefix
kp = params.dict()[k]
if k == 'meta' or kp is None:
if kp is None:
continue
if k == 'patches_3d':
files = export_patches_from_zstack(
......@@ -136,6 +134,7 @@ class RoiSet(object):
def build_zmask_from_object_mask(
obmask: GenericImageDataAccessor,
zstack: GenericImageDataAccessor,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment