import unittest

import numpy as np
from pathlib import Path

from model_server.conf.testing import output_path, roiset_test_data

from model_server.base.roiset import RoiSetMetaParams
from model_server.base.roiset import _get_label_ids, RoiSet
from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.base.models import DummyInstanceSegmentationModel


class BaseTestRoiSetMonoProducts(object):
    """Shared fixture for RoiSet product tests.

    Loads the multichannel z-stack, a single-channel view of it (the channel
    named by ``roiset_test_data['pipeline_params']['patches_channel']``), and
    the matching segmentation mask, all from ``roiset_test_data``.
    """

    def setUp(self) -> None:
        # set up test raw data and segmentation from file
        self.stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path'])
        self.stack_ch_pa = self.stack.get_one_channel_data(roiset_test_data['pipeline_params']['patches_channel'])
        self.seg_mask = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path'])


class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
    """Tests of RoiSet products built from a single-channel z-stack."""

    def _make_roi_set(self, mask_type='boxes', **kwargs):
        """Build a RoiSet over the single-channel stack.

        :param mask_type: forwarded to ``RoiSetMetaParams.mask_type``
        :param kwargs: ``filters`` overrides the default area filter
            ``{'area': {'min': 1e3, 'max': 1e4}}``; other keys are ignored
        :return: RoiSet with ``expand_box_by=(64, 2)``
        """
        id_map = _get_label_ids(self.seg_mask)
        roiset = RoiSet(
            self.stack_ch_pa,
            id_map,
            params=RoiSetMetaParams(
                mask_type=mask_type,
                filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}),
                expand_box_by=(64, 2)
            )
        )
        return roiset

    def test_roi_mask_shape(self, **kwargs):
        """The z-mask is a non-trivial single-channel 3D mask and box metadata is sane."""
        roiset = self._make_roi_set(**kwargs)
        zmask = roiset.get_zmask()
        zmask_acc = InMemoryDataAccessor(zmask)
        self.assertTrue(zmask_acc.is_mask())

        # assert dimensionality of zmask
        self.assertGreater(zmask_acc.shape_dict['Z'], 1)
        self.assertEqual(zmask_acc.shape_dict['C'], 1)
        write_accessor_data_to_file(output_path / 'zmask.tif', zmask_acc)

        # mask values are not just all True or all False
        self.assertTrue(np.any(zmask))
        self.assertFalse(np.all(zmask))

        # assert non-trivial meta info in boxes
        self.assertGreater(roiset.count, 1)
        sh = roiset.get_df().iloc[1]['mask'].shape
        ar = roiset.get_df().iloc[1]['area']
        # a ROI's area can never exceed the area of its own mask's bounding array
        self.assertGreaterEqual(sh[0] * sh[1], ar)

    def test_roiset_from_non_zstacks(self, **kwargs):
        """A RoiSet can be built from a single z-slice and still yields a mask."""
        acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
        self.assertEqual(acc_zstack_slice.nz, 1)

        id_map = _get_label_ids(self.seg_mask)

        roiset = RoiSet(acc_zstack_slice, id_map, params=RoiSetMetaParams(mask_type='boxes'))
        zmask = roiset.get_zmask()

        zmask_acc = InMemoryDataAccessor(zmask)
        self.assertTrue(zmask_acc.is_mask())

    def test_slices_are_valid(self):
        """Every slice from get_slices() selects a non-empty 4D region of the raw data."""
        roiset = self._make_roi_set()
        for s in roiset.get_slices():
            ebb = roiset.acc_raw.data[s]
            self.assertEqual(len(ebb.shape), 4)
            self.assertTrue(np.all([si >= 1 for si in ebb.shape]))

    def test_rel_slices_are_valid(self):
        """Each ROI's slice, and its relative slice within that region, are non-empty and 4D."""
        roiset = self._make_roi_set()
        for roi in roiset:
            ebb = roiset.acc_raw.data[roi.slice]
            self.assertEqual(len(ebb.shape), 4)
            self.assertTrue(np.all([si >= 1 for si in ebb.shape]))
            rbb = ebb[roi.relative_slice]
            self.assertEqual(len(rbb.shape), 4)
            self.assertTrue(np.all([si >= 1 for si in rbb.shape]))

    def test_make_2d_patches(self):
        """Exporting 2D patches with bounding boxes produces at least one file."""
        roiset = self._make_roi_set()
        files = roiset.export_patches(
            output_path / '2d_patches',
            draw_bounding_box=True,
        )
        self.assertGreaterEqual(len(files), 1)

    def test_make_3d_patches(self):
        """Exporting 3D patches produces at least one file."""
        roiset = self._make_roi_set()
        files = roiset.export_patches(
            output_path / '3d_patches',
            make_3d=True)
        self.assertGreaterEqual(len(files), 1)

    def test_export_annotated_zstack(self):
        """The exported annotated z-stack reloads with the same shape as the raw data."""
        roiset = self._make_roi_set()
        file = roiset.export_annotated_zstack(
            output_path / 'annotated_zstack',
        )
        result = generate_file_accessor(Path(file['location']) / file['filename'])
        self.assertEqual(result.shape, roiset.acc_raw.shape)

    def test_flatten_image(self):
        """Projecting the stack through ROI focal points keeps the stack's YX shape."""
        id_map = _get_label_ids(self.seg_mask)
        roiset = RoiSet(self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='boxes'))
        df = roiset.get_df()

        # Import via the same package path as the module-level RoiSet imports;
        # the bare ``base.roiset`` path only resolves if ``model_server/`` is on sys.path.
        from model_server.base.roiset import project_stack_from_focal_points

        img = project_stack_from_focal_points(
            df['centroid-0'].to_numpy(),
            df['centroid-1'].to_numpy(),
            df['zi'].to_numpy(),
            self.stack,
            degree=4,
        )

        self.assertEqual(img.shape[0:2], self.stack.shape[0:2])

        write_accessor_data_to_file(
            output_path / 'flattened.tif',
            InMemoryDataAccessor(img)
        )

    def test_make_binary_masks(self):
        """Exporting per-ROI patch masks produces at least one file."""
        roiset = self._make_roi_set()
        files = roiset.export_patch_masks(output_path / '2d_mask_patches', )
        self.assertGreaterEqual(len(files), 1)

    def test_classify_by(self):
        """Classifying with the dummy model labels every ROI as class 1."""
        roiset = self._make_roi_set()
        roiset.classify_by('dummy_class', 0, DummyInstanceSegmentationModel())
        self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
        self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1]))

    def test_raw_patches_are_correct_shape(self):
        """Raw patches have one entry per ROI and the raw data's channel count."""
        roiset = self._make_roi_set()
        patches = roiset.get_raw_patches()
        # do not unpack into ``np``: that would shadow the numpy module in this scope
        n_patches, _h, _w, nc, _nz = patches.shape
        self.assertEqual(n_patches, roiset.count)
        self.assertEqual(nc, roiset.acc_raw.chroma)

    def test_patch_masks_are_correct_shape(self):
        """Patch masks have one entry per ROI and a single channel."""
        roiset = self._make_roi_set()
        patch_masks = roiset.get_patch_masks()
        # do not unpack into ``np``: that would shadow the numpy module in this scope
        n_patches, _h, _w, nc, _nz = patch_masks.shape
        self.assertEqual(n_patches, roiset.count)
        self.assertEqual(nc, 1)


class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
    """Tests of RoiSet products built from the full multichannel z-stack."""

    def setUp(self) -> None:
        super().setUp()
        id_map = _get_label_ids(self.seg_mask)
        self.roiset = RoiSet(
            self.stack,
            id_map,
            params=RoiSetMetaParams(
                expand_box_by=(128, 2),
                mask_type='boxes',
                filters={'area': {'min': 1e3, 'max': 1e4}},
            )
        )

    def test_multichannel_to_mono_2d_patches(self):
        """Choosing a white channel with a plain bounding box yields single-channel patches."""
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'mono_2d_patches',
            white_channel=3,
            draw_bounding_box=True,
        )
        result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
        self.assertEqual(result.chroma, 1)

    def test_multichannnel_to_mono_2d_patches_rgb_bbox(self):
        """Setting a bounding-box channel promotes mono patches to 3-channel output."""
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'mono_2d_patches_rgb_bbox',
            white_channel=3,
            draw_bounding_box=True,
            bounding_box_channel=1,
        )
        result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
        self.assertEqual(result.chroma, 3)

    def test_multichannnel_to_rgb_2d_patches_bbox(self):
        """RGB overlay with a drawn mask yields 3-channel patches."""
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'rgb_2d_patches_bbox',
            white_channel=4,
            rgb_overlay_channels=(3, None, None),
            draw_mask=True,
            mask_channel=0,
            rgb_overlay_weights=(0.1, 1.0, 1.0)
        )
        result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
        self.assertEqual(result.chroma, 3)

    def test_multichannnel_to_rgb_2d_patches_contour(self):
        """RGB overlay with a contour on channel 1 leaves the blue channel empty."""
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'rgb_2d_patches_contour',
            rgb_overlay_channels=(3, None, None),
            draw_contour=True,
            contour_channel=1,
            rgb_overlay_weights=(0.1, 1.0, 1.0)
        )
        result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
        self.assertEqual(result.chroma, 3)
        self.assertEqual(result.get_one_channel_data(2).data.max(), 0)  # blue channel is black

    def test_multichannel_to_multichannel_tif_patches(self):
        """With no channel options, patches keep all five source channels."""
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'multichannel_tif_patches',
        )
        result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
        self.assertEqual(result.chroma, 5)

    def test_multichannel_annotated_zstack(self):
        """The annotated multichannel z-stack keeps the source's channel and z counts."""
        file = self.roiset.export_annotated_zstack(
            output_path / 'multichannel' / 'annotated_zstack',
            'test_multichannel_annotated_zstack',
        )
        result = generate_file_accessor(Path(file['location']) / file['filename'])
        self.assertEqual(result.chroma, self.stack.chroma)
        self.assertEqual(result.nz, self.stack.nz)