diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py index 85f33a665ec4d7582b52aa871bb1f0e0d8a3038a..042a95512bb4417efd0e09f1359db11a37e4ff45 100644 --- a/model_server/conf/testing.py +++ b/model_server/conf/testing.py @@ -1,13 +1,18 @@ import json import os import unittest +from math import floor from multiprocessing import Process from pathlib import Path from shutil import copyfile +import numpy as np import requests from urllib3 import Retry +from ..base.accessors import GenericImageDataAccessor, InMemoryDataAccessor +from ..base.models import SemanticSegmentationModel, InstanceSegmentationModel + from ..base.accessors import generate_file_accessor class TestServerBaseClass(unittest.TestCase): @@ -128,4 +133,49 @@ def setup_test_data(): return meta # object containing test data paths and metadata, for import into unittest modules -meta = setup_test_data() \ No newline at end of file +meta = setup_test_data() + + +class DummySemanticSegmentationModel(SemanticSegmentationModel): + + model_id = 'dummy_make_white_square' + + def load(self): + return True + + def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): + super().infer(img) + w = img.shape_dict['X'] + h = img.shape_dict['Y'] + result = np.zeros([h, w], dtype='uint8') + result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255 + return InMemoryDataAccessor(data=result), {'success': True} + + def label_pixel_class( + self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: + mask, _ = self.infer(img) + return mask + + +class DummyInstanceSegmentationModel(InstanceSegmentationModel): + + model_id = 'dummy_pass_input_mask' + + def load(self): + return True + + def infer( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor + ) -> (GenericImageDataAccessor, dict): + return img.__class__( + (mask.data / mask.data.max()).astype('uint16') + ) + + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + """ + Returns a trivial segmentation, i.e. the input mask with value 1 + """ + super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs) + return self.infer(img, mask) diff --git a/tests/base/test_api.py b/tests/base/test_api.py index f368dc05212c59f1b2ece208a719b4e249cedf70..e31b4366f3d4baa9371fcd80c3338cc2c5fbb013 100644 --- a/tests/base/test_api.py +++ b/tests/base/test_api.py @@ -8,7 +8,7 @@ import model_server.conf.testing as conf from model_server.base.accessors import InMemoryDataAccessor from model_server.base.api import app from model_server.base.session import session -from tests.base.test_model import DummyInstanceSegmentationModel, DummySemanticSegmentationModel +from model_server.conf.testing import DummySemanticSegmentationModel, DummyInstanceSegmentationModel czifile = conf.meta['image_files']['czifile'] diff --git a/tests/base/test_model.py b/tests/base/test_model.py index fb5a8f21e263ea5ab5f08d31f727ddcd24386e63..d975f7cd8725e0215391b4a526feab7cd69eeb31 100644 --- a/tests/base/test_model.py +++ b/tests/base/test_model.py @@ -1,58 +1,12 @@ -from math import floor import unittest -import numpy as np - import model_server.conf.testing as conf -from model_server.base.accessors import CziImageFileAccessor, GenericImageDataAccessor, InMemoryDataAccessor -from model_server.base.models import CouldNotLoadModelError, InstanceSegmentationModel, SemanticSegmentationModel, BinaryThresholdSegmentationModel +from model_server.conf.testing import DummySemanticSegmentationModel, DummyInstanceSegmentationModel +from model_server.base.accessors import CziImageFileAccessor +from model_server.base.models import CouldNotLoadModelError, BinaryThresholdSegmentationModel czifile = conf.meta['image_files']['czifile'] -class DummySemanticSegmentationModel(SemanticSegmentationModel): - - model_id = 'dummy_make_white_square' - - def load(self): - return True - - def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): - super().infer(img) - w = img.shape_dict['X'] - h = img.shape_dict['Y'] - result = np.zeros([h, w], dtype='uint8') - result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255 - return InMemoryDataAccessor(data=result), {'success': True} - - def label_pixel_class( - self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: - mask, _ = self.infer(img) - return mask - - -class DummyInstanceSegmentationModel(InstanceSegmentationModel): - - model_id = 'dummy_pass_input_mask' - - def load(self): - return True - - def infer( - self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor - ) -> (GenericImageDataAccessor, dict): - return img.__class__( - (mask.data / mask.data.max()).astype('uint16') - ) - - def label_instance_class( - self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs - ) -> GenericImageDataAccessor: - """ - Returns a trivial segmentation, i.e. the input mask with value 1 - """ - super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs) - return self.infer(img, mask) - class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py index 2a97c5b4088976f96a4973e0f70e8472fc722047..9f1b0303cbacf9fe9fe9969f7beadfe1ce942711 100644 --- a/tests/base/test_pipelines.py +++ b/tests/base/test_pipelines.py @@ -4,7 +4,7 @@ from model_server.base.accessors import generate_file_accessor, write_accessor_d from model_server.base.pipelines import router, segment, segment_zproj import model_server.conf.testing as conf -from tests.base.test_model import DummySemanticSegmentationModel +from model_server.conf.testing import DummySemanticSegmentationModel czifile = conf.meta['image_files']['czifile'] zstack = conf.meta['image_files']['tifffile'] diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index 0a03973e78ca824083f2821157c70bb9adc77265..785358c92961662273f97f8f8c74037e03fc3cf3 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -11,7 +11,7 @@ from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_s from model_server.base.roiset import RoiSet from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack import model_server.conf.testing as conf -from tests.base.test_model import DummyInstanceSegmentationModel +from model_server.conf.testing import DummyInstanceSegmentationModel data = conf.meta['image_files'] output_path = conf.meta['output_path'] diff --git a/tests/base/test_roiset_derived.py b/tests/base/test_roiset_derived.py index 52e7f6fc0a917d719262dcd0c19f3170414f2f18..156ef9fe42c472dc829085ed3b8c26bb46ac003a 100644 --- a/tests/base/test_roiset_derived.py +++ b/tests/base/test_roiset_derived.py @@ -7,7 +7,7 @@ from model_server.base.roiset import RoiSetWithDerivedChannelsExportParams, RoiS from model_server.base.roiset import RoiSetWithDerivedChannels from model_server.base.accessors import generate_file_accessor, PatchStack import model_server.conf.testing as conf -from tests.base.test_model import DummyInstanceSegmentationModel +from model_server.conf.testing import DummyInstanceSegmentationModel data = conf.meta['image_files'] params = conf.meta['roiset']