Skip to content
Snippets Groups Projects
test_pipelines.py 2.38 KiB
Newer Older
import unittest

from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file
from model_server.base.pipelines import router, segment, segment_zproj
import model_server.conf.testing as conf
from tests.base.test_model import DummySemanticSegmentationModel
# Test-fixture metadata taken from the testing configuration.
# The tests below read 'path', 'w', 'h' and 'name' from `czifile`, and 'path'
# from `zstack`; `output_path` is joined with `/` to build output locations.
_image_files = conf.meta['image_files']
czifile, zstack = _image_files['czifile'], _image_files['tifffile']
output_path = conf.meta['output_path']
class TestSegmentationPipelines(unittest.TestCase):
    """Tests for the ``segment`` and ``segment_zproj`` pipelines.

    Both pipelines are called with a ``DummySemanticSegmentationModel`` and the
    image files described by the testing configuration.
    """

    def setUp(self):
        """Create a new dummy segmentation model before each test."""
        self.model = DummySemanticSegmentationModel()

    def test_call_segment_pipeline(self):
        """Run ``segment_pipeline`` on the CZI image and check the written mask.

        The last item of the pipeline trace is written to a TIFF file, read
        back, and checked for shape and for two pixel values. The intermediate
        products are then written out and their file-name suffixes verified.
        """
        # Local import: tifffile is only needed to read the result back.
        import tifffile

        acc = generate_file_accessor(czifile['path'])
        trace = segment.segment_pipeline(
            {'accessor': acc},
            {'model': self.model},
            channel=2,
            smooth=3,
        )
        outfp = output_path / 'pipelines' / 'segment_binary_mask.tif'
        write_accessor_data_to_file(outfp, trace.last)

        img = tifffile.imread(outfp)
        w = czifile['w']
        h = czifile['h']

        # NumPy arrays are indexed (row, column), so a single 2-D plane has
        # shape (h, w). Assumes the mask is written as one 2-D plane -- TODO
        # confirm against write_accessor_data_to_file.
        self.assertEqual(
            img.shape,
            (h, w),
            'Inferred image is not the expected shape'
        )

        # Row index comes first: the centre pixel is (h // 2, w // 2).
        self.assertEqual(
            img[h // 2, w // 2],
            255,
            'Middle pixel is not white as expected'
        )

        self.assertEqual(
            img[0, 0],
            0,
            'First pixel is not black as expected'
        )

        interm_fps = trace.write_interm(
            output_path / 'pipelines' / 'segment_interm',
            prefix=czifile['name']
        )
        # assertEqual (rather than assertTrue on ==) reports the differing
        # lists when the check fails.
        self.assertEqual(
            [ofp.stem.split('_')[-1] for ofp in interm_fps],
            ['mono', 'inference', 'smooth'],
        )

    def test_call_segment_zproj_pipeline(self):
        """Run ``segment_zproj_pipeline`` with several keyword combinations.

        For every combination the last trace item is expected to report
        ``chroma == 1`` and ``nz == 1``.
        """
        acc = generate_file_accessor(zstack['path'])

        keyword_cases = (
            {'channel': 0, 'smooth': 3, 'zi': 4},
            {'channel': 0, 'smooth': 3},
            # No channel selected: chroma is still expected to be 1 because
            # the model returns a single channel regardless of input.
            {},
        )
        for kwargs in keyword_cases:
            # subTest reports each combination separately and keeps going
            # after a failure instead of stopping at the first one.
            with self.subTest(**kwargs):
                trace = segment_zproj.segment_zproj_pipeline(
                    {'accessor': acc},
                    {'model': self.model},
                    **kwargs
                )
                self.assertEqual(trace.last.chroma, 1)
                self.assertEqual(trace.last.nz, 1)