diff --git a/docs/docs.md b/docs/docs.md
new file mode 100644
index 0000000000000000000000000000000000000000..3b45340c13a7672cea3a87c98e8a56a4c99c13cb
--- /dev/null
+++ b/docs/docs.md
@@ -0,0 +1,200 @@
+# SVLT: <u>S</u>erving <u>V</u>ision to <u>L</u>iving <u>T</u>hings
+
+## Overview
+SVLT is a service for on-demand computer vision, adapted specifically to life sciences applications that require
+rapid feedback as images are acquired. It exposes APIs that make it easy for instruments to request
+sophisticated image analyses while experiments are running. This is intended to accelerate advanced imaging,
+experimental manipulation of biological entities, and image feedback for the preparation and manufacturing of
+bioproducts.
+
+## Key design goals
+
+**Interoperability:** Image feedback is desired across diverse instrumentation and computing infrastructures. SVLT's
+service architecture makes it broadly usable in instruments, labs, clusters, and clouds.
+
+**Low latency:** SVLT is the opposite of a workflow manager, in that it prioritizes a single rapid answer over the
+throughput of batches of answers. But its API is async, so clients can readily parallelize computation as needed.
+Pipeline performance is automatically recorded to help accurately profile system performance.
+
+**Extensibility:** There will always be new tasks, formats, and models. Core classes and modular pipelines make it
+as easy as possible to incorporate these into the project.
+
+## Solutions to specific problems
+
+**Problem:** Instrument control and data analysis thrive in very different environments.\
+**Solution:** A FastAPI service supports rapid communication between environments. It also eases distributed deployment to
+dedicated on-prem hardware, HPC clusters, the cloud, etc. MQTT is also implemented on an experimental basis for
+embedded and robotics applications.
+
+**Problem:** Microscopes see heterogeneous objects in complex scenes.\
+**Solution:** SVLT leverages cutting-edge machine learning frameworks. It serves ML models for segmentation, the
+results of which are tabulated as RoiSet objects. Queries of this set are then passed to patch- or scene-based,
+context-aware object classification models, e.g. YOLO. These results accumulate during sessions, paving the way
+for spatial statistics and advanced learning workflows.
+
+**Problem:** Machine learning uses large in-memory models. Loading them is a bottleneck.\
+**Solution:** The service holds models in state, so they are only loaded once per session. Hence, serial inference
+calls are usually very fast.
+
+**Problem:** ML models are very particular about their input specification.\
+**Solution:** SVLT wraps image data in Accessors with intuitive dimensions (height, width, depth, colorspace).
+Where possible, these are derived automatically from microscope file metadata. Models then introspect their
+compatibility when attempting to run predictions.
+
+**Problem:** It is hard to keep track of parameters in image analysis pipelines.\
+**Solution:** Pipeline parameters are specified as JSON-serializable types with helpful docstrings. These schemas
+are inheritable and discoverable by clients. Intermediate data representations are persisted and can be exported during
+parameter-tuning campaigns.
+
+**Problem:** ML frameworks are picky about their platform and dependencies.\
+**Solution:** The project comprises core classes (data and model abstraction, base API) and framework-specific
+extensions in a monorepo. Framework-platform pairs can then be built to spec with conda build.
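+
+To make the parameter schemas concrete: a pipeline's parameters are declared as a Pydantic model, so they
+serialize to and from JSON and self-document through field descriptions. The sketch below is illustrative only;
+`ExampleSegmentParams` and its fields are hypothetical, but real pipelines declare their parameters the same way by
+subclassing `PipelineParams` (see "Extending pipelines" below).
+
+```
+>>> from pydantic import BaseModel, Field
+
+>>> class ExampleSegmentParams(BaseModel):  # hypothetical; real pipelines subclass PipelineParams
+...     channel: int = Field(0, description='channel to segment')
+...     smooth: int = Field(3, description='smoothing kernel width in pixels')
+
+>>> ExampleSegmentParams(smooth=5).json()
+'{"channel": 0, "smooth": 5}'
+```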
+
+## Components
+
+### svlt.core.accessors
+Abstractions of single- and multi-position image data, generally multichannel Z-stacks. Data can either exist purely
+in-memory, such as the results of a pipeline operation, or be backed by image files. For example:<br>
+
+```
+>>> import numpy as np
+>>> from svlt.core.accessors import InMemoryDataAccessor
+
+>>> arr = np.array([[4, 4], [2, 3]], dtype='uint8')
+>>> acc = InMemoryDataAccessor(arr)
+
+>>> acc.shape_dict
+{'Y': 2, 'X': 2, 'C': 1, 'Z': 1}
+
+>>> acc.is_mask()
+False
+
+>>> acc.chroma == 1
+True
+
+>>> new_acc = acc.apply(lambda x: 2 * x)
+>>> type(new_acc)
+<class 'svlt.core.accessors.InMemoryDataAccessor'>
+
+>>> new_acc.unique()
+(array([4, 6, 8], dtype=uint8), array([1, 1, 2], dtype=int64))
+```
+Data accessors can be imported directly from microscope image file formats:
+```
+>>> from svlt.core.accessors import GenericImageFileAccessor
+>>> facc = GenericImageFileAccessor.read('~/svlt/test_data/D3-selection-01.czi')
+>>> facc.shape_dict
+{'Y': 1274, 'X': 1274, 'C': 5, 'Z': 1}
+>>> facc.get_channels([2, 3]).shape_dict
+{'Y': 1274, 'X': 1274, 'C': 2, 'Z': 1}
+>>> facc.pixel_scale_in_micrometers
+{'X': 0.25074568288853993, 'Y': 0.25074568288853993}
+```
+
+### svlt.core.api
+The core FastAPI methods, e.g. for configuring and restarting sessions, importing and exporting data accessors,
+and interacting with generic models.
+
+### svlt.core.models
+Abstractions of models that generally are (1) loaded once then (2) called for inference many times during an experiment.
+
+### svlt.core.pipelines.*
+Implements classical, learned, and hybrid pipelines for image segmentation and analysis. Each module is a single
+pipeline, comprising a main function that acts on accessors, its API endpoints, and a parameter model.
+
+### svlt.core.roiset
+The RoiSet data structure catalogs regions of interest (ROIs) inside of images. ROI metadata are queryable with the
+Pandas API. Patches can be exported for visualization and annotation.
+
+### svlt.core.session
+This module represents the state of a "session" and is mainly called via the core FastAPI.
+
+## Examples
+
+### Running unittests
+1. Build and activate the conda environment as described in README.md
+2. Copy test data from https://s3.embl.de/model-server-fixtures (inquire for access) to a local directory `<data_root>`
+3. In the development shell or IDE, set an environment variable `UNITTEST_DATA_ROOT=<data_root>`
+4. From the project root, run `python -m unittest discover`
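+
+For orientation, a minimal test against the fixture data might look like the sketch below. It assumes
+`D3-selection-01.czi` sits at the top level of `<data_root>`; adapt the path to the actual fixture layout.
+
+```
+import os
+import unittest
+from pathlib import Path
+
+from svlt.core.accessors import GenericImageFileAccessor
+
+
+class TestReadFixture(unittest.TestCase):
+    def test_fixture_has_five_channels(self):
+        # UNITTEST_DATA_ROOT is the same variable the project's own tests use
+        root = Path(os.environ['UNITTEST_DATA_ROOT'])
+        acc = GenericImageFileAccessor.read(root / 'D3-selection-01.czi')
+        self.assertEqual(acc.chroma, 5)  # five channels, per the example above
+```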
+
+### Interactive server example
+This involves manual clicking but covers the main steps that a client should call programmatically.
+The same general sequence is covered in:
+`tests.ilastik.test_ilastik.TestRoiSetWorkflowOverApi.test_ilastik_infer_pixel_probability`
+1. Procure the test data as described previously
+2. Start the FastAPI server: `python scripts/run_server.py`
+3. Wait for the browser to open with `http://127.0.0.1:8000/status`
+4. Download and open the configuration JSON if it does not appear automatically in the browser.
+5. Copy the test image `D3-selection-01.czi` to the session input directory listed under `paths.inbound_images`
+6. Note the location of `paths.outbound_images`
+7. In the browser, navigate to `http://127.0.0.1:8000/docs`
+8. Load an ilastik model:
+   - under `/ilastik/seg/load/`, press "Try it out"
+   - under "Request Body", copy the full local path to `ilastik/demo_px.ilp` in the "project_file" field
+   - set "model_id" to "test_model"
+   - set "duplicate" to false
+   - click "Execute"
+   - confirm that you receive a server response code of 200
+9. Load the image as an in-memory accessor under `/accessors/read_from_file/`
+   - set "filename" to `D3-selection-01.czi`
+   - set "accessor_id" to "test_acc"
+   - again, "Execute" and confirm a status code of 200
+10. Run inference on the image under `/pipelines/segment`
+    - set "accessor_id" to "test_acc"
+    - set "model_id" to "test_model"
+    - "Execute" and confirm a status code of 200
+    - The output mask should soon appear in the directory described in `paths.outbound_images`
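+
+The same sequence can be scripted against the HTTP API. The sketch below is illustrative: it assumes these
+endpoints accept PUT requests with the fields shown, which may differ from the deployed schema; consult
+`http://127.0.0.1:8000/docs` for the authoritative request layout.
+
+```
+import requests
+
+base = 'http://127.0.0.1:8000'
+
+# step 8: load the ilastik pixel classifier
+requests.put(f'{base}/ilastik/seg/load/', json={
+    'project_file': '/full/path/to/ilastik/demo_px.ilp',  # adapt to your checkout
+    'model_id': 'test_model',
+    'duplicate': False,
+}).raise_for_status()
+
+# step 9: register the input image as an in-memory accessor
+requests.put(f'{base}/accessors/read_from_file/', params={
+    'filename': 'D3-selection-01.czi',
+    'accessor_id': 'test_acc',
+}).raise_for_status()
+
+# step 10: run the segmentation pipeline on the registered accessor
+requests.put(f'{base}/pipelines/segment', json={
+    'accessor_id': 'test_acc',
+    'model_id': 'test_model',
+}).raise_for_status()
+```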
+
+## Extending SVLT
+
+### Where to write extensions
+1. If using the same dependencies as `svlt-core`, create a subpackage of `base.extensions`
+2. If using different or new dependencies, create a new package that imports `svlt-core`
+3. Write unittests in a subdirectory of `tests` with the same (sub)package name
+4. For the time being, push to a branch starting with `extension_` followed by the (sub)package name
+5. Write command-line scripts in `scripts`, i.e. outside of the package's directory tree
+
+### Extending models
+1. First see if any existing model superclass matches the model's task. E.g. if a model processes an image into a
+segmentation mask, inherit from `models.SemanticSegmentationModel`
+2. Otherwise, inherit from `core.models.Model`
+3. In the `.load()` method, implement any resource-intensive code that loads a model once for subsequent use.
+4. In the `.infer()` method, implement any code that runs inference on the model.
+5. If extensive setup parameters are used, specify these as a Pydantic model, e.g. in
+`extensions.ilastik.models.IlastikParams`
+
+### Extending pipelines
+
+Using `pipelines.segment` as a template, generally work from bottom to top:
+
+1. Implement the main pipeline function ending in `*_pipeline()`.
+   - This function interacts with accessors and models as dictionaries of their python objects.
+   - E.g. parameter `example_accessor_id` becomes `accessors['example_accessor']` in the calling function,
+   - and `example_model_id` populates `models['example_model']`, and so on.
+   - Instantiate, populate, and return a single `PipelineTrace` object with any meaningful intermediate accessors.
+   - The API will automatically assign an ID to the final accessor in the trace, and to all intermediate accessors if
+the `keep_interm` parameter is True.
+2. Implement the API function and its API path.
+   - Other than renaming the function, its path, and the `Params` and `Record` subclasses, nothing here needs to change.
+   - This receives a single `PipelineParams` object as input, automatically maps accessors and models from their ID
+strings to objects, passes them to `*_pipeline()`, and returns a single `PipelineRecord` as output.
+   - FastAPI handles the parsing and validation of fields in the request and response.
+3. Optionally, subclass `PipelineRecord` if additional parameters are expected from the pipeline.
+4. Subclass `PipelineParams` with additional parameters that the pipeline needs.
+   - By default, single `accessor_id` and `model_id` fields are already defined.
+   - These and anything ending in `accessor_id` or `model_id` are automatically validated to check that their IDs match
+loaded models.
+   - Use a descriptive class name. FastAPI uses this in a global schema, so it helps users understand the expected
+parameters.
+   - Use default values where possible. Field validation can optionally use `pydantic.Field` classes for this.
+
+### Testing
+
+1. Write unittests that cover new models, pipeline functions, and API calls.
+2. Where possible, use the existing test data described in "Examples" > "Running unittests"
+3. If adding test data or models, inquire about adding these to the distributed test data.
+4. For new pipeline modules, it is advisable to both (a sketch of the first approach follows this list):
+   - test `*_pipeline()` directly in a subclass of `unittest.TestCase`
+   - and test FastAPI endpoints in a subclass of `conf.TestServerBaseClass`, by way of its `._get()` and `._put()` methods
+5. Where possible, assertions should ensure not just error-free functionality, but also test that the content of outputs
+is correct.
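+
+A minimal sketch of the first kind of test, using `segment_pipeline` as the example. The dictionary keys follow the
+`*_id`-stripping convention described under "Extending pipelines", and the `BinaryThresholdSegmentationModel`
+constructor argument is an assumption; check the pipeline module and model class for the names they actually expect.
+
+```
+import unittest
+
+import numpy as np
+
+from svlt.core.accessors import InMemoryDataAccessor
+from svlt.core.models import BinaryThresholdSegmentationModel
+from svlt.core.pipelines.segment import segment_pipeline
+
+
+class TestSegmentPipeline(unittest.TestCase):
+    def test_pipeline_returns_mask(self):
+        # synthetic monochannel image in YXCZ order
+        acc = InMemoryDataAccessor(
+            np.random.randint(0, 255, (64, 64, 1, 1), dtype='uint8')
+        )
+        trace = segment_pipeline(
+            accessors={'accessor': acc},  # from the 'accessor_id' parameter
+            models={'model': BinaryThresholdSegmentationModel(tr=0.5)},  # 'tr' is assumed
+        )
+        # assert on the content of the output, not just error-free execution
+        self.assertTrue(trace.last.is_mask())
+```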
diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py
index 187b093038a76c87f10cd1e1b60e90b8b17de095..2129820310ef7f19c02326ec45b593f00ff8a863 100644
--- a/model_server/base/accessors.py
+++ b/model_server/base/accessors.py
@@ -99,6 +99,18 @@ class GenericImageDataAccessor(ABC):
         else:
             return self.data[:, :, 0, :]
 
+    @property
+    def data_yxcz(self) -> np.ndarray:
+        return self.data
+
+    @property
+    def data_yxzc(self) -> np.ndarray:
+        return np.moveaxis(
+            self.data,
+            [0, 1, 3, 2],
+            [0, 1, 2, 3]
+        )
+
     @property
     def data_mono(self) -> np.ndarray:
         if self.nz > 1:
@@ -186,17 +198,27 @@ class GenericImageDataAccessor(ABC):
         """
         return InMemoryDataAccessor(data)
 
-    def apply(self, func, preserve_dtype=True):
+    def apply(self, func, params: dict = {}, preserve_dtype=True, mono=False):
         """
         Apply func to data and return as a new in-memory accessor
         :param func: function that receives and returns the same size np.ndarray
+        :param params: (optional) dictionary of parameters to pass to function
         :param preserve_dtype: returned accessor gets the same dtype as self if True
+        :param mono: check that accessor is mono and pass only YXZ data to func
        :return: InMemoryDataAccessor
        """
-        nda = func(self.data)
+        if mono:
+            if self.chroma != 1:
+                raise DataShapeError(f'Expecting monochrome accessor when calling with mono=True')
+            nda = func(self.data_mono, **params)
+        else:
+            nda = func(self.data, **params)
         if preserve_dtype:
             nda = nda.astype(self.dtype)
-        return self._derived_accessor(nda)
+        if mono:
+            return self._derived_accessor(np.expand_dims(nda, 2))
+        else:
+            return self._derived_accessor(nda)
 
     @property
     def info(self):
diff --git a/model_server/base/annotators.py b/model_server/base/annotators.py
index 7626d76e72fdfe9e8f2c74b072d5764c9a568d1a..5b79599d8ab4812f6c686b08108adbae16a19a0b 100644
--- a/model_server/base/annotators.py
+++ b/model_server/base/annotators.py
@@ -31,20 +31,21 @@ def draw_boxes_on_3d_image(roiset, draw_full_depth=False, **kwargs):
 
     for zi in range(0, nz):
         if draw_full_depth:
-            subset = roiset.get_df()
+            subset = roiset.df().bounding_box
         else:
-            subset = roiset.get_df().query(f'zi == {zi}')
+            subset = roiset.df().bounding_box.query(f'zi == {zi}')
         for ci in range(0, len(channels)):
             pilimg = Image.fromarray(roiset.acc_raw.data[:, :, channels[ci], zi])
             draw = ImageDraw.Draw(pilimg)
             draw.font = _get_font()
 
-            for roi in subset.itertuples('Roi'):
-                xm = round((roi.x0 + roi.x1) / 2)
-                draw.rectangle([(roi.x0, roi.y0), (roi.x1, roi.y1)], outline='white', width=linewidth)
+            def _draw_boxes(bb):
+                xm = round((bb.x0 + bb.x1) / 2)
+                draw.rectangle([(bb.x0, bb.y0), (bb.x1, bb.y1)], outline='white', width=linewidth)
                 if kwargs.get('draw_label') is True:
-                    draw.text((xm, roi.y0), f'{roi.label:04d}', fill='white', anchor='mb')
+                    draw.text((xm, bb.y0), f'{bb.name:04d}', fill='white', anchor='mb')
+            subset.apply(_draw_boxes, axis=1)
 
             annotated[:, :, ci, zi] = pilimg
diff --git a/model_server/base/api.py b/model_server/base/api.py
index b1d92d5d88437ec444df3301952fec4eff0ed945..574f1ebcbb67b57fbe6f38a0c32115589f89830c 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -7,8 +7,8 @@ from pydantic import BaseModel, Field
 
 from .accessors import generate_file_accessor, generate_multiposition_file_accessors
 from .models import BinaryThresholdSegmentationModel
+from model_server.rois.models import IntensityThresholdInstanceMaskSegmentationModel
 from .pipelines.shared import PipelineRecord
-from .roiset import IntensityThresholdInstanceMaskSegmentationModel
 from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError
 
 app = FastAPI(debug=True)
@@ -59,7 +59,6 @@ def show_session_status():
         'memory': session.get_mem(),
         'models': session.describe_loaded_models(),
         'paths': session.get_paths(),
-        'paths': session.get_paths(),
         'accessors': session.list_accessors(),
         'tasks': session.tasks.list_tasks(),
     }
@@ -258,4 +257,11 @@ def get_output_accessor_id_for_task(task_id: str) -> str:
 
 @app.get('/tasks')
 def list_tasks() -> Dict[str, TaskInfo]:
-    return session.tasks.list_tasks()
\ No newline at end of file
+    return session.tasks.list_tasks()
+
+
+@app.get('/phenobase/bounding_box')
+def get_phenobase_bounding_boxes() -> list:
+    if session.phenobase is None:
+        return []
+    return session.phenobase.list_bounding_boxes()
\ No newline at end of file
diff --git a/model_server/base/models.py b/model_server/base/models.py
index f1fccd37d99e237c10747a0a83c30db745c1fbf3..f4064feb451804bee297839fa02c8213ee637243 100644
--- a/model_server/base/models.py
+++ b/model_server/base/models.py
@@ -4,7 +4,6 @@ import numpy as np
 
 from .accessors import GenericImageDataAccessor, PatchStack
 
-
 class Model(ABC):
 
     def __init__(self, autoload: bool = True, info: dict = None):
diff --git a/model_server/base/pipelines/router.py b/model_server/base/pipelines/router.py
index 5ce7e96fdf384aeebe42eaa7e9289ebe7981db0c..bf430f8c9d14864230043c5d14656c6ee42c036d 100644
--- a/model_server/base/pipelines/router.py
+++ b/model_server/base/pipelines/router.py
@@ -7,7 +7,7 @@ router = APIRouter(
     tags=['pipelines'],
 )
 
-for m in ['segment', 'roiset_obmap', 'segment_zproj']:
+for m in ['segment', 'segment_zproj']:
     router.include_router(
         importlib.import_module(
             f'{__package__}.{m}'
diff --git a/model_server/base/pipelines/segment.py b/model_server/base/pipelines/segment.py
index 96cb5f19a1e2b2b338203692f3ecfccaf1c06b7c..8604ddd276f9e277c0a654b5bead491ca15eb801 100644
--- a/model_server/base/pipelines/segment.py
+++ b/model_server/base/pipelines/segment.py
@@ -41,9 +41,9 @@ def segment_pipeline(
     if not isinstance(model, SemanticSegmentationModel):
         raise IncompatibleModelsError('Expecting a semantic segmentation model')
 
-    if (ch := k.get('channel')) is not None:
+    if (ch := k.get('channel')) is not None:  # parentheses required: the walrus binds tighter than 'is not None' only when grouped
         d['mono'] = d['input'].get_mono(ch)
     d['inference'] = model.label_pixel_class(d.last)
-    if (sm := k.get('smooth')) is not None:
+    if (sm := k.get('smooth')) is not None:
         d['smooth'] = d.last.apply(lambda x: smooth(x, sm))
     return d
\ No newline at end of file
diff --git a/model_server/base/pipelines/shared.py b/model_server/base/pipelines/shared.py
index 9518a08a35452edb7d9842242e60f5dd4d01d184..fca3f620f68cf137bf6dd33f5b0aeb4baafe6d24 
100644 --- a/model_server/base/pipelines/shared.py +++ b/model_server/base/pipelines/shared.py @@ -7,6 +7,7 @@ from fastapi import HTTPException from pydantic import BaseModel, Field, root_validator from ..accessors import GenericImageDataAccessor, InMemoryDataAccessor +from model_server.rois.roiset import PatchParams, RoiSetExportParams from ..session import session, AccessorIdError @@ -53,6 +54,16 @@ class PipelineQueueRecord(BaseModel): task_id: str def call_pipeline(func, p: PipelineParams) -> Union[PipelineRecord, PipelineQueueRecord]: + """ + Resolve accessor and model objects from their IDs in session, pass them to the given pipeline function, and + register output and (if specified) intermediate accessors in the session. If p.schedule is True, enqueue + this as a task in the session queue; otherwise call it immediately. + :param func: pipeline function with the signature: + func(accessors: Dict[str, GenericImageDataAccessor], models: Dict[str, Model], **k) -> PipelineTrace + :param p: pipeline parameters object + :return: + record objects describing either the results of pipeline execution or the queued task + """ # instead of running right away, schedule pipeline as a task if p.schedule: p.schedule = False @@ -137,7 +148,7 @@ class PipelineTrace(OrderedDict): self.enforce_accessors = enforce_accessors self.allow_overwrite = allow_overwrite self.last_time = self.tfunc() - self.misc = {} + self.markers = None self.timer = OrderedDict() super().__init__() if start_acc is not None: @@ -163,6 +174,7 @@ class PipelineTrace(OrderedDict): def append(self, tr, skip_first=True): new_tr = self.copy() + new_tr.timer = self.timer for k, v in tr.items(): if skip_first and v == tr.first: continue @@ -246,6 +258,34 @@ class PipelineTrace(OrderedDict): return paths +class RoiSetPipelineParams(PipelineParams): + exports: RoiSetExportParams = RoiSetExportParams() + roiset_index: Dict[str, int] + patches: Dict[str, PatchParams] = {} + + +def call_roiset_pipeline(func, p: RoiSetPipelineParams) -> Union[PipelineRecord, PipelineQueueRecord]: + """ + Wraps pipelines.shared.call_pipeline for pipeline functions that return an RoiSet in addition to PipelineTrace. 
+ Upon execution, add the returned RoiSet to session PhenoBase and otherwise delegate to call_pipeline + :param func: pipeline function with the signature: + func(accessors: Dict[str, GenericImageDataAccessor], models: Dict[str, Model], **k) -> (PipelineTrace, RoiSet) + :param p: pipeline parameters object, which must contain a unique index for the RoiSet + :return: + record objects describing either the results of pipeline execution or the queued task + """ + def outer(*args, **kwargs): + trace, rois = func(*args, **kwargs) + session.add_roiset( + roiset=rois, + index_dict=p.roiset_index, + export_params=p.exports, + ) + session.write_phenobase_table() + return trace + return call_pipeline(outer, p) + + class Error(Exception): pass @@ -259,4 +299,4 @@ class NoAccessorsFoundError(Error): pass class UnexpectedPipelineReturnError(Error): - pass + pass \ No newline at end of file diff --git a/model_server/base/process.py b/model_server/base/process.py index d2ba6abf743b01f7df2917699186b8e75b1e2ddb..0be8c8894ed254792ead1ce649e9c98440be53d3 100644 --- a/model_server/base/process.py +++ b/model_server/base/process.py @@ -156,3 +156,13 @@ class TooManyObjectError(Exception): pass +def safe_add(a, g, b): + assert a.dtype == b.dtype + assert a.shape == b.shape + assert g >= 0.0 + + return np.clip( + a.astype('uint32') + g * b.astype('uint32'), + 0, + np.iinfo(a.dtype).max + ).astype(a.dtype) diff --git a/model_server/base/session.py b/model_server/base/session.py index 112c021fa7a8f4b7480fb425a58642146a599181..032270193f1179b9f7f469d3dcdcc11306d441ae 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -2,7 +2,7 @@ from collections import OrderedDict import itertools import logging import os -from multiprocessing.managers import Namespace +from typing import Dict import psutil import uuid @@ -17,6 +17,8 @@ import pandas as pd from ..conf import defaults from .accessors import GenericImageDataAccessor from .models import Model +from model_server.rois.phenobase import PhenoBase, RoiSetIndex +from model_server.rois.roiset import RoiSetExportParams, RoiSet logger = logging.getLogger(__name__) @@ -195,35 +197,12 @@ class _Session(object): self.accessor_info = OrderedDict() self.accessor_objects = {} self.tasks = TaskCollection() + self.phenobase = None self.logfile = self.paths['logs'] / f'session.log' logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True, format=self.log_format) self.log_info('Initialized session') - self.tables = {} - - def write_to_table(self, name: str, coords: dict, data: pd.DataFrame): - """ - Write data to a named data table, initializing if it does not yet exist. 
- :param name: name of the table to persist through session - :param coords: dictionary of coordinates to associate with all rows in this method call - :param data: DataFrame containing data - :return: True if successful - """ - try: - if name in self.tables.keys(): - table = self.tables.get(name) - else: - table = CsvTable(self.paths['tables'] / (name + '.csv')) - self.tables[name] = table - except Exception: - raise CouldNotCreateTable(f'Unable to create table named {name}') - - try: - table.append(coords, data) - return True - except Exception: - raise CouldNotAppendToTable(f'Unable to append data to table named {name}') def get_paths(self): return self.paths @@ -495,6 +474,38 @@ class _Session(object): return mid return None + def add_roiset( + self, + roiset: RoiSet, + index_dict: RoiSetIndex, + export_params: RoiSetExportParams = None, + ): + """ + Add an RoiSet into PhenoBase; initialize PhenoBase if empty + :param roiset: RoiSet object to be added to PhenoBase + :param index_dict: dict that uniquely identifies the RoiSet; keys must match existing ones + :param export_params: (optional) parameters for exporting each RoiSet patch series + return: dict of added RoiSet in PhenoBase if successful + """ + if self.phenobase is None: + self.phenobase = PhenoBase.from_roiset( + root=self.paths['outbound_images'], + roiset=roiset, + index_dict=index_dict, + export_params=export_params, + ) + else: + self.phenobase.push( + roiset, + index_dict=index_dict, + export_params=export_params, + ) + return index_dict + + def write_phenobase_table(self): + if self.phenobase is not None: + return self.phenobase.write_df() + def restart(self, **kwargs): self.__init__() diff --git a/model_server/clients/batch_runner.py b/model_server/clients/batch_runner.py index 6fd3dbd4c3e44de06b8315ba380958e24554e711..439cbeb65f888ce7b1cce7c33e286f4d610f3bef 100644 --- a/model_server/clients/batch_runner.py +++ b/model_server/clients/batch_runner.py @@ -53,6 +53,11 @@ class FileBatchRunnerClient(HttpClient): self.stacks.to_csv(self.conf_root / 'filelist.csv') self.stacks.to_csv(self.local_paths['output'] / 'filelist_copy.csv') + def write_tasks_json(self): + tasks_dict = self.hit('get', 'tasks') + with open(self.local_paths['output'] / 'tasks.json', 'w') as fh: + json.dump(tasks_dict, fh) + def hit(self, method, endpoint, params=None, body=None, catch=True, **kwargs): resp = super(FileBatchRunnerClient, self).hit(method, endpoint, params=params, body=body) if resp.status_code != 200: @@ -144,7 +149,6 @@ class FileBatchRunnerClient(HttpClient): matching_files.append(f.name) return matching_files - files += _append_files_by_pattern(where_local, inp.get('pattern')) is_multiposition = inp.get('multiposition', False) where_remote = Path(self.remote_paths['input']) / inp['directory'] @@ -153,20 +157,27 @@ class FileBatchRunnerClient(HttpClient): for subdir in where_local.iterdir(): if not subdir.is_dir(): continue + if (sdp := inp.get('subdirectory_pattern')) is not None: + if sdp.upper() not in subdir.name.upper(): + continue matches = _append_files_by_pattern(subdir, inp.get('pattern')) files += [f'{subdir}/{f}' for f in matches] + else: + files += _append_files_by_pattern(where_local, inp.get('pattern')) - def _get_file_info(filename): + def _get_file_info(fpath_str): info = { - 'remote_path': (where_remote / filename).as_posix(), - 'local_path': where_local / filename, + 'remote_path': (where_remote / fpath_str).as_posix(), + 'local_path': where_local / fpath_str, 'is_multiposition': is_multiposition, } + filename = 
Path(fpath_str).name if (coord_regex := inp.get('coord_regex')) is not None: for coord_k, coord_v in re.search(coord_regex, filename).groupdict().items(): - if coord_k.lower() not in ['well', 'position', 'time']: - raise InvalidStackCoordinateKeyError(f'Cannot interpret coordinate {coord_k}') - info[f'coord_{coord_k.lower()}'] = int(coord_v) + if coord_k.lower() in ['well', 'position', 'time', 'date']: + info[f'coord_{coord_k.lower()}'] = int(coord_v) + else: + info[f'coord_{coord_k.lower()}'] = coord_v return info paths = paths + [_get_file_info(f) for f in files] if max_count is not None: @@ -225,6 +236,8 @@ class FileBatchRunnerClient(HttpClient): v['body'][pki] = output_acc_ids[tki] else: # input is the output of the last task v['body']['accessor_id'] = list(output_acc_ids.values())[-1] + if cpi := v.get('stack_info_to_param'): + v['body'][cpi] = {'stack': stack.name, 'position': stack.position} task_id = self.hit(**v, catch=not self.debug)['task_id'] v['task_id'] = task_id @@ -258,6 +271,7 @@ class FileBatchRunnerClient(HttpClient): if all(df_tasks[df_tasks.stack_index == stack_index].complete): self.stacks.loc[stack_index, 'all_tasks_complete'] = True self.write_df() + self.write_tasks_json() self.hit('put', f'tasks/delete_accessors') diff --git a/model_server/clients/phenobase.py b/model_server/clients/phenobase.py new file mode 100644 index 0000000000000000000000000000000000000000..fea82ea5c0fb3d2f4a0762de1853cb1b48f4b108 --- /dev/null +++ b/model_server/clients/phenobase.py @@ -0,0 +1,32 @@ +from pathlib import Path + +import pandas as pd + +from model_server.rois.phenobase import PhenoBase + +from model_server.clients.batch_runner import FileBatchRunnerClient + +class MakePhenoBaseClient(FileBatchRunnerClient): + + def __init__(self, *args, labels_csv: Path = None, **kwargs): + """ + Create a batch runner client that outputs a PhenoBase in its output directory + :param args: + :param labels_csv: + optional path to a CSV file that relates object labels to substrings matched on each input stack + :param kwargs: + """ + self.labels_csv_path = labels_csv + return super().__init__(*args, **kwargs) + + def get_stacks(self, *args, **kwargs): + df_stacks = super().get_stacks(*args, **kwargs) + if self.labels_csv_path is not None: + labels = pd.read_csv(self.labels_csv_path) + def get_label(pa): + for row in labels.itertuples(): + if row.pattern in str(pa): + return row.label + return None + df_stacks['category_label'] = df_stacks.local_path.apply(get_label) + return df_stacks \ No newline at end of file diff --git a/model_server/conf/fastapi.py b/model_server/conf/fastapi.py index 53edadd80cda52c565a4d97ca99923bac92996e1..76eed2cc13935cab242269cb71ff7e6a27f6bf63 100644 --- a/model_server/conf/fastapi.py +++ b/model_server/conf/fastapi.py @@ -2,6 +2,6 @@ import importlib from ..base.api import app -for ex in ['ilastik']: - m = importlib.import_module(f'..extensions.{ex}.router', package=__package__) +for ex in ['ilastik', 'rois']: + m = importlib.import_module(f'..{ex}.router', package=__package__) app.include_router(m.router) diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py index 6b05dbe54bb5a0aa0fe92405844120d68e313e27..b4535a5967877d715d173caee742489dc0d1162a 100644 --- a/model_server/conf/testing.py +++ b/model_server/conf/testing.py @@ -4,9 +4,7 @@ import unittest from math import floor from multiprocessing import Process from pathlib import Path -import requests from shutil import copyfile -from urllib3 import Retry from fastapi import APIRouter import numpy 
as np diff --git a/model_server/extensions/__init__.py b/model_server/ilastik/__init__.py similarity index 100% rename from model_server/extensions/__init__.py rename to model_server/ilastik/__init__.py diff --git a/model_server/extensions/ilastik/__init__.py b/model_server/ilastik/examples/__init__.py similarity index 100% rename from model_server/extensions/ilastik/__init__.py rename to model_server/ilastik/examples/__init__.py diff --git a/model_server/extensions/ilastik/examples/ilastik3d.py b/model_server/ilastik/examples/ilastik3d.py similarity index 100% rename from model_server/extensions/ilastik/examples/ilastik3d.py rename to model_server/ilastik/examples/ilastik3d.py diff --git a/model_server/extensions/ilastik/models.py b/model_server/ilastik/models.py similarity index 97% rename from model_server/extensions/ilastik/models.py rename to model_server/ilastik/models.py index e57f9726f4b43cf7c3c1a17f62c8cc2fc0c89ce1..2167fbbcd923d6552a60ad992dc41aec59d365b9 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/ilastik/models.py @@ -7,9 +7,9 @@ import warnings import numpy as np import vigra -from ...base.accessors import PatchStack -from ...base.accessors import GenericImageDataAccessor, InMemoryDataAccessor -from ...base.models import Model, ImageToImageModel, InstanceMaskSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel +from model_server.base.accessors import PatchStack +from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor +from model_server.base.models import Model, ImageToImageModel, InstanceMaskSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel class IlastikModel(Model): diff --git a/model_server/extensions/ilastik/examples/__init__.py b/model_server/ilastik/pipelines/__init__.py similarity index 100% rename from model_server/extensions/ilastik/examples/__init__.py rename to model_server/ilastik/pipelines/__init__.py diff --git a/model_server/extensions/ilastik/pipelines/px_then_ob.py b/model_server/ilastik/pipelines/px_then_ob.py similarity index 91% rename from model_server/extensions/ilastik/pipelines/px_then_ob.py rename to model_server/ilastik/pipelines/px_then_ob.py index 335e4115d0b9d0cdb2f9e405bfa402e88efc5a36..b9d210b6c16555d24c07ec8c55fbe337ddb99259 100644 --- a/model_server/extensions/ilastik/pipelines/px_then_ob.py +++ b/model_server/ilastik/pipelines/px_then_ob.py @@ -3,9 +3,9 @@ from typing import Dict from fastapi import APIRouter, HTTPException from pydantic import Field -from ....base.accessors import GenericImageDataAccessor -from ....base.models import Model -from ....base.pipelines.shared import call_pipeline, PipelineTrace, PipelineParams, PipelineRecord +from model_server.base.accessors import GenericImageDataAccessor +from model_server.base.models import Model +from model_server.base.pipelines.shared import call_pipeline, PipelineTrace, PipelineParams, PipelineRecord from ..models import IlastikPixelClassifierModel, IlastikObjectClassifierFromPixelPredictionsModel diff --git a/model_server/extensions/ilastik/router.py b/model_server/ilastik/router.py similarity index 92% rename from model_server/extensions/ilastik/router.py rename to model_server/ilastik/router.py index c21db3fe871759a01847ab0acef7edd659de3426..22cd7500151fec1180aeeb2ef7abbb7ce70663da 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/ilastik/router.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field from 
model_server.base.session import session -from model_server.extensions.ilastik import models as ilm +from model_server.ilastik import models as ilm router = APIRouter( prefix='/ilastik', @@ -13,8 +13,8 @@ router = APIRouter( ) -import model_server.extensions.ilastik.pipelines.px_then_ob -router.include_router(model_server.extensions.ilastik.pipelines.px_then_ob.router) +import model_server.ilastik.pipelines.px_then_ob +router.include_router(model_server.ilastik.pipelines.px_then_ob.router) class IlastikParams(BaseModel): diff --git a/model_server/rois/__init__.py b/model_server/rois/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_server/rois/derived.py b/model_server/rois/derived.py new file mode 100644 index 0000000000000000000000000000000000000000..ae821c62fe3708a54ce8747bf4ba16f1e9b64c45 --- /dev/null +++ b/model_server/rois/derived.py @@ -0,0 +1,90 @@ +from pathlib import Path + +import numpy as np +import pandas as pd + +from model_server.base.models import InstanceMaskSegmentationModel +from model_server.base.process import mask_largest_object +from model_server.rois.roiset import RoiSetExportParams, RoiSet + + +class RoiSetWithDerivedChannelsExportParams(RoiSetExportParams): + derived_channels: bool = False + + +class RoiSetWithDerivedChannels(RoiSet): + + def __init__(self, *a, **k): + self.accs_derived = [] + super().__init__(*a, **k) + + def classify_by( + self, name: str, channels: list[int], + object_classification_model: InstanceMaskSegmentationModel, + derived_channel_functions: list[callable] = None + ): + """ + Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing + specified inputs through an instance segmentation classifier. Derive additional inputs for object + classification by passing a raw input channel through one or more functions. + + :param name: name of column to insert + :param channels: list of nc raw input channels to send to classifier + :param object_classification_model: InstanceSegmentation model object + :param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and + that return a single-channel PatchStack accessor of the same shape + :return: None + """ + + acc_in = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) + if derived_channel_functions is not None: + for fcn in derived_channel_functions: + der = fcn(acc_in) # returns patch stack + self.accs_derived.append(der) + + # combine channels + acc_app = acc_in + for acc_der in self.accs_derived: + acc_app = acc_app.append_channels(acc_der) + + else: + acc_app = acc_in + + # do this on a patch basis, i.e. only one object per frame + obmap_patches = object_classification_model.label_patch_stack( + acc_app, + self.get_patch_masks_acc(expanded=False, pad_to=None) + ) + + self._df['classify_by_' + name] = pd.Series(dtype='Int16') + + se = pd.Series(dtype='Int64', index=self._df.index) + + for i, la in enumerate(self._df.index): + oc = np.unique( + mask_largest_object( + obmap_patches.iat(i).data_yxz + ) + )[-1] + se[la] = oc + self.set_classification(name, se) + + def run_exports(self, where: Path, prefix, params: RoiSetWithDerivedChannelsExportParams) -> dict: + """ + Export various representations of ROIs, e.g. patches, annotated stacks, and object maps. 
+        :param where: path of directory in which to write all export products
+        :param prefix: prefix of the name of each product's file or subfolder
+        :param params: RoiSetExportParams object describing which products to export and with which parameters
+        :return: nested dict of Path objects describing the location of export products
+        """
+        record = super().run_exports(where, prefix, params)
+
+        k = 'derived_channels'
+        if k in params.dict().keys():
+            record[k] = []
+            for di, dacc in enumerate(self.accs_derived):
+                fp = where / k / f'dc{di:01d}.tif'
+                fp.parent.mkdir(exist_ok=True, parents=True)
+                dacc.export_pyxcz(fp)
+                record[k].append(str(fp))
+        return record
diff --git a/model_server/rois/df.py b/model_server/rois/df.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f953a5714403f4bebb890232422d9d60ca2c2b4
--- /dev/null
+++ b/model_server/rois/df.py
@@ -0,0 +1,210 @@
+import itertools
+from math import sqrt
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+def filter_df(df: pd.DataFrame, filters: dict = {}) -> pd.DataFrame:
+    query_str = 'label > 0'  # always true
+    if filters is not None:  # parse filters
+        for k, val in filters.items():
+            assert k in ('area', 'diag', 'min_hw')
+            if val is None:
+                continue
+            vmin = val['min']
+            vmax = val['max']
+            assert vmin >= 0
+            query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'
+    return df.loc[df.bounding_box.query(query_str).index, :]
+
+
+def filter_df_overlap_bbox(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame:
+    """
+    If passed a single DataFrame, return the subset whose bounding boxes overlap in 3D space. If passed two DataFrames,
+    return the subset where a ROI in the first overlaps a ROI in the second. May return duplicate entries where a ROI
+    overlaps with multiple neighbors.
+    :param df1: DataFrame with potentially overlapping bounding boxes
+    :param df2: (optional) second DataFrame
+    :return: DataFrame describing subset of overlapping ROIs
+        bbox_overlaps_with: index of ROI that overlaps
+        bbox_intersec: pixel area of intersecting region
+    """
+
+    def _compare(r0, r1):
+        olx = (r0.x0 < r1.x1) and (r0.x1 > r1.x0)
+        oly = (r0.y0 < r1.y1) and (r0.y1 > r1.y0)
+        olz = (r0.zi == r1.zi)
+        return olx and oly and olz
+
+    def _intersec(r0, r1):
+        return (r0.x1 - r1.x0) * (r0.y1 - r1.y0)
+
+    first = []
+    second = []
+    intersec = []
+
+    if df2 is not None:
+        for pair in itertools.product(df1.index, df2.index):
+            if _compare(
+                    df1.bounding_box.loc[pair[0]],
+                    df2.bounding_box.loc[pair[1]]
+            ):
+                first.append(pair[0])
+                second.append(pair[1])
+                intersec.append(
+                    _intersec(
+                        df1.bounding_box.loc[pair[0]],
+                        df2.bounding_box.loc[pair[1]]
+                    )
+                )
+    else:
+        for pair in itertools.combinations(df1.index, 2):
+            if _compare(
+                    df1.bounding_box.loc[pair[0]],
+                    df1.bounding_box.loc[pair[1]]
+            ):
+                first.append(pair[0])
+                second.append(pair[1])
+                first.append(pair[1])
+                second.append(pair[0])
+                isc = _intersec(
+                    df1.bounding_box.loc[pair[0]],
+                    df1.bounding_box.loc[pair[1]]
+                )
+                intersec.append(isc)
+                intersec.append(isc)
+
+    sdf = df1.bounding_box.loc[first]
+    sdf.loc[:, 'overlaps_with'] = second
+    sdf.loc[:, 'bbox_intersec'] = intersec
+    return sdf
+
+
+def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame:
+    """
+    If passed a single DataFrame, return the subset whose segmentations overlap in 3D space. If passed two DataFrames,
+    return the subset where a ROI in the first overlaps a ROI in the second. May return duplicate entries where a ROI
+    overlaps with multiple neighbors.
+    :param df1: DataFrame with potentially overlapping bounding boxes
+    :param df2: (optional) second DataFrame
+    :return: DataFrame describing subset of overlapping ROIs
+        seg_overlaps_with: index of ROI that overlaps
+        seg_intersec: pixel area of intersecting region
+        seg_iou: intersection over union
+    """
+
+    dfbb = filter_df_overlap_bbox(df1, df2)
+
+    def _overlap_seg(r):
+        roi1 = df1.loc[r.name]
+        if df2 is not None:
+            roi2 = df2.loc[r.overlaps_with]
+        else:
+            roi2 = df1.loc[r.overlaps_with]
+        bb1 = roi1.bounding_box
+        bb2 = roi2.bounding_box
+        ex0 = min(bb1.x0, bb2.x0, bb1.x1, bb2.x1)
+        ew = max(bb1.x0, bb2.x0, bb1.x1, bb2.x1) - ex0
+        ey0 = min(bb1.y0, bb2.y0, bb1.y1, bb2.y1)
+        eh = max(bb1.y0, bb2.y0, bb1.y1, bb2.y1) - ey0
+        emask = np.zeros((eh, ew), dtype='uint8')
+        sl1 = np.s_[(bb1.y0 - ey0): (bb1.y1 - ey0), (bb1.x0 - ex0): (bb1.x1 - ex0)]
+        sl2 = np.s_[(bb2.y0 - ey0): (bb2.y1 - ey0), (bb2.x0 - ex0): (bb2.x1 - ex0)]
+        emask[sl1] = roi1.masks.binary_mask
+        emask[sl2] = emask[sl2] + roi2.masks.binary_mask
+        return emask
+
+    emasks = dfbb.apply(_overlap_seg, axis=1)
+    dfbb['seg_overlaps'] = emasks.apply(lambda x: np.any(x > 1))
+    dfbb['seg_intersec'] = emasks.apply(lambda x: (x == 2).sum())
+    dfbb['seg_iou'] = emasks.apply(lambda x: (x == 2).sum() / (x > 0).sum())
+    return dfbb
+
+
+def is_df_3d(df: pd.DataFrame) -> bool:
+    return 'z0' in df.bounding_box.columns and 'z1' in df.bounding_box.columns
+
+
+def insert_level(df: pd.DataFrame, name: str):
+    df.columns = pd.MultiIndex.from_product(
+        [
+            [name],
+            list(df.columns.values),
+        ],
+    )
+
+
+def read_roiset_df(csv_path: Path) -> pd.DataFrame:
+    return pd.read_csv(csv_path, header=[0, 1], index_col=0)
+
+
+def df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame:
+    h = sd['Y']
+    w = sd['X']
+    nz = sd['Z']
+
+    bb = 'bounding_box'
+    df[bb, 'h'] = df[bb, 'y1'] - df[bb, 'y0']
+    df[bb, 'w'] = df[bb, 'x1'] - df[bb, 'x0']
+    df[bb, 'diag'] = (df[bb, 'w'] ** 2 + df[bb, 'h'] ** 2).apply(sqrt)
+    df[bb, 'min_hw'] = df[[(bb, 'w'), (bb, 'h')]].min(axis=1)
+
+    ebxy, ebz = expand_box_by
+    ebb = 'expanded_bounding_box'
+    df[ebb, 'ebb_y0'] = (df[bb, 'y0'] - ebxy).apply(lambda x: max(x, 0))
+    df[ebb, 'ebb_y1'] = (df[bb, 'y1'] + ebxy).apply(lambda x: min(x, h))
+    df[ebb, 'ebb_x0'] = (df[bb, 'x0'] - ebxy).apply(lambda x: max(x, 0))
+    df[ebb, 'ebb_x1'] = (df[bb, 'x1'] + ebxy).apply(lambda x: min(x, w))
+
+    # handle based on whether bounding box coordinates are 2d or 3d
+    if is_df_3d(df):
+        df[ebb, 'ebb_z0'] = (df[bb, 'z0'] - ebz).apply(lambda x: max(x, 0))
+        df[ebb, 'ebb_z1'] = (df[bb, 'z1'] + ebz).apply(lambda x: min(x, nz))
+    else:
+        if 'zi' not in df.bounding_box.columns:
+            df[bb, 'zi'] = 0
+        df[ebb, 'ebb_z0'] = (df[bb, 'zi'] - ebz).apply(lambda x: max(x, 0))
+        df[ebb, 'ebb_z1'] = (df[bb, 'zi'] + ebz).apply(lambda x: min(x, nz))
+
+    df[ebb, 'ebb_h'] = df[ebb, 'ebb_y1'] - df[ebb, 'ebb_y0']
+    df[ebb, 'ebb_w'] = df[ebb, 'ebb_x1'] - df[ebb, 'ebb_x0']
+    df[ebb, 'ebb_nz'] = df[ebb, 'ebb_z1'] - df[ebb, 'ebb_z0'] + 1
+
+    # compute relative bounding boxes
+    rbb = 'relative_bounding_box'
+    df[rbb, 'rel_y0'] = df[bb, 'y0'] - df[bb, 'y0']
+    df[rbb, 'rel_y1'] = df[bb, 'y1'] - df[bb, 'y0']
+    df[rbb, 'rel_x0'] = df[bb, 'x0'] - df[bb, 'x0']
+    df[rbb, 'rel_x1'] = df[bb, 'x1'] - df[bb, 'x0']
+
+    assert np.all(df[rbb, 'rel_x1'] <= (df[ebb, 'ebb_x1'] - df[ebb, 'ebb_x0']))
+    assert np.all(df[rbb, 'rel_y1'] <= (df[ebb, 'ebb_y1'] - df[ebb, 'ebb_y0']))
+
+    if is_df_3d(df):
+        df['slices', 'slice'] = df['bounding_box'].apply(
+            lambda r:
+            np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.z0): int(r.z1)],
+            axis=1,
+            result_type='reduce',
+        )
+    else:
+        df['slices', 'slice'] = df['bounding_box'].apply(
+            lambda r:
+            np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)],
+            axis=1,
+            result_type='reduce',
+        )
+    df['slices', 'expanded_slice'] = df['expanded_bounding_box'].apply(
+        lambda r:
+        np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1],
+        axis=1,
+        result_type='reduce',
+    )
+    df['slices', 'relative_slice'] = df['relative_bounding_box'].apply(
+        lambda r:
+        np.s_[int(r.rel_y0): int(r.rel_y1), int(r.rel_x0): int(r.rel_x1), :, :],
+        axis=1,
+        result_type='reduce',
+    )
+    return df
diff --git a/model_server/rois/features.py b/model_server/rois/features.py
new file mode 100644
index 0000000000000000000000000000000000000000..491764dd3060fb639ca17031505e98d4044040a5
--- /dev/null
+++ b/model_server/rois/features.py
@@ -0,0 +1,57 @@
+import pandas as pd
+from skimage.measure import regionprops_table
+
+from .roiset import RoiSet
+
+def regionprops(rois: RoiSet, make_3d: bool = False, channel: int = None) -> pd.DataFrame:
+    props = [
+        'area',
+        'area_bbox',
+        'area_convex',
+        'area_filled',
+        'axis_major_length',
+        'axis_minor_length',
+        'equivalent_diameter_area',
+        'euler_number',
+        'extent',
+        'feret_diameter_max',
+        'intensity_max',
+        'intensity_mean',
+        'intensity_min',
+        'intensity_std',
+        'num_pixels',
+        'solidity',
+    ]
+    if not make_3d:
+        props = props + [
+            'eccentricity',
+            'orientation',
+            'perimeter',
+            'perimeter_crofton',
+        ]
+
+    acc_la = rois.get_patch_obmap_acc(make_3d=make_3d)
+    acc_in = rois.get_patches_acc(make_3d=make_3d, channels=[channel])
+
+    def _extract_features(roi):
+        i = roi['patches', 'index']
+        if make_3d:
+            la_i = acc_la.iat(i).data_yxz
+            im_i = acc_in.iat(i).data_yxzc
+        else:
+            la_i = acc_la.iat(i).data_yx
+            im_i = acc_in.iat(i).data_yxzc[:, :, 0, :]
+        return pd.Series(
+            {k: v[0] for k, v in regionprops_table(la_i, im_i, properties=props).items()}
+        )
+
+    dff = rois.df().apply(_extract_features, axis=1)
+    if channel is not None:
+        dff.rename(
+            columns={
+                f'intensity_{k}-0': f'intensity_{k}-{channel}' for k in ['max', 'min', 'mean', 'std']
+            },
+            inplace=True
+        )
+    return dff
\ No newline at end of file
diff --git a/model_server/rois/labels.py b/model_server/rois/labels.py
new file mode 100644
index 0000000000000000000000000000000000000000..2444231c8fca99d06e0bbf83df034a5e255769d8
--- /dev/null
+++ b/model_server/rois/labels.py
@@ -0,0 +1,206 @@
+from math import sqrt
+
+import numpy as np
+import pandas as pd
+from scipy.stats import moment
+from skimage.filters import sobel
+from skimage.measure import label, shannon_entropy, regionprops_table
+
+from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
+from model_server.rois.df import filter_df, insert_level, is_df_3d, df_insert_slices
+
+
+def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor:
+    """
+    Convert binary segmentation mask into either a 2D or 3D object identities map
+    :param acc_seg_mask: binary segmentation mask (mono) of either two or three dimensions
+    :param allow_3d: return a 3D map if True; return a 2D map of the mask's maximum intensity projection if False
+    :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z 
if False + :return: object identities map + """ + if allow_3d and connect_3d: + nda_la = label( + acc_seg_mask.data_yxz, + connectivity=3, + ).astype('uint16') + return InMemoryDataAccessor(np.expand_dims(nda_la, 2)) + elif allow_3d and not connect_3d: + nla = 0 + la_3d = np.zeros((*acc_seg_mask.hw, 1, acc_seg_mask.nz), dtype='uint16') + for zi in range(0, acc_seg_mask.nz): + la_2d = label( + acc_seg_mask.data_yxz[:, :, zi], + connectivity=2, + ).astype('uint16') + la_2d[la_2d > 0] = la_2d[la_2d > 0] + nla + nla = la_2d.max() + la_3d[:, :, 0, zi] = la_2d + return InMemoryDataAccessor(la_3d) + else: + return InMemoryDataAccessor( + label( + acc_seg_mask.get_mip().data_yx, + connectivity=1, + ).astype('uint16') + ) + + +def focus_metrics(): + return { + 'max_intensity': lambda x: np.max(x), + 'stdev': lambda x: np.std(x), + 'max_sobel': lambda x: np.max(sobel(x)), + 'rms_sobel': lambda x: sqrt(np.mean(sobel(x) ** 2)), + 'entropy': lambda x: shannon_entropy(x), + 'moment': lambda x: moment(x.flatten(), moment=2), + } + + +def make_df_from_object_ids( + acc_raw, + acc_obj_ids, + expand_box_by, + deproject_channel=None, + filters=None, + deproject_intensity_threshold=0.0 +) -> pd.DataFrame: + """ + Build dataframe that associate object IDs with summary stats; + :param acc_raw: accessor to raw image data + :param acc_obj_ids: accessor to map of object IDs + :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary) + :param deproject_channel: if objects' z-coordinates are not specified, compute them based on argmax of this channel + :param deproject_intensity_threshold: when deprojecting, round MIP deprojection_channel to zero if below this + threshold (as fraction of full range, 0.0 to 1.0) + :return: pd.DataFrame + """ + # build dataframe of objects, assign z index to each object + + if acc_obj_ids.nz == 1 and acc_raw.nz > 1: # apply deprojection + + if deproject_channel is None or deproject_channel >= acc_raw.chroma or deproject_channel < 0: + if acc_raw.chroma == 1: + deproject_channel = 0 + else: + raise NoDeprojectChannelSpecifiedError( + f'When labeling objects, either their z-coordinates or a valid deprojection channel are required.' 
+ ) + + mono = acc_raw.get_mono(deproject_channel) + intensity_weight = mono.get_mip().data_yx.astype('uint16') + intensity_weight[intensity_weight < (deproject_intensity_threshold * mono.dtype_max)] = 0 + argmax = mono.get_z_argmax().data_yx.astype('uint16') + zi_map = np.stack([ + intensity_weight, + argmax * intensity_weight, + ], axis=-1) + + assert len(zi_map.shape) == 3 + df = pd.DataFrame(regionprops_table( + acc_obj_ids.data_yx, + intensity_image=zi_map, + properties=('label', 'area', 'intensity_mean', 'bbox') + )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'}) + + df['zi'] = (df['intensity_mean-1'] / df['intensity_mean-0']).fillna(0).round().astype('int16') + df = df.drop(['intensity_mean-0', 'intensity_mean-1'], axis=1) + + def _make_binary_mask(r): + acc = InMemoryDataAccessor(acc_obj_ids.data == r.name) + cropped = acc.get_mono(0, mip=True).crop_hw( + (int(r.y0), int(r.x0), int(r.y1 - r.y0), int(r.x1 - r.x0)) + ).data_yx + return cropped + + elif acc_obj_ids.nz == 1 and acc_raw.nz == 1: # purely 2d, no z information in dataframe + df = pd.DataFrame(regionprops_table( + acc_obj_ids.data_yx, + properties=('label', 'area', 'bbox') + )).rename(columns={ + 'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1' + }) + + def _make_binary_mask(r): + acc = InMemoryDataAccessor(acc_obj_ids.data == r.name) + cropped = acc.get_mono(0).crop_hw( + (int(r.y0), int(r.x0), int(r.y1 - r.y0), int(r.x1 - r.x0)) + ).data_yx + return cropped + + else: # purely 3d: objects' median z-coordinates come from arg of max count in object identities map + df = pd.DataFrame(regionprops_table( + acc_obj_ids.data_yxz, + properties=('label', 'area', 'bbox') + )).rename(columns={ + 'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1' + }) + + def _get_zi_from_label(r): + r = r.convert_dtypes() + la = r.name + crop = acc_obj_ids.crop_hwd((r.y0, r.x0, r.z0, r.y1 - r.y0, r.x1 - r.x0, r.z1 - r.z0)) + rel_argzmax = crop.apply(lambda x: x == la).get_focus_vector().argmax() + return rel_argzmax + r.z0 + + df['zi'] = df.apply(_get_zi_from_label, axis=1, result_type='reduce') + df['nz'] = df['z1'] - df['z0'] + + def _make_binary_mask(r): + r = r.convert_dtypes() + la = r.name + crop = acc_obj_ids.crop_hwd( + (int(r.y0), int(r.x0), int(r.z0), int(r.y1 - r.y0), int(r.x1 - r.x0), int(r.z1 - r.z0)) + ) + return crop.apply(lambda x: x == la).data_yxz + df = df.set_index('label') + insert_level(df, 'bounding_box') + df = df_insert_slices(df, acc_raw.shape_dict, expand_box_by) + filters_dict = {} if filters is None else filters.dict(exclude_unset=True) + df_fil = filter_df(df, filters_dict) + df_fil['masks', 'binary_mask'] = df_fil.bounding_box.apply( + _make_binary_mask, + axis=1, + result_type='reduce', + ) + return df_fil + + +def make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor: + id_mask = np.zeros((sd['Y'], sd['X'], 1, sd['Z']), dtype='uint16') + + if 'binary_mask' not in df.masks.columns: + raise MissingSegmentationError('RoiSet dataframe does not contain segmentation') + + if is_df_3d(df): # use 3d coordinates + def _label_obj(r): + bb = r.bounding_box + sl = np.s_[bb.y0:bb.y1, bb.x0:bb.x1, :, bb.z0:bb.z1] + mask = np.expand_dims(r.masks.binary_mask, 2) + id_mask[sl] = id_mask[sl] + r.name * mask + elif 'zi' in df.bounding_box.columns: + def _label_obj(r): + bb = r.bounding_box + sl = np.s_[bb.y0:bb.y1, bb.x0:bb.x1, :, bb.zi: (bb.zi + 1)] + mask = np.expand_dims(r.masks.binary_mask, (2, 3)) + id_mask[sl] = 
id_mask[sl] + r.name * mask + else: + def _label_obj(r): + bb = r.bounding_box + sl = np.s_[bb.y0:bb.y1, bb.x0:bb.x1, :] + mask = np.expand_dims(r.masks.binary_mask, (2, 3)) + id_mask[sl] = id_mask[sl] + r.name * mask + + df.apply(_label_obj, axis=1) + return InMemoryDataAccessor(id_mask) + + +class Error(Exception): + pass + + +class NoDeprojectChannelSpecifiedError(Error): + pass + + +class MissingSegmentationError(Error): + pass diff --git a/model_server/rois/models.py b/model_server/rois/models.py new file mode 100644 index 0000000000000000000000000000000000000000..51cfe598978066dcf28fdff27993366e0180c5d9 --- /dev/null +++ b/model_server/rois/models.py @@ -0,0 +1,71 @@ +import numpy as np +import pandas as pd +from skimage.measure import regionprops_table + +from model_server.base.accessors import GenericImageDataAccessor, PatchStack, InMemoryDataAccessor +from model_server.base.models import InstanceMaskSegmentationModel +from model_server.rois.labels import get_label_ids + + +class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel): + def __init__(self, tr: float = 0.5): + """ + Model that labels all objects as class 1 if the intensity in specified channel exceeds a threshold; labels all + objects as class 1 if threshold = 0.0 + :param tr: threshold in range of 0.0 to 1.0; model handles normalization to full pixel intensity range + :param channel: channel to use for thresholding + """ + self.tr = tr + self.loaded = self.load() + super().__init__(info={'tr': tr}) + + def load(self): + return True + + def infer( + self, + img: GenericImageDataAccessor, + mask: GenericImageDataAccessor, + allow_3d: bool = False, + connect_3d: bool = True, + ) -> GenericImageDataAccessor: + if img.chroma != 1: + raise ShapeMismatchError( + f'IntensityThresholdInstanceMaskSegmentationModel expects 1 channel but received {img.chroma}' + ) + if isinstance(img, PatchStack): # assume one object per patch + df = img.get_object_df(mask) + om = np.zeros(mask.shape, 'uint16') + def _label_patch_class(la): + om[la] = (mask.iat(la).data > 0) * 1 + df.loc[df['intensity_mean'] > (self.tr * img.dtype_max), 'label'].apply(_label_patch_class) + return PatchStack(om) + else: + labels = get_label_ids(mask) + df = pd.DataFrame(regionprops_table( + labels.data_yxz, + intensity_image=img.data_yxz, + properties=('label', 'area', 'intensity_mean') + )) + + om = np.zeros(labels.shape, labels.dtype) + + def _label_object_class(la): + om[labels.data == la] = 1 + df.loc[df['intensity_mean'] > (self.tr * img.dtype_max), 'label'].apply(_label_object_class) + + return InMemoryDataAccessor(om) + + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + super().label_instance_class(img, mask, **kwargs) + return self.infer(img, mask) + + +class Error(Exception): + pass + + +class ShapeMismatchError(Error): + pass \ No newline at end of file diff --git a/model_server/rois/phenobase.py b/model_server/rois/phenobase.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c2f2e91ce24febbe1bca773409f0ecad909e86 --- /dev/null +++ b/model_server/rois/phenobase.py @@ -0,0 +1,342 @@ +import re +from pathlib import Path +from typing import Dict, Union + +import pandas as pd +from sklearn.model_selection import train_test_split + +from model_server.base.accessors import PatchStack, GenericImageFileAccessor +from model_server.rois.roiset import PatchParams, RoiSet, RoiSetExportParams + +class RoiSetIndex(dict): + """ + 
Dictionary that identifies each unique RoiSet in the PhenoBase.
+    """
+    pass
+
+class PhenoBase(object):
+    def __init__(
+            self,
+            dataframe: pd.DataFrame,
+            root: Path,
+            info: Dict[str, str],
+            roiset_subdir='phenobase',
+            roiset_df_paths: list[Path] = None,
+    ):
+        """
+        Create a database containing one or more RoiSets and their associated data
+        :param dataframe: DataFrame object containing RoiSet information indexed by a unique RoiSet identifier
+        :param root: top-level directory containing a directory with RoiSets
+        :param info: about patch subdirectories, filters, etc.; passed to methods that create derivative PhenoBases
+        :param roiset_subdir: name of subdirectory that contains serialized RoiSets
+        :param roiset_df_paths: list of CSV files that individually define an RoiSet
+        """
+        self.df = dataframe
+        self.root = root
+        self.info = info
+        self.roiset_root_path = root / roiset_subdir
+        self.roiset_df_paths = roiset_df_paths
+
+        assert len(self.list_patch_series()) > 0, f'No patch collections found in {self.roiset_root_path}'
+        for patch_series in self.list_patch_series():
+            patches_path = self.roiset_root_path / patch_series
+            self.verify_patches(
+                self.df,
+                patches_path.name,
+                self.info['mask_patch_series'],
+                self.roiset_root_path
+            )
+
+    def update(self):
+        """
+        Scan the target directory specified in .read() for additional entries and update accordingly.
+        """
+        list_df_paths = []
+        for pa in (self.roiset_root_path / 'dataframe').iterdir():
+            if pa.suffix.upper() == '.CSV' and pa not in self.roiset_df_paths:
+                list_df_paths.append(pa)
+                self.roiset_df_paths.append(pa)
+        new_roisets = self._read_roiset_dfs(list_df_paths, max_stack_count=None, df_stacks=None)
+
+        if len(new_roisets) == 0:
+            return
+
+        for patch_series in self.list_patch_series():
+            patches_path = self.roiset_root_path / patch_series
+            self.verify_patches(
+                pd.concat(new_roisets),
+                patches_path.name,
+                self.info['mask_patch_series'],
+                self.roiset_root_path
+            )
+
+        self.df = pd.concat([
+            self.df,
+            *new_roisets,
+        ])
+
+    def push(
+            self,
+            roiset: RoiSet,
+            index_dict: RoiSetIndex,
+            csv_prefix: str = None,
+            export_params: RoiSetExportParams = None,
+    ):
+        """
+        Push an RoiSet onto a PhenoBase based on a unique key
+        :param roiset: RoiSet object to be added to PhenoBase
+        :param index_dict: dict that uniquely identifies the RoiSet; keys must match existing ones
+        :param csv_prefix: (optional) prefix of newly created CSV file
+        :param export_params: (optional) parameters for exporting RoiSet patch series and other products
+        """
+        # validate that index of RoiSet matches PhenoBase's schema and is unique
+        index_str_parts = [f'{k}{v:04d}' for k, v in index_dict.items()]
+        if csv_prefix is not None:
+            index_str_parts.insert(0, csv_prefix)
+        index_str = '-'.join(index_str_parts)
+        for k in index_dict:
+            if k not in self.roiset_index:
+                raise RoiSetIndexError(f'RoiSet index {k} not recognized in existing PhenoBase')
+        if all([v in self.df.index.to_frame()[f'coord_{k}'] for k, v in index_dict.items()]):
+            raise RoiSetIndexError(f'RoiSet index {index_dict} already exists in PhenoBase')
+
+        # validate that patch sets are specified in RoiSet export parameters before attempting to export them
+        if (export_params is None or len(export_params.patches) == 0) and self.list_patch_series() is not None:
+            raise MissingPatchSeriesError(f'RoiSet needs to specify patch export parameters: {self.list_patch_series()}')
+        missing_patch_series = [
+            psk for psk in self.list_patch_series() if psk.split('patches_')[-1] not in 
export_params.patches.keys() + ] + if any(missing_patch_series): + raise MissingPatchSeriesError(f'Could not find patch series: {missing_patch_series}') + + # export RoiSet products and then re-scan target directory to update PhenoBase + roiset.run_exports( + self.roiset_root_path, + prefix=index_str, + params=export_params, + ) + self.update() + + @staticmethod + def _read_roiset_dfs(paths, max_stack_count, df_stacks=None): + list_df = [] + for pa in paths[0: max_stack_count]: + df_i = pd.read_csv(pa, header=[0, 1], index_col=0) + df_i.index.name = 'label' + roiset_index = {m[0]: int(m[1]) for m in re.findall(r'-*([a-zA-Z]+)(\d+)', pa.stem)} + if len(roiset_index) == 0: + raise RoiSetIndexError(f'No unique RoiSet identifier found in {pa.name}') + + for k, v in roiset_index.items(): + df_i.insert(0, f'coord_{k}', v) + + # transfer metadata from table of stack, if this is specified + if df_stacks is not None: + df_i['paths', 'input_file'] = df_stacks['remote_path'].loc[roiset_index] + + # if ROIs are labeled the same in their input stacks, transfer this to ROI table + if 'category_label' in df_stacks.columns: + df_i['annotations', 'category_label'] = df_stacks['category_label'].loc[roiset_index] + df_i['annotations', 'category_id'] = df_stacks['category_id'].loc[roiset_index] + + # transfer in object classification results, if relevant + if 'classifications' in df_i.columns: + classifications = df_i.classifications.columns + if len(classifications) > 1 or ('annotations', 'category_label') in df_i.columns: + raise AnnotationColumnsError('PhenoBase can only contain one annotation or classification column') + if len(classifications) == 1: + df_i.rename(columns={classifications[0]: 'category_id'}, inplace=True) + df_i['classifications', 'category_label'] = df_i['classifications', 'category_id'] + df_i.set_index([f'coord_{k}' for k in roiset_index.keys()] + [df_i.index], inplace=True) + list_df.append(df_i) + return list_df + + @classmethod + def from_roiset( + cls, + root: Path, + roiset: RoiSet, + index_dict: RoiSetIndex, + roiset_subdir='phenobase', + export_params: RoiSetExportParams = None, + ): + # serialize RoiSet + index_str = '-'.join([f'{k}{v:04d}' for k, v in index_dict.items()]) + roiset.run_exports( + root / roiset_subdir, + prefix=index_str, + params=export_params + ) + return cls.read(root, roiset_subdir) + + @classmethod + def read( + cls, + root: Path, + roiset_subdir='phenobase', + mask_patch_series: str = 'tight_patch_masks', + max_stack_count: Union[int, None] = None, + input_files_csv: str = None, + ): + """ + Automatically read serialized RoiSet from the specified directory and create a PhenoBase. The resulting table + parses its (generally multicolumn) index from RoiSet CSV filenames in the /dataframe subdirectory: + roiset-x000-y001-t002 results in the index {x=0, y=1, t=2} + and so on. 
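+        Filenames that contain no parsable index raise RoiSetIndexError.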
+        :param root: top-level directory containing a directory with RoiSets
+        :param roiset_subdir: name of subdirectory that contains serialized RoiSets
+        :param mask_patch_series: name of subdirectory containing patch masks needed to deserialize each RoiSet
+        :param max_stack_count: read only the specified number of RoiSets, or read all if None
+        :param input_files_csv: (optional) the name of a file with supplementary information about RoiSets' input files
+        :return: PhenoBase object
+        """
+        # concatenate the single-stack dataframes
+        where_dfs = root / roiset_subdir / 'dataframe'
+        roiset_df_paths = [pa for pa in where_dfs.iterdir() if pa.suffix.upper() == '.CSV']
+        n_df = len(roiset_df_paths)
+        assert n_df > 0, f'No RoiSet dataframes found in {where_dfs}'
+        if max_stack_count is None:
+            max_stack_count = n_df
+
+        if input_files_csv is not None:
+            df_stacks = pd.read_csv(root / input_files_csv)
+        else:
+            df_stacks = None
+
+        df_concat = pd.concat(
+            cls._read_roiset_dfs(
+                roiset_df_paths,
+                max_stack_count=max_stack_count,
+                df_stacks=df_stacks,
+            )
+        )
+
+        return cls(
+            df_concat,
+            root,
+            {'mask_patch_series': mask_patch_series},
+            roiset_df_paths=roiset_df_paths,
+        )
+
+    def list_patch_series(self):
+        return [x.name for x in self.roiset_root_path.iterdir() if x.is_dir() and x.name.startswith('patches_')]
+
+    @property
+    def count(self):
+        return len(self.df)
+
+    def chroma(self, patch_series: str):
+        return self.sample(1).get_raw_patchstack(patch_series).chroma
+
+    def filter(self, filters) -> 'PhenoBase':
+        if filters is None:
+            self.info['filters'] = None
+            return self
+        else:  # parse filters into a query string
+            query_str = 'label > 0'
+            for k, val in filters.items():
+                assert k in ('area', 'diag', 'min_hw')
+                if val is None:
+                    continue
+                vmin = val['min']
+                vmax = val['max']
+                assert vmin >= 0
+                query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'
+            info = {'filters': query_str, **self.info}
+            return self.__class__(self.df.query(query_str), self.root, info)
+
+    def sample(self, n_max: int):
+        if n_max is None or n_max >= self.count:
+            return self
+        else:
+            return self.__class__(self.df.sample(int(n_max)), self.root, self.info)
+
+    def get_rows(self, rows: pd.Index) -> 'PhenoBase':
+        return self.__class__(self.df.loc[rows], self.root, self.info)
+
+    def write_df(self):
+        return self.df.reset_index(
+            col_level=1
+        ).to_csv(
+            self.root / 'phenobase.csv',
+            index=False
+        )
+
+    def get_raw_patchstack(self, patch_series: str) -> PatchStack:
+        def _get_patch(roi):
+            pa = getattr(roi.patches, f'{patch_series}_path')
+            fp = self.roiset_root_path / pa
+            acc = GenericImageFileAccessor.read(fp)
+            return acc.data
+        return PatchStack(self.df.apply(_get_patch, axis=1).to_list(), force_ydim_longest=True)
+
+    def get_patch_masks(self) -> PatchStack:
+        patch_mask_column = self.info['mask_patch_series'] + '_path'
+        def _get_mask(roi):
+            fp = self.roiset_root_path / getattr(roi.patches, patch_mask_column)
+            return GenericImageFileAccessor.read(fp).data
+        return PatchStack(self.df.apply(_get_mask, axis=1).to_list(), force_ydim_longest=True)
+
+    def get_patch_obmaps(self) -> PatchStack:
+        if ('classifications', 'category_id') not in self.df.columns:
+            raise AnnotationColumnsError(
+                'PhenoBase does not have an annotation column with which to generate object maps'
+            )
+        patch_mask_column = self.info['mask_patch_series'] + '_path'
+        def _get_obmap(roi):
+            fp = self.roiset_root_path / getattr(roi.patches, patch_mask_column)
+            nda = GenericImageFileAccessor.read(fp).data
+            return ((nda > 0) * roi.classifications.category_id).astype('uint8')
+        return PatchStack(self.df.apply(_get_obmap, axis=1).to_list(), force_ydim_longest=True)
+
+    @staticmethod
+    def verify_patches(df, patch_series: str, mask_patch_series: str, root: Path):
+        # verify that patch files exist on disk and record the result in per-series columns
+        all_exist = {}
+        for k in [f'{patch_series}_path', f'{mask_patch_series}_path']:
+            df['patches', f'{k}_exists'] = df.patches[k].apply(lambda x: (root / x).exists())
+            all_exist[k] = df['patches', f'{k}_exists'].all()
+        assert all(all_exist.values()), f'Could not verify that all patches in {patch_series} exist'
+
+    @property
+    def labels(self):
+        if 'classifications' not in self.df.columns:
+            return None
+        dfcl = self.df.classifications
+        if 'category_id' not in dfcl.columns and 'category_label' not in dfcl.columns:
+            return None
+        labels_dict = {g[0][0]: g[0][1] for g in dfcl.groupby(['category_id', 'category_label'])}
+        return [labels_dict.get(i, 'undefined') for i in range(1, 1 + max(labels_dict.keys()))]
+
+    def split(self, test_size=0.3) -> Dict[str, 'PhenoBase']:
+        tr_rows, te_rows = train_test_split(self.df.index, test_size=test_size)
+
+        return {
+            'train': self.__class__(self.df.loc[tr_rows], self.root, {'split': 'train', **self.info}),
+            'test': self.__class__(self.df.loc[te_rows], self.root, {'split': 'test', **self.info})
+        }
+
+    @property
+    def roiset_index(self) -> list:
+        return [n for n in self.df.index.names if n != 'label']
+
+    def list_bounding_boxes(self):
+        return self.df.bounding_box.reset_index().to_dict(orient='records')
+
+
+class Error(Exception):
+    pass
+
+class InvalidPatchStackError(Error):
+    pass
+
+class RoiSetIndexError(Error):
+    pass
+
+class AnnotationColumnsError(Error):
+    pass
+
+class MissingPatchSeriesError(Error):
+    pass
diff --git a/model_server/rois/pipelines/__init__.py b/model_server/rois/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model_server/rois/pipelines/add_roiset.py b/model_server/rois/pipelines/add_roiset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b0ce415adc974d1d4fa2a5f0fb62aa3d650aa2d
--- /dev/null
+++ b/model_server/rois/pipelines/add_roiset.py
@@ -0,0 +1,62 @@
+from typing import Dict, Union
+
+from fastapi import APIRouter
+from pydantic import Field
+
+from model_server.base.accessors import GenericImageDataAccessor
+from model_server.base.models import Model
+from model_server.base.pipelines.shared import call_roiset_pipeline, RoiSetPipelineParams, PipelineQueueRecord, PipelineRecord, PipelineTrace
+from model_server.rois.roiset import RoiSet, RoiSetMetaParams
+
+
+class AddRoiSetParams(RoiSetPipelineParams):
+
+    accessor_id: str = Field(
+        description='ID of raw data to use in RoiSet'
+    )
+    labels_accessor_id: str = Field(
+        description='ID of label mask to use in RoiSet'
+    )
+    roi_params: RoiSetMetaParams = RoiSetMetaParams(**{
+        'mask_type': 'boxes',
+        'filters': {
+            'area': {'min': 1e3, 'max': 1e8}
+        },
+        'expand_box_by': [0, 0],
+        'deproject_channel': None,
+    })
+
+class SegToRoiSetRecord(PipelineRecord):
+    pass
+
+
+router = APIRouter()
+
+
+@router.put('/add_roiset')
+def seg_to_roiset(p: AddRoiSetParams) -> Union[SegToRoiSetRecord, PipelineQueueRecord]:
+    """
+    Create a RoiSet from a raw accessor and a label mask, then compute the requested patch products.
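+    Returns a PipelineQueueRecord if the task is queued for asynchronous execution, otherwise a SegToRoiSetRecord.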
+ """ + return call_roiset_pipeline(add_roiset_pipeline, p) + + +def add_roiset_pipeline( + accessors: Dict[str, GenericImageDataAccessor], + models: Dict[str, Model], + **k +) -> PipelineTrace: + d = PipelineTrace(accessors['']) + d['labels'] = accessors['labels_'] + rois = RoiSet.from_object_ids( + d['input'], + d['labels'], + RoiSetMetaParams(**k['roi_params']) + ) + + for patch_series, patch_params in k['patches'].items(): + d[patch_series] = rois.get_patches_acc(**patch_params) + d['patch_masks'] = rois.get_patch_masks_acc() + + + return d, rois \ No newline at end of file diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/rois/pipelines/roiset_obmap.py similarity index 84% rename from model_server/base/pipelines/roiset_obmap.py rename to model_server/rois/pipelines/roiset_obmap.py index 57e7f6d8679acef37756d1b81029af42f5993d43..2f69cf3434db7d2cf1a0d1f97740be07952e190c 100644 --- a/model_server/base/pipelines/roiset_obmap.py +++ b/model_server/rois/pipelines/roiset_obmap.py @@ -3,14 +3,15 @@ from typing import Dict, Union from fastapi import APIRouter from pydantic import BaseModel, Field -from ..accessors import GenericImageDataAccessor -from .segment_zproj import segment_zproj_pipeline -from .shared import call_pipeline -from ..roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams +from model_server.base.accessors import GenericImageDataAccessor +from model_server.base.pipelines.segment_zproj import segment_zproj_pipeline +from model_server.base.pipelines.shared import call_pipeline +from model_server.rois.roiset import RoiSet, RoiSetMetaParams, RoiSetExportParams +from model_server.rois.labels import get_label_ids -from ..pipelines.shared import PipelineQueueRecord, PipelineTrace, PipelineParams, PipelineRecord +from model_server.base.pipelines.shared import PipelineQueueRecord, PipelineTrace, PipelineParams, PipelineRecord -from ..models import Model, InstanceMaskSegmentationModel +from model_server.base.models import Model, InstanceMaskSegmentationModel class RoiSetObjectMapParams(PipelineParams): class _SegmentationParams(BaseModel): @@ -60,7 +61,7 @@ class RoiSetToObjectMapRecord(PipelineRecord): router = APIRouter() -@router.put('/roiset_to_obmap/infer') +@router.put('/roiset_to_obmap') def roiset_object_map(p: RoiSetObjectMapParams) -> Union[RoiSetToObjectMapRecord, PipelineQueueRecord]: """ Compute a RoiSet from 2d segmentation, apply to z-stack, and optionally apply object classification. 
diff --git a/model_server/base/roiset.py b/model_server/rois/roiset.py similarity index 57% rename from model_server/base/roiset.py rename to model_server/rois/roiset.py index 9961de249cf5ecf990325fec366c9c1944072cb7..1d833dfa8ccb63b8a876867dd6d4a3d5112dfe43 100644 --- a/model_server/base/roiset.py +++ b/model_server/rois/roiset.py @@ -1,6 +1,5 @@ from collections import OrderedDict -import itertools -from math import sqrt, floor +from math import floor from pathlib import Path from typing import Dict, List, Union from typing_extensions import Self @@ -10,19 +9,20 @@ import glasbey import numpy as np import pandas as pd from pydantic import BaseModel, Field -from scipy.stats import moment -from skimage.filters import sobel from skimage import draw -from skimage.measure import approximate_polygon, find_contours, label, points_in_poly, regionprops, regionprops_table, shannon_entropy +from skimage.measure import approximate_polygon, find_contours, label, points_in_poly, regionprops from skimage.morphology import binary_dilation, disk -from .accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file -from .models import InstanceMaskSegmentationModel -from .process import get_safe_contours, pad, rescale, make_rgb -from .annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image -from .accessors import generate_file_accessor, PatchStack -from .process import mask_largest_object +from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file +from model_server.base.models import InstanceMaskSegmentationModel +from model_server.base.process import get_safe_contours, pad, rescale, make_rgb, safe_add +from model_server.base.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image +from model_server.base.accessors import generate_file_accessor, PatchStack +from model_server.base.process import mask_largest_object +from model_server.rois.df import filter_df_overlap_seg, is_df_3d, insert_level, read_roiset_df, df_insert_slices +from model_server.rois.labels import get_label_ids, focus_metrics, make_df_from_object_ids, make_object_ids_from_df, \ + NoDeprojectChannelSpecifiedError class PatchParams(BaseModel): @@ -77,7 +77,7 @@ class RoiFilter(BaseModel): class RoiSetMetaParams(BaseModel): filters: Union[RoiFilter, None] = None - expand_box_by: List[int] = [128, 0] + expand_box_by: List[int] = [0, 0] deproject_channel: Union[int, None] = None deproject_intensity_threshold: float = 0.0 @@ -89,371 +89,10 @@ class RoiSetExportParams(BaseModel): labels_overlay: Union[RoiSetLabelsOverlayParams, None] = None derived_channels: bool = False write_patches_to_subdirectory: bool = Field( - False, + True, description='Write all patches to a subdirectory with prefix as name' ) - -def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor: - """ - Convert binary segmentation mask into either a 2D or 3D object identities map - :param acc_seg_mask: binary segmentation mask (mono) of either two or three dimensions - :param allow_3d: return a 3D map if True; return a 2D map of the mask's maximum intensity project if False - :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False - :return: object identities map - """ - if allow_3d and connect_3d: - nda_la = label( - acc_seg_mask.data_yxz, - connectivity=3, - ).astype('uint16') - return InMemoryDataAccessor(np.expand_dims(nda_la, 
2)) - elif allow_3d and not connect_3d: - nla = 0 - la_3d = np.zeros((*acc_seg_mask.hw, 1, acc_seg_mask.nz), dtype='uint16') - for zi in range(0, acc_seg_mask.nz): - la_2d = label( - acc_seg_mask.data_yxz[:, :, zi], - connectivity=2, - ).astype('uint16') - la_2d[la_2d > 0] = la_2d[la_2d > 0] + nla - nla = la_2d.max() - la_3d[:, :, 0, zi] = la_2d - return InMemoryDataAccessor(la_3d) - else: - return InMemoryDataAccessor( - label( - acc_seg_mask.get_mip().data_yx, - connectivity=1, - ).astype('uint16') - ) - - -def focus_metrics(): - return { - 'max_intensity': lambda x: np.max(x), - 'stdev': lambda x: np.std(x), - 'max_sobel': lambda x: np.max(sobel(x)), - 'rms_sobel': lambda x: sqrt(np.mean(sobel(x) ** 2)), - 'entropy': lambda x: shannon_entropy(x), - 'moment': lambda x: moment(x.flatten(), moment=2), - } - - -def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: - query_str = 'label > 0' # always true - if filters is not None: # parse filters - for k, val in filters.dict(exclude_unset=True).items(): - assert k in ('area', 'diag', 'min_hw') - if val is None: - continue - vmin = val['min'] - vmax = val['max'] - assert vmin >= 0 - query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}' - return df.loc[df.query(query_str).index, :] - - -def filter_df_overlap_bbox(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame: - """ - If passed a single DataFrame, return the subset whose bounding boxes overlap in 3D space. If passed two DataFrames, - return the subset where a ROI in the first overlaps a ROI in the second. May return duplicates entries where a ROI - overlaps with multiple neighbors. - :param df1: DataFrame with potentially overlapping bounding boxes - :param df2: (optional) second DataFrame - :return DataFrame describing subset of overlapping ROIs - bbox_overlaps_with: index of ROI that overlaps - bbox_intersec: pixel area of intersecting region - """ - - def _compare(r0, r1): - olx = (r0.x0 < r1.x1) and (r0.x1 > r1.x0) - oly = (r0.y0 < r1.y1) and (r0.y1 > r1.y0) - olz = (r0.zi == r1.zi) - return olx and oly and olz - - def _intersec(r0, r1): - return (r0.x1 - r1.x0) * (r0.y1 - r1.y0) - - first = [] - second = [] - intersec = [] - - if df2 is not None: - for pair in itertools.product(df1.index, df2.index): - if _compare(df1.loc[pair[0]], df2.loc[pair[1]]): - first.append(pair[0]) - second.append(pair[1]) - intersec.append( - _intersec(df1.loc[pair[0]], df2.loc[pair[1]]) - ) - else: - for pair in itertools.combinations(df1.index, 2): - if _compare(df1.loc[pair[0]], df1.loc[pair[1]]): - first.append(pair[0]) - second.append(pair[1]) - first.append(pair[1]) - second.append(pair[0]) - isc = _intersec(df1.loc[pair[0]], df1.loc[pair[1]]) - intersec.append(isc) - intersec.append(isc) - - sdf = df1.loc[first] - sdf.loc[:, 'overlaps_with'] = second - sdf.loc[:, 'bbox_intersec'] = intersec - return sdf - - -def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame: - """ - If passed a single DataFrame, return the subset whose segmentations overlap in 3D space. If passed two DataFrames, - return the subset where a ROI in the first overlaps a ROI in the second. May return duplicates entries where a ROI - overlaps with multiple neighbors. 
- :param df1: DataFrame with potentially overlapping bounding boxes - :param df2: (optional) second DataFrame - :return DataFrame describing subset of overlapping ROIs - seg_overlaps_with: index of ROI that overlaps - seg_intersec: pixel area of intersecting region - seg_iou: intersection over union - """ - - dfbb = filter_df_overlap_bbox(df1, df2) - - def _overlap_seg(r): - roi1 = df1.loc[r.name] - if df2 is not None: - roi2 = df2.loc[r.overlaps_with] - else: - roi2 = df1.loc[r.overlaps_with] - ex0 = min(roi1.x0, roi2.x0, roi1.x1, roi2.x1) - ew = max(roi1.x0, roi2.x0, roi1.x1, roi2.x1) - ex0 - ey0 = min(roi1.y0, roi2.y0, roi1.y1, roi2.y1) - eh = max(roi1.y0, roi2.y0, roi1.y1, roi2.y1) - ey0 - emask = np.zeros((eh, ew), dtype='uint8') - sl1 = np.s_[(roi1.y0 - ey0): (roi1.y1 - ey0), (roi1.x0 - ex0): (roi1.x1 - ex0)] - sl2 = np.s_[(roi2.y0 - ey0): (roi2.y1 - ey0), (roi2.x0 - ex0): (roi2.x1 - ex0)] - emask[sl1] = roi1.binary_mask - emask[sl2] = emask[sl2] + roi2.binary_mask - return emask - - emasks = dfbb.apply(_overlap_seg, axis=1) - dfbb['seg_overlaps'] = emasks.apply(lambda x: np.any(x > 1)) - dfbb['seg_intersec'] = emasks.apply(lambda x: (x == 2).sum()) - dfbb['seg_iou'] = emasks.apply(lambda x: (x == 2).sum() / (x > 0).sum()) - return dfbb - - -def is_df_3d(df: pd.DataFrame) -> bool: - return 'z0' in df.columns and 'z1' in df.columns - - -def make_df_from_object_ids( - acc_raw, - acc_obj_ids, - expand_box_by, - deproject_channel=None, - filters=None, - deproject_intensity_threshold=0.0 -) -> pd.DataFrame: - """ - Build dataframe that associate object IDs with summary stats; - :param acc_raw: accessor to raw image data - :param acc_obj_ids: accessor to map of object IDs - :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary) - :param deproject_channel: if objects' z-coordinates are not specified, compute them based on argmax of this channel - :param deproject_intensity_threshold: when deprojecting, round MIP deprojection_channel to zero if below this - threshold (as fraction of full range, 0.0 to 1.0) - :return: pd.DataFrame - """ - # build dataframe of objects, assign z index to each object - - if acc_obj_ids.nz == 1 and acc_raw.nz > 1: # apply deprojection - - if deproject_channel is None or deproject_channel >= acc_raw.chroma or deproject_channel < 0: - if acc_raw.chroma == 1: - deproject_channel = 0 - else: - raise NoDeprojectChannelSpecifiedError( - f'When labeling objects, either their z-coordinates or a valid deprojection channel are required.' 
- ) - - mono = acc_raw.get_mono(deproject_channel) - intensity_weight = mono.get_mip().data_yx.astype('uint16') - intensity_weight[intensity_weight < (deproject_intensity_threshold * mono.dtype_max)] = 0 - argmax = mono.get_z_argmax().data_yx.astype('uint16') - zi_map = np.stack([ - intensity_weight, - argmax * intensity_weight, - ], axis=-1) - - assert len(zi_map.shape) == 3 - df = pd.DataFrame(regionprops_table( - acc_obj_ids.data_yx, - intensity_image=zi_map, - properties=('label', 'area', 'intensity_mean', 'bbox') - )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'}) - - df['zi'] = (df['intensity_mean-1'] / df['intensity_mean-0']).fillna(0).round().astype('int16') - df = df.drop(['intensity_mean-0', 'intensity_mean-1'], axis=1) - - def _make_binary_mask(r): - acc = InMemoryDataAccessor(acc_obj_ids.data == r.label) - cropped = acc.get_mono(0, mip=True).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_yx - return cropped - - elif acc_obj_ids.nz == 1 and acc_raw.nz == 1: # purely 2d, no z information in dataframe - df = pd.DataFrame(regionprops_table( - acc_obj_ids.data_yx, - properties=('label', 'area', 'bbox') - )).rename(columns={ - 'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1' - }) - - def _make_binary_mask(r): - acc = InMemoryDataAccessor(acc_obj_ids.data == r.label) - cropped = acc.get_mono(0).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_yx - return cropped - - else: # purely 3d: objects' median z-coordinates come from arg of max count in object identities map - df = pd.DataFrame(regionprops_table( - acc_obj_ids.data_yxz, - properties=('label', 'area', 'bbox') - )).rename(columns={ - 'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1' - }) - - def _get_zi_from_label(r): - r = r.convert_dtypes() - la = r.label - crop = acc_obj_ids.crop_hwd((r.y0, r.x0, r.z0, r.y1 - r.y0, r.x1 - r.x0, r.z1 - r.z0)) - rel_argzmax = crop.apply(lambda x: x == la).get_focus_vector().argmax() - return rel_argzmax + r.z0 - - df['zi'] = df.apply(_get_zi_from_label, axis=1, result_type='reduce') - df['nz'] = df['z1'] - df['z0'] - - def _make_binary_mask(r): - r = r.convert_dtypes() - la = r.label - crop = acc_obj_ids.crop_hwd((r.y0, r.x0, r.z0, r.y1 - r.y0, r.x1 - r.x0, r.z1 - r.z0)) - return crop.apply(lambda x: x == la).data_yxz - - df = df_insert_slices(df, acc_raw.shape_dict, expand_box_by) - df_fil = filter_df(df, filters) - df_fil['binary_mask'] = df_fil.apply( - _make_binary_mask, - axis=1, - result_type='reduce', - ) - return df_fil - - -def df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame: - h = sd['Y'] - w = sd['X'] - nz = sd['Z'] - - df['h'] = df['y1'] - df['y0'] - df['w'] = df['x1'] - df['x0'] - df['diag'] = (df['w']**2 + df['h']**2).apply(sqrt) - df['min_hw'] = df[['w', 'h']].min(axis=1) - - ebxy, ebz = expand_box_by - df['ebb_y0'] = (df.y0 - ebxy).apply(lambda x: max(x, 0)) - df['ebb_y1'] = (df.y1 + ebxy).apply(lambda x: min(x, h)) - df['ebb_x0'] = (df.x0 - ebxy).apply(lambda x: max(x, 0)) - df['ebb_x1'] = (df.x1 + ebxy).apply(lambda x: min(x, w)) - - # handle based on whether bounding box coordinates are 2d or 3d - if is_df_3d(df): - df['ebb_z0'] = (df.z0 - ebz).apply(lambda x: max(x, 0)) - df['ebb_z1'] = (df.z1 + ebz).apply(lambda x: max(x, nz)) - else: - if 'zi' not in df.columns: - df['zi'] = 0 - df['ebb_z0'] = (df.zi - ebz).apply(lambda x: max(x, 0)) - df['ebb_z1'] = (df.zi + ebz).apply(lambda x: min(x, nz)) - - df['ebb_h'] = df['ebb_y1'] - df['ebb_y0'] 
- df['ebb_w'] = df['ebb_x1'] - df['ebb_x0'] - df['ebb_nz'] = df['ebb_z1'] - df['ebb_z0'] + 1 - - # compute relative bounding boxes - df['rel_y0'] = df.y0 - df.ebb_y0 - df['rel_y1'] = df.y1 - df.ebb_y0 - df['rel_x0'] = df.x0 - df.ebb_x0 - df['rel_x1'] = df.x1 - df.ebb_x0 - - assert np.all(df['rel_x1'] <= (df['ebb_x1'] - df['ebb_x0'])) - assert np.all(df['rel_y1'] <= (df['ebb_y1'] - df['ebb_y0'])) - - if is_df_3d(df): - df['slice'] = df.apply( - lambda r: - np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.z0): int(r.z1)], - axis=1, - result_type='reduce', - ) - else: - df['slice'] = df.apply( - lambda r: - np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)], - axis=1, - result_type='reduce', - ) - df['expanded_slice'] = df.apply( - lambda r: - np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1], - axis=1, - result_type='reduce', - ) - df['relative_slice'] = df.apply( - lambda r: - np.s_[int(r.rel_y0): int(r.rel_y1), int(r.rel_x0): int(r.rel_x1), :, :], - axis=1, - result_type='reduce', - ) - return df - - -def safe_add(a, g, b): - assert a.dtype == b.dtype - assert a.shape == b.shape - assert g >= 0.0 - - return np.clip( - a.astype('uint32') + g * b.astype('uint32'), - 0, - np.iinfo(a.dtype).max - ).astype(a.dtype) - - -def make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor: - id_mask = np.zeros((sd['Y'], sd['X'], 1, sd['Z']), dtype='uint16') - - if 'binary_mask' not in df.columns: - raise MissingSegmentationError('RoiSet dataframe does not contain segmentation') - - if is_df_3d(df): # use 3d coordinates - def _label_obj(r): - sl = np.s_[r.y0:r.y1, r.x0:r.x1, :, r.z0:r.z1] - mask = np.expand_dims(r.binary_mask, 2) - id_mask[sl] = id_mask[sl] + r.label * mask - elif 'zi' in df.columns: - def _label_obj(r): - sl = np.s_[r.y0:r.y1, r.x0:r.x1, :, r.zi: (r.zi + 1)] - mask = np.expand_dims(r.binary_mask, (2, 3)) - id_mask[sl] = id_mask[sl] + r.label * mask - else: - def _label_obj(r): - sl = np.s_[r.y0:r.y1, r.x0:r.x1, :] - mask = np.expand_dims(r.binary_mask, (2, 3)) - id_mask[sl] = id_mask[sl] + r.label * mask - - df.apply(_label_obj, axis=1) - return InMemoryDataAccessor(id_mask) - - class RoiSet(object): def __init__( @@ -473,12 +112,10 @@ class RoiSet(object): self.accs_derived = [] self.params = params + df['patches', 'index'] = df.reset_index().index self._df = df - self.count = len(self._df) - def __iter__(self): - """Expose ROI meta information via the Pandas.DataFrame API""" - return self._df.itertuples(name='Roi') + self.count = len(self._df) @classmethod def from_object_ids( @@ -547,17 +184,19 @@ class RoiSet(object): bbox_df['x1'] = bbox_df['x0'] + bbox_df['w'] bbox_df['label'] = bbox_df.index - + bbox_df.set_index('label', inplace=True) + bbox_df = bbox_df.drop(['x', 'y', 'w', 'h'], axis=1) + insert_level(bbox_df, 'bounding_box') df = df_insert_slices( - bbox_df[['y0', 'x0', 'y1', 'x1', 'zi', 'label']], + bbox_df, acc_raw.shape_dict, params.expand_box_by, ) def _make_binary_mask(r): - return np.ones((r.h, r.w), dtype=bool) + return np.ones((int(r.h), int(r.w)), dtype=bool) - df['binary_mask'] = df.apply( + df['masks', 'binary_mask'] = df.bounding_box.apply( _make_binary_mask, axis=1, result_type='reduce', @@ -623,14 +262,17 @@ class RoiSet(object): 'raw_shape_dict': self.acc_raw.shape_dict, 'count': self.count, 'classify_by': self.classification_columns, - 'df_memory_usage': int(self.get_df().memory_usage(deep=True).sum()) + 'df_memory_usage': 
int(self.df().memory_usage(deep=True).sum()) } - def get_df(self) -> pd.DataFrame: - return self._df + def df(self) -> pd.DataFrame: + return self._df.copy() + + def list_bounding_boxes(self): + return self.df.bounding_box.reset_index().to_dict(orient='records') def get_slices(self) -> pd.Series: - return self.get_df()['slice'] + return self.df()['slices', 'slice'] def add_df_col(self, name, se: pd.Series) -> None: self._df[name] = se @@ -663,8 +305,8 @@ class RoiSet(object): # make an object map where label is replaced by focus position in stack and background is -1 lut = np.zeros(lamap.max() + 1) - 1 - df = self.get_df() - lut[df.label] = df.zi + df = self.df() + lut[df.index] = df.bounding_box.zi if mask_type == 'contours': zi_map = (lut[lamap] + 1.0).astype('int') @@ -680,14 +322,14 @@ class RoiSet(object): zi_st[:, :, :, -1][lamap == 0] = 0 elif mask_type == 'boxes': - for roi in self: - zi_st[roi.slice] = True - + def _set_box(sl): + zi_st[sl] = True + self._df.slices.slice.apply(_set_box) return zi_st @property def is_3d(self) -> bool: - return is_df_3d(self.get_df()) + return is_df_3d(self.df()) def classify_by( self, name: str, channels: list[int], @@ -703,7 +345,7 @@ class RoiSet(object): :return: None """ if self.count == 0: - self._df['classify_by_' + name] = None + self._df['classifications', name] = None return True input_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) # all channels @@ -716,14 +358,14 @@ class RoiSet(object): se = pd.Series(dtype='Int64', index=self._df.index) - for i, roi in enumerate(self): + for i, la in enumerate(self._df.index): oc = np.unique( mask_largest_object( obmap_patches.iat(i).data_yxz ) )[-1] - se[roi.Index] = oc - self.set_classification(f'classify_by_{name}', se) + se[la] = oc + self.set_classification(name, se) def get_instance_classification(self, roiset_from: Self, iou_min: float = 0.5) -> pd.DataFrame: """ @@ -736,18 +378,18 @@ class RoiSet(object): raise ShapeMismatchError( f'Expecting two RoiSets of same shape: {self.acc_raw.shape} != {roiset_from.acc_raw.shape}') - columns = [f'classify_by_{c}' for c in roiset_from.classification_columns] + columns = roiset_from.classification_columns if len(columns) == 0: raise MissingInstanceLabelsError('Expecting at least on instance classification channel but none found') df_overlaps = filter_df_overlap_seg( - roiset_from.get_df(), - self.get_df() + roiset_from.df(), + self.df() ) df_overlaps['transfer'] = df_overlaps.seg_iou > iou_min df_merge = pd.merge( - roiset_from.get_df()[columns], + roiset_from.df().classifications[columns], df_overlaps.loc[df_overlaps.transfer, ['overlaps_with']], left_index=True, right_index=True, @@ -758,25 +400,24 @@ class RoiSet(object): return df_overlaps - def get_object_class_map(self, name: str, filter_by: Union[List, None] = None) -> InMemoryDataAccessor: + def get_object_class_map(self, class_name: str, filter_by: Union[List, None] = None) -> InMemoryDataAccessor: """ For a given classification result, return a map where object IDs are replaced by each object's class :param name: name of the classification result, same as passed to RoiSet.classify_by() :param filter_by: only include ROIs if the intersection of all specified classifications is True :return: accessor of object class map """ - colname = ('classify_by_' + name) - assert colname in self._df.columns + assert class_name in self._df.classifications.columns obj_ids = self.acc_obj_ids om = np.zeros(obj_ids.shape, obj_ids.dtype) def _label_object_class(roi): - 
om[self.acc_obj_ids.data == roi.label] = roi[colname] + om[self.acc_obj_ids.data == roi.name] = roi[class_name] if filter_by is None: - self._df.apply(_label_object_class, axis=1) + self._df.classifications.apply(_label_object_class, axis=1) else: - pd_fil = self._df[[f'classify_by_{fb}' for fb in filter_by]] - self._df.loc[pd_fil.all(axis=1), :].apply(_label_object_class, axis=1) + pd_fil = self._df.classifications[filter_by] + self._df.classifications.loc[pd_fil.all(axis=1), :].apply(_label_object_class, axis=1) return InMemoryDataAccessor(om) def get_object_identities_overlay_map( @@ -790,7 +431,7 @@ class RoiSet(object): if rescale_clip is not None: mono = mono.apply(lambda x: rescale(x, clip=rescale_clip)) mono = mono.to_8bit().data_yxz - max_label = self.get_df()['label'].max() + max_label = self.df().index.max() palette = np.array([[0, 0, 0]] + glasbey.create_palette(max_label, as_hex=False)) rgb_8bit_palette = (255 * palette).round().astype('uint8') id_map_yxzc = rgb_8bit_palette[self.acc_obj_ids.data_yxz] @@ -814,23 +455,35 @@ class RoiSet(object): return acc.write(fp, composite=True).name def get_serializable_dataframe(self) -> pd.DataFrame: - return self._df.drop(['expanded_slice', 'slice', 'relative_slice', 'binary_mask'], axis=1) + return self._df.drop([ + ('slices', 'expanded_slice'), + ('slices', 'slice'), + ('slices', 'relative_slice'), + ('masks', 'binary_mask') + ], + axis=1 + ) def export_dataframe(self, csv_path: Path) -> str: csv_path.parent.mkdir(parents=True, exist_ok=True) - self.get_serializable_dataframe().to_csv(csv_path, index=False) + self.get_serializable_dataframe().reset_index( + col_level=1, + ).to_csv( + csv_path, + index=False, + ) return csv_path.name def export_patch_masks(self, where: Path, prefix='mask', expanded=False, make_3d=True, mask_mip=False, **kwargs) -> pd.DataFrame: patches_df = self.get_patch_masks(pad_to=None, expanded=expanded, make_3d=make_3d, mask_mip=mask_mip).copy() - if 'nz' in patches_df.columns and any(patches_df['nz'] > 1): + if 'nz' in patches_df.bounding_box.columns and any(patches_df.bounding_box['nz'] > 1): ext = 'tif' else: ext = 'png' def _export_patch_mask(roi): - patch = InMemoryDataAccessor.from_mono(roi.patch_mask) - fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' + patch = InMemoryDataAccessor.from_mono(roi.masks.patch_mask) + fname = f'{prefix}-la{roi.name:04d}-zi{roi.bounding_box.zi:04d}.{ext}' write_accessor_data_to_file(where / fname, patch) return fname @@ -845,14 +498,13 @@ class RoiSet(object): :return: pd.Series of patch paths """ make_3d = kwargs.get('make_3d', False) - patches_df = self._df.join( - self.get_patches(**kwargs).rename('patch') - ) + patches_df = self._df.copy() + patches_df['patches', 'patch'] = self.get_patches(**kwargs) def _export_patch(roi): - patch = InMemoryDataAccessor(roi.patch) + patch = InMemoryDataAccessor(roi.patches.patch) ext = 'tif' if make_3d or patch.chroma > 3 or kwargs.get('force_tif') else 'png' - fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' + fname = f'{prefix}-la{roi.name:04d}-zi{roi.bounding_box.zi:04d}.{ext}' if patch.dtype == 'uint16': resampled = patch.to_8bit() @@ -870,10 +522,10 @@ class RoiSet(object): def _make_patch_mask(roi): if expanded: - patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') - patch[roi.relative_slice][:, :, 0, 0] = roi.binary_mask * 255 + patch = np.zeros((roi.expanded_bounding_box.ebb_h, roi.expanded_bounding_box.ebb_w, 1, 1), dtype='uint8') + patch[roi.slices.relative_slice][:, :, 0, 0] = roi.masks.binary_mask 
* 255 else: - patch = (roi.binary_mask * 255).astype('uint8') + patch = (roi.masks.binary_mask * 255).astype('uint8') if pad_to: patch = pad(patch, pad_to) if self.is_3d and make_3d: @@ -881,22 +533,29 @@ class RoiSet(object): elif self.is_3d and mask_mip: return np.max(patch, axis=-1) elif self.is_3d: - rzi = roi.zi - roi.z0 + rzi = roi.bounding_box.zi - roi.bounding_box.z0 return patch[:, :, rzi: rzi + 1] else: return np.expand_dims(patch, 2) dfe = self._df.copy() - dfe['patch_mask'] = dfe.apply(_make_patch_mask, axis=1) + dfe['masks', 'patch_mask'] = dfe.apply(_make_patch_mask, axis=1) return dfe def get_patch_masks_acc(self, **kwargs) -> Union[PatchStack, None]: if self.count == 0: return None - se_pm = self.get_patch_masks(**kwargs).patch_mask + se_pm = self.get_patch_masks(**kwargs).masks.patch_mask se_ext = se_pm.apply(lambda x: np.expand_dims(x, 2)) return PatchStack(list(se_ext)) + def get_patch_obmap_acc(self, **kwargs) -> Union[PatchStack, None]: + if self.count == 0: + return None + labels = self.df().index.sort_values().to_list() + acc_masks = self.get_patch_masks_acc(**kwargs) + return PatchStack([(acc_masks.iat(i).data > 0) * labels[i] for i in range(0, len(labels))]) + def get_patches( self, rescale_clip: float = 0.0, @@ -962,10 +621,10 @@ class RoiSet(object): def _make_patch(roi): # extract, focus, and annotate a patch if expanded: - patch3d = stack[roi.expanded_slice] - subpatch = patch3d[roi.relative_slice] + patch3d = stack[roi.slices.expanded_slice] + subpatch = patch3d[roi.slices.relative_slice] else: - patch3d = stack[roi.slice] + patch3d = stack[roi.slices.slice] subpatch = patch3d ph, pw, pc, pz = patch3d.shape @@ -973,7 +632,7 @@ class RoiSet(object): # make a 3d patch, focus stays where it is if make_3d: patch = patch3d.copy() - zif = roi.zi + zif = roi.bounding_box.zi # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately elif focus_metric is not None: @@ -995,9 +654,9 @@ class RoiSet(object): mask = np.zeros(patch3d.shape[0:2], dtype=bool) if expanded: - mask[roi.relative_slice[0:2]] = roi.binary_mask + mask[roi.slices.relative_slice[0:2]] = roi.masks.binary_mask else: - mask = roi.binary_mask + mask = roi.masks.binary_mask if rescale_clip is not None: if rgb_overlay_channels: # rescale all equally to preserve white balance @@ -1014,7 +673,10 @@ class RoiSet(object): for zi in range(0, patch.shape[3]): patch[:, :, bci, zi] = draw_box_on_patch( patch[:, :, bci, zi], - ((roi.rel_x0, roi.rel_y0), (roi.rel_x1, roi.rel_y1)), + ( + (roi.relative_bounding_box.rel_x0, roi.relative_bounding_box.rel_y0), + (roi.relative_bounding_box.rel_x1, roi.relative_bounding_box.rel_y1) + ), linewidth=kwargs.get('bounding_box_linewidth', 1) ) @@ -1037,29 +699,30 @@ class RoiSet(object): patch = pad(patch, pad_to) return { 'patch': patch, - 'zif': (roi.z0 + zif) if hasattr(roi, 'z0') else roi.zi, + 'zif': (roi.bounding_box.z0 + zif) if hasattr(roi.bounding_box, 'z0') else roi.bounding_box.zi, } df_processed_patches = self._df.apply(lambda r: _make_patch(r), axis=1, result_type='expand') if update_focus_zi: - self._df['zi'] = df_processed_patches['zif'] + self._df.loc[:, ('bounding_box', 'zi')] = df_processed_patches['zif'] return df_processed_patches['patch'] @property - def classification_columns(self): + def classification_columns(self) -> List[str]: """ Return list of columns that describe instance classification results """ - pr = 'classify_by_' - return [c.split(pr)[1] for c in self._df.columns if c.startswith(pr)] + if (dfc := 
self._df.get('classifications')) is None: + return [] + return dfc.columns.to_list() - def set_classification(self, colname: str, se: pd.Series): + def set_classification(self, classification_class: str, se: pd.Series): """ Set instance classification result as a column addition on dataframe - :param colname: name of classification result + :param classification_class: name of classification result :param se: series containing class information """ - self._df[colname] = se + self._df['classifications', classification_class] = se def run_exports(self, where: Path, prefix, params: RoiSetExportParams) -> dict: """ @@ -1092,7 +755,9 @@ class RoiSet(object): f'{product_name}_path': se_paths, f'{product_name}_id': se_paths.apply(lambda _: uuid4()), }) + insert_level(df_patch_info, 'patches') self._df = self._df.join(df_patch_info) + assert isinstance(self._df.columns, pd.MultiIndex) record[product_name] = list(se_paths) if k == 'annotated_zstacks': record[k] = str(Path(k) / self.export_annotated_zstack(where / k, prefix=prefix, **kp)) @@ -1146,9 +811,9 @@ class RoiSet(object): if k == 'annotated_zstacks': interm[k] = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kp)) if k == 'object_classes': - pr = 'classify_by_' - cnames = [c.split(pr)[1] for c in self._df.columns if c.startswith(pr)] - for n in cnames: + # pr = 'classify_by_' + # cnames = [c.split(pr)[1] for c in self._df.columns if c.startswith(pr)] + for n in self.classification_columns: interm[f'{k}_{n}'] = self.get_object_class_map(n) if k == 'labels_overlay': interm[k] = self.get_object_identities_overlay_map(**kp) @@ -1165,7 +830,7 @@ class RoiSet(object): :return: nested dict of Path objects describing the locations of export products """ record = {} - if not self._df.binary_mask.apply(lambda x: np.all(x)).all(): # binary masks aren't just all True + if not self._df.masks.binary_mask.apply(lambda x: np.all(x)).all(): # binary masks aren't just all True subdir = Path('tight_patch_masks') if write_patches_to_subdirectory: subdir = subdir / prefix @@ -1179,7 +844,7 @@ class RoiSet(object): se_pa = se_exp.apply( lambda x: str(subdir / x) ).rename('tight_patch_masks_path') - self._df = self._df.join(se_pa) + self._df['patches', 'tight_patch_masks_path'] = se_exp.apply(lambda x: str(subdir / x)) record['tight_patch_masks'] = list(se_pa) csv_path = where / 'dataframe' / (prefix + '.csv') @@ -1205,7 +870,7 @@ class RoiSet(object): pad_to = 1 def _poly_from_mask(roi): - mask = roi.binary_mask + mask = roi.masks.binary_mask if len(mask.shape) != 2: raise PatchShapeError(f'Patch mask needs to be two dimensions to fit a polygon') @@ -1225,7 +890,7 @@ class RoiSet(object): break rel_polygon = approximate_polygon(contour[:, [1, 0]], poly_threshold) - [pad_to, pad_to] - return rel_polygon + [roi.x0, roi.y0] + return rel_polygon + [roi.bounding_box.x0, roi.bounding_box.y0] return self._df.apply(_poly_from_mask, axis=1) @@ -1234,6 +899,28 @@ class RoiSet(object): def acc_obj_ids(self): return make_object_ids_from_df(self._df, self.acc_raw.shape_dict) + + def extract_features( + self, + extractor: callable, + **kwargs, + ): + """ + Join a grouping of each Roi's features + :param extractor: function that takes an RoiSet object and returns a DataFrame of features + :param kwargs: variable-length keyword arguments that are passed to feature extractor + """ + if self.count == 0: + return + df_features = extractor(self, **kwargs) + insert_level(df_features, 'features') + self._df = self._df.join(df_features) + + + def get_features(self) -> 
pd.DataFrame: + return self.df().get('features') + + @classmethod def deserialize(cls, acc_raw: GenericImageDataAccessor, where: Path, prefix='roiset') -> Self: """ @@ -1244,14 +931,15 @@ class RoiSet(object): :param prefix: starting prefix of patch mask filenames :return: RoiSet object """ - df = pd.read_csv(where / 'dataframe' / (prefix + '.csv')) + df = read_roiset_df(where / 'dataframe' / (prefix + '.csv')) + df.index.name = 'label' pa_masks = where / 'tight_patch_masks' is_3d = is_df_3d(df) ext = 'tif' if is_3d else 'png' if pa_masks.exists(): # import segmentation masks def _read_binary_mask(r): - fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}' + fname = f'{prefix}-la{r.name:04d}-zi{r.bounding_box.zi:04d}.{ext}' try: ma_acc = generate_file_accessor(pa_masks / fname) if is_3d: @@ -1262,155 +950,22 @@ class RoiSet(object): except Exception as e: raise DeserializeRoiSetError(e) - df['binary_mask'] = df.apply(_read_binary_mask, axis=1) + df['masks', 'binary_mask'] = df.apply(_read_binary_mask, axis=1) id_mask = make_object_ids_from_df(df, acc_raw.shape_dict) return cls.from_object_ids(acc_raw, id_mask) else: # assume bounding boxes, exclusively 2d objects - df['y'] = df['y0'] - df['x'] = df['x0'] - df['h'] = df['y1'] - df['y0'] - df['w'] = df['x1'] - df['x0'] + df['bounding_box', 'y'] = df.bounding_box['y0'] + df['bounding_box', 'x'] = df.bounding_box['x0'] + df['bounding_box', 'h'] = df.bounding_box['y1'] - df.bounding_box['y0'] + df['bounding_box', 'w'] = df.bounding_box['x1'] - df.bounding_box['x0'] return cls.from_bounding_boxes( acc_raw, - df[['y', 'x', 'h', 'w']].to_dict(orient='records'), - list(df['zi']) + df.bounding_box[['y', 'x', 'h', 'w']].to_dict(orient='records'), + list(df.bounding_box['zi']) ) -class RoiSetWithDerivedChannelsExportParams(RoiSetExportParams): - derived_channels: bool = False - - -class RoiSetWithDerivedChannels(RoiSet): - - def __init__(self, *a, **k): - self.accs_derived = [] - super().__init__(*a, **k) - - def classify_by( - self, name: str, channels: list[int], - object_classification_model: InstanceMaskSegmentationModel, - derived_channel_functions: list[callable] = None - ): - """ - Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing - specified inputs through an instance segmentation classifier. Derive additional inputs for object - classification by passing a raw input channel through one or more functions. - - :param name: name of column to insert - :param channels: list of nc raw input channels to send to classifier - :param object_classification_model: InstanceSegmentation model object - :param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and - that return a single-channel PatchStack accessor of the same shape - :return: None - """ - - acc_in = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) - if derived_channel_functions is not None: - for fcn in derived_channel_functions: - der = fcn(acc_in) # returns patch stack - self.accs_derived.append(der) - - # combine channels - acc_app = acc_in - for acc_der in self.accs_derived: - acc_app = acc_app.append_channels(acc_der) - - else: - acc_app = acc_in - - # do this on a patch basis, i.e. 
only one object per frame - obmap_patches = object_classification_model.label_patch_stack( - acc_app, - self.get_patch_masks_acc(expanded=False, pad_to=None) - ) - - self._df['classify_by_' + name] = pd.Series(dtype='Int16') - - for i, roi in enumerate(self): - oc = np.unique( - mask_largest_object( - obmap_patches.iat(i).data - ) - )[-1] - self._df.loc[roi.Index, 'classify_by_' + name] = oc - - def run_exports(self, where: Path, prefix, params: RoiSetWithDerivedChannelsExportParams) -> dict: - """ - Export various representations of ROIs, e.g. patches, annotated stacks, and object maps. - :param where: path of directory in which to write all export products - :param prefix: prefix of the name of each product's file or subfolder - :param params: RoiSetExportParams object describing which products to export and with which parameters - :return: nested dict of Path objects describing the location of export products - """ - record = super().run_exports(where, prefix, params) - - k = 'derived_channels' - if k in params.dict().keys(): - record[k] = [] - for di, dacc in enumerate(self.accs_derived): - fp = where / k / f'dc{di:01d}.tif' - fp.parent.mkdir(exist_ok=True, parents=True) - dacc.export_pyxcz(fp) - record[k].append(str(fp)) - return record - - -class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel): - def __init__(self, tr: float = 0.5): - """ - Model that labels all objects as class 1 if the intensity in specified channel exceeds a threshold; labels all - objects as class 1 if threshold = 0.0 - :param tr: threshold in range of 0.0 to 1.0; model handles normalization to full pixel intensity range - :param channel: channel to use for thresholding - """ - self.tr = tr - self.loaded = self.load() - super().__init__(info={'tr': tr}) - - def load(self): - return True - - def infer( - self, - img: GenericImageDataAccessor, - mask: GenericImageDataAccessor, - allow_3d: bool = False, - connect_3d: bool = True, - ) -> GenericImageDataAccessor: - if img.chroma != 1: - raise ShapeMismatchError( - f'IntensityThresholdInstanceMaskSegmentationModel expects 1 channel but received {img.chroma}' - ) - if isinstance(img, PatchStack): # assume one object per patch - df = img.get_object_df(mask) - om = np.zeros(mask.shape, 'uint16') - def _label_patch_class(la): - om[la] = (mask.iat(la).data > 0) * 1 - df.loc[df['intensity_mean'] > (self.tr * img.dtype_max), 'label'].apply(_label_patch_class) - return PatchStack(om) - else: - labels = get_label_ids(mask) - df = pd.DataFrame(regionprops_table( - labels.data_yxz, - intensity_image=img.data_yxz, - properties=('label', 'area', 'intensity_mean') - )) - - om = np.zeros(labels.shape, labels.dtype) - def _label_object_class(la): - om[labels.data == la] = 1 - df.loc[df['intensity_mean'] > (self.tr * img.dtype_max), 'label'].apply(_label_object_class) - return InMemoryDataAccessor(om) - - def label_instance_class( - self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs - ) -> GenericImageDataAccessor: - super().label_instance_class(img, mask, **kwargs) - return self.infer(img, mask) - - class Error(Exception): pass @@ -1431,18 +986,10 @@ class SerializeRoiSetError(Error): pass -class NoDeprojectChannelSpecifiedError(Error): - pass - - class DerivedChannelError(Error): pass -class MissingSegmentationError(Error): - pass - - class PatchShapeError(Error): pass diff --git a/model_server/rois/router.py b/model_server/rois/router.py new file mode 100644 index 
0000000000000000000000000000000000000000..a28e2ffb11bd2724c2806d24f09eda358e7bf91c --- /dev/null +++ b/model_server/rois/router.py @@ -0,0 +1,15 @@ +import importlib + +from fastapi import APIRouter + +router = APIRouter( + prefix='/rois/pipelines', + tags=['pipelines'], +) + +for m in ['roiset_obmap', 'add_roiset']: + router.include_router( + importlib.import_module( + f'{__package__}.pipelines.{m}' + ).router + ) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cc720ff8e064dc2ade4137f5b3c7aa554a686034..d2741541a66bcf061898d7163ef5f52cec311bf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" [project] name = "model_server" license = {file = "LICENSE"} -version = "2024.11.01" +version = "2025.02.01" authors = [ { name="Christopher Rhodes", email="christopher.rhodes@embl.de" }, ] diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py index 93c80046fe357bb9a55c561225f60e2f9fd17f07..2fe1eedf2d1e4e1a1a80a78ed21fa66045fd96d7 100644 --- a/tests/base/test_accessors.py +++ b/tests/base/test_accessors.py @@ -155,6 +155,15 @@ class TestCziImageFileAccess(unittest.TestCase): self.assertEqual(sc3d.shape_dict['C'], nc) self.assertEqual(sc3d.hw, yxhw[2:]) + def test_yxzc(self): + w = 256 + h = 512 + nc = 4 + nz = 11 + cf = InMemoryDataAccessor(_random_int(h, w, nc, nz)) + + self.assertEqual(cf.data_yxzc.shape, (h, w, nz, nc)) + def test_write_single_channel_tif(self): ch = 4 @@ -279,6 +288,21 @@ class TestCziImageFileAccess(unittest.TestCase): self.assertIsInstance(acc_cf._data, np.ndarray) self.assertIsInstance(acc_cf._metadata, dict) + def test_apply_function_with_params(self): + acc = InMemoryDataAccessor(np.ones((16, 16, 1, 3), dtype='uint8')) + def _multiply(nda, by): + return nda * by + res = acc.apply(_multiply, {'by': 2}) + self.assertTrue(np.all(res.data == 2)) + + def test_apply_mono_function(self): + acc = InMemoryDataAccessor(np.ones((16, 16, 2, 3), dtype='uint8')) + with self.assertRaises(DataShapeError): + res = acc.apply(lambda x: 2 * x, mono=True) + res = acc.get_mono(channel=0).apply(lambda x: 2 * x, mono=True) + self.assertTrue(np.all(res.data == 2)) + self.assertEqual(res.nz, acc.nz) + class TestPatchStackAccessor(unittest.TestCase): def setUp(self) -> None: diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py index d2bc6bc87025d1ba48b6c14459794008b8988a52..df8bae6733173dc3f5d776d4e38e8037ac85901e 100644 --- a/tests/base/test_pipelines.py +++ b/tests/base/test_pipelines.py @@ -3,7 +3,7 @@ import unittest import numpy as np from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file -from model_server.base.pipelines import router, segment, segment_zproj +from model_server.base.pipelines import segment, segment_zproj from model_server.base.pipelines.shared import PipelineParams, PipelineRecord, PipelineTrace from model_server.base.session import RunTaskError, session @@ -73,15 +73,23 @@ class TestSegmentationPipelines(unittest.TestCase): def test_append_traces(self): acc = generate_file_accessor(zstack['path']).apply(lambda x: x * 0.5) trace1 = PipelineTrace(acc) - trace1['double'] = trace1.last.apply(lambda x: 2 * x) + + from time import sleep + def _slowly_multiply(arr, by): + sleep(0.1 * by) + return arr * by + + trace1['double'] = trace1.last.apply(_slowly_multiply, params={'by': 2}) trace2 = PipelineTrace(trace1.last) - trace2['halve'] = trace2.last.apply(lambda x: 0.5 * x) + trace2['halve'] = 
trace2.last.apply(_slowly_multiply, params={'by': 0.5}) trace3 = trace1.append(trace2, skip_first=False) self.assertEqual(len(trace3), len(trace1) + len(trace2)) self.assertEqual(trace3['double'], trace3['appended_input']) self.assertTrue(np.all(trace3['input'].data == trace3['halve'].data)) + self.assertGreater(trace3.timer['double'], 0.2) + self.assertGreater(trace3.timer['halve'], 0.05) trace4 = trace1.append(trace2, skip_first=True) self.assertEqual(len(trace4), len(trace1) + len(trace2) - 1) diff --git a/tests/base/test_session.py b/tests/base/test_session.py index a8cad804ea5c03559f7080772f0ae186e8844348..568fa7014f726e15164f90794b2a25f960de9648 100644 --- a/tests/base/test_session.py +++ b/tests/base/test_session.py @@ -4,7 +4,6 @@ import unittest import numpy as np from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor -from model_server.base.roiset import RoiSet, RoiSetMetaParams from model_server.base.session import session import model_server.conf.testing as conf @@ -76,40 +75,12 @@ class TestGetSessionObject(unittest.TestCase): self.assertEqual(self.sesh.paths['inbound_images'], self.sesh.paths['outbound_images']) self.assertIsInstance(self.sesh.paths['outbound_images'], pathlib.Path) - def test_make_table(self): - import pandas as pd - data = [{'modulo': i % 2, 'times one hundred': i * 100} for i in range(0, 8)] - self.sesh.write_to_table( - 'test_numbers', {'X': 0, 'Y': 0}, pd.DataFrame(data[0:4]) - ) - self.assertTrue(self.sesh.tables['test_numbers'].path.exists()) - self.sesh.write_to_table( - 'test_numbers', {'X': 1, 'Y': 1}, pd.DataFrame(data[4:8]) - ) - - dfv = pd.read_csv(self.sesh.tables['test_numbers'].path) - self.assertEqual(len(dfv), len(data)) - self.assertEqual(dfv.columns[0], 'X') - self.assertEqual(dfv.columns[1], 'Y') - class TestSessionPersistentData(unittest.TestCase): def test_add_and_remove_accessor(self): data = conf.meta['image_files'] acc_in = generate_file_accessor(data['multichannel_zstack_raw']['path']) - mask = generate_file_accessor(data['multichannel_zstack_mask2d']['path']) - - self.roiset = RoiSet.from_binary_mask( - acc_in, - mask, - params=RoiSetMetaParams( - filters={'area': {'min': 1e3, 'max': 1e4}}, - expand_box_by=(128, 2), - deproject_channel=0, - ) - ) - shd = acc_in.shape_dict # add accessor to session registry diff --git a/tests/rois/__init__.py b/tests/rois/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/rois/test_features.py b/tests/rois/test_features.py new file mode 100644 index 0000000000000000000000000000000000000000..dceb59f5533920d8e9e101aaaa961b1d8b483356 --- /dev/null +++ b/tests/rois/test_features.py @@ -0,0 +1,92 @@ +import unittest + +from model_server.base.accessors import generate_file_accessor +import model_server.conf.testing as conf +from model_server.rois.features import regionprops +from model_server.rois.roiset import RoiSet, RoiSetMetaParams + + +data = conf.meta['image_files'] +output_path = conf.meta['output_path'] +params = conf.meta['roiset'] +stack = generate_file_accessor(data['multichannel_zstack_raw']['path']) +seg_mask = generate_file_accessor(data['multichannel_zstack_mask3d']['path']) + + +class TestRoiSetMonoProducts(unittest.TestCase): + + def _make_roi_set(self, mask_type='boxes', **kwargs): + roiset = RoiSet.from_binary_mask( + stack, + seg_mask, + params=RoiSetMetaParams( + mask_type=mask_type, + filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}), + 
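+                # 'area' filter bounds are in pixels; ROIs outside the range are excluded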
expand_box_by=(128, 2) + ), + allow_3d=True, + ) + return roiset + + def test_multichannel_regionprops_features(self): + roiset = self._make_roi_set() + self.assertGreater(roiset.acc_raw.chroma, 1) + self.assertGreater(roiset.acc_obj_ids.nz, 1) + features_2d = regionprops(roiset, make_3d=False) + features_3d = regionprops(roiset, make_3d=True) + self.assertTrue(all(features_3d.area == roiset.df().bounding_box.area)) + + # test channel-specific + for s in ['max', 'min', 'mean', 'std']: + self.assertTrue( + all([f'intensity_{s}-{c}' in features_2d.columns for c in range(0, roiset.acc_raw.chroma)]) + ) + self.assertTrue( + all([f'intensity_{s}-{c}' in features_3d.columns for c in range(0, roiset.acc_raw.chroma)]) + ) + + def test_mono_regionprops_features(self): + ch = 2 + roiset = self._make_roi_set() + self.assertGreater(roiset.acc_raw.chroma, 1) + self.assertGreater(roiset.acc_obj_ids.nz, 1) + features_2d = regionprops(roiset, make_3d=False, channel=ch) + features_3d = regionprops(roiset, make_3d=True, channel=ch) + self.assertTrue(all(features_3d.area == roiset.df().bounding_box.area)) + + # test channel-specific + for s in ['max', 'min', 'mean', 'std']: + self.assertTrue(f'intensity_{s}-{ch}' in features_2d.columns) + self.assertTrue(all( + [f'intensity_{s}-{c}' not in features_2d.columns for c in range(0, roiset.acc_raw.chroma) if c != ch] + )) + self.assertTrue(f'intensity_{s}-{ch}' in features_3d.columns) + self.assertTrue(all( + [f'intensity_{s}-{c}' not in features_3d.columns for c in range(0, roiset.acc_raw.chroma) if c != ch] + )) + + def test_roiset_extract_features(self): + roiset = self._make_roi_set() + roiset.extract_features(regionprops, make_3d=True, channel=2) + self.assertTrue(('features', 'intensity_std-2') in roiset.df().columns) + self.assertTrue('intensity_std-2' in roiset.get_features()) + + def test_empty_roiset_extract_features(self): + import numpy as np + from model_server.base.accessors import InMemoryDataAccessor + + arr_mask = InMemoryDataAccessor(np.zeros([*stack.hw, 1, stack.nz], dtype='uint8')) + + roiset = RoiSet.from_binary_mask( + stack, + arr_mask, + params=RoiSetMetaParams( + filters={'area': {'min': 1e3, 'max': 1e4}}, + expand_box_by=(128, 2) + ), + allow_3d=True, + ) + self.assertEqual(roiset.count, 0) + roiset.extract_features(regionprops, make_3d=True, channel=2) + self.assertTrue(('features', 'intensity_std-2') not in roiset.df().columns) + self.assertIsNone(roiset.get_features()) \ No newline at end of file diff --git a/tests/rois/test_phenobase.py b/tests/rois/test_phenobase.py new file mode 100644 index 0000000000000000000000000000000000000000..77828146b5c535c5c96a44a820ff51283e1bf373 --- /dev/null +++ b/tests/rois/test_phenobase.py @@ -0,0 +1,282 @@ +import math +import shutil +import unittest + +import numpy as np + +from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor +from model_server.base.session import session +from model_server.conf import testing as conf + +from model_server.rois.models import IntensityThresholdInstanceMaskSegmentationModel +from model_server.rois.phenobase import MissingPatchSeriesError, PhenoBase, RoiSetIndexError +from model_server.rois.roiset import RoiSet, PatchParams, RoiSetExportParams, RoiSetMetaParams + +data = conf.meta['image_files'] +output_path = conf.meta['output_path'] +params = conf.meta['roiset'] + + +class TestPhenoBase(unittest.TestCase): + + def setUp(self, **kwargs): + def _split_in_quarters(acc): + sd = acc.shape_dict + self.assertEqual(list(sd.keys()), ['Y', 'X', 
+ self.assertEqual(list(sd.keys()), ['Y', 'X', 'C', 'Z'])
+ h = math.floor(0.5 * sd['Y'])
+ w = math.floor(0.5 * sd['X'])
+ nda = acc.data
+ return [
+ InMemoryDataAccessor(a) for a in [
+ nda[0: h, 0:w, :, :],
+ nda[0:h, w: 2 * w, :, :],
+ nda[h: 2*h, 0:w, :, :],
+ nda[h: 2 * h, w: 2 * w, :, :],
+ ]
+ ]
+
+ data = conf.meta['image_files']
+ zstacks = _split_in_quarters(
+ generate_file_accessor(
+ data['multichannel_zstack_raw']['path']
+ ).get_mono(
+ conf.meta['roiset']['patches_channel']
+ )
+ )
+ masks = _split_in_quarters(
+ generate_file_accessor(
+ data['multichannel_zstack_mask2d']['path']
+ )
+ )
+ self.roisets = [
+ RoiSet.from_binary_mask(
+ zstacks[i],
+ masks[i],
+ params=RoiSetMetaParams(
+ filters={'area': {'min': 1e2, 'max': 1e4}},
+ expand_box_by=(128, 2)
+ )
+ ) for i in range(0, 4)
+ ]
+
+ def test_roisets_contain_mix_of_rois(self):
+ counts = [r.count for r in self.roisets]
+ self.assertTrue(0 in counts)
+ self.assertGreater(sum([c > 0 for c in counts]), 2)
+
+ @staticmethod
+ def _serialize_roisets(roisets, where):
+ write_to = where / 'phenobase'
+ if write_to.exists():
+ shutil.rmtree(write_to)
+ write_to.mkdir(parents=True, exist_ok=True)
+ for ri, roiset in enumerate(roisets):
+ roiset.run_exports(
+ write_to,
+ prefix=f'acq{ri:02d}',
+ params=RoiSetExportParams(
+ patches={'channel_zero': PatchParams(white_channel=0)},
+ write_patches_to_subdirectory=True,
+ )
+ )
+ return where
+
+ def test_from_serialized_roisets(self):
+ count = sum([r.count for r in self.roisets])
+ series = 'patches_channel_zero'
+ where = self._serialize_roisets(self.roisets, output_path / 'phenobase_unlabeled')
+ phenobase = PhenoBase.read(where)
+ self.assertEqual(phenobase.count, count)
+ self.assertIsNone(phenobase.labels)
+
+ # test size of patch stack products
+ self.assertEqual(phenobase.list_patch_series(), [series])
+ self.assertEqual(phenobase.get_raw_patchstack(series).count, count)
+ self.assertEqual(phenobase.get_patch_masks().count, count)
+
+ phenobase.write_df()
+ self.assertTrue(
+ (output_path / 'phenobase_unlabeled' / 'phenobase.csv').exists()
+ )
+ return phenobase
+
+ def test_classified_rois(self):
+ for roiset in self.roisets:
+ roiset.classify_by(
+ 'permissive_model',
+ [0],
+ IntensityThresholdInstanceMaskSegmentationModel(tr=0.0)
+ )
+ phenobase = PhenoBase.read(
+ self._serialize_roisets(self.roisets, output_path / 'phenobase_labeled')
+ )
+ self.assertEqual(phenobase.count, sum([r.count for r in self.roisets]))
+ self.assertEqual(phenobase.labels, [1])
+
+ obmaps = phenobase.get_patch_obmaps()
+ self.assertTrue(all(obmaps.unique()[0] == [0, 1]))
+
+ phenobase.write_df()
+ self.assertTrue(
+ (output_path / 'phenobase_labeled' / 'phenobase.csv').exists()
+ )
+
+ def test_split_phenobase(self):
+ phenobase = self.test_from_serialized_roisets()
+ split = phenobase.split(0.5)
+ self.assertEqual(split['train'].count, split['test'].count)
+
+ def test_sample_phenobase(self):
+ phenobase = self.test_from_serialized_roisets()
+ sample = phenobase.sample(10)
+ self.assertEqual(sample.count, 10)
+
+ def test_update_with_new_roiset(self):
+ where = self._serialize_roisets(self.roisets, output_path / 'phenobase_update')
+ phenobase = PhenoBase.read(where)
+ count = sum([r.count for r in self.roisets])
+ self.assertEqual(phenobase.count, count)
+
+ # export new RoiSet
+ new_roiset = RoiSet.from_binary_mask(
+ self.roisets[0].acc_raw,
+ self.roisets[0].acc_obj_ids,
+ params=RoiSetMetaParams(
+ filters={'area': {'min': 1e2, 'max': 1e4}},
+ expand_box_by=(128, 2)
+ )
+ )
+ ri = len(self.roisets)
+ new_roiset.run_exports(
+ output_path / 'phenobase_update' / 'phenobase',
+ prefix=f'phenobase_acq{ri:02d}',
+ params=RoiSetExportParams(
+ patches={'channel_zero': PatchParams(white_channel=0)},
+ write_patches_to_subdirectory=True,
+ )
+ )
+
+ phenobase.update()
+ self.assertEqual(phenobase.count, count + new_roiset.count)
+
+ def test_init_from_roiset(self):
+ # export new RoiSet
+ roiset = RoiSet.from_binary_mask(
+ self.roisets[0].acc_raw,
+ self.roisets[0].acc_obj_ids,
+ params=RoiSetMetaParams(
+ filters={'area': {'min': 1e2, 'max': 1e4}},
+ expand_box_by=(128, 2)
+ )
+ )
+
+ phenobase = PhenoBase.from_roiset(
+ root=output_path / 'phenobase_init_from_roiset',
+ roiset=roiset,
+ index_dict={'acq': 1},
+ export_params=RoiSetExportParams(patches={'channel_zero': PatchParams(white_channel=0)}),
+ )
+ self.assertEqual(phenobase.count, roiset.count)
+ self.assertEqual(phenobase.roiset_index, ['coord_acq'])
+
+ def test_push_roiset(self):
+ where = self._serialize_roisets(self.roisets, output_path / 'phenobase_push')
+ phenobase = PhenoBase.read(where)
+ count = sum([r.count for r in self.roisets])
+ self.assertEqual(phenobase.count, count)
+
+ # export new RoiSet
+ new_roiset = RoiSet.from_binary_mask(
+ self.roisets[0].acc_raw,
+ self.roisets[0].acc_obj_ids,
+ params=RoiSetMetaParams(
+ filters={'area': {'min': 1e2, 'max': 1e4}},
+ expand_box_by=(128, 2)
+ )
+ )
+ self.assertEqual(phenobase.roiset_index, ['coord_acq'])
+
+ # without specifying patch exports
+ with self.assertRaises(MissingPatchSeriesError):
+ phenobase.push(new_roiset, {'acq': 99})
+
+ # clashes with existing index
+ with self.assertRaises(RoiSetIndexError):
+ phenobase.push(new_roiset, {'acq': 3})
+
+ # delegate patch exports to phenobase
+ phenobase.push(
+ new_roiset,
+ {'acq': 99},
+ export_params=RoiSetExportParams(patches={'channel_zero': PatchParams(white_channel=0)}),
+ )
+ self.assertEqual(phenobase.count, count + new_roiset.count)
+
+ def test_push_empty_roiset(self):
+ where = self._serialize_roisets(self.roisets, output_path / 'phenobase_push')
+ phenobase = PhenoBase.read(where)
+ count = phenobase.count
+
+ # export new RoiSet
+ acc_in = self.roisets[0].acc_raw
+ mask = InMemoryDataAccessor(np.zeros(acc_in.shape, dtype=acc_in.dtype))
+ empty_roiset = RoiSet.from_binary_mask(
+ acc_in,
+ mask,
+ )
+ self.assertEqual(empty_roiset.count, 0)
+ self.assertEqual(phenobase.roiset_index, ['coord_acq'])
+ phenobase.push(
+ roiset=empty_roiset,
+ index_dict={'acq': 10},
+ export_params=RoiSetExportParams(patches={'channel_zero': PatchParams(white_channel=0)}),
+ )
+ self.assertEqual(phenobase.count - count, 0)
+
+
+class TestPhenoBaseApi(unittest.TestCase):
+
+ def test_initiate_phenobase_with_roiset(self):
+ data = conf.meta['image_files']
+ acc_in = generate_file_accessor(data['multichannel_zstack_raw']['path'])
+ mask = generate_file_accessor(data['multichannel_zstack_mask2d']['path'])
+
+ # make test roiset
+ roiset1 = RoiSet.from_binary_mask(
+ acc_in,
+ mask,
+ params=RoiSetMetaParams(
+ filters={'area': {'min': 1e3, 'max': 1e4}},
+ expand_box_by=(128, 2),
+ deproject_channel=0,
+ )
+ )
+ self.assertGreater(roiset1.count, 2)
+
+ # initiate phenobase in session scope
+ self.assertIsNone(session.phenobase)
+ session.add_roiset(
+ roiset1,
+ index_dict={'X': 0, 'T': 0},
+ export_params=RoiSetExportParams(patches={'ch0': PatchParams(white_channel=0)}),
+ )
+ self.assertEqual(session.phenobase.count, roiset1.count)
+
+ # add to phenobase in session scope
+ roiset2 = RoiSet.from_binary_mask(
+ acc_in.apply(lambda x: x + 1),
+ mask,
+ params=RoiSetMetaParams(
+ filters={'area': {'min': 1e3, 'max': 1e4}},
+ expand_box_by=(128, 2),
+ deproject_channel=0,
+ )
+ )
+ self.assertGreater(roiset2.count, 2)
+
+ session.add_roiset(
+ roiset2,
+ index_dict={'X': 1, 'T': 1},
+ export_params=RoiSetExportParams(patches={'ch0': PatchParams(white_channel=0)}),
+ )
+ self.assertEqual(session.phenobase.count, roiset1.count + roiset2.count)
\ No newline at end of file
diff --git a/tests/base/test_roiset.py b/tests/rois/test_roiset.py
similarity index 83%
rename from tests/base/test_roiset.py
rename to tests/rois/test_roiset.py
index a95358817e3803593d6d8c0be0cae6b01a60dc25..fa6654e0e29e1a9e2b46e2799983fc8e2a4595f0 100644
--- a/tests/base/test_roiset.py
+++ b/tests/rois/test_roiset.py
@@ -6,12 +6,17 @@
from pathlib import Path
import pandas as pd
+from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.base.process import smooth
-from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, IntensityThresholdInstanceMaskSegmentationModel, RoiSet, RoiSetExportParams, RoiSetMetaParams
-from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file
+
import model_server.conf.testing as conf
from model_server.conf.testing import DummyInstanceMaskSegmentationModel
+from model_server.rois.labels import get_label_ids
+from model_server.rois.roiset import RoiSet, RoiSetExportParams, RoiSetMetaParams
+from model_server.rois.models import IntensityThresholdInstanceMaskSegmentationModel
+from model_server.rois.df import filter_df_overlap_bbox, filter_df_overlap_seg, read_roiset_df
+
data = conf.meta['image_files']
output_path = conf.meta['output_path']
params = conf.meta['roiset']
@@ -45,11 +50,15 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
roiset = self._make_roi_set(**kwargs)
# all masks' bounding boxes are at least as big as ROI area
- for roi in roiset.get_df().itertuples():
- self.assertEqual(roi.binary_mask.dtype, 'bool')
- sh = roi.binary_mask.shape
- self.assertEqual(sh, (roi.h, roi.w))
- self.assertGreaterEqual(sh[0] * sh[1], roi.area)
+ def _validate_roi(roi):
+ self.assertEqual(roi.masks.binary_mask.dtype, 'bool')
+ sh = roi.masks.binary_mask.shape
+ h = roi.bounding_box.h
+ w = roi.bounding_box.w
+ area = roi.bounding_box.area
+ self.assertEqual(sh, (h, w))
+ self.assertGreaterEqual(sh[0] * sh[1], area)
+ roiset.df().apply(_validate_roi, axis=1)
def test_roi_zmask(self, **kwargs):
roiset = self._make_roi_set(**kwargs)
@@ -81,16 +90,18 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
zero_obmap = InMemoryDataAccessor(np.zeros(self.seg_mask.shape, self.seg_mask.dtype))
roiset = RoiSet.from_object_ids(self.stack_ch_pa, zero_obmap)
self.assertEqual(roiset.count, 0)
- roiset.classify_by('dummy_class', [0], DummyInstanceMaskSegmentationModel())
- self.assertTrue('classify_by_dummy_class' in roiset.get_df().columns)
+ cl = 'dummy_class'
+ roiset.classify_by(cl, [0], DummyInstanceMaskSegmentationModel())
+ self.assertTrue(cl in roiset.df().classifications.columns)
def test_create_roiset_with_no_3d_objects(self):
seg_mask_3d = generate_file_accessor(data['multichannel_zstack_mask3d']['path'])
zero_obmap = InMemoryDataAccessor(np.zeros(seg_mask_3d.shape, seg_mask_3d.dtype))
roiset = RoiSet.from_object_ids(self.stack_ch_pa, zero_obmap)
self.assertEqual(roiset.count, 0)
- roiset.classify_by('dummy_class', [0], DummyInstanceMaskSegmentationModel())
- self.assertTrue('classify_by_dummy_class' in roiset.get_df().columns)
+ cl = 'dummy_class'
+ roiset.classify_by(cl, [0], DummyInstanceMaskSegmentationModel())
+ self.assertTrue(cl in roiset.df().classifications.columns)
def test_slices_are_valid(self):
roiset = self._make_roi_set()
@@ -101,20 +112,23 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_dataframe_and_mask_array_in_iterator(self):
roiset = self._make_roi_set()
- for roi in roiset:
- ma = roi.binary_mask
+ def _validate_mask_shape(roi):
+ ma = roi.masks.binary_mask
+ bb = roi.bounding_box
self.assertEqual(ma.dtype, 'bool')
- self.assertEqual(ma.shape, (roi.h, roi.w))
+ self.assertEqual(ma.shape, (bb.h, bb.w))
+ roiset.df().apply(_validate_mask_shape, axis=1)
def test_rel_slices_are_valid(self):
roiset = self._make_roi_set()
- for roi in roiset:
- ebb = roiset.acc_raw.data[roi.expanded_slice]
+ def _validate_rel_slices(roi):
+ ebb = roiset.acc_raw.data[roi.slices.expanded_slice]
self.assertEqual(len(ebb.shape), 4)
self.assertTrue(np.all([si >= 1 for si in ebb.shape]))
- rbb = ebb[roi.relative_slice]
+ rbb = ebb[roi.slices.relative_slice]
self.assertEqual(len(rbb.shape), 4)
self.assertTrue(np.all([si >= 1 for si in rbb.shape]))
+ roiset.df().apply(_validate_rel_slices, axis=1)
def test_make_expanded_2d_patches(self):
@@ -127,11 +141,11 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
pad_to=256,
make_3d=False,
)
- df = roiset.get_df()
+ df = roiset.df()
for f in se_res:
acc = generate_file_accessor(where / f)
la = int(re.search(r'la([\d]+)', str(f)).group(1))
- roi_q = df.loc[df.label == la, :]
+ roi_q = df.loc[df.index == la, :]
self.assertEqual(len(roi_q), 1)
self.assertEqual((256, 256), acc.hw)
self.assertEqual(1, acc.nz)
@@ -144,14 +158,14 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
draw_bounding_box=True,
expanded=False,
)
- df = roiset.get_df()
+ df = roiset.df()
for f in se_res:
# all exported files are same shape as bounding boxes in RoiSet's datatable
acc = generate_file_accessor(where / f)
la = int(re.search(r'la([\d]+)', str(f)).group(1))
- roi_q = df.loc[df.label == la, :]
+ roi_q = df.loc[df.index == la, :]
self.assertEqual(len(roi_q), 1)
- roi = roi_q.iloc[0]
- self.assertEqual((roi.h, roi.w), acc.hw)
+ bbox = roi_q.iloc[0].bounding_box
+ self.assertEqual((bbox.h, bbox.w), acc.hw)
self.assertEqual(1, acc.nz)
def test_make_expanded_3d_patches(self):
@@ -182,26 +196,28 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
roiset = self._make_roi_set()
se_res = roiset.export_patch_masks(output_path / '2d_mask_patches', )
- df = roiset.get_df()
+ df = roiset.df()
for f in se_res:
# all exported files are same shape as bounding boxes in RoiSet's datatable
acc = generate_file_accessor(output_path / '2d_mask_patches' / f)
la = int(re.search(r'la([\d]+)', str(f)).group(1))
- roi_q = df.loc[df.label == la, :]
+ roi_q = df.loc[df.index == la, :]
self.assertEqual(len(roi_q), 1)
- roi = roi_q.iloc[0]
- self.assertEqual((roi.h, roi.w), acc.hw)
+ bbox = roi_q.iloc[0].bounding_box
+ self.assertEqual((bbox.h, bbox.w), acc.hw)
def test_classify_by(self):
roiset = self._make_roi_set()
- roiset.classify_by('dummy_class', [0], DummyInstanceMaskSegmentationModel())
- self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
- self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1]))
+ cl = 'dummy_class'
+ roiset.classify_by(cl, [0], DummyInstanceMaskSegmentationModel())
+ self.assertTrue(all(roiset.df()['classifications', cl].unique() == [1]))
+ self.assertTrue(all(np.unique(roiset.get_object_class_map(cl).data) == [0, 1]))
return roiset
def test_classify_by_multiple_channels(self):
roiset = RoiSet.from_binary_mask(self.stack, self.seg_mask, params=RoiSetMetaParams(deproject_channel=0))
- roiset.classify_by('dummy_class', [0, 1], DummyInstanceMaskSegmentationModel())
- self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
+ cl = 'dummy_class'
+ roiset.classify_by(cl, [0, 1], DummyInstanceMaskSegmentationModel())
+ self.assertTrue(all(roiset.df().classifications[cl].unique() == [1]))
self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1]))
return roiset
@@ -217,15 +233,16 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
self.assertGreater(total_iou, 0.5)
# classify first RoiSet
- roiset1.classify_by('dummy_class', [0, 1], DummyInstanceMaskSegmentationModel())
+ cl = 'dummy_class'
+ roiset1.classify_by(cl, [0, 1], DummyInstanceMaskSegmentationModel())
- self.assertTrue('dummy_class' in roiset1.classification_columns)
- self.assertFalse('dummy_class' in roiset2.classification_columns)
+ self.assertTrue(cl in roiset1.classification_columns)
+ self.assertFalse(cl in roiset2.classification_columns)
res = roiset2.get_instance_classification(roiset1)
- self.assertTrue('dummy_class' in roiset2.classification_columns)
+ self.assertTrue(cl in roiset2.classification_columns)
self.assertLess(
- roiset2.get_df().classify_by_dummy_class.count(),
- roiset1.get_df().classify_by_dummy_class.count(),
+ roiset2.df().classifications[cl].count(),
+ roiset1.df().classifications[cl].count(),
)
@@ -252,11 +269,29 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_patch_masks_are_correct_shape(self):
roiset = self._make_roi_set()
df_patch_masks = roiset.get_patch_masks()
- for roi in df_patch_masks.itertuples():
- h, w, nz = roi.patch_mask.shape
+ self.assertIsInstance(df_patch_masks.columns, pd.MultiIndex)
+ def _validate_roi(roi):
+ h, w, nz = roi.masks.patch_mask.shape
self.assertEqual(nz, 1)
- self.assertEqual(h, roi.h)
- self.assertEqual(w, roi.w)
+ self.assertEqual(h, roi.bounding_box.h)
+ self.assertEqual(w, roi.bounding_box.w)
+ df_patch_masks.apply(_validate_roi, axis=1)
+
+ def test_patch_obmaps_match_df_labels(self):
+ roiset = self._make_roi_set()
+ acc_patch_obmaps = roiset.get_patch_obmap_acc()
+ acc_la = [la for la in acc_patch_obmaps.unique()[0] if la > 0]
+ acc_la.sort()
+ df_la = roiset.df().index.values
+ df_la.sort()
+ self.assertEqual(len(acc_la), len(df_la))
+ self.assertTrue(all([acc_la[i] == df_la[i] for i in range(0, len(acc_la))]))
+
+ def _check_label(roi):
+ i = roi['patches', 'index']
+ la = roi.name
+ return acc_patch_obmaps.iat(i).unique()[0][-1] == la
+ self.assertTrue(roiset.df()[[('patches', 'index')]].apply(_check_label, axis=1).all())
class TestRoiSet3dProducts(unittest.TestCase):
@@ -277,11 +312,12 @@ class TestRoiSet3dProducts(unittest.TestCase):
connect_3d=True,
params=RoiSetMetaParams(filters={'area': {'min': 1e2, 'max': 1e8}})
)
- df = roiset.get_df()
+ df = roiset.df()
df.to_csv(self.where / 'df.csv')
- self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)
- self.assertTrue((df['z1'] - df['z0'] > 1).any())
+ bbdf = roiset.df().bounding_box
+ self.assertGreater(len(bbdf['zi'].unique()), 1)
+ self.assertTrue((bbdf['z1'] - bbdf['z0'] > 1).any())
roiset.acc_obj_ids.write(self.where / 'labels.tif')
return roiset
@@ -341,14 +377,15 @@ class TestRoiSet3dProducts(unittest.TestCase):
},
})
roiset = self.test_create_roiset_from_3d_obj_ids()
- starting_zi = roiset.get_df()['zi']
+ starting_zi = roiset.df().bounding_box['zi']
res = roiset.run_exports(
self.where,
prefix='test',
params=p
)
- updated_zi = roiset.get_df()['zi']
+ self.assertIsInstance(roiset.df().columns, pd.MultiIndex)
+ updated_zi = roiset.df().bounding_box['zi']
self.assertTrue((starting_zi != updated_zi).any())
def test_mip_patch_masks(self):
@@ -366,7 +403,9 @@ class TestRoiSet3dProducts(unittest.TestCase):
},
})
- res = self.test_create_roiset_from_3d_obj_ids().run_exports(
+ rois = self.test_create_roiset_from_3d_obj_ids()
+ self.assertIsInstance(rois._df.columns, pd.MultiIndex)
+ res = rois.run_exports(
self.where,
prefix='test',
params=p
@@ -416,7 +455,7 @@ class TestRoiSet3dProducts(unittest.TestCase):
self.assertTrue((ref_roiset.get_zmask() == test_roiset.get_zmask()).all())
self.assertTrue(
np.all(
- test_roiset.get_df().label.unique() == ref_roiset.get_df().label.unique()
+ test_roiset.df().index.unique() == ref_roiset.df().index.unique()
)
)
@@ -643,10 +682,10 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
self.assertTrue((where / v).exists())
# test on paths in CSV
- test_df = pd.read_csv(where / res['dataframe'])
+ test_df = read_roiset_df(where / res['dataframe'])
for c in ['tight_patch_masks_path', 'patches_2d_path', 'patches_2d_annotated_path']:
- self.assertTrue(c in test_df.columns)
- for f in test_df[c]:
+ self.assertTrue(c in test_df.patches.columns)
+ for f in test_df.patches[c]:
self.assertTrue((where / f).exists(), where / f)
def test_get_interm_prods(self):
@@ -673,13 +712,14 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
params=p
)
self.assertNotIn('patches_3d', interm.keys())
+ bbox = self.roiset.df().bounding_box
self.assertEqual(
interm['patches_2d_annotated'].hw,
- (self.roiset.get_df().h.max(), self.roiset.get_df().w.max())
+ (bbox.h.max(), bbox.w.max())
)
self.assertEqual(
interm['patches_2d'].hw,
- (self.roiset.get_df().h.max(), self.roiset.get_df().w.max())
+ (bbox.h.max(), bbox.w.max())
)
self.assertEqual(
interm['annotated_zstacks'].hw,
@@ -751,7 +791,6 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
self.assertEqual(pacc.chroma, 1)
-from model_server.base.roiset import get_label_ids
class TestRoiSetSerialization(unittest.TestCase):
def setUp(self) -> None:
@@ -790,7 +829,7 @@ class TestRoiSetSerialization(unittest.TestCase):
params=RoiSetMetaParams(mask_type='contours')
)
self.assertEqual(roiset.count, id_map.data.max())
- self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)
+ self.assertGreater(len(roiset.df()['bounding_box', 'zi'].unique()), 1)
return roiset
def test_create_roiset_from_df_and_patch_masks(self):
@@ -799,7 +838,7 @@ class TestRoiSetSerialization(unittest.TestCase):
ref_roiset.serialize(where_ser, prefix='ref')
where_df = where_ser / 'dataframe' / 'ref.csv'
self.assertTrue(where_df.exists())
- df_test = pd.read_csv(where_df)
+ df_test = read_roiset_df(where_df)
# check that patches are correct size
where_patch_masks = where_ser / 'tight_patch_masks'
@@ -807,21 +846,30 @@ class TestRoiSetSerialization(unittest.TestCase):
for pmf in where_patch_masks.iterdir():
self.assertTrue(pmf.suffix.upper() == '.PNG')
la = int(re.search(r'la([\d]+)', str(pmf)).group(1))
- roi_q = df_test.loc[df_test.label == la, :]
+ roi_q = df_test.loc[df_test.index == la, :]
self.assertEqual(len(roi_q), 1)
- roi = roi_q.iloc[0]
+ bb = roi_q.iloc[0].bounding_box
m_acc = generate_file_accessor(pmf)
- self.assertEqual((roi.h, roi.w), m_acc.hw)
+ self.assertEqual((bb.h, bb.w), m_acc.hw)
patch_filenames.append(pmf.name)
self.assertEqual(m_acc.nz, 1)
# make another RoiSet from just the data table, raw images, and (tight) patch masks
test_roiset = RoiSet.deserialize(self.stack_ch_pa, where_ser, prefix='ref')
- self.assertEqual(ref_roiset.get_zmask().shape, test_roiset.get_zmask().shape,)
- self.assertTrue((ref_roiset.get_zmask() == test_roiset.get_zmask()).all())
- self.assertTrue(np.all(test_roiset.get_df().label == ref_roiset.get_df().label))
- cols = ['label', 'y1', 'y0', 'x1', 'x0', 'zi']
- self.assertTrue((test_roiset.get_df()[cols] == ref_roiset.get_df()[cols]).all().all())
+ self.assertEqual(
+ ref_roiset.get_zmask().shape,
+ test_roiset.get_zmask().shape
+ )
+ self.assertTrue(
+ (ref_roiset.get_zmask() == test_roiset.get_zmask()).all()
+ )
+ self.assertTrue(
+ np.all(test_roiset.df().index == ref_roiset.df().index)
+ )
+ cols = ['y1', 'y0', 'x1', 'x0', 'zi']
+ self.assertTrue(
+ (test_roiset.df().bounding_box[cols] == ref_roiset.df().bounding_box[cols]).all().all()
+ )
# re-serialize and check that patch masks are the same
where_dser = output_path / 'deserialize'
@@ -862,11 +910,17 @@ class TestEmptyRoiSet(unittest.TestCase):
def test_get_patch_masks(self):
roiset = self.empty_roiset
self.assertEqual(roiset.count, 0)
- se_patches = roiset.get_patch_masks(make_3d=True)
- self.assertIsInstance(se_patches, pd.DataFrame)
- self.assertEqual(len(se_patches), 0)
+ df_patches = roiset.get_patch_masks(make_3d=True)
+ self.assertIsInstance(df_patches, pd.DataFrame)
+ self.assertEqual(len(df_patches), 0)
self.assertIsNone(roiset.get_patches_acc(make_3d=True))
+ def test_get_obmaps(self):
+ roiset = self.empty_roiset
+ self.assertEqual(roiset.count, 0)
+ acc_obmap = roiset.get_patch_obmap_acc(make_3d=True)
+ self.assertIsNone(acc_obmap)
+
def test_run_exports(self):
roiset = self.empty_roiset
export_params = RoiSetExportParams(**{
@@ -910,12 +964,12 @@ class TestEmptyRoiSet(unittest.TestCase):
def test_classify_by(self):
roiset = self.empty_roiset
- self.assertFalse('classify_by_permissive_model' in roiset.get_df().columns)
+ self.assertFalse('permissive_model' in roiset.classification_columns)
self.assertTrue(
roiset.classify_by('permissive_model', [0], IntensityThresholdInstanceMaskSegmentationModel(tr=0.0))
)
self.assertEqual(roiset.count, 0)
- self.assertTrue('classify_by_permissive_model' in roiset.get_df().columns)
+ self.assertTrue('permissive_model' in roiset.classification_columns)
class TestRoiSetObjectDetection(unittest.TestCase):
@@ -942,7 +996,6 @@ class TestRoiSetObjectDetection(unittest.TestCase):
bboxes = table[['y', 'x', 'h', 'w']].to_dict(orient='records')
roiset_bbox = RoiSet.from_bounding_boxes(self.stack_ch_pa, bboxes)
- self.assertTrue('label' in roiset_bbox.get_df().columns)
patches_bbox = roiset_bbox.get_patches_acc()
self.assertEqual(len(table), patches_bbox.count)
@@ -1010,11 +1063,11 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_overlap_bbox(self):
df = pd.DataFrame({
- 'x0': [0, 1, 2, 1, 1],
- 'x1': [2, 3, 4, 3, 3],
- 'y0': [0, 0, 0, 2, 0],
- 'y1': [2, 2, 2, 3, 2],
- 'zi': [0, 0, 0, 0, 1],
+ ('bounding_box', 'x0'): [0, 1, 2, 1, 1],
+ ('bounding_box', 'x1'): [2, 3, 4, 3, 3],
+ ('bounding_box', 'y0'): [0, 0, 0, 2, 0],
+ ('bounding_box', 'y1'): [2, 2, 2, 3, 2],
+ ('bounding_box', 'zi'): [0, 0, 0, 0, 1],
})
df.set_index(pd.Index(range(100, 105)), inplace=True)
res = filter_df_overlap_bbox(df)
@@ -1027,19 +1080,19 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_overlap_bbox_multiple(self):
df1 = pd.DataFrame({
- 'x0': [0, 1],
- 'x1': [2, 3],
- 'y0': [0, 0],
- 'y1': [2, 2],
- 'zi': [0, 0],
+ ('bounding_box', 'x0'): [0, 1],
+ ('bounding_box', 'x1'): [2, 3],
+ ('bounding_box', 'y0'): [0, 0],
+ ('bounding_box', 'y1'): [2, 2],
+ ('bounding_box', 'zi'): [0, 0],
})
df1.set_index(pd.Index(range(100, 102)), inplace=True)
df2 = pd.DataFrame({
- 'x0': [2],
- 'x1': [4],
- 'y0': [0],
- 'y1': [2],
- 'zi': [0],
+ ('bounding_box', 'x0'): [2],
+ ('bounding_box', 'x1'): [4],
+ ('bounding_box', 'y0'): [0],
+ ('bounding_box', 'y1'): [2],
+ ('bounding_box', 'zi'): [0],
})
df2.set_index(pd.Index(range(200, 201)), inplace=True)
res = filter_df_overlap_bbox(df1, df2)
@@ -1050,12 +1103,12 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_overlap_seg(self):
df = pd.DataFrame({
- 'x0': [0, 1, 2],
- 'x1': [2, 3, 4],
- 'y0': [0, 0, 0],
- 'y1': [2, 2, 2],
- 'zi': [0, 0, 0],
- 'binary_mask': [
+ ('bounding_box', 'x0'): [0, 1, 2],
+ ('bounding_box', 'x1'): [2, 3, 4],
+ ('bounding_box', 'y0'): [0, 0, 0],
+ ('bounding_box', 'y1'): [2, 2, 2],
+ ('bounding_box', 'zi'): [0, 0, 0],
+ ('masks', 'binary_mask'): [
[
[1, 1],
[1, 0]
]
@@ -1077,12 +1130,12 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_overlap_seg_multiple(self):
df1 = pd.DataFrame({
- 'x0': [0, 1],
- 'x1': [2, 3],
- 'y0': [0, 0],
- 'y1': [2, 2],
- 'zi': [0, 0],
- 'binary_mask': [
+ ('bounding_box', 'x0'): [0, 1],
+ ('bounding_box', 'x1'): [2, 3],
+ ('bounding_box', 'y0'): [0, 0],
+ ('bounding_box', 'y1'): [2, 2],
+ ('bounding_box', 'zi'): [0, 0],
+ ('masks', 'binary_mask'): [
[
[1, 1],
[1, 0]
]
],
@@ -1095,12 +1148,12 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
df1.set_index(pd.Index(range(100, 102)), inplace=True)
df2 = pd.DataFrame({
- 'x0': [2],
- 'x1': [4],
- 'y0': [0],
- 'y1': [2],
- 'zi': [0],
- 'binary_mask': [
+ ('bounding_box', 'x0'): [2],
+ ('bounding_box', 'x1'): [4],
+ ('bounding_box', 'y0'): [0],
+ ('bounding_box', 'y1'): [2],
+ ('bounding_box', 'zi'): [0],
+ ('masks', 'binary_mask'): [
[
[1, 1],
[1, 1]
]
@@ -1141,10 +1194,13 @@ class TestIntensityThresholdObjectModel(BaseTestRoiSetMonoProducts, unittest.Tes
deproject_channel=0
)
)
- roiset.classify_by('permissive_model', [0], IntensityThresholdInstanceMaskSegmentationModel(tr=0.0))
- self.assertEqual(roiset.get_df()['classify_by_permissive_model'].sum(), roiset.count)
- roiset.classify_by('avg_intensity', [0], IntensityThresholdInstanceMaskSegmentationModel(tr=0.2))
- self.assertLess(roiset.get_df()['classify_by_avg_intensity'].sum(), roiset.count)
+ cl1 = 'permissive_model'
+ roiset.classify_by(cl1, [0], IntensityThresholdInstanceMaskSegmentationModel(tr=0.0))
+ self.assertEqual(roiset.df().classifications[cl1].sum(), roiset.count)
+
+ cl2 = 'avg_intensity'
+ roiset.classify_by(cl2, [0], IntensityThresholdInstanceMaskSegmentationModel(tr=0.2))
+ self.assertLess(roiset.df().classifications[cl2].sum(), roiset.count)
return roiset
def test_aggregate_classification_results(self):
diff --git a/tests/base/test_roiset_derived.py b/tests/rois/test_roiset_derived.py
similarity index 90%
rename from tests/base/test_roiset_derived.py
rename to tests/rois/test_roiset_derived.py
index cea9e904362cf39c1709966d5509702737576fa0..025eda5760f6cd6feb00e485b8c1abbee7298128 100644
--- a/tests/base/test_roiset_derived.py
+++ b/tests/rois/test_roiset_derived.py
@@ -3,8 +3,8 @@
import unittest
import numpy as np
-from model_server.base.roiset import RoiSetWithDerivedChannelsExportParams, RoiSetMetaParams
-from model_server.base.roiset import RoiSetWithDerivedChannels
+from model_server.rois.roiset import RoiSetMetaParams
+from model_server.rois.derived import RoiSetWithDerivedChannelsExportParams, RoiSetWithDerivedChannels
from model_server.base.accessors import generate_file_accessor, PatchStack
import model_server.conf.testing as conf
from model_server.conf.testing import DummyInstanceMaskSegmentationModel
@@ -43,7 +43,7 @@ class TestDerivedChannels(unittest.TestCase):
]
)
self.assertGreater(roiset.accs_derived[0].chroma, 1)
- self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [5]))
+ self.assertTrue(all(roiset.df()['classifications', 'multiple_input_model'].unique() == [5]))
self.assertTrue(all(np.unique(roiset.get_object_class_map('multiple_input_model').data) == [0, 5]))
self.assertEqual(len(roiset.accs_derived), 2)
diff --git a/tests/base/test_roiset_pipeline.py b/tests/rois/test_roiset_pipeline.py
similarity index 68%
rename from tests/base/test_roiset_pipeline.py
rename to tests/rois/test_roiset_pipeline.py
index a0a903e89200b7faaa54702b9c3661cb2ea4e854..1ce96223681f500a0d1cbf3787e6fb8887998469 100644
--- a/tests/base/test_roiset_pipeline.py
+++ b/tests/rois/test_roiset_pipeline.py
@@ -1,12 +1,12 @@
-import json
from pathlib import Path
+from shutil import copyfile
import unittest
import numpy as np
from model_server.base.accessors import generate_file_accessor
import model_server.conf.testing as conf
-from model_server.base.pipelines.roiset_obmap import RoiSetObjectMapParams, roiset_object_map_pipeline
+from model_server.rois.pipelines.roiset_obmap import RoiSetObjectMapParams, roiset_object_map_pipeline
data = conf.meta['image_files']
output_path = conf.meta['output_path']
@@ -61,7 +61,7 @@ class BaseTestRoiSetMonoProducts(object):
def _get_models(self):
from model_server.base.models import BinaryThresholdSegmentationModel
- from model_server.base.roiset import IntensityThresholdInstanceMaskSegmentationModel
+ from model_server.rois.models import IntensityThresholdInstanceMaskSegmentationModel
return {
'pixel_classifier_segmentation': {
'name': 'min_px_mod',
@@ -109,82 +109,6 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
self.assertTrue('ob_id' in trace.keys())
self.assertEqual(len(trace['ob_id'].unique()[0]), 2)
-class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts):
-
- input_data = data['multichannel_zstack_raw']
-
-
- def setUp(self) -> None:
- self.where_out = output_path / 'roiset'
- self.where_out.mkdir(parents=True, exist_ok=True)
- return conf.TestServerBaseClass.setUp(self)
-
- def test_trivial_api_response(self):
- self.assertGetSuccess('')
-
- def test_load_input_accessor(self):
- fname = self.copy_input_file_to_server()
- return self.assertPutSuccess(f'accessors/read_from_file/{fname}')
-
- def test_load_pixel_classifier(self):
- mid = self.assertPutSuccess(
- 'models/seg/threshold/load/',
- query={'tr': 0.2},
- )['model_id']
- self.assertTrue(mid.startswith('BinaryThresholdSegmentationModel'))
- return mid
-
- def test_load_object_classifier(self):
- mid = self.assertPutSuccess(
- 'models/classify/threshold/load/',
- body={'tr': 0}
- )['model_id']
- self.assertTrue(mid.startswith('IntensityThresholdInstanceMaskSegmentation'))
- return mid
-
- def _object_map_workflow(self, ob_classifer_id):
- res = self.assertPutSuccess(
- 'pipelines/roiset_to_obmap/infer',
- body={
- 'accessor_id': self.test_load_input_accessor(),
- 'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
- 'object_classifier_model_id': ob_classifer_id,
- 'segmentation': {'channel': 0},
- 'patches_channel': 1,
- 'roi_params': self._get_roi_params(),
- 'export_params': self._get_export_params(),
- },
- )
-
- # check on automatically written RoiSet
- roiset_id = res['roiset_id']
- roiset_info = self.assertGetSuccess(f'rois/{roiset_id}')
- self.assertGreater(roiset_info['count'], 0)
- return res
-
- def test_workflow_with_object_classifier(self):
- obmod_id = self.test_load_object_classifier()
- res = self._object_map_workflow(obmod_id)
- acc_obmap = self.get_accessor(res['output_accessor_id'])
- self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1]))
-
- # get object map via RoiSet API
- roiset_id = res['roiset_id']
- obmap_id = self.assertPutSuccess(f'rois/obmap/{roiset_id}/{obmod_id}', query={'object_classes': True})
- acc_obmap_roiset = self.get_accessor(obmap_id)
- self.assertTrue(np.all(acc_obmap_roiset.data == acc_obmap.data))
-
- # check serialize RoiSet
- self.assertPutSuccess(f'rois/write/{roiset_id}')
- self.assertFalse(
- self.assertGetSuccess(f'rois/{roiset_id}')['loaded']
- )
-
-
- def test_workflow_without_object_classifier(self):
- res = self._object_map_workflow(None)
- acc_obmap = self.get_accessor(res['output_accessor_id'])
- self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1]))
class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts):
@@ -218,7 +142,7 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd
def _object_map_workflow(self, ob_classifer_id):
return self.assertPutSuccess(
- 'pipelines/roiset_to_obmap/infer',
+ 'rois/pipelines/roiset_to_obmap',
body={
'accessor_id': self.test_load_input_accessor(),
'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
@@ -247,7 +171,7 @@ class TestTaskQueuedRoiSetWorkflowOverApi(TestRoiSetWorkflowOverApi):
def _object_map_workflow(self, ob_classifer_id):
res_queue = self.assertPutSuccess(
- 'pipelines/roiset_to_obmap/infer',
+ 'rois/pipelines/roiset_to_obmap',
body={
'schedule': True,
'accessor_id': self.test_load_input_accessor(),
@@ -272,4 +196,35 @@
self.assertTrue(res_run)
self.assertEqual(self.assertGetSuccess(f'tasks/get/{task_id}')['status'], 'FINISHED')
- return res_run
\ No newline at end of file
+ return res_run
+
+class TestAddRoiSetOverApi(conf.TestServerBaseClass):
+ def test_add_roiset(self):
+ pa_stack = data['multichannel_zstack_raw']['path']
+ pa_mask = data['multichannel_zstack_mask2d']['path']
+ where_in = Path(self.assertGetSuccess('paths')['inbound_images'])
+ copyfile(pa_stack, where_in / pa_stack.name)
+ copyfile(pa_mask, where_in / pa_mask.name)
+
+ acc_id_stack = self.assertPutSuccess(f'accessors/read_from_file/{pa_stack.name}')
+ acc_id_mask = self.assertPutSuccess(f'accessors/read_from_file/{pa_mask.name}')
+
+ res = self.assertPutSuccess(
+ 'rois/pipelines/add_roiset',
+ body={
+ 'accessor_id': acc_id_stack,
+ 'labels_accessor_id': acc_id_mask, # binary mask, not labels, so there's only one ROI
+ 'roiset_index': {'X': 0, 'T': 0},
+ 'exports': {'patches': {'channel_zero': {'white_channel': 0}}},
+ 'roi_params': {'deproject_channel': 0}
+ }
+ )
+
+ # mask patch stack has one position
+ acc_id_out = res['output_accessor_id']
+ sd = self.assertGetSuccess(f'accessors/get/{acc_id_out}')['shape_dict']
+ self.assertEqual(sd['P'], 1)
+
+ # RoiSet has one entry
+ phenobase = self.assertGetSuccess('phenobase/bounding_box')
+ self.assertEqual(len(phenobase), 1)
\ No newline at end of file
diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py
index 81f9b0126837db7169a75f61d6142ea6bc5b7d92..9fe0356f36e8040b46ea7c0ba077ca482eccfeb0 100644
--- a/tests/test_ilastik/test_ilastik.py
+++ b/tests/test_ilastik/test_ilastik.py
@@ -6,10 +6,10 @@
import numpy as np
from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file
from model_server.base.api import app
-from model_server.extensions.ilastik import models as ilm
-from model_server.extensions.ilastik.pipelines import px_then_ob
-from model_server.extensions.ilastik.router import router
-from model_server.base.roiset import RoiSet, RoiSetMetaParams
+from model_server.ilastik import models as ilm
+from model_server.ilastik.pipelines import px_then_ob
+from model_server.ilastik.router import router
+from model_server.rois.roiset import RoiSet, RoiSetMetaParams
from model_server.base.pipelines import segment
import model_server.conf.testing as conf