diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index 2c972e9d179a2f2d4314eeda0b23c43a181b0a20..c21e338d758a253392f19823bffd72504e7d0082 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -4,9 +4,10 @@ import unittest import numpy as np -from model_server.conf.testing import czifile, ilastik_classifiers, output_path -from model_server.base.accessors import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file +from model_server.conf.testing import czifile, ilastik_classifiers, output_path, roiset_test_data +from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.extensions.ilastik import models as ilm +from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams from model_server.base.workflows import classify_pixels from tests.test_api import TestServerBaseClass @@ -270,4 +271,26 @@ class TestIlastikOverApi(TestServerBaseClass): class TestIlastikObjectClassification(unittest.TestCase): def setUp(self): - pass + stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path']) + stack_ch_pa = stack.get_one_channel_data(roiset_test_data['pipeline_params']['patches_channel']) + seg_mask = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path']) + + self.roiset = RoiSet( + stack_ch_pa, + _get_label_ids(seg_mask), + params=RoiSetMetaParams( + mask_type='boxes', + filters={'area': {'min': 1e3, 'max': 1e4}}, + expand_box_by=(64, 2) + ) + ) + + self.object_classifier = ilm.PatchStackObjectClassifier( + params={'project_file': ilastik_classifiers['seg_to_obj']} + ) + + def test_classify_patches(self): + raw_patches = self.roiset.get_raw_patches() + patch_masks = self.roiset.get_patch_masks() + res = self.object_classifier.infer(raw_patches, patch_masks) + self.assertEqual(0, 1)