From a3043334826dc7589d6d699abb7951a46e4f1353 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 20 Dec 2023 10:01:39 +0100 Subject: [PATCH] Renamed ZMask to RoiSet, prototyped object classification methods that add classification result as a dataframe column --- extensions/chaeo/params.py | 25 ++++++++++++++++++++++--- extensions/chaeo/workflows.py | 4 ++-- extensions/chaeo/zmask.py | 29 +++++++++++++++++++++-------- 3 files changed, 45 insertions(+), 13 deletions(-) diff --git a/extensions/chaeo/params.py b/extensions/chaeo/params.py index 1d03cef5..e516f49b 100644 --- a/extensions/chaeo/params.py +++ b/extensions/chaeo/params.py @@ -2,19 +2,22 @@ from typing import Dict, List, Union from pydantic import BaseModel + class PatchParams(BaseModel): draw_bounding_box: bool = False draw_contour: bool = False draw_mask: bool = False rescale_clip: float = 0.001 focus_metric: str = 'max_sobel' - rgb_overlay_channels: List[Union[int, None]] = (None, None, None), - rgb_overlay_weights: List[float] = (1.0, 1.0, 1.0) + rgb_overlay_channels: List[Union[int, None]] = [None, None, None] + rgb_overlay_weights: List[float] = [1.0, 1.0, 1.0] + class AnnotatedZStackParams(BaseModel): draw_label: bool = False -class ZMaskExportParams(BaseModel): + +class RoiSetExportParams(BaseModel): pixel_probabilities: bool = False patches_3d: Union[PatchParams, None] = None patches_2d_for_annotation: Union[PatchParams, None] = None @@ -22,3 +25,19 @@ class ZMaskExportParams(BaseModel): patch_masks: bool = False annotated_z_stack: Union[AnnotatedZStackParams, None] = None + +class RoiClassifierValue(BaseModel): + name: str + value: int + + +class RoiFilterRange(BaseModel): + min: float + max: float + + +class RoiFilter(BaseModel): + area: Union[RoiFilterRange, None] = None + solidity: Union[RoiFilterRange, None] = None + classifiers: List[RoiClassifierValue] = [] + diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py index 7dfe5b84..6e9732f6 100644 --- a/extensions/chaeo/workflows.py +++ b/extensions/chaeo/workflows.py @@ -15,7 +15,7 @@ from extensions.chaeo.models import PatchStackObjectClassifier from extensions.chaeo.params import ZMaskExportParams from extensions.chaeo.process import mask_largest_object from extensions.chaeo.products import export_patches_from_zstack, export_patch_masks_from_zstack, export_multichannel_patches_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta -from extensions.chaeo.zmask import project_stack_from_focal_points, ZMaskObjectTable +from extensions.chaeo.zmask import project_stack_from_focal_points, RoiSet from extensions.ilastik.models import IlastikPixelClassifierModel from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file @@ -281,7 +281,7 @@ def infer_object_map_from_zstack( ti.click('threshold_pixel_mask') # make zmask - obj_table = ZMaskObjectTable( + obj_table = RoiSet( obmask.get_one_channel_data(pxmap_foreground_channel), stack.get_one_channel_data(segmentation_channel), mask_type=zmask_type, diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py index 04bfeb41..67a78334 100644 --- a/extensions/chaeo/zmask.py +++ b/extensions/chaeo/zmask.py @@ -9,12 +9,12 @@ from sklearn.linear_model import LinearRegression from extensions.chaeo.annotators import draw_boxes_on_3d_image from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack -from extensions.chaeo.params import ZMaskExportParams +from extensions.chaeo.params import RoiFilter, RoiSetExportParams from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file -class ZMaskObjectTable(object): +class RoiSet(object): def __init__( self, @@ -41,7 +41,7 @@ class ZMaskObjectTable(object): def get_argmax(self): return self.interm.argmax - def export_3d_patches(self, where, prefix, channel, params: ZMaskExportParams): + def export_3d_patches(self, where, prefix, channel, params: RoiSetExportParams): if not self.count: return files = export_patches_from_zstack( @@ -54,7 +54,7 @@ class ZMaskObjectTable(object): make_3d=True, ) - def export_2d_patches_for_annotation(self, where, prefix, channel, params: ZMaskExportParams): + def export_2d_patches_for_annotation(self, where, prefix, channel, params: RoiSetExportParams): if not self.count: return files = export_multichannel_patches_from_zstack( @@ -78,7 +78,7 @@ class ZMaskObjectTable(object): self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1) - def export_2d_patches_for_training(self, where, prefix, channel, params: ZMaskExportParams): + def export_2d_patches_for_training(self, where, prefix, channel, params: RoiSetExportParams): if not self.count: return files = export_multichannel_patches_from_zstack( @@ -102,7 +102,7 @@ class ZMaskObjectTable(object): self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1) - def export_2d_patches_for_annotation(self, where, prefix, channel, params: ZMaskExportParams): + def export_2d_patches_for_annotation(self, where, prefix, channel, params: RoiSetExportParams): if not self.count: return files = export_multichannel_patches_from_zstack( @@ -115,7 +115,7 @@ class ZMaskObjectTable(object): focus_metric=params.focus_metric, ) - def export_patch_masks(self, where, prefix, channel, params: ZMaskExportParams): + def export_patch_masks(self, where, prefix, channel, params: RoiSetExportParams): if not self.count: return files = export_patch_masks_from_zstack( @@ -125,7 +125,7 @@ class ZMaskObjectTable(object): prefix=prefix, ) - def export_annotated_zstack(self, where, prefix, channel, params: ZMaskExportParams): + def export_annotated_zstack(self, where, prefix, channel, params: RoiSetExportParams): annotated = InMemoryDataAccessor( draw_boxes_on_3d_image( self.acc_raw.get_one_channel_data(channel).data, @@ -152,6 +152,19 @@ class ZMaskObjectTable(object): projected = self.acc_raw.data.max(axis=-1) return projected + def classify_by(self, column_name, object_classification_model, patch_basis=True): + # add one column to df where each value is an integer + pass + + def get_raw_patches(self, filters: RoiFilter): + pass + + def get_patch_masks(self, filters: RoiFilter): + pass + + def get_object_map(self, filters: RoiFilter): + pass + def build_zmask_from_object_mask( obmask: GenericImageDataAccessor, zstack: GenericImageDataAccessor, -- GitLab