Skip to content
Snippets Groups Projects
Commit 16d96d55 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Implemented but did not yet test adding object classification result to RoiSet dataframe

parent 73754817
No related branches found
No related tags found
No related merge requests found
...@@ -26,11 +26,6 @@ class RoiSetExportParams(BaseModel): ...@@ -26,11 +26,6 @@ class RoiSetExportParams(BaseModel):
annotated_z_stack: Union[AnnotatedZStackParams, None] = None annotated_z_stack: Union[AnnotatedZStackParams, None] = None
class RoiClassifierValue(BaseModel):
name: str
value: int
class RoiFilterRange(BaseModel): class RoiFilterRange(BaseModel):
min: float min: float
max: float max: float
...@@ -39,5 +34,3 @@ class RoiFilterRange(BaseModel): ...@@ -39,5 +34,3 @@ class RoiFilterRange(BaseModel):
class RoiFilter(BaseModel): class RoiFilter(BaseModel):
area: Union[RoiFilterRange, None] = None area: Union[RoiFilterRange, None] = None
solidity: Union[RoiFilterRange, None] = None solidity: Union[RoiFilterRange, None] = None
classifiers: List[RoiClassifierValue] = []
...@@ -12,7 +12,7 @@ from sklearn.model_selection import train_test_split ...@@ -12,7 +12,7 @@ from sklearn.model_selection import train_test_split
from extensions.chaeo.accessors import MonoPatchStack from extensions.chaeo.accessors import MonoPatchStack
from extensions.chaeo.annotators import draw_boxes_on_3d_image from extensions.chaeo.annotators import draw_boxes_on_3d_image
from extensions.chaeo.models import PatchStackObjectClassifier from extensions.chaeo.models import PatchStackObjectClassifier
from extensions.chaeo.params import ZMaskExportParams from extensions.chaeo.params import RoiSetExportParams
from extensions.chaeo.process import mask_largest_object from extensions.chaeo.process import mask_largest_object
from extensions.chaeo.products import export_patches_from_zstack, export_patch_masks_from_zstack, export_multichannel_patches_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta from extensions.chaeo.products import export_patches_from_zstack, export_patch_masks_from_zstack, export_multichannel_patches_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta
from extensions.chaeo.zmask import project_stack_from_focal_points, RoiSet from extensions.chaeo.zmask import project_stack_from_focal_points, RoiSet
...@@ -240,7 +240,7 @@ def infer_object_map_from_zstack( ...@@ -240,7 +240,7 @@ def infer_object_map_from_zstack(
zmask_type: str = 'boxes', zmask_type: str = 'boxes',
zmask_filters: Dict = None, zmask_filters: Dict = None,
# zmask_expand_box_by: int = None, # zmask_expand_box_by: int = None,
exports: ZMaskExportParams = None, exports: RoiSetExportParams = None,
**kwargs, **kwargs,
) -> Dict: ) -> Dict:
assert len(models) == 2 assert len(models) == 2
...@@ -281,7 +281,7 @@ def infer_object_map_from_zstack( ...@@ -281,7 +281,7 @@ def infer_object_map_from_zstack(
ti.click('threshold_pixel_mask') ti.click('threshold_pixel_mask')
# make zmask # make zmask
obj_table = RoiSet( rois = RoiSet(
obmask.get_one_channel_data(pxmap_foreground_channel), obmask.get_one_channel_data(pxmap_foreground_channel),
stack.get_one_channel_data(segmentation_channel), stack.get_one_channel_data(segmentation_channel),
mask_type=zmask_type, mask_type=zmask_type,
...@@ -291,7 +291,7 @@ def infer_object_map_from_zstack( ...@@ -291,7 +291,7 @@ def infer_object_map_from_zstack(
ti.click('generate_zmasks') ti.click('generate_zmasks')
# record pixel scale # record pixel scale
obj_table.df['pixel_scale_in_micrometers'] = float(stack.pixel_scale_in_micrometers.get('X')) rois.df['pixel_scale_in_micrometers'] = float(stack.pixel_scale_in_micrometers.get('X'))
# ti, stack, fstem, obmask, pxmap, obj_table = get_zmask_meta( # ti, stack, fstem, obmask, pxmap, obj_table = get_zmask_meta(
# input_file_path, # input_file_path,
...@@ -307,79 +307,79 @@ def infer_object_map_from_zstack( ...@@ -307,79 +307,79 @@ def infer_object_map_from_zstack(
# **kwargs # **kwargs
# ) # )
# extract patches to accessor # # extract patches to accessor
patches_acc = get_patches_from_zmask_meta( # patches_acc = get_patches_from_zmask_meta(
stack.get_one_channel_data(patches_channel), # stack.get_one_channel_data(patches_channel),
obj_table.zmask_meta, # obj_table.zmask_meta,
rescale_clip=zmask_clip, # rescale_clip=zmask_clip,
make_3d=False, # make_3d=False,
focus_metric='max_sobel', # focus_metric='max_sobel',
**kwargs # **kwargs
) # )
#
# TODO: make this a method of ZMaskObjectTable class # # extract masks
# extract masks # patch_masks_acc = get_patch_masks_from_zmask_meta(
patch_masks_acc = get_patch_masks_from_zmask_meta( # stack,
stack, # obj_table.zmask_meta,
obj_table.zmask_meta, # **kwargs
**kwargs # )
)
# TODO: add ZMaskObjectTable method to apply object classification results as new DataFrame column
# send patches and mask stacks to object classifier
result_acc, _ = object_classifier.infer(patches_acc, patch_masks_acc)
labels_map = obj_table.interm['label_map']
output_map = np.zeros(labels_map.shape, dtype=labels_map.dtype)
assert labels_map.shape == obj_table.get_label_map().shape
assert labels_map.dtype == obj_table.get_label_map().dtype
# assign labels to object map: # # send patches and mask stacks to object classifier
meta = [] # result_acc, _ = object_classifier.infer(patches_acc, patch_masks_acc)
for ii in range(0, len(obj_table.zmask_meta)):
object_id = obj_table.zmask_meta[ii]['info'].label # labels_map = obj_table.interm['label_map']
result_patch = mask_largest_object(result_acc.iat(ii)) # output_map = np.zeros(labels_map.shape, dtype=labels_map.dtype)
object_class = np.unique(result_patch)[1] # assert labels_map.shape == obj_table.get_label_map().shape
output_map[labels_map == object_id] = object_class # assert labels_map.dtype == obj_table.get_label_map().dtype
meta.append({'object_id': ii, 'object_class': object_id}) #
# # assign labels to object map:
# meta = []
# for ii in range(0, len(obj_table.zmask_meta)):
# object_id = obj_table.zmask_meta[ii]['info'].label
# result_patch = mask_largest_object(result_acc.iat(ii))
# object_class = np.unique(result_patch)[1]
# output_map[labels_map == object_id] = object_class
# meta.append({'object_id': ii, 'object_class': object_id})
object_class_map = rois.classify_by(patches_channel)
# TODO: add ZMaskObjectTable method to export object map # TODO: add ZMaskObjectTable method to export object map
output_path = Path(output_folder_path) / ('obj_classes_' + (fstem + '.tif')) output_path = Path(output_folder_path) / ('obj_classes_' + (fstem + '.tif'))
write_accessor_data_to_file( write_accessor_data_to_file(
output_path, output_path,
InMemoryDataAccessor(output_map) InMemoryDataAccessor(object_class_map)
) )
ti.click('export_object_classes') ti.click('export_object_classes')
if exports.patches_3d: if exports.patches_3d:
obj_table.export_3d_patches( rois.export_3d_patches(
Path(output_folder_path) / '3d_patches', Path(output_folder_path) / '3d_patches',
fstem, fstem,
patches_channel, patches_channel,
exports.patches_3d exports.patches_3d
) )
ti.click('export_3d_patches') ti.click('export_3d_patches')
if exports.patches_2d_for_annotation: if exports.patches_2d_for_annotation:
obj_table.export_2d_patches_for_annotation( rois.export_2d_patches_for_annotation(
Path(output_folder_path) / '2d_patches_annotation', Path(output_folder_path) / '2d_patches_annotation',
fstem, fstem,
patches_channel, patches_channel,
exports.patches_2d_for_annotation exports.patches_2d_for_annotation
) )
ti.click('export_2d_patches_for_annotation') ti.click('export_2d_patches_for_annotation')
if exports.patches_2d_for_training: if exports.patches_2d_for_training:
obj_table.export_2d_patches_for_training( rois.export_2d_patches_for_training(
Path(output_folder_path) / '2d_patches_training', Path(output_folder_path) / '2d_patches_training',
fstem, fstem,
patches_channel, patches_channel,
exports.patches_2d_for_training exports.patches_2d_for_training
) )
ti.click('export_2d_patches_for_training') ti.click('export_2d_patches_for_training')
if exports.patch_masks: if exports.patch_masks:
obj_table.export_patch_masks( rois.export_patch_masks(
Path(output_folder_path) / 'patch_masks', Path(output_folder_path) / 'patch_masks',
fstem, fstem,
patches_channel, patches_channel,
...@@ -387,18 +387,17 @@ def infer_object_map_from_zstack( ...@@ -387,18 +387,17 @@ def infer_object_map_from_zstack(
) )
if exports.annotated_z_stack: if exports.annotated_z_stack:
obj_table.export_annotated_zstack( rois.export_annotated_zstack(
Path(output_folder_path) / 'patch_masks', Path(output_folder_path) / 'patch_masks',
fstem, fstem,
patches_channel, patches_channel,
exports.annotated_z_stack exports.annotated_z_stack
) )
ti.click('export_annotated_zstack') ti.click('export_annotated_zstack')
return { return {
'timer_results': ti.events, 'timer_results': ti.events,
'dataframe': pd.DataFrame(meta), 'dataframe': rois.df,
'interm': {}, 'interm': {},
'output_path': output_path.__str__(), 'output_path': output_path.__str__(),
} }
......
...@@ -8,10 +8,11 @@ from sklearn.preprocessing import PolynomialFeatures ...@@ -8,10 +8,11 @@ from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression from sklearn.linear_model import LinearRegression
from extensions.chaeo.annotators import draw_boxes_on_3d_image from extensions.chaeo.annotators import draw_boxes_on_3d_image
from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta
from extensions.chaeo.params import RoiFilter, RoiSetExportParams from extensions.chaeo.params import RoiFilter, RoiSetExportParams
from extensions.chaeo.process import mask_largest_object
from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.models import InstanceSegmentationModel
class RoiSet(object): class RoiSet(object):
...@@ -30,13 +31,11 @@ class RoiSet(object): ...@@ -30,13 +31,11 @@ class RoiSet(object):
filters=filters, filters=filters,
mask_type=mask_type, mask_type=mask_type,
expand_box_by=expand_box_by expand_box_by=expand_box_by
) # currently, some methods can add columns to self.df )
self.acc_raw = acc_raw self.acc_raw = acc_raw
self.count = len(self.zmask_meta) self.count = len(self.zmask_meta)
self.object_id_labels = self.interm['label_map']
def get_label_map(self):
return self.interm.lamap
def get_argmax(self): def get_argmax(self):
return self.interm.argmax return self.interm.argmax
...@@ -152,19 +151,48 @@ class RoiSet(object): ...@@ -152,19 +151,48 @@ class RoiSet(object):
projected = self.acc_raw.data.max(axis=-1) projected = self.acc_raw.data.max(axis=-1)
return projected return projected
def classify_by(self, column_name, object_classification_model, patch_basis=True): def get_raw_patches(self, channel):
# add one column to df where each value is an integer return get_patches_from_zmask_meta(self.acc_raw(channel), self.zmask_meta)
pass
def get_raw_patches(self, filters: RoiFilter): def get_patch_masks(self):
pass return get_patch_masks_from_zmask_meta(self.acc_raw, self.zmask_meta)
def classify_by(self, channel, object_classification_model: InstanceSegmentationModel):
# do this on a patch basis, i.e. only one object per frame
obmap_patches = object_classification_model.label_instance_class(
self.get_raw_patches(channel),
self.get_patch_masks()
)
def get_patch_masks(self, filters: RoiFilter): lamap = self.object_id_labels
output_map = np.zeros(lamap.shape, dtype=lamap.dtype)
self.df['instance_class'] = np.nan
# assign labels to object map:
for ii in range(0, self.count):
object_id = self.zmask_meta[ii]['info'].label
result_patch = mask_largest_object(obmap_patches.iat(ii))
object_class = np.unique(result_patch)[1]
output_map[self.object_id_labels == object_id] = object_class
self.df[object_id, 'instance_class'] = object_class
return InMemoryDataAccessor(output_map)
# TODO: test
def get_object_mask_by_id(self, obj_id):
return self.object_id_labels == obj_id
def get_object_mask_by_class(self, class_id):
return self.object_id_labels == class_id
# TODO: implement
def get_object_patch_by_id(self, obj_id):
pass pass
def get_object_map(self, filters: RoiFilter): def get_object_map(self, filters: RoiFilter):
pass pass
def build_zmask_from_object_mask( def build_zmask_from_object_mask(
obmask: GenericImageDataAccessor, obmask: GenericImageDataAccessor,
zstack: GenericImageDataAccessor, zstack: GenericImageDataAccessor,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment