Skip to content
Snippets Groups Projects
test_zstack.py 9.08 KiB
Newer Older
import unittest

import numpy as np

from conf.testing import output_path
from extensions.chaeo.conf.testing import multichannel_zstack, pixel_classifier, pipeline_params
from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams
from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack
from extensions.chaeo.workflows import infer_object_map_from_zstack
from extensions.chaeo.zmask import build_zmask_from_object_mask
from extensions.ilastik.models import IlastikPixelClassifierModel
from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.models import DummyInstanceSegmentationModel


class TestZStackDerivedDataProducts(unittest.TestCase):

    def setUp(self) -> None:
        """Load the test z-stack and derive a 2D segmentation mask from it.

        Sets on the instance:
            stack: accessor for the multichannel test z-stack
            stack_ch_seg: the segmentation channel (pipeline_params['segmentation_channel'])
            stack_ch_pa: the patches channel (pipeline_params['patches_channel'])
            pxmodel: ilastik pixel classifier loaded from the test project file
            seg_mask: pixel-probability map of the segmentation-channel projection,
                thresholded at pipeline_params['pxmap_threshold']

        Writes pxmap.tif and seg_mask.tif into output_path for inspection.
        """
        self.stack = generate_file_accessor(multichannel_zstack['path'])
        self.stack_ch_seg = self.stack.get_one_channel_data(pipeline_params['segmentation_channel'])
        self.stack_ch_pa = self.stack.get_one_channel_data(pipeline_params['patches_channel'])
        self.pxmodel = IlastikPixelClassifierModel(
            {'project_file': pixel_classifier['path']},
        )
        # Max projection over the last axis (presumably Z -- confirm); keepdims keeps
        # the array rank unchanged for InMemoryDataAccessor.
        mip = InMemoryDataAccessor(
            self.stack_ch_seg.data.max(axis=-1, keepdims=True)
        )
        pxmap, _ = self.pxmodel.infer(mip)
        write_accessor_data_to_file(output_path / 'pxmap.tif', pxmap)
        self.seg_mask = InMemoryDataAccessor(
            pxmap.get_one_channel_data(
                pipeline_params['pxmap_channel']
            ).data > pipeline_params['pxmap_threshold']
        )
        write_accessor_data_to_file(output_path / 'seg_mask.tif', self.seg_mask)

    def test_zmask_makes_correct_boxes(self, mask_type='boxes', **kwargs):
        """Build a zmask from the segmentation mask and check its shape and metadata.

        Other tests also call this as a helper. They vary ``mask_type`` and pass extra
        keyword arguments (e.g. ``filters``, ``expand_box_by``), then use the return value.

        :param mask_type: forwarded to build_zmask_from_object_mask
        :param kwargs: forwarded unchanged to build_zmask_from_object_mask
        :return: (zmask, meta) as produced by build_zmask_from_object_mask
        """
        zmask, meta, df, interm = build_zmask_from_object_mask(
            self.seg_mask,
            self.stack_ch_pa,
            mask_type=mask_type,
            **kwargs,
        )
        zmask_acc = InMemoryDataAccessor(zmask)
        self.assertTrue(zmask_acc.is_mask())

        # assert dimensionality of zmask
        self.assertGreater(zmask_acc.shape_dict['Z'], 1)
        self.assertEqual(zmask_acc.shape_dict['C'], 1)
        write_accessor_data_to_file(output_path / 'zmask.tif', zmask_acc)

        # mask values are not just all True or all False
        self.assertTrue(np.any(zmask))
        self.assertFalse(np.all(zmask))

        # assert non-trivial meta info in boxes
        self.assertGreater(len(meta), 1)
        sh = meta[1]['mask'].shape
        ar = meta[1]['info'].area
        # a box must be at least as large as the object area it encloses
        self.assertGreaterEqual(sh[0] * sh[1], ar)
        # assert dimensionality of intermediate data products
        self.assertEqual(interm['label_map'].shape, zmask.shape[0:2])
        self.assertEqual(interm['argmax'].shape, zmask.shape[0:2])

        return zmask, meta

    def test_zmask_works_on_non_zstacks(self, **kwargs):
        """Check that a zmask can be built from a stack with a single Z slice.

        :param kwargs: forwarded unchanged to build_zmask_from_object_mask
        """
        acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
        self.assertEqual(acc_zstack_slice.nz, 1)
        # Same argument order as the other calls in this file: object mask, then the
        # (here single-slice) stack.
        zmask, meta, df, interm = build_zmask_from_object_mask(
            self.seg_mask,
            acc_zstack_slice,
            mask_type='boxes',
            **kwargs,
        )
        zmask_acc = InMemoryDataAccessor(zmask)
        self.assertTrue(zmask_acc.is_mask())

    def test_zmask_makes_correct_contours(self):
        """Repeat the zmask checks using contour-shaped masks instead of boxes."""
        options = {'mask_type': 'contours'}
        return self.test_zmask_makes_correct_boxes(**options)

    def test_zmask_makes_correct_boxes_with_filters(self):
        """Repeat the zmask checks with an object-area filter applied."""
        area_limits = (1e3, 1e4)
        return self.test_zmask_makes_correct_boxes(filters={'area': area_limits})

    def test_zmask_makes_correct_expanded_boxes(self):
        """Repeat the zmask checks with boxes expanded by expand_box_by=(64, 2)."""
        expansion = (64, 2)
        return self.test_zmask_makes_correct_boxes(expand_box_by=expansion)

    def test_make_2d_patches_from_zmask(self):
        """Export 2D patches, with bounding boxes drawn, from area-filtered expanded boxes."""
        zmask, meta = self.test_zmask_makes_correct_boxes(
            filters={'area': (1e3, 1e4)},
            expand_box_by=(64, 2)
        )
        # NOTE(review): the head of this call was missing from the source. The output
        # directory and stack argument are reconstructed to mirror the other export
        # tests in this class -- confirm against history.
        files = export_patches_from_zstack(
            output_path / '2d_patches',
            InMemoryDataAccessor(self.stack_ch_pa.data),
            meta,
            draw_bounding_box=True,
        )
        self.assertGreaterEqual(len(files), 1)

    def test_make_3d_patches_from_zmask(self):
        """Export 3D patches (make_3d=True) from area-filtered expanded boxes."""
        zmask, meta = self.test_zmask_makes_correct_boxes(
            filters={'area': (1e3, 1e4)},
            expand_box_by=(64, 2),
        )
        # NOTE(review): the head of this call was missing from the source. The output
        # directory and stack argument are reconstructed to mirror the other export
        # tests in this class -- confirm against history.
        files = export_patches_from_zstack(
            output_path / '3d_patches',
            InMemoryDataAccessor(self.stack_ch_pa.data),
            meta,
            make_3d=True,
        )
        self.assertGreaterEqual(len(files), 1)

    def test_flatten_image(self):
        """Project the z-stack through per-object focal points; check the projection's YX shape."""
        zmask, meta, df, interm = build_zmask_from_object_mask(
            self.seg_mask,
            self.stack_ch_pa,
            mask_type='boxes',
        )

        from extensions.chaeo.zmask import project_stack_from_focal_points

        # Focal points: one per row of the object dataframe returned above
        # (the original referenced an undefined `dff`).
        # NOTE(review): the stack argument is restored because the result's shape is
        # compared to self.stack below -- confirm against the function's signature.
        img = project_stack_from_focal_points(
            df['centroid-0'].to_numpy(),
            df['centroid-1'].to_numpy(),
            df['zi'].to_numpy(),
            self.stack,
        )

        self.assertEqual(img.shape[0:2], self.stack.shape[0:2])

        write_accessor_data_to_file(
            output_path / 'flattened.tif',
            InMemoryDataAccessor(img)
        )

    def test_make_multichannel_2d_patches_from_zmask(self):
        """Export multichannel 2D patches with bounding boxes drawn (bounding_box_channel=1)."""
        zmask, meta = self.test_zmask_makes_correct_boxes(
            filters={'area': (1e3, 1e4)},
            expand_box_by=(128, 2)
        )
        files = export_multichannel_patches_from_zstack(
            output_path / '2d_patches_chlorophyl_bbox_overlay',
            InMemoryDataAccessor(self.stack.data),
            meta,
            ch_white=4,
            draw_bounding_box=True,
            bounding_box_channel=1,
        )
        self.assertGreaterEqual(len(files), 1)

    def test_make_multichannel_2d_patches_with_mask_overlay(self):
        """Export multichannel 2D patches with object masks drawn as an overlay."""
        box_options = dict(filters={'area': (1e3, 1e4)}, expand_box_by=(128, 2))
        _, meta = self.test_zmask_makes_correct_boxes(**box_options)

        export_options = dict(
            ch_white=4,
            ch_rgb_overlay=(3, None, None),
            draw_mask=True,
            mask_channel=0,
            overlay_gain=(0.1, 1.0, 1.0),
        )
        destination = output_path / '2d_patches_chlorophyl_mask_overlay'
        exported = export_multichannel_patches_from_zstack(
            destination,
            InMemoryDataAccessor(self.stack.data),
            meta,
            **export_options,
        )
        self.assertGreaterEqual(len(exported), 1)

    def test_make_multichannel_2d_patches_with_contour_overlay(self):
        """Export multichannel 2D patches with object contours drawn as an overlay."""
        zmask, meta = self.test_zmask_makes_correct_boxes(
            filters={'area': (1e3, 1e4)},
            expand_box_by=(128, 2)
        )
        files = export_multichannel_patches_from_zstack(
            output_path / '2d_patches_chlorophyl_contour_overlay',
            InMemoryDataAccessor(self.stack.data),
            meta,
            ch_white=4,
            ch_rgb_overlay=(3, None, None),
            draw_contour=True,
            contour_channel=1,
        )
        self.assertGreaterEqual(len(files), 1)

    def test_make_binary_masks_from_zmask(self):
        """Export per-object 2D mask patches and check that at least one file is written."""
        box_options = dict(filters={'area': (1e3, 1e4)}, expand_box_by=(128, 2))
        _, meta = self.test_zmask_makes_correct_boxes(**box_options)

        full_stack = InMemoryDataAccessor(self.stack.data)
        written = export_patch_masks_from_zstack(
            output_path / '2d_mask_patches',
            full_stack,
            meta,
        )
        self.assertGreaterEqual(len(written), 1)

    def test_object_map_workflow(self):
        pp = pipeline_params
        models = [
            self.pxmodel,
            DummyInstanceSegmentationModel(),
        ]
        models = {
            'pixel_classifier': {
                'model': self.pxmodel,
                'params': {
                    'px_class': 0,
                    'px_prob_threshold': 0.6,
                }
            },
            'object_classifier': {
                'model': DummyInstanceSegmentationModel(),
            }
        }

        roi_params = RoiSetMetaParams(**{
            'mask_type': 'boxes',
            'filters': {
                'area': {'min': 1e3, 'max': 1e8}
            },
        export_params = RoiSetExportParams(**{
            'pixel_probabilities': True,
            'patches_3d': {},
            'patches_2d_for_annotation': {
                'draw_bounding_box': True,
                'rgb_overlay_channels': [3, None, None],
                'rgb_overlay_weights': [0.2, 1.0, 1.0],
                'pad_to': 512,
            },
            'patches_2d_for_training': {
                'draw_bounding_box': False,
                'draw_mask': False,
            },
            'patch_masks': True,
            'annotated_zstacks': {},
            'object_classes': True,
            'dataframe': True,
        infer_object_map_from_zstack(
            multichannel_zstack['path'],
            output_path / 'roiset' / 'workflow',
            models,
            segmentation_channel=pp['segmentation_channel'],
            patches_channel=pp['patches_channel'],