diff --git a/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py b/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py index 2fee54bc06427cc099af900c894bb6d015e753a4..3c721d25da6393d3ffdeb411d02b86ceb5e34223 100644 --- a/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py +++ b/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py @@ -5,10 +5,51 @@ import json import numpy as np import pandas as pd import uuid +import vigra from extensions.chaeo.util import autonumber_new_file from extensions.ilastik.models import IlastikObjectClassifierFromSegmentationModel -from model_server.accessors import generate_file_accessor, write_accessor_data_to_file +from model_server.accessors import generate_file_accessor, GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file + +class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): + + @staticmethod + def make_tczyx(acc): + assert acc.chroma == 1 + tyx = np.moveaxis( + acc.data[:, :, 0, :], # YX(C)Z + [2, 0, 1], + [0, 1, 2] + ) + return np.expand_dims(tyx, (1, 2)) + # return tyx + + def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): + assert segmentation_img.is_mask() + assert input_img.chroma == 1 + + tagged_input_data = vigra.taggedView(self.make_tczyx(input_img), 'tczyx') + tagged_seg_data = vigra.taggedView(self.make_tczyx(segmentation_img), 'tczyx') + + dsi = [ + { + 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), + 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), + } + ] + + obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] + + assert len(obmaps) == 1, 'ilastik generated more than one object map' + assert obmaps[0].shape == (input_img.nz, 1, input_img.hw[0], input_img.hw[1], 1) # z(1)yx(1) + + yxcz = np.moveaxis( + obmaps[0][:, :, :, :, 0], + [2, 3, 1, 0], + [0, 1, 2, 3] + ) + assert yxcz.shape == input_img.shape + return InMemoryDataAccessor(data=yxcz), {'success': True} def get_dataset_info(h5, lane=0): lns = f'{lane:04d}' @@ -154,8 +195,8 @@ if __name__ == '__main__': train_zstack_mask = generate_file_accessor(where_patch_stack / 'zstack_train_mask.tif') new_ilp = root / 'exp0014/test_obj_from_seg.ilp' - mod = IlastikObjectClassifierFromSegmentationModel({'project_file': new_ilp}) + mod = PatchStackObjectClassifier({'project_file': new_ilp}) - result = mod.infer(train_zstack_raw, train_zstack_mask) - write_accessor_data_to_file(where_patch_stack / 'result.tif', result) - print(mod.project_file_abspath) \ No newline at end of file + result_acc, _ = mod.infer(train_zstack_raw, train_zstack_mask) + write_accessor_data_to_file(where_patch_stack / 'result.tif', result_acc) + print(where_patch_stack / 'result.tif') \ No newline at end of file