From ecf4bdfed8789aa2571b8b360414fdc1683bf29c Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 16 Oct 2023 16:04:16 +0200
Subject: [PATCH] Populate input data in object classification model

---
 ...fer_labels_to_ilastik_object_classifier.py | 107 +++++++++++-------
 extensions/chaeo/util.py                      |   8 ++
 extensions/chaeo/workflows.py                 |   4 +
 3 files changed, 80 insertions(+), 39 deletions(-)

diff --git a/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py b/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py
index d528f86b..14c4a985 100644
--- a/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py
+++ b/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py
@@ -1,3 +1,4 @@
+import shutil
 from pathlib import Path
 import h5py
 import json
@@ -5,10 +6,12 @@ import numpy as np
 import pandas as pd
 import uuid
 
+from extensions.chaeo.util import autonumber_new_file
 from model_server.accessors import generate_file_accessor
 
-def get_dataset_info(h5):
-    lane = 'Input Data/infos/lane0000'
+def get_dataset_info(h5, lane=0):
+    lns = f'{lane:04d}'
+    lane = f'Input Data/infos/lane{lns}'
     info = {}
     for gk in ['Raw Data', 'Segmentation Image']:
         info[gk] = {}
@@ -23,9 +26,14 @@ def get_dataset_info(h5):
             info[gk]['id'] = '<invalid UUID>'
         info[gk]['axistags'] = json.loads(h5[f'{lane}/{gk}/axistags'][()].decode())
         info[gk]['axes'] = [ax['key'] for ax in info[gk]['axistags']['axes']]
+
+    obj_cl_group = h5[f'ObjectClassification/LabelInputs/{lns}']
+    info['misc'] = {
+        'number_of_label_inputs': len(obj_cl_group.items())
+    }
     return info
 
-def transfer_labels_to_ilastik_ilp(ilp, df_stack_meta):
+def transfer_labels_to_ilastik_ilp(ilp, df_stack_meta, dump_csv=False):
 
     with h5py.File(ilp, 'r+') as h5:
         # TODO: force make copy if ilp file starts with template_
@@ -33,12 +41,12 @@ def transfer_labels_to_ilastik_ilp(ilp, df_stack_meta):
         where_out = Path(ilp).parent
 
         # export complete HDF5 tree
-        with open(where_out / 'h5tree.txt', 'w') as hf:
-            tt = []
-            h5.visititems(lambda k, v: tt.append([k, str(v)]))
-            for line in tt:
-                hf.write(f'{line[0]} --- {line[1]}\n')
-        h5.visititems(lambda k, v: print(k + ' : ' + str(v)))
+        if dump_csv:
+            with open(where_out / 'h5tree.txt', 'w') as hf:
+                tt = []
+                h5.visititems(lambda k, v: tt.append([k, str(v)]))
+                for line in tt:
+                    hf.write(f'{line[0]} --- {line[1]}\n')
 
         # put certain h5 groups in scope
         h5info = get_dataset_info(h5)
@@ -60,57 +68,78 @@ def transfer_labels_to_ilastik_ilp(ilp, df_stack_meta):
             ds[1] = float(df_stack_meta.loc[df_stack_meta.zi == idx, 'annotation_class_id'].iat[0])
             print(f'Changed label {ti} from {la_old} to {ds[1]}')
 
-def generate_ilastik_object_classifier(template_ilp, where_training: str):
+def generate_ilastik_object_classifier(template_ilp, where: str, lane=0):
+    """Copy `template_ilp` to a new auto-numbered project file in `where`, then fill in that copy's input stacks, label names and object labels for `lane`."""
 
     # validate z-stack input data
-    where = Path(where_training)
-    zstacks = {
-        'Raw Data': {
-            'path': where / 'zstack_train_raw.tif',
-        },
-        'Segmentation Image': {
-            'path': where / 'zstack_train_mask.tif',
-        }
+    root = Path(where)
+    paths = {
+        'Raw Data': root / 'zstack_train_raw.tif',
+        'Segmentation Image': root / 'zstack_train_mask.tif',
     }
 
-    for k, v in zstacks.items():
-        v['acc'] = generate_file_accessor(v['path'])
-
-    assert zstacks['Segmentation Image']['acc'].is_mask()
+    accessors = {k: generate_file_accessor(pa) for k, pa in paths.items()}
 
-    assert len(set([v['acc'].hw for k, v in zstacks.items()])) == 1  # same height and width
-    assert len(set([v['acc'].nz for k, v in zstacks.items()])) == 1  # same z-depth
+    assert accessors['Raw Data'].chroma == 1
+    assert accessors['Segmentation Image'].is_mask()
+    assert len(set([a.hw for a in accessors.values()])) == 1  # same height and width
+    assert len(set([a.nz for a in accessors.values()])) == 1  # same z-depth
+    nz = accessors['Raw Data'].nz
 
     # now load CSV
-    csv_path = where / 'train_stack.csv'
+    csv_path = root / 'train_stack.csv'
     assert csv_path.exists()
-    df_meta = pd.read_csv(csv_path)
+    df_patches = pd.read_csv(csv_path)
     assert np.all(
-        df_meta['zi'].sort_values().to_numpy() == np.arange(0, zstacks['Raw Data']['acc'].nz)
+        df_patches['zi'].sort_values().to_numpy() == np.arange(0, nz)
     )
+    df_labels = pd.read_csv(root / 'labels_key.csv')
+    label_names = list(df_labels.sort_values('annotation_class_id').annotation_class.unique())
+    assert len(label_names) >= 2  # checked before indexing so an empty key fails here, not on the assignment below
+    label_names[0] = 'none'
 
-    with h5py.File(template_ilp, 'r+') as h5:
-        info = get_dataset_info(h5)
+    # open, validate, and copy template project file
+    with h5py.File(template_ilp, 'r') as h5:
+        info = get_dataset_info(h5, lane)  # validate the same lane that is written below
 
+        for hg in ['Raw Data', 'Segmentation Image']:
+            assert info[hg]['location'] == b'FileSystem'
+            assert info[hg]['axes'] == ['t', 'y', 'x']
+
+    new_ilp = shutil.copy(template_ilp, root / autonumber_new_file(root, 'auto-obj', 'ilp'))
+
+    # write to new project file
+    lns = f'{lane:04d}'
+    with h5py.File(new_ilp, 'r+') as h5:
         def set_ds(grp, ds, val):
-            ds = h5[f'Input Data/infos/lane0000/{grp}/{ds}']
+            ds = h5[f'Input Data/infos/lane{lns}/{grp}/{ds}']
             ds[()] = val
             return ds[()]
 
+        def get_label(idx):
+            return df_patches.loc[df_patches.zi == idx, 'annotation_class_id'].iat[0]
+
         for hg in ['Raw Data', 'Segmentation Image']:
-            assert info[hg]['location'] == b'FileSystem'
-            assert info[hg]['axes'] == ['t', 'y', 'x']
-            set_ds(hg, 'filePath', zstacks[hg]['path'].__str__())
-            set_ds(hg, 'nickname', zstacks[hg]['path'].stem)
-            shape_zyx = [zstacks[hg]['acc'].shape_dict[ax] for ax in ['Z', 'Y', 'X']]
+            set_ds(hg, 'filePath', paths[hg].__str__())
+            set_ds(hg, 'nickname', paths[hg].stem)
+            shape_zyx = [accessors[hg].shape_dict[ax] for ax in ['Z', 'Y', 'X']]
             set_ds(hg, 'shape', np.array(shape_zyx))
-        new_info = get_dataset_info(h5)
 
+        # change key of label names
+        del h5['ObjectClassification/LabelNames']
+        ln = np.array(label_names)
+        h5.create_dataset('ObjectClassification/LabelNames', data=ln.astype('O'))
+
+        # change object labels
+        la_groupname = f'ObjectClassification/LabelInputs/{lns}'
+
+        del h5[la_groupname]
+        lag = h5.create_group(la_groupname)
+        for zi in range(0, nz):
+            lag[f'{zi}'] = np.array([0., float(get_label(zi))])
 
 if __name__ == '__main__':
-    ilp = 'c:/Users/rhodes/model-server/ilastik/test_autolabel_obj - Copy.ilp'
-
     generate_ilastik_object_classifier(
-        'c:/Users/rhodes/model-server/ilastik/test_template_obj.ilp',
-        'c:/Users/rhodes/projects/proj0011-plankton-seg/exp0009/output/labeled_patches-20231014-0004'
+        'c:/Users/rhodes/projects/proj0011-plankton-seg/exp0014/template_obj.ilp',
+        'c:/Users/rhodes/projects/proj0011-plankton-seg/exp0009/output/labeled_patches-20231016-0002'
     )
\ No newline at end of file
diff --git a/extensions/chaeo/util.py b/extensions/chaeo/util.py
index 5432a95c..9bed0644 100644
--- a/extensions/chaeo/util.py
+++ b/extensions/chaeo/util.py
@@ -18,6 +18,14 @@ def autonumber_new_directory(where: str, prefix: str) -> str:
     new_path.mkdir(parents=True, exist_ok=False)
     return new_path.__str__()
 
+def autonumber_new_file(where: str, prefix: str, ext: str) -> str:
+    """Return '<prefix>-NNNN.<ext>' numbered one past the highest matching file in `where` (0000 if none); nothing is created."""
+    # escape prefix/ext and match the whole name, so only exact '<prefix>-<digits>.<ext>' entries are counted
+    pattern = re.compile(rf'{re.escape(prefix)}-(\d+)\.{re.escape(ext)}')
+    matches = (pattern.fullmatch(ff.name) for ff in Path(where).iterdir())
+    idx = max((int(ma.group(1)) + 1 for ma in matches if ma), default=0)
+    return f'{prefix}-{idx:04d}.{ext}'
+
 def get_matching_files(where: str, ext: str, coord_filter: dict={}) -> str:
     files = []
 
diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py
index d949323d..5ef2ec16 100644
--- a/extensions/chaeo/workflows.py
+++ b/extensions/chaeo/workflows.py
@@ -4,6 +4,7 @@ from uuid import uuid4
 
 import numpy as np
 import pandas as pd
+from skimage.morphology import dilation
 from sklearn.model_selection import train_test_split
 
 from extensions.ilastik.models import IlastikPixelClassifierModel
@@ -176,6 +177,7 @@ def transfer_ecotaxa_labels_to_patch_stacks(
         where_output: str,
         patch_size: tuple = (256, 256),
         tr_split=0.6,
+        dilate_label_mask: bool = True, # to mitigate connected components error in ilastik
 ) -> Dict:
     assert tr_split > 0.5 # reduce chance that low-probability objects are omitted from training
 
@@ -255,6 +257,8 @@ def transfer_ecotaxa_labels_to_patch_stacks(
             assert acc_bm.chroma == 1
             assert acc_bm.nz == 1
             mask = acc_bm.data[:, :, 0, 0]
+            if dilate_label_mask:
+                mask = dilation(mask)
             zstacks[dfk + '_mask'][:, :, 0, fi] = mask
             zstacks[dfk + '_label'][:, :, 0, fi] = (mask == 255) * aci
 
-- 
GitLab