From c6e2d7fc6278bb759a99ccd166bfe24750e343aa Mon Sep 17 00:00:00 2001
From: Constantin Pape <constantin.pape@iwr.uni-heidelberg.de>
Date: Tue, 11 Jun 2019 13:53:20 +0200
Subject: [PATCH] Implement test for genes table

---
 test/attributes/test_genes.py | 47 ++++++++++++++++++++++++++++++++++-
 1 file changed, 46 insertions(+), 1 deletion(-)

diff --git a/test/attributes/test_genes.py b/test/attributes/test_genes.py
index 8848ab1..8f34910 100644
--- a/test/attributes/test_genes.py
+++ b/test/attributes/test_genes.py
@@ -1,3 +1,48 @@
 import unittest
+import sys
+import os
+import numpy as np
+sys.path.append('../..')
 
-# TODO write unittest that checks new version of gene mapping against original
+
+# check new version of gene mapping against original
+class TestGeneAttributes(unittest.TestCase):
+    test_file = 'test_table.csv'
+
+    def tearDown(self):
+        try:
+            os.remove(self.test_file)
+        except OSError:
+            pass
+
+    def load_table(self, table_file):
+        table = np.genfromtxt(table_file, delimiter='\t', skip_header=1,
+                              dtype='float32')
+        return table
+
+    def test_genes(self):
+        from scripts.attributes.genes import write_genes_table
+        from scripts.files import get_h5_path_from_xml
+
+        # load original genes table
+        original_table_file = '../../data/0.0.0/tables/em-segmented-cells-labels/genes.csv'
+        original_table = self.load_table(original_table_file)
+        self.assertEqual(original_table.dtype, np.dtype('float32'))
+        labels = original_table[:, 0].astype('uint64')
+
+        # compute and load the genes table
+        segm_file = '../../data/0.0.0/segmentations/em-segmented-cells-labels.h5'
+        genes_file = '../../data/0.0.0/misc/meds_all_genes.xml'
+        genes_file = get_h5_path_from_xml(genes_file)
+        table_file = self.test_file
+        print("Start computation ...")
+        write_genes_table(segm_file, genes_file, table_file, labels)
+        table = self.load_table(table_file)
+
+        # make sure new and old table agree
+        self.assertEqual(table.shape, original_table.shape)
+        self.assertTrue(np.allclose(table, original_table))
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab