From 03e748c198b49fe4b3d5f5ac5dd079499f491c70 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 17 Jul 2024 13:13:49 +0200
Subject: [PATCH] Changed names of static methods

---
 model_server/base/roiset.py        | 35 +++++++++++++++---------------
 tests/base/test_roiset.py          | 14 ++++++------
 tests/test_ilastik/test_ilastik.py |  4 ++--
 3 files changed, 26 insertions(+), 27 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 70e73469..74d85ace 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -62,8 +62,7 @@ class RoiSetExportParams(BaseModel):
     derived_channels: bool = False
 
 
-
-def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor:
+def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor:
     """
     Convert binary segmentation mask into either a 2D or 3D object identities map
     :param acc_seg_mask: binary segmentation mask (mono) of either two or three dimensions
@@ -98,7 +97,7 @@ def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, conne
         )
 
 
-def _focus_metrics():
+def focus_metrics():
     return {
         'max_intensity': lambda x: np.max(x),
         'stdev': lambda x: np.std(x),
@@ -109,7 +108,7 @@ def _focus_metrics():
     }
 
 
-def _filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame:
+def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame:
     query_str = 'label > 0'  # always true
     if filters is not None:  # parse filters
         for k, val in filters.dict(exclude_unset=True).items():
@@ -122,7 +121,7 @@ def _filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame:
 
 
 # TODO: get overlapping bounding boxes
-def _filter_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame:
+def filter_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame:
 
     def _compare(r0, r1):
         olx = (r0.x0 < r1.x1) and (r0.x1 > r1.x0)
@@ -141,7 +140,7 @@ def _filter_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame:
     sdf['overlaps_with'] = second
     return sdf
 
-def _make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
+def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
     """
     Build dataframe associate object IDs with summary stats
     :param acc_raw: accessor to raw image data
@@ -169,7 +168,7 @@ def _make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFram
         })
         df['zi'] = df['label'].apply(lambda x: (acc_obj_ids.data == x).sum(axis=(0, 1, 2)).argmax())
 
-    df = _df_insert_slices(df, acc_raw.shape_dict, expand_box_by)
+    df = df_insert_slices(df, acc_raw.shape_dict, expand_box_by)
 
     # TODO: make this contingent on whether seg is included
     def _make_binary_mask(r):
@@ -185,7 +184,7 @@ def _make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFram
     return df
 
 
-def _df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame:
+def df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame:
     h = sd['Y']
     w = sd['X']
     nz = sd['Z']
@@ -233,7 +232,7 @@ def _df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame
     return df
 
 
-def _safe_add(a, g, b):
+def safe_add(a, g, b):
     assert a.dtype == b.dtype
     assert a.shape == b.shape
     assert g >= 0.0
@@ -244,7 +243,7 @@ def _safe_add(a, g, b):
         np.iinfo(a.dtype).max
     ).astype(a.dtype)
 
-def _make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor:
+def make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor:
     id_mask = np.zeros((sd['Y'], sd['X'], 1, sd['Z']), dtype='uint16')
     if 'binary_mask' not in df.columns:
         raise MissingSegmentationError('RoiSet dataframe does not contain segmentation')
@@ -305,8 +304,8 @@ class RoiSet(object):
     ):
         assert acc_obj_ids.chroma == 1
 
-        df = _filter_df(
-            _make_df_from_object_ids(
+        df = filter_df(
+            make_df_from_object_ids(
                 acc_raw, acc_obj_ids, expand_box_by=params.expand_box_by
             ),
             params.filters,
@@ -333,7 +332,7 @@ class RoiSet(object):
         :param params: optional arguments that influence the definition and representation of ROIs
         :return: object identities map
         """
-        return RoiSet.from_object_ids(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params)
+        return RoiSet.from_object_ids(acc_raw, get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params)
 
 
     # TODO: generate overlapping RoiSet from multiple masks
@@ -347,7 +346,7 @@ class RoiSet(object):
 
     # TODO: get overlapping segments
     def get_overlap_seg(self) -> pd.DataFrame:
-        dfbb = _filter_overlap_bbox(self._df)
+        dfbb = filter_overlap_bbox(self._df)
         def _iou(roi_i):
             roi1 = self._df.loc[roi_i.index]
             roi2 = self._df.loc[roi_i.overlaps_with]
@@ -568,7 +567,7 @@ class RoiSet(object):
                     continue
                 assert isinstance(ci, int)
                 assert ci < raw.chroma
-                stack[:, :, ii, :] = _safe_add(
+                stack[:, :, ii, :] = safe_add(
                     stack[:, :, ii, :],  # either black or grayscale channel
                     rgb_overlay_weights[ii],
                     raw.data[:, :, ci, :]
@@ -610,7 +609,7 @@ class RoiSet(object):
 
             # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately
             elif focus_metric is not None:
-                foc = _focus_metrics()[focus_metric]
+                foc = focus_metrics()[focus_metric]
 
                 patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype)
 
@@ -775,7 +774,7 @@ class RoiSet(object):
 
     @property
     def acc_obj_ids(self):
-        return _make_object_ids_from_df(self._df, self.acc_raw.shape_dict)
+        return make_object_ids_from_df(self._df, self.acc_raw.shape_dict)
 
     # TODO: add docstring
     # TODO: make this work with obj det dataset
@@ -794,7 +793,7 @@ class RoiSet(object):
             except Exception as e:
                 raise DeserializeRoiSet(e)
         df['binary_mask'] = df.apply(_read_binary_mask, axis=1)
-        id_mask = _make_object_ids_from_df(df, acc_raw.shape_dict)
+        id_mask = make_object_ids_from_df(df, acc_raw.shape_dict)
         return RoiSet.from_object_ids(acc_raw, id_mask)
 
 
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index bb428b5b..faf53d61 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -7,7 +7,7 @@ from pathlib import Path
 
 import pandas as pd
 
-from model_server.base.roiset import _filter_overlap_bbox, RoiSetExportParams, RoiSetMetaParams
+from model_server.base.roiset import filter_overlap_bbox, RoiSetExportParams, RoiSetMetaParams
 from model_server.base.roiset import RoiSet
 from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack
 from model_server.base.models import DummyInstanceSegmentationModel
@@ -70,7 +70,7 @@ class TestOverlapLogic(unittest.TestCase):
         ]
 
     def test_overlap_bbox(self):
-        res = _filter_overlap_bbox(self.df)
+        res = filter_overlap_bbox(self.df)
         self.assertEqual(len(res), 2)
         self.assertTrue((res.loc[0, 'overlaps_with'] == 1).all())
         self.assertTrue((res.loc[1, 'overlaps_with'] == 2).all())
@@ -560,7 +560,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
             self.assertEqual(pacc.chroma, 1)
 
 
-from model_server.base.roiset import _get_label_ids
+from model_server.base.roiset import get_label_ids
 class TestRoiSetSerialization(unittest.TestCase):
 
     def setUp(self) -> None:
@@ -577,19 +577,19 @@ class TestRoiSetSerialization(unittest.TestCase):
         return mask_3d.sum() == mask_mip.sum()
 
     def test_id_map_connects_z(self):
-        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=True)
+        id_map = get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=True)
         labels = np.unique(id_map.data)[1:]
         is_2d = all([self._label_is_2d(id_map.data, la) for la in labels])
         self.assertFalse(is_2d)
 
     def test_id_map_disconnects_z(self):
-        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
+        id_map = get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
         labels = np.unique(id_map.data)[1:]
         is_2d = all([self._label_is_2d(id_map.data, la) for la in labels])
         self.assertTrue(is_2d)
 
     def test_create_roiset_from_3d_obj_ids(self):
-        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
+        id_map = get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
         self.assertEqual(self.stack_ch_pa.shape, id_map.shape)
 
         roiset = RoiSet.from_object_ids(
@@ -601,7 +601,7 @@ class TestRoiSetSerialization(unittest.TestCase):
         self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)
 
     def test_create_roiset_from_2d_obj_ids(self):
-        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=False)
+        id_map = get_label_ids(self.seg_mask_3d, allow_3d=False)
         self.assertEqual(self.stack_ch_pa.shape[0:3], id_map.shape[0:3])
         self.assertEqual(id_map.nz, 1)
 
diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py
index bd2be94a..35a41813 100644
--- a/tests/test_ilastik/test_ilastik.py
+++ b/tests/test_ilastik/test_ilastik.py
@@ -5,7 +5,7 @@ import numpy as np
 from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file
 from model_server.extensions.ilastik import models as ilm
 from model_server.extensions.ilastik.workflows import infer_px_then_ob_model
-from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams
+from model_server.base.roiset import get_label_ids, RoiSet, RoiSetMetaParams
 from model_server.base.workflows import classify_pixels
 import model_server.conf.testing as conf
 
@@ -403,7 +403,7 @@ class TestIlastikObjectClassification(unittest.TestCase):
 
         self.roiset = RoiSet(
             stack_ch_pa,
-            _get_label_ids(seg_mask),
+            get_label_ids(seg_mask),
             params=RoiSetMetaParams(
                 mask_type='boxes',
                 filters={'area': {'min': 1e3, 'max': 1e4}},
-- 
GitLab