Skip to content
Snippets Groups Projects
Commit 01adb4b2 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Pass params model straight to product exporters, so less code dedicated to matching parameters

parent 401c6ee7
No related branches found
No related tags found
No related merge requests found
...@@ -23,7 +23,7 @@ def draw_boxes_on_2d_image(yx_img, boxes, **kwargs): ...@@ -23,7 +23,7 @@ def draw_boxes_on_2d_image(yx_img, boxes, **kwargs):
draw.rectangle([(x0, y0), (x1, y1)], outline='white', width=linewidth) draw.rectangle([(x0, y0), (x1, y1)], outline='white', width=linewidth)
if kwargs.get('add_label') is True: if kwargs.get('draw_label') is True:
draw.text((xm, y0), f'{la:04d}', fill='white', anchor='mb') draw.text((xm, y0), f'{la:04d}', fill='white', anchor='mb')
return pilimg return pilimg
......
...@@ -331,8 +331,8 @@ def export_multichannel_patches_from_zstack( ...@@ -331,8 +331,8 @@ def export_multichannel_patches_from_zstack(
where: Path, where: Path,
stack: GenericImageDataAccessor, stack: GenericImageDataAccessor,
zmask_meta: list, zmask_meta: list,
ch_rgb_overlay: tuple = None, rgb_overlay_channels: list = None,
overlay_gain: tuple = (1.0, 1.0, 1.0), rgb_overlay_weights: list = [1.0, 1.0, 1.0],
ch_white: int = None, ch_white: int = None,
**kwargs **kwargs
): ):
...@@ -360,17 +360,17 @@ def export_multichannel_patches_from_zstack( ...@@ -360,17 +360,17 @@ def export_multichannel_patches_from_zstack(
else: else:
mdata = idata mdata = idata
if ch_rgb_overlay: if rgb_overlay_channels:
assert len(ch_rgb_overlay) == 3 assert len(rgb_overlay_channels) == 3
assert len(overlay_gain) == 3 assert len(rgb_overlay_weights) == 3
for ii, ci in enumerate(ch_rgb_overlay): for ii, ci in enumerate(rgb_overlay_channels):
if ci is None: if ci is None:
continue continue
assert isinstance(ci, int) assert isinstance(ci, int)
assert ci < stack.chroma assert ci < stack.chroma
mdata[:, :, ii, :] = _safe_add( mdata[:, :, ii, :] = _safe_add(
mdata[:, :, ii, :], mdata[:, :, ii, :],
overlay_gain[ii], rgb_overlay_weights[ii],
idata[:, :, ci, :] idata[:, :, ci, :]
) )
......
...@@ -122,7 +122,6 @@ def infer_object_map_from_zstack( ...@@ -122,7 +122,6 @@ def infer_object_map_from_zstack(
Path(output_folder_path) / 'patch_masks', Path(output_folder_path) / 'patch_masks',
fstem, fstem,
patches_channel, patches_channel,
exports.patch_masks
) )
if exports.annotated_z_stack: if exports.annotated_z_stack:
......
...@@ -9,7 +9,7 @@ from sklearn.linear_model import LinearRegression ...@@ -9,7 +9,7 @@ from sklearn.linear_model import LinearRegression
from extensions.chaeo.annotators import draw_boxes_on_3d_image from extensions.chaeo.annotators import draw_boxes_on_3d_image
from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta
from extensions.chaeo.params import RoiFilter, RoiSetExportParams from extensions.chaeo.params import AnnotatedZStackParams, PatchParams, RoiFilter, RoiSetExportParams
from extensions.chaeo.process import mask_largest_object from extensions.chaeo.process import mask_largest_object
from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.models import InstanceSegmentationModel from model_server.models import InstanceSegmentationModel
...@@ -40,7 +40,7 @@ class RoiSet(object): ...@@ -40,7 +40,7 @@ class RoiSet(object):
def get_argmax(self): def get_argmax(self):
return self.interm.argmax return self.interm.argmax
def export_3d_patches(self, where, prefix, channel, params: RoiSetExportParams): def export_3d_patches(self, where, prefix, channel, params: PatchParams):
if not self.count: if not self.count:
return return
files = export_patches_from_zstack( files = export_patches_from_zstack(
...@@ -48,12 +48,11 @@ class RoiSet(object): ...@@ -48,12 +48,11 @@ class RoiSet(object):
self.acc_raw.get_one_channel_data(channel), self.acc_raw.get_one_channel_data(channel),
self.zmask_meta, self.zmask_meta,
prefix=prefix, prefix=prefix,
draw_bounding_box=params.draw_bounding_box,
rescale_clip=params.rescale_clip,
make_3d=True, make_3d=True,
**params.__dict__,
) )
def export_2d_patches_for_training(self, where, prefix, channel, params: RoiSetExportParams): def export_2d_patches_for_training(self, where, prefix, channel, params: PatchParams):
if not self.count: if not self.count:
return return
files = export_multichannel_patches_from_zstack( files = export_multichannel_patches_from_zstack(
...@@ -62,15 +61,14 @@ class RoiSet(object): ...@@ -62,15 +61,14 @@ class RoiSet(object):
self.zmask_meta, self.zmask_meta,
ch_white=channel, ch_white=channel,
prefix=prefix, prefix=prefix,
rescale_clip=params.rescale_clip,
make_3d=False, make_3d=False,
focus_metric=params.focus_metric, **params.__dict__,
) )
df_patches = pd.DataFrame(files) df_patches = pd.DataFrame(files)
self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1) self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1)
def export_2d_patches_for_annotation(self, where, prefix, channel, params: RoiSetExportParams): def export_2d_patches_for_annotation(self, where, prefix, channel, params: PatchParams):
if not self.count: if not self.count:
return return
files = export_multichannel_patches_from_zstack( files = export_multichannel_patches_from_zstack(
...@@ -78,21 +76,14 @@ class RoiSet(object): ...@@ -78,21 +76,14 @@ class RoiSet(object):
self.acc_raw, self.acc_raw,
self.zmask_meta, self.zmask_meta,
prefix=prefix, prefix=prefix,
draw_bounding_box=params.draw_bounding_box,
rescale_clip=params.rescale_clip,
make_3d=False, make_3d=False,
focus_metric=params.focus_metric,
ch_white=channel, ch_white=channel,
ch_rgb_overlay=params.rgb_overlay_channels,
bounding_box_channel=1, bounding_box_channel=1,
bounding_box_linewidth=2, bounding_box_linewidth=2,
draw_contour=params.draw_contour, **params.__dict__,
draw_mask=params.draw_mask,
overlay_gain=params.rgb_overlay_weights,
pad_to=params.pad_to,
) )
def export_patch_masks(self, where, prefix, channel, params: RoiSetExportParams): def export_patch_masks(self, where, prefix, channel):
if not self.count: if not self.count:
return return
files = export_patch_masks_from_zstack( files = export_patch_masks_from_zstack(
...@@ -102,12 +93,12 @@ class RoiSet(object): ...@@ -102,12 +93,12 @@ class RoiSet(object):
prefix=prefix, prefix=prefix,
) )
def export_annotated_zstack(self, where, prefix, channel, params: RoiSetExportParams): def export_annotated_zstack(self, where, prefix, channel, params: AnnotatedZStackParams):
annotated = InMemoryDataAccessor( annotated = InMemoryDataAccessor(
draw_boxes_on_3d_image( draw_boxes_on_3d_image(
self.acc_raw.get_one_channel_data(channel).data, self.acc_raw.get_one_channel_data(channel).data,
self.zmask_meta, self.zmask_meta,
add_label=params.draw_label, **params.__dict__,
) )
) )
write_accessor_data_to_file( write_accessor_data_to_file(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment