From da091a5d7200bf1b6ed5ac37eb952de664b4908d Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 13 Oct 2023 10:42:49 +0200
Subject: [PATCH] Implemented training-test split in label mask export

---
 extensions/chaeo/examples/label_patches.py |  1 +
 extensions/chaeo/workflows.py              | 37 +++++++++++++++++-----
 2 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/extensions/chaeo/examples/label_patches.py b/extensions/chaeo/examples/label_patches.py
index 5e032086..1ccd7033 100644
--- a/extensions/chaeo/examples/label_patches.py
+++ b/extensions/chaeo/examples/label_patches.py
@@ -11,3 +11,4 @@ if __name__ == '__main__':
         ecotaxa_tsv='c:/Users/rhodes/projects/proj0011-plankton-seg/exp0013/ecotaxa_export_10468_20231012_0930.tsv',
         where_output=autonumber_new_directory(root, 'labeled_patches')
     )
+    print('Finished')
diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py
index d6fc26e1..aa59fce8 100644
--- a/extensions/chaeo/workflows.py
+++ b/extensions/chaeo/workflows.py
@@ -4,6 +4,7 @@ from uuid import uuid4
 
 import numpy as np
 import pandas as pd
+from sklearn.model_selection import train_test_split
 
 from extensions.ilastik.models import IlastikPixelClassifierModel
 from extensions.chaeo.annotators import draw_boxes_on_3d_image
@@ -173,7 +174,11 @@ def transfer_ecotaxa_labels_to_patch_stacks(
         ecotaxa_tsv: str,
         where_output: str,
         patch_size: tuple = (256, 256),
+        tr_split=0.6,
 ) -> Dict:
+    assert tr_split > 0.5 # train on the majority of patches to reduce the chance that rare (low-frequency) classes are omitted from training
+
+    # read patch metadata
     df_obj = pd.read_csv(
         object_csv,
     )
@@ -188,6 +193,8 @@ def transfer_ecotaxa_labels_to_patch_stacks(
         }
     )
     df_merge = pd.merge(df_obj, df_ecotaxa, left_on='patch_id', right_on='object_id')
+
+    # assign each unique lowest-level annotation to a class index
     se_unique = pd.Series(
         df_merge.object_annotation_hierarchy.unique()
     )
@@ -202,6 +209,7 @@ def transfer_ecotaxa_labels_to_patch_stacks(
         'annotation_class': df_split.loc[:, 1].str.lower()
     })
 
+    # join patch filenames and annotation classes
     df_pf = pd.merge(
         df_merge[['patch_filename', 'object_annotation_hierarchy']],
         df_labels,
@@ -210,13 +218,28 @@ def transfer_ecotaxa_labels_to_patch_stacks(
     )
     df_pl = df_pf[df_pf['object_annotation_hierarchy'].notnull()]
 
-    zstack = np.zeros((*patch_size, 1, len(df_pl)), dtype='uint8')
-
+    # split patches into training and test sets, then export annotation classes with total and per-split counts
+    df_tr, df_te = train_test_split(df_pl, train_size=tr_split)
     df_labels['counts'] = df_pl['annotation_class_id'].value_counts()
-    df_labels.to_csv(Path(where_output) / 'labels_key.csv')
-
-    # export patches as z-stack
-    for fi, pl in enumerate(df_pl.itertuples(name='PatchFile')):
+    df_labels = pd.merge(
+        df_labels,
+        pd.DataFrame(
+            [df_tr.annotation_class_id.value_counts(), df_te.annotation_class_id.value_counts()],
+            index=['to_train', 'to_test']
+        ).T,
+        left_on='annotation_class_id',
+        right_index=True,
+        how='outer'
+    )
+    df_labels.loc[df_labels.to_train.isna(), 'to_train'] = 0
+    df_labels.loc[df_labels.to_test.isna(), 'to_test'] = 0
+    for col in ['to_train', 'to_test', 'counts']:
+        df_labels.loc[df_labels[col].isna(), col] = 0
+    df_labels.to_csv(Path(where_output) / 'labels_key.csv', index=False)
+
+    # export the training-split patches as a single z-stack
+    zstack = np.zeros((*patch_size, 1, len(df_tr)), dtype='uint8')
+    for fi, pl in enumerate(df_tr.itertuples(name='PatchFile')):
         fn = pl._asdict()['patch_filename']
         ac = pl._asdict()['annotation_class_id']
         acc_bm = generate_file_accessor(Path(where_masks) / fn)
@@ -224,8 +247,6 @@ def transfer_ecotaxa_labels_to_patch_stacks(
         assert acc_bm.chroma == 1
         assert acc_bm.nz == 1
         zstack[:, :, 0, fi] = (acc_bm.data[:, :, 0, 0] == 255) * ac
-
-    # export masks as z-stack
     write_accessor_data_to_file(Path(where_output) / 'zstack_object_label.tif', InMemoryDataAccessor(zstack))
 
 
-- 
GitLab