diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index fbda1cf11d853803ae2eac06b2425cdfbebcc97c..24b9dd6c461c9d36d864017b29df9ad878c0e69b 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -341,15 +341,12 @@ class RoiSet(object): :param df: dataframe containing at minimum bounding box and segmentation mask information :param params: optional arguments that influence the definition and representation of ROIs """ - # assert acc_obj_ids.chroma == 1 - # self.acc_obj_ids = acc_obj_ids self.acc_raw = acc_raw self.accs_derived = [] self.params = params self._df = df self.count = len(self._df) - # self.object_class_maps = {} # classification results def __iter__(self): """Expose ROI meta information via the Pandas.DataFrame API""" @@ -555,7 +552,7 @@ class RoiSet(object): self.get_patch_masks_acc(expanded=False, pad_to=None) ) - self._df['classify_by_' + name] = pd.Series(dtype='Int64') + se = pd.Series(dtype='Int64', index=self._df.index) for i, roi in enumerate(self): oc = np.unique( @@ -563,7 +560,9 @@ class RoiSet(object): obmap_patches.iat(i).data ) )[-1] - self._df.loc[roi.Index, 'classify_by_' + name] = oc + se[roi.Index] = oc + self.set_classification(name, se) + def get_object_class_map(self, name: str) -> InMemoryDataAccessor: """ @@ -774,6 +773,23 @@ class RoiSet(object): dfe['patch'] = dfe.apply(lambda r: _make_patch(r), axis=1) return dfe + @property + def classification_columns(self): + """ + Return list of columns that describe instance classification results + """ + pr = 'classify_by_' + return [c.split(pr)[1] for c in self._df.columns if c.startswith(pr)] + + def set_classification(self, cname: str, se: pd.Series): + """ + Set instance classification result as a column addition on dataframe + :param cname: name of classification result + :param se: series containing class information + """ + col = f'classify_by_{cname}' + self._df[col] = se + def run_exports(self, where: Path, channel, prefix, params: RoiSetExportParams) -> dict: """ Export various representations of ROIs, e.g. patches, annotated stacks, and object maps. @@ -814,9 +830,7 @@ class RoiSet(object): if k == 'annotated_zstacks': record[k] = str(Path(k) / self.export_annotated_zstack(subdir, prefix=pr, **kp)) if k == 'object_classes': - pr = 'classify_by_' - cnames = [c.split(pr)[1] for c in self._df.columns if c.startswith(pr)] - for n in cnames: + for n in self.classification_columns: fp = subdir / n / (pr + '.tif') write_accessor_data_to_file(fp, self.get_object_class_map(n)) record[f'{k}_{n}'] = str(fp)