Skip to content
Snippets Groups Projects
Commit 01adb4b2 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Pass params model straight to product exporters, so less code dedicated to matching parameters

parent 401c6ee7
No related branches found
No related tags found
No related merge requests found
......@@ -23,7 +23,7 @@ def draw_boxes_on_2d_image(yx_img, boxes, **kwargs):
draw.rectangle([(x0, y0), (x1, y1)], outline='white', width=linewidth)
if kwargs.get('add_label') is True:
if kwargs.get('draw_label') is True:
draw.text((xm, y0), f'{la:04d}', fill='white', anchor='mb')
return pilimg
......
......@@ -331,8 +331,8 @@ def export_multichannel_patches_from_zstack(
where: Path,
stack: GenericImageDataAccessor,
zmask_meta: list,
ch_rgb_overlay: tuple = None,
overlay_gain: tuple = (1.0, 1.0, 1.0),
rgb_overlay_channels: list = None,
rgb_overlay_weights: list = [1.0, 1.0, 1.0],
ch_white: int = None,
**kwargs
):
......@@ -360,17 +360,17 @@ def export_multichannel_patches_from_zstack(
else:
mdata = idata
if ch_rgb_overlay:
assert len(ch_rgb_overlay) == 3
assert len(overlay_gain) == 3
for ii, ci in enumerate(ch_rgb_overlay):
if rgb_overlay_channels:
assert len(rgb_overlay_channels) == 3
assert len(rgb_overlay_weights) == 3
for ii, ci in enumerate(rgb_overlay_channels):
if ci is None:
continue
assert isinstance(ci, int)
assert ci < stack.chroma
mdata[:, :, ii, :] = _safe_add(
mdata[:, :, ii, :],
overlay_gain[ii],
rgb_overlay_weights[ii],
idata[:, :, ci, :]
)
......
......@@ -122,7 +122,6 @@ def infer_object_map_from_zstack(
Path(output_folder_path) / 'patch_masks',
fstem,
patches_channel,
exports.patch_masks
)
if exports.annotated_z_stack:
......
......@@ -9,7 +9,7 @@ from sklearn.linear_model import LinearRegression
from extensions.chaeo.annotators import draw_boxes_on_3d_image
from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta
from extensions.chaeo.params import RoiFilter, RoiSetExportParams
from extensions.chaeo.params import AnnotatedZStackParams, PatchParams, RoiFilter, RoiSetExportParams
from extensions.chaeo.process import mask_largest_object
from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.models import InstanceSegmentationModel
......@@ -40,7 +40,7 @@ class RoiSet(object):
def get_argmax(self):
return self.interm.argmax
def export_3d_patches(self, where, prefix, channel, params: RoiSetExportParams):
def export_3d_patches(self, where, prefix, channel, params: PatchParams):
if not self.count:
return
files = export_patches_from_zstack(
......@@ -48,12 +48,11 @@ class RoiSet(object):
self.acc_raw.get_one_channel_data(channel),
self.zmask_meta,
prefix=prefix,
draw_bounding_box=params.draw_bounding_box,
rescale_clip=params.rescale_clip,
make_3d=True,
**params.__dict__,
)
def export_2d_patches_for_training(self, where, prefix, channel, params: RoiSetExportParams):
def export_2d_patches_for_training(self, where, prefix, channel, params: PatchParams):
if not self.count:
return
files = export_multichannel_patches_from_zstack(
......@@ -62,15 +61,14 @@ class RoiSet(object):
self.zmask_meta,
ch_white=channel,
prefix=prefix,
rescale_clip=params.rescale_clip,
make_3d=False,
focus_metric=params.focus_metric,
**params.__dict__,
)
df_patches = pd.DataFrame(files)
self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1)
def export_2d_patches_for_annotation(self, where, prefix, channel, params: RoiSetExportParams):
def export_2d_patches_for_annotation(self, where, prefix, channel, params: PatchParams):
if not self.count:
return
files = export_multichannel_patches_from_zstack(
......@@ -78,21 +76,14 @@ class RoiSet(object):
self.acc_raw,
self.zmask_meta,
prefix=prefix,
draw_bounding_box=params.draw_bounding_box,
rescale_clip=params.rescale_clip,
make_3d=False,
focus_metric=params.focus_metric,
ch_white=channel,
ch_rgb_overlay=params.rgb_overlay_channels,
bounding_box_channel=1,
bounding_box_linewidth=2,
draw_contour=params.draw_contour,
draw_mask=params.draw_mask,
overlay_gain=params.rgb_overlay_weights,
pad_to=params.pad_to,
**params.__dict__,
)
def export_patch_masks(self, where, prefix, channel, params: RoiSetExportParams):
def export_patch_masks(self, where, prefix, channel):
if not self.count:
return
files = export_patch_masks_from_zstack(
......@@ -102,12 +93,12 @@ class RoiSet(object):
prefix=prefix,
)
def export_annotated_zstack(self, where, prefix, channel, params: RoiSetExportParams):
def export_annotated_zstack(self, where, prefix, channel, params: AnnotatedZStackParams):
annotated = InMemoryDataAccessor(
draw_boxes_on_3d_image(
self.acc_raw.get_one_channel_data(channel).data,
self.zmask_meta,
add_label=params.draw_label,
**params.__dict__,
)
)
write_accessor_data_to_file(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment