From f21346e9c320767ce5780248b4524786904f212d Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 19 Feb 2024 13:27:35 +0100
Subject: [PATCH] Roughed in test coverage of ilastik classification by patch
 stack

---
 .../extensions/ilastik/tests/test_ilastik.py  | 29 +++++++++++++++++--
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py
index 2c972e9d..c21e338d 100644
--- a/model_server/extensions/ilastik/tests/test_ilastik.py
+++ b/model_server/extensions/ilastik/tests/test_ilastik.py
@@ -4,9 +4,10 @@ import unittest
 
 import numpy as np
 
-from model_server.conf.testing import czifile, ilastik_classifiers, output_path
-from model_server.base.accessors import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file
+from model_server.conf.testing import czifile, ilastik_classifiers, output_path, roiset_test_data
+from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
 from model_server.extensions.ilastik import models as ilm
+from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams
 from model_server.base.workflows import classify_pixels
 from tests.test_api import TestServerBaseClass
 
@@ -270,4 +271,26 @@ class TestIlastikOverApi(TestServerBaseClass):
 
 class TestIlastikObjectClassification(unittest.TestCase):
     def setUp(self):
-        pass
+        stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path'])
+        stack_ch_pa = stack.get_one_channel_data(roiset_test_data['pipeline_params']['patches_channel'])
+        seg_mask = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path'])
+
+        self.roiset = RoiSet(
+            stack_ch_pa,
+            _get_label_ids(seg_mask),
+            params=RoiSetMetaParams(
+                mask_type='boxes',
+                filters={'area': {'min': 1e3, 'max': 1e4}},
+                expand_box_by=(64, 2)
+            )
+        )
+
+        self.object_classifier = ilm.PatchStackObjectClassifier(
+            params={'project_file': ilastik_classifiers['seg_to_obj']}
+        )
+
+    def test_classify_patches(self):
+        raw_patches = self.roiset.get_raw_patches()
+        patch_masks = self.roiset.get_patch_masks()
+        res = self.object_classifier.infer(raw_patches, patch_masks)
+        self.assertEqual(0, 1)
-- 
GitLab