From 03e748c198b49fe4b3d5f5ac5dd079499f491c70 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 17 Jul 2024 13:13:49 +0200 Subject: [PATCH] Changed names of static methods --- model_server/base/roiset.py | 35 +++++++++++++++--------------- tests/base/test_roiset.py | 14 ++++++------ tests/test_ilastik/test_ilastik.py | 4 ++-- 3 files changed, 26 insertions(+), 27 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 70e73469..74d85ace 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -62,8 +62,7 @@ class RoiSetExportParams(BaseModel): derived_channels: bool = False - -def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor: +def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor: """ Convert binary segmentation mask into either a 2D or 3D object identities map :param acc_seg_mask: binary segmentation mask (mono) of either two or three dimensions @@ -98,7 +97,7 @@ def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, conne ) -def _focus_metrics(): +def focus_metrics(): return { 'max_intensity': lambda x: np.max(x), 'stdev': lambda x: np.std(x), @@ -109,7 +108,7 @@ def _focus_metrics(): } -def _filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: +def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: query_str = 'label > 0' # always true if filters is not None: # parse filters for k, val in filters.dict(exclude_unset=True).items(): @@ -122,7 +121,7 @@ def _filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: # TODO: get overlapping bounding boxes -def _filter_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame: +def filter_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame: def _compare(r0, r1): olx = (r0.x0 < r1.x1) and (r0.x1 > r1.x0) @@ -141,7 +140,7 @@ def _filter_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame: sdf['overlaps_with'] = second return sdf -def _make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame: +def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame: """ Build dataframe associate object IDs with summary stats :param acc_raw: accessor to raw image data @@ -169,7 +168,7 @@ def _make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFram }) df['zi'] = df['label'].apply(lambda x: (acc_obj_ids.data == x).sum(axis=(0, 1, 2)).argmax()) - df = _df_insert_slices(df, acc_raw.shape_dict, expand_box_by) + df = df_insert_slices(df, acc_raw.shape_dict, expand_box_by) # TODO: make this contingent on whether seg is included def _make_binary_mask(r): @@ -185,7 +184,7 @@ def _make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFram return df -def _df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame: +def df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame: h = sd['Y'] w = sd['X'] nz = sd['Z'] @@ -233,7 +232,7 @@ def _df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame return df -def _safe_add(a, g, b): +def safe_add(a, g, b): assert a.dtype == b.dtype assert a.shape == b.shape assert g >= 0.0 @@ -244,7 +243,7 @@ def _safe_add(a, g, b): np.iinfo(a.dtype).max ).astype(a.dtype) -def _make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor: +def make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor: id_mask = np.zeros((sd['Y'], sd['X'], 1, sd['Z']), dtype='uint16') if 'binary_mask' not in df.columns: raise MissingSegmentationError('RoiSet dataframe does not contain segmentation') @@ -305,8 +304,8 @@ class RoiSet(object): ): assert acc_obj_ids.chroma == 1 - df = _filter_df( - _make_df_from_object_ids( + df = filter_df( + make_df_from_object_ids( acc_raw, acc_obj_ids, expand_box_by=params.expand_box_by ), params.filters, @@ -333,7 +332,7 @@ class RoiSet(object): :param params: optional arguments that influence the definition and representation of ROIs :return: object identities map """ - return RoiSet.from_object_ids(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params) + return RoiSet.from_object_ids(acc_raw, get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params) # TODO: generate overlapping RoiSet from multiple masks @@ -347,7 +346,7 @@ class RoiSet(object): # TODO: get overlapping segments def get_overlap_seg(self) -> pd.DataFrame: - dfbb = _filter_overlap_bbox(self._df) + dfbb = filter_overlap_bbox(self._df) def _iou(roi_i): roi1 = self._df.loc[roi_i.index] roi2 = self._df.loc[roi_i.overlaps_with] @@ -568,7 +567,7 @@ class RoiSet(object): continue assert isinstance(ci, int) assert ci < raw.chroma - stack[:, :, ii, :] = _safe_add( + stack[:, :, ii, :] = safe_add( stack[:, :, ii, :], # either black or grayscale channel rgb_overlay_weights[ii], raw.data[:, :, ci, :] @@ -610,7 +609,7 @@ class RoiSet(object): # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately elif focus_metric is not None: - foc = _focus_metrics()[focus_metric] + foc = focus_metrics()[focus_metric] patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype) @@ -775,7 +774,7 @@ class RoiSet(object): @property def acc_obj_ids(self): - return _make_object_ids_from_df(self._df, self.acc_raw.shape_dict) + return make_object_ids_from_df(self._df, self.acc_raw.shape_dict) # TODO: add docstring # TODO: make this work with obj det dataset @@ -794,7 +793,7 @@ class RoiSet(object): except Exception as e: raise DeserializeRoiSet(e) df['binary_mask'] = df.apply(_read_binary_mask, axis=1) - id_mask = _make_object_ids_from_df(df, acc_raw.shape_dict) + id_mask = make_object_ids_from_df(df, acc_raw.shape_dict) return RoiSet.from_object_ids(acc_raw, id_mask) diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index bb428b5b..faf53d61 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -7,7 +7,7 @@ from pathlib import Path import pandas as pd -from model_server.base.roiset import _filter_overlap_bbox, RoiSetExportParams, RoiSetMetaParams +from model_server.base.roiset import filter_overlap_bbox, RoiSetExportParams, RoiSetMetaParams from model_server.base.roiset import RoiSet from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack from model_server.base.models import DummyInstanceSegmentationModel @@ -70,7 +70,7 @@ class TestOverlapLogic(unittest.TestCase): ] def test_overlap_bbox(self): - res = _filter_overlap_bbox(self.df) + res = filter_overlap_bbox(self.df) self.assertEqual(len(res), 2) self.assertTrue((res.loc[0, 'overlaps_with'] == 1).all()) self.assertTrue((res.loc[1, 'overlaps_with'] == 2).all()) @@ -560,7 +560,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa self.assertEqual(pacc.chroma, 1) -from model_server.base.roiset import _get_label_ids +from model_server.base.roiset import get_label_ids class TestRoiSetSerialization(unittest.TestCase): def setUp(self) -> None: @@ -577,19 +577,19 @@ class TestRoiSetSerialization(unittest.TestCase): return mask_3d.sum() == mask_mip.sum() def test_id_map_connects_z(self): - id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=True) + id_map = get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=True) labels = np.unique(id_map.data)[1:] is_2d = all([self._label_is_2d(id_map.data, la) for la in labels]) self.assertFalse(is_2d) def test_id_map_disconnects_z(self): - id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False) + id_map = get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False) labels = np.unique(id_map.data)[1:] is_2d = all([self._label_is_2d(id_map.data, la) for la in labels]) self.assertTrue(is_2d) def test_create_roiset_from_3d_obj_ids(self): - id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False) + id_map = get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False) self.assertEqual(self.stack_ch_pa.shape, id_map.shape) roiset = RoiSet.from_object_ids( @@ -601,7 +601,7 @@ class TestRoiSetSerialization(unittest.TestCase): self.assertGreater(len(roiset.get_df()['zi'].unique()), 1) def test_create_roiset_from_2d_obj_ids(self): - id_map = _get_label_ids(self.seg_mask_3d, allow_3d=False) + id_map = get_label_ids(self.seg_mask_3d, allow_3d=False) self.assertEqual(self.stack_ch_pa.shape[0:3], id_map.shape[0:3]) self.assertEqual(id_map.nz, 1) diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py index bd2be94a..35a41813 100644 --- a/tests/test_ilastik/test_ilastik.py +++ b/tests/test_ilastik/test_ilastik.py @@ -5,7 +5,7 @@ import numpy as np from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file from model_server.extensions.ilastik import models as ilm from model_server.extensions.ilastik.workflows import infer_px_then_ob_model -from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams +from model_server.base.roiset import get_label_ids, RoiSet, RoiSetMetaParams from model_server.base.workflows import classify_pixels import model_server.conf.testing as conf @@ -403,7 +403,7 @@ class TestIlastikObjectClassification(unittest.TestCase): self.roiset = RoiSet( stack_ch_pa, - _get_label_ids(seg_mask), + get_label_ids(seg_mask), params=RoiSetMetaParams( mask_type='boxes', filters={'area': {'min': 1e3, 'max': 1e4}}, -- GitLab