From c35f657eb812f906e204ce2f80b1a93767030607 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 9 Oct 2023 10:57:34 +0200
Subject: [PATCH] Batch-export 3d patches and table of focus metrics;
 optionally annotate bounding box at plane of best focus

---
 .../chaeo/examples/batch_run_patches.py       |   6 +-
 .../examples/export_patch_focus_metrics.py    | 142 ++++++++++++++++++
 extensions/chaeo/products.py                  |  89 ++++++++++-
 extensions/chaeo/workflows.py                 |  40 ++---
 4 files changed, 255 insertions(+), 22 deletions(-)
 create mode 100644 extensions/chaeo/examples/export_patch_focus_metrics.py

diff --git a/extensions/chaeo/examples/batch_run_patches.py b/extensions/chaeo/examples/batch_run_patches.py
index c831c08e..7d38aa13 100644
--- a/extensions/chaeo/examples/batch_run_patches.py
+++ b/extensions/chaeo/examples/batch_run_patches.py
@@ -1,4 +1,3 @@
-import json
 from pathlib import Path
 import re
 from time import localtime, strftime
@@ -10,7 +9,7 @@ from model_server.accessors import InMemoryDataAccessor, write_accessor_data_to_
 
 if __name__ == '__main__':
     where_czi = Path(
-        'z:/rhodes/projects/proj0004-marine-photoactivation/data/exp0038/AutoMic/20230906-163415/Selection'
+        'c:/Users/rhodes/projects/proj0004-marine-photoactivation/data/exp0038/AutoMic/20230906-163415/Selection'
     )
 
     where_output_root = Path(
@@ -67,5 +66,4 @@ if __name__ == '__main__':
             write_accessor_data_to_file(
                 where_output / k / (ff.stem + '.tif'),
                 InMemoryDataAccessor(result['interm'][k])
-            )
-        break
\ No newline at end of file
+            )
\ No newline at end of file
diff --git a/extensions/chaeo/examples/export_patch_focus_metrics.py b/extensions/chaeo/examples/export_patch_focus_metrics.py
new file mode 100644
index 00000000..36e2ab76
--- /dev/null
+++ b/extensions/chaeo/examples/export_patch_focus_metrics.py
@@ -0,0 +1,142 @@
+from pathlib import Path
+import re
+from time import localtime, strftime
+from typing import Dict
+
+import pandas as pd
+
+from extensions.ilastik.models import IlastikPixelClassifierModel
+from extensions.chaeo.products import export_3d_patches_with_focus_metrics
+from extensions.chaeo.zmask import build_zmask_from_object_mask
+from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
+from model_server.workflows import Timer
+
+
+def export_patch_focus_metrics_from_multichannel_zstack(
+        input_zstack_path: str,
+        ilastik_project_file: str,
+        pxmap_threshold: float,
+        pixel_class: int,
+        zmask_channel: int,
+        patches_channel: int,
+        where_output: str,
+        mask_type: str = 'boxes',
+        zmask_filters: Dict = None,
+        zmask_expand_box_by: int = None,
+        annotate_focus_metric=None,
+        **kwargs,
+) -> Dict:
+
+    ti = Timer()
+    stack = generate_file_accessor(Path(input_zstack_path))
+    fstem = Path(input_zstack_path).stem
+    ti.click('file_input')
+    assert stack.nz > 1, 'Expecting z-stack'
+
+    # MIP and classify pixels
+    mip = InMemoryDataAccessor(
+        stack.get_one_channel_data(channel=0).data.max(axis=-1, keepdims=True)
+    )
+    px_model = IlastikPixelClassifierModel(
+        params={'project_file': Path(ilastik_project_file)}
+    )
+    pxmap, _ = px_model.infer(mip)
+    ti.click('infer_pixel_probability')
+
+    obmask = InMemoryDataAccessor(
+        pxmap.data > pxmap_threshold
+    )
+    ti.click('threshold_pixel_mask')
+
+    # make zmask
+    zmask, zmask_meta, df, interm = build_zmask_from_object_mask(
+        obmask.get_one_channel_data(pixel_class),
+        stack.get_one_channel_data(zmask_channel),
+        mask_type=mask_type,
+        filters=zmask_filters,
+        expand_box_by=zmask_expand_box_by,
+    )
+    zmask_acc = InMemoryDataAccessor(zmask)
+    ti.click('generate_zmasks')
+
+    files = export_3d_patches_with_focus_metrics(
+        Path(where_output) / '3d_patches',
+        stack.get_one_channel_data(patches_channel),
+        zmask_meta,
+        prefix=fstem,
+        draw_bounding_box=False,
+        rescale_clip=0.0,
+        make_3d=True,
+        annotate_focus_metric=annotate_focus_metric,
+    )
+    ti.click('export_patches')
+
+    return {
+        'pixel_model_id': px_model.model_id,
+        'input_filepath': input_zstack_path,
+        'number_of_objects': len(zmask_meta),
+        'success': True,
+        'timer_results': ti.events,
+        'dataframe': df,
+        'interm': interm,
+    }
+
+if __name__ == '__main__':
+    where_czi = Path(
+        'c:/Users/rhodes/projects/proj0004-marine-photoactivation/data/exp0038/AutoMic/20230906-163415/Selection'
+    )
+
+    where_output_root = Path(
+        'c:/Users/rhodes/projects/proj0011-plankton-seg/exp0009/output'
+    )
+    yyyymmdd = strftime('%Y%m%d', localtime())
+    idx = 0
+    while Path(where_output_root / f'batch-output-{yyyymmdd}-{idx:04d}').exists():
+        idx += 1
+    where_output = Path(
+        where_output_root / f'batch-output-{yyyymmdd}-{idx:04d}'
+    )
+
+    csv_args = {'mode': 'w', 'header': True} # when creating file
+    px_ilp = Path.home() / 'model-server' / 'ilastik' / 'AF405-bodies_boundaries.ilp'
+
+    for ff in where_czi.iterdir():
+        pattern = 'Selection--W([\d]+)--P([\d]+)-T([\d]+)'
+        ma = re.match(pattern, ff.stem)
+
+        print(ff)
+        if not ff.suffix.upper() == '.CZI':
+            continue
+        if int(ma.groups()[1]) > 10: # skip second half of set
+            continue
+
+        export_kwargs = {
+            'input_zstack_path': (where_czi / ff).__str__(),
+            'ilastik_project_file': px_ilp.__str__(),
+            'pxmap_threshold': 0.25,
+            'pixel_class': 0,
+            'zmask_channel': 0,
+            'patches_channel': 4,
+            'where_output': where_output.__str__(),
+            'mask_type': 'boxes',
+            'zmask_filters': {'area': (1e3, 1e8)},
+            'zmask_expand_box_by': (128, 3),
+            'annotate_focus_metric': 'rms_sobel'
+        }
+
+        result = export_patch_focus_metrics_from_multichannel_zstack(**export_kwargs)
+
+        # parse and record results
+        df = result['dataframe']
+        df['filename'] = ff.name
+        df.to_csv(where_output / 'df_objects.csv', **csv_args)
+        pd.DataFrame(result['timer_results'], index=[0]).to_csv(where_output / 'timer_results.csv', **csv_args)
+        pd.json_normalize(export_kwargs).to_csv(where_output / 'workflow_params.csv', **csv_args)
+        csv_args = {'mode': 'a', 'header': False} # append to CSV from here on
+
+        # export intermediate data if flagged
+        for k in result['interm'].keys():
+            write_accessor_data_to_file(
+                where_output / k / (ff.stem + '.tif'),
+                InMemoryDataAccessor(result['interm'][k])
+            )
\ No newline at end of file
diff --git a/extensions/chaeo/products.py b/extensions/chaeo/products.py
index e7a12e76..df2ac7dc 100644
--- a/extensions/chaeo/products.py
+++ b/extensions/chaeo/products.py
@@ -1,8 +1,12 @@
+from math import sqrt
 from pathlib import Path
 
 import numpy as np
-
+import pandas as pd
+from scipy.stats import moment
+from skimage.filters import sobel
 from skimage.io import imsave
+from skimage.measure import shannon_entropy
 from tifffile import imwrite
 
 from extensions.chaeo.annotators import draw_box_on_patch
@@ -80,3 +84,86 @@ def export_patches_from_zstack(
         _write_patch_to_file(where, fname, resample_to_8bit(patch))
         exported.append(fname)
     return exported
+
+def export_3d_patches_with_focus_metrics(
+        where: Path,
+        stack: GenericImageDataAccessor,
+        zmask_meta: list,
+        rescale_clip: float = 0.0,
+        pad_to: int = 256,
+        prefix='patch',
+        **kwargs
+):
+    """
+    Export 3D patches as multi-level z-stacks, along with CSV of various focus methods for each z-position
+
+    :param kwargs:
+        annotate_focus_metric: name focus metric to use when drawing bounding box at optimal focus z-position
+    :return:
+        list of exported files
+    """
+    assert stack.chroma == 1, 'Expecting monochromatic image data'
+    assert stack.nz > 1, 'Expecting z-stack'
+
+    def get_zstack_focus_metrics(zs):
+        nz = zs.shape[3]
+        me = {
+            'max_intensity': lambda x: np.max(x),
+            'stdev': lambda x: np.std(x),
+            'max_sobel': lambda x: np.max(sobel(x)),
+            'rms_sobel': lambda x: sqrt(np.mean(sobel(x) ** 2)),
+            'entropy': lambda x: shannon_entropy(x),
+            'moment': lambda x: moment(x, moment=2),
+        }
+        dd = {}
+        for zi in range(0, nz):
+            spf = zs[:, :, :, zi].flatten()
+            dd[zi] = {k: me[k](spf) for k in me.keys()}
+        return dd
+
+    exported = []
+    patch_meta = []
+    for mi in zmask_meta:
+        obj = mi['info']
+        sl = mi['slice']
+        rbb = mi['relative_bounding_box']
+
+        patch = stack.data[sl]
+
+        assert len(patch.shape) == 4
+        assert patch.shape[2] == stack.chroma
+
+        if rescale_clip is not None:
+            patch = rescale(patch, rescale_clip)
+
+        # unpack relative bounding box and define subset of patch data
+        x0 = rbb['x0']
+        y0 = rbb['y0']
+        x1 = rbb['x1']
+        y1 = rbb['y1']
+        sp_sl = np.s_[y0: y1, x0: x1, :, :]
+        subpatch = patch[sp_sl]
+
+        # compute focus metrics for all z-levels
+        me_dict = get_zstack_focus_metrics(subpatch)
+        patch_meta.append({'label': obj.label, 'zi': obj.zi, 'metrics': me_dict})
+        me_df = pd.DataFrame(me_dict).T
+
+        # drawing bounding box only on focused slice
+        if ak := kwargs.get('annotate_focus_metric') in me_dict.keys():
+            zi_foc = me_df.idxmax().to_dict()[ak]
+            patch[:, :, 0, zi_foc] = draw_box_on_patch(
+                patch[:, :, 0, zi_foc],
+                ((x0, y0), (x1, y1)),
+            )
+
+        if pad_to:
+            patch = pad(patch, pad_to)
+
+        fstem = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}'
+        _write_patch_to_file(where, fstem + '.tif', resample_to_8bit(patch))
+        exported.append(fstem + '.tif')
+        me_df.to_csv(where / (fstem + '.csv'))
+        exported.append(fstem + '.csv')
+
+    return exported
diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py
index 30c2221b..b0d6287a 100644
--- a/extensions/chaeo/workflows.py
+++ b/extensions/chaeo/workflows.py
@@ -64,20 +64,26 @@ def export_patches_from_multichannel_zstack(
     import numpy as np
     from skimage.filters import gaussian, sobel
 
-    def zs_projector(zs):
-        sigma = 1.5
-        blur = gaussian(sobel(zs), sigma)
-        argmax = np.argmax(blur, axis=3, keepdims=True)
-        return np.take_along_axis(zs, argmax, axis=3)
+    # def zs_projector(zs):
+    #     sigma = 1.5
+    #     blur = gaussian(sobel(zs), sigma)
+    #     argmax = np.argmax(blur, axis=3, keepdims=True)
+    #     return np.take_along_axis(zs, argmax, axis=3)
+    #
+    # def zs_annotate_best_focus(zs):
+    #     pass
 
     files = export_patches_from_zstack(
-        Path(where_output) / '2d_patches',
+        # Path(where_output) / '2d_patches',
+        Path(where_output) / '3d_patches',
         stack.get_one_channel_data(patches_channel),
         zmask_meta,
         prefix=fstem,
-        draw_bounding_box=True,
+        # draw_bounding_box=True,
+        draw_bounding_box=False,
         rescale_clip=0.0,
-        projector=zs_projector,
+        # projector=zs_projector,
+        make_3d=True,
     )
     ti.click('export_patches')
 
@@ -94,15 +100,15 @@ def export_patches_from_multichannel_zstack(
     )
     ti.click('export_annotated_zstack')
 
-    # generate multichannel projection from label centroids
-    dff = df[df['keeper']]
-    interm['projected'] = project_stack_from_focal_points(
-        dff['centroid-0'].to_numpy(),
-        dff['centroid-1'].to_numpy(),
-        dff['zi'].to_numpy(),
-        stack,
-        degree=4,
-    )
+    # # generate multichannel projection from label centroids
+    # dff = df[df['keeper']]
+    # interm['projected'] = project_stack_from_focal_points(
+    #     dff['centroid-0'].to_numpy(),
+    #     dff['centroid-1'].to_numpy(),
+    #     dff['zi'].to_numpy(),
+    #     stack,
+    #     degree=4,
+    # )
 
     return {
         'pixel_model_id': px_model.model_id,
-- 
GitLab