Commit 42daed9d authored by Constantin Pape

Initial commit with stardist implementation

__pycache__/
*.egg-info/
MIT License
Copyright (c) 2020 Constantin Pape
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# Deep Cell
Training and inference scripts for deep learning tools for cell segmentation in microscopy images.
Available tools:
- [stardist](https://github.com/mpicbg-csbd/stardist): stardist for star-convex object segmentation
- utils: functionality for visualisation and job submission on compute clusters
## Data Layout
The goal of this small package is to provide an easy way to train different tools via the command line, based on a common data layout.
In order to use it, you will need training images and labels in the following layout:
```
root-folder/
images/
labels/
```
The folder `images` contains the training image data and the folder `labels` contains the training label data.
The corresponding images and labels **must have exactly the same name**.
The data should be stored in tif format. For multi-channel images, we assume that they are stored channel-first, i.e. in cyx order.
## Setting up conda
The software dependencies of this repository are installed via conda.
If you don't have conda installed yet, check out the [installation instructions](https://docs.conda.io/projects/conda/en/latest/user-guide/install/).
Note that there are two different distributions of conda:
Anaconda, which comes with the conda package manager and a full environment of popular python packages, and Miniconda, which contains only the conda package manager.
For our purposes, it is sufficient to install miniconda (but using anaconda does not hurt if you need it for some other purpose).
### Setting up a multi-user conda installation
In order to set up a shared conda installation for multiple users on Linux infrastructure, follow these steps:
1. Download the latest version of miniconda: https://docs.conda.io/en/latest/miniconda.html.
2. Open a terminal and cd to the location where you have downloaded the installation script.
3. Set executable permissions: `chmod +x Miniconda3-latest-Linux-x86_64.sh`
4. Execute the installation script: `./Miniconda3-latest-Linux-x86_64.sh`
5. Agree to the license
6. Choose the installation directory and start the installation. Make sure that all users have read and execute permissions for the installation location.
7. After the installation is finished, choose `no` when asked `Do you wish to initialize Miniconda3 by running conda init?`.
8. This will end the installation and print out a few lines explaining how to activate the conda base environment.
9. Copy the line `eval "$(/path/to/miniconda3/bin/conda shell.YOUR_SHELL_NAME hook)"` and paste it into a file `init_conda.sh`.
10. Replace `YOUR_SHELL_NAME` with `bash` (assuming you and your users use a bash shell; for zsh, replace it with `zsh`, etc.).
Now, the conda base environment can be activated via running `source init_conda.sh`.
Use it to set up the environments to make the applications available to your users.
Important: in order to avoid version conflicts, use separate environments for different applications and don't install anything into the base environment!
In order for users to activate one of these environments, they will need to first activate the base environment and then the desired env:
```shell
source init_conda.sh
conda activate <ENV_NAME>
```
### Advanced: setting up a conda environment for multiple maintainers and users
TODO
# Stardist
Training and prediction scripts for [stardist](https://github.com/mpicbg-csbd/stardist) models in 2d and 3d.
These scripts were adapted from the [stardist example notebooks](https://github.com/mpicbg-csbd/stardist/tree/master/examples).
## Installation
In order to install the software, you need miniconda. If you have not installed it yet, please download and install it [following the online instructions](https://docs.conda.io/en/latest/miniconda.html).
Once you have miniconda installed, make sure it is activated. Then you can create the environment with all requirements and activate it via:
```
conda env create -f environment_gpu.yaml
conda activate stardist-gpu
```
or, if you don't have a gpu available, via
```
conda env create -f environment_cpu.yaml
conda activate stardist-cpu
```
Finally, install the scripts to the environment via running
```
pip install -e .
```
in this folder.
Note that on the EMBL cluster, you need to make sure to use the correct OpenMPI version: run
```
module load OpenMPI/3.1.4-GCC-7.3.0-2.30
```
**BEFORE** the installation steps.
## Running the scripts
You can run the following scripts to train or predict a stardist model:
```
CUDA_VISIBLE_DEVICES=0 train_stardist_2d /path/to/data /path/to/model
```
```
CUDA_VISIBLE_DEVICES=0 predict_stardist_2d /path/to/data /path/to/model
```
The `CUDA_VISIBLE_DEVICES=0` part determines which GPU is used. If you have a machine with multiple GPUs and don't want to use the first one, change the `0` to the id of the GPU you want to use.
In order to run these scripts on the EMBL cluster via slurm, you can use the `submit_slurm` script from `deep_cell.utils`, e.g.
```
submit_slurm train_stardist_2d /path/to/data /path/to/model
```
Scripts to train and predict with a 3d stardist model are also available: `train_stardist_3d`, `predict_stardist_3d`.
channels:
  - conda-forge
  - defaults
name: stardist-cpu
dependencies:
  - imageio
  - jupyter
  - h5py
  - napari
  - pip
  - python 3.7
  - scikit-image
  - tensorflow 1.14
  - tqdm
  - pip:
    - stardist
channels:
  - conda-forge
  - defaults
name: stardist-gpu
dependencies:
  - imageio
  - jupyter
  - h5py
  - napari
  - python 3.7
  - pip
  - scikit-image
  - tqdm
  - tensorflow-gpu 1.14
  - pip:
    - stardist
from setuptools import setup, find_packages

setup(
    name="deep_cell.stardist",
    packages=find_packages(),
    version="0.0.1",
    author="Constantin Pape",
    url="https://github.com/constantinpape/deep-cell",
    license='MIT',
    entry_points={
        "console_scripts": [
            "train_stardist_2d = stardist_impl.train_stardist_2d:main",
            "predict_stardist_2d = stardist_impl.predict_stardist_2d:main",
            "train_stardist_3d = stardist_impl.train_stardist_3d:main",
            "predict_stardist_3d = stardist_impl.predict_stardist_3d:main",
            "stardist_model_to_fiji = stardist_impl.stardist_model_to_fiji:main"
        ]
    },
)
import argparse
import os
from glob import glob

import imageio
from tqdm import tqdm

from csbdeep.utils import normalize
from stardist.models import StarDist2D


def get_image_files(root, image_folder, ext):
    # get the image paths and validate them
    image_pattern = os.path.join(root, image_folder, f'*{ext}')
    print("Looking for images with the pattern", image_pattern)
    images = glob(image_pattern)
    assert len(images) > 0, "Did not find any images"
    images.sort()
    return images


# could be done more efficiently, see
# https://github.com/hci-unihd/batchlib/blob/master/batchlib/segmentation/stardist_prediction.py
def run_prediction(image_files, model_path, root, prediction_folder, multichannel):
    # load the model
    model_root, model_name = os.path.split(model_path.rstrip('/'))
    model = StarDist2D(None, name=model_name, basedir=model_root)

    res_folder = os.path.join(root, prediction_folder)
    os.makedirs(res_folder, exist_ok=True)

    # normalization parameters: lower and upper percentile used for image normalization
    # maybe these should be exposed
    lower_percentile = 1
    upper_percentile = 99.8
    ax_norm = (0, 1)  # normalize channels independently for multichannel images

    for im_file in tqdm(image_files, desc="run stardist prediction"):
        if multichannel:
            # images are stored channel-first (cyx), but stardist expects channel-last (yxc)
            im = imageio.volread(im_file).transpose((1, 2, 0))
        else:
            im = imageio.imread(im_file)
        im = normalize(im, lower_percentile, upper_percentile, axis=ax_norm)
        pred, _ = model.predict_instances(im)

        im_name = os.path.split(im_file)[1]
        save_path = os.path.join(res_folder, im_name)
        imageio.imsave(save_path, pred)


def predict_stardist(root, model_path, image_folder, prediction_folder, ext, multichannel):
    print("Loading images")
    image_files = get_image_files(root, image_folder, ext)
    print("Found", len(image_files), "images for prediction")

    print("Start prediction ...")
    run_prediction(image_files, model_path, root, prediction_folder, multichannel)
    print("Finished prediction")


def main():
    parser = argparse.ArgumentParser(description="Predict new images with a stardist model")
    parser.add_argument('root', type=str, help="Root folder with image data.")
    parser.add_argument('model_path', type=str, help="Where the model is saved.")
    parser.add_argument('--image_folder', type=str, default='images',
                        help="Name of the folder with the input images, default: images.")
    parser.add_argument('--prediction_folder', type=str, default='predictions',
                        help="Name of the folder where the predictions should be stored, default: predictions.")
    parser.add_argument('--ext', type=str, default='.tif', help="Image file extension, default: .tif")
    parser.add_argument('--multichannel', type=int, default=0, help="Do we have multichannel images? Default: 0")

    args = parser.parse_args()
    predict_stardist(args.root, args.model_path, args.image_folder, args.prediction_folder,
                     args.ext, args.multichannel)


if __name__ == '__main__':
    main()
import argparse
import os
from glob import glob

import imageio
from tqdm import tqdm

from csbdeep.utils import normalize
from stardist.models import StarDist3D


def get_image_files(root, image_folder, ext):
    # get the image paths and validate them
    image_pattern = os.path.join(root, image_folder, f'*{ext}')
    print("Looking for images with the pattern", image_pattern)
    images = glob(image_pattern)
    assert len(images) > 0, "Did not find any images"
    images.sort()
    return images


# could be done more efficiently, see
# https://github.com/hci-unihd/batchlib/blob/master/batchlib/segmentation/stardist_prediction.py
def run_prediction(image_files, model_path, root, prediction_folder):
    # load the model
    model_root, model_name = os.path.split(model_path.rstrip('/'))
    model = StarDist3D(None, name=model_name, basedir=model_root)

    res_folder = os.path.join(root, prediction_folder)
    os.makedirs(res_folder, exist_ok=True)

    # normalization parameters: lower and upper percentile used for image normalization
    # maybe these should be exposed
    lower_percentile = 1
    upper_percentile = 99.8
    ax_norm = (0, 1, 2)

    for im_file in tqdm(image_files, desc="run stardist prediction"):
        im = imageio.volread(im_file)
        im = normalize(im, lower_percentile, upper_percentile, axis=ax_norm)
        pred, _ = model.predict_instances(im)

        im_name = os.path.split(im_file)[1]
        save_path = os.path.join(res_folder, im_name)
        imageio.imsave(save_path, pred)


def predict_stardist(root, model_path, image_folder, prediction_folder, ext):
    print("Loading images")
    image_files = get_image_files(root, image_folder, ext)
    print("Found", len(image_files), "images for prediction")

    print("Start prediction ...")
    run_prediction(image_files, model_path, root, prediction_folder)
    print("Finished prediction")


def main():
    parser = argparse.ArgumentParser(description="Predict new images with a stardist model")
    parser.add_argument('root', type=str, help="Root folder with image data.")
    parser.add_argument('model_path', type=str, help="Where the model is saved.")
    parser.add_argument('--image_folder', type=str, default='images',
                        help="Name of the folder with the input images, default: images.")
    parser.add_argument('--prediction_folder', type=str, default='predictions',
                        help="Name of the folder where the predictions should be stored, default: predictions.")
    parser.add_argument('--ext', type=str, default='.tif', help="Image file extension, default: .tif")

    args = parser.parse_args()
    predict_stardist(args.root, args.model_path, args.image_folder,
                     args.prediction_folder, args.ext)


if __name__ == '__main__':
    main()
import argparse
import os

from stardist.models import StarDist2D


def stardist_model_to_fiji(model_path, model=None):
    if model is None:
        save_root, save_name = os.path.split(model_path.rstrip('/'))
        model = StarDist2D(None, name=save_name, basedir=save_root)
    # export_TF saves the model as TF_SavedModel.zip into the model folder
    fiji_save_path = os.path.join(model_path, 'TF_SavedModel.zip')
    print("Saving model for fiji", fiji_save_path)
    model.export_TF()


def main():
    parser = argparse.ArgumentParser(description="Save a stardist model for fiji")
    parser.add_argument('model_path', type=str, help="Where the model is saved.")

    args = parser.parse_args()
    stardist_model_to_fiji(args.model_path)


if __name__ == '__main__':
    main()
import argparse
import os
from glob import glob

import imageio
import numpy as np

from csbdeep.utils import normalize
from stardist import fill_label_holes, gputools_available
from stardist.models import Config2D, StarDist2D

# relative import, so that the console script entry point installed via setup.py works
from .stardist_model_to_fiji import stardist_model_to_fiji

def check_training_data(train_images, train_labels):
    train_names = [os.path.split(train_im)[1] for train_im in train_images]
    label_names = [os.path.split(label_im)[1] for label_im in train_labels]
    assert len(train_names) == len(label_names), "Number of training images and label masks does not match"
    assert len(set(train_names) - set(label_names)) == 0, "Image names and label mask names do not match"


def check_training_images(train_images, train_labels):
    ndim = train_images[0].ndim
    assert all(im.ndim == ndim for im in train_images), "Inconsistent image dimensions"
    assert all(im.ndim == 2 for im in train_labels), "Inconsistent label dimensions"

    def get_n_channels(im):
        return 1 if im.ndim == 2 else im.shape[-1]

    def get_im_shape(im):
        return im.shape if im.ndim == 2 else im.shape[:-1]

    n_channels = get_n_channels(train_images[0])
    assert all(get_n_channels(im) == n_channels for im in train_images), "Inconsistent number of image channels"
    assert all(label.shape == get_im_shape(im)
               for label, im in zip(train_labels, train_images)), "Inconsistent shapes of images and labels"
    return n_channels

def load_training_data(root, image_folder, labels_folder, ext, multichannel):
    # get the image and label mask paths and validate them
    image_pattern = os.path.join(root, image_folder, f'*{ext}')
    print("Looking for images with the pattern", image_pattern)
    train_images = glob(image_pattern)
    assert len(train_images) > 0, "Did not find any images"
    train_images.sort()

    label_pattern = os.path.join(root, labels_folder, f'*{ext}')
    print("Looking for labels with the pattern", label_pattern)
    train_labels = glob(label_pattern)
    assert len(train_labels) > 0, "Did not find any labels"
    train_labels.sort()
    check_training_data(train_images, train_labels)

    # normalization parameters: lower and upper percentile used for image normalization
    # maybe these should be exposed
    lower_percentile = 1
    upper_percentile = 99.8
    ax_norm = (0, 1)  # normalize channels independently for multichannel images

    # load the images, check them and preprocess the data
    if multichannel:
        # NOTE: we assume that images are stored channel-first, but stardist expects channel-last
        train_images = [imageio.volread(im).transpose((1, 2, 0)) for im in train_images]
    else:
        train_images = [imageio.imread(im) for im in train_images]
    train_labels = [imageio.imread(im) for im in train_labels]
    n_channels = check_training_images(train_images, train_labels)
    train_images = [normalize(im, lower_percentile, upper_percentile, axis=ax_norm) for im in train_images]
    train_labels = [fill_label_holes(im) for im in train_labels]
    return train_images, train_labels, n_channels

def make_train_val_split(train_images, train_labels, validation_fraction):
    n_samples = len(train_images)

    # we do train/val split with a fixed seed in order to be reproducible
    rng = np.random.RandomState(42)
    indices = rng.permutation(n_samples)
    n_val = max(1, int(validation_fraction * n_samples))
    train_indices, val_indices = indices[:-n_val], indices[-n_val:]

    x_train, y_train = [train_images[i] for i in train_indices], [train_labels[i] for i in train_indices]
    x_val, y_val = [train_images[i] for i in val_indices], [train_labels[i] for i in val_indices]
    return x_train, y_train, x_val, y_val

# TODO add more augmentations and refactor this so it can be used elsewhere
def random_flips_and_rotations(x, y):
    # first, randomly permute the spatial axes (rotations by multiples of 90 degrees and transpositions)
    axes = tuple(range(y.ndim))
    permute = np.random.permutation(axes)
    if x.ndim > y.ndim:
        # multichannel image: keep the channel axis (the last axis) fixed
        x = x.transpose(tuple(permute) + (x.ndim - 1,))
    else:
        x = x.transpose(permute)
    y = y.transpose(permute)

    # second, flip randomly along each spatial axis
    for ax in axes:
        if np.random.rand() > .5:
            x, y = np.flip(x, axis=ax), np.flip(y, axis=ax)
    return x, y


# multiplicative and additive random noise
def random_uniform_noise(x):
    return x * np.random.uniform(0.6, 2) + np.random.uniform(-0.2, 0.2)


def augmenter(x, y):
    x, y = random_flips_and_rotations(x, y)
    x = random_uniform_noise(x)
    return x, y

# we leave n_rays at the default of 32, but may want to expose this as well
def train_model(x_train, y_train, x_val, y_val, save_path,
                n_channels, patch_size, n_rays=32):

    # make the model config
    # note: `False and gputools_available()` always evaluates to False; this line is
    # kept from the stardist example notebook, where gputools-based computation is disabled by default
    use_gpu = False and gputools_available()

    # predict on subsampled image for increased efficiency
    grid = (2, 2)
    config = Config2D(
        n_rays=n_rays,
        grid=grid,
        use_gpu=use_gpu,
        n_channel_in=n_channels,
        train_patch_size=patch_size
    )

    if use_gpu:
        print("Using a GPU for training")
        # limit the gpu memory used by tensorflow
        from csbdeep.utils.tf import limit_gpu_memory
        limit_gpu_memory(0.8)
    else:
        print("GPU not found, using the CPU for training")

    save_root, save_name = os.path.split(save_path)
    os.makedirs(save_root, exist_ok=True)
    model = StarDist2D(config, name=save_name, basedir=save_root)

    model.train(x_train, y_train, validation_data=(x_val, y_val), augmenter=augmenter)
    optimal_parameters = model.optimize_thresholds(x_val, y_val)
    return model, optimal_parameters

def train_stardist_model(root, model_save_path, image_folder, labels_folder, ext,
                         validation_fraction, patch_size, multichannel,
                         save_for_fiji):
    print("Loading training data")
    train_images, train_labels, n_channels = load_training_data(root,
                                                                image_folder, labels_folder,
                                                                ext, multichannel)
    print("Found", len(train_images), "images and label masks for training")

    x_train, y_train, x_val, y_val = make_train_val_split(train_images, train_labels,
                                                          validation_fraction)
    print("Made train validation split with validation fraction",
          validation_fraction, "resulting in")
    print(len(x_train), "training images")
    print(len(x_val), "validation images")

    print("Start model training ...")
    print("You can monitor the training with tensorboard by running 'tensorboard --logdir=.' in the folder where the training runs")
    model, opt_params = train_model(x_train, y_train, x_val, y_val, model_save_path,
                                    n_channels, patch_size)
    print("The model has been trained and was saved to", model_save_path)
    print("The following optimal parameters were found:", opt_params)

    if save_for_fiji:
        stardist_model_to_fiji(model_save_path, model)

# use configargparse?
# TODO set batch size
# TODO enable fine-tuning on pre-trained models
# TODO enable excluding images by name
def main():
    parser = argparse.ArgumentParser(description="Train a 2d stardist model")
    parser.add_argument('root', type=str,
                        help="Root folder with folders for the training images and labels.")
    parser.add_argument('model_save_path', type=str, help="Where to save the model.")
    parser.add_argument('--image_folder', type=str, default='images',
                        help="Name of the folder with the training images, default: images.")
    parser.add_argument('--labels_folder', type=str, default='labels',
                        help="Name of the folder with the training labels, default: labels.")
    parser.add_argument('--ext', type=str, default='.tif',
                        help="Image file extension, default: .tif")
    parser.add_argument('--validation_fraction', type=float, default=.1,