Commit 2f0a0fce authored by Constantin Pape's avatar Constantin Pape
Browse files

Refactor stardist and util; add stub for cellpose and plantseg

parent 11776c40
......@@ -4,7 +4,7 @@ Training and inference scripts for deep learning tools for cell segmentation in
Available tools:
- [stardist](https://github.com/mpicbg-csbd/stardist): stardist for convex object segmentation
- embl-tools: general purpose tools for visualising training data and predictions as well as running gpu jobs on the embl cluster
- utils: functionality for visualisation and job submission on compute clusters
## Data Layout
......@@ -18,4 +18,4 @@ root-folder/
```
The folder `images` contains the training image data and the folder `labels` contains the training label data.
The corresponding images and labels **must have exactly the same name**.
The data should be stored in tif format.
The data should be stored in tif format. For multi-channel images, we assume that they are stored channel-first, i.e. in cyx order.
......@@ -9,8 +9,8 @@ setup(
license='MIT',
entry_points={
"console_scripts": [
"train_stardist_2d = training.train_stardist_2d:main",
"predict_stardist_2d = prediction.predict_stardist_2d:main"
"train_stardist_2d = stardist_impl.train_stardist_2d:main",
"predict_stardist_2d = stardist_impl.predict_stardist_2d:main"
]
},
)
......@@ -20,9 +20,9 @@ def get_image_files(root, image_folder, ext):
return images
# TODO could be done more efficiently, see
# could be done more efficiently, see
# https://github.com/hci-unihd/batchlib/blob/master/batchlib/segmentation/stardist_prediction.py
def run_prediction(image_files, model_path, root, prediction_folder):
def run_prediction(image_files, model_path, root, prediction_folder, multichannel):
# load the model
model_name, model_root = os.path.split(model_path)
......@@ -38,7 +38,10 @@ def run_prediction(image_files, model_path, root, prediction_folder):
ax_norm = (0, 1) # independent normalization for multichannel images
for im_file in tqdm(image_files, desc="run stardist prediction"):
im = imageio.imread(im_file)
if multichannel:
im = imageio.imread(im_file).transpose((1, 2, 0))
else:
im = imageio.imread(im_file)
im = normalize(im, lower_percentile, upper_percentile, axis=ax_norm)
pred, _ = model.predict_instances(im)
......@@ -47,13 +50,13 @@ def run_prediction(image_files, model_path, root, prediction_folder):
imageio.imsave(save_path, pred)
def predict_stardist(root, model_path, image_folder, prediction_folder, ext, multichannel):
    """Run stardist prediction for all images found in a folder.

    Parameters:
        root: root directory that contains the image folder and into which
            the prediction folder will be written
        model_path: path to the trained stardist model
        image_folder: name of the folder (inside root) with the input images
        prediction_folder: name of the folder (inside root) where the
            predictions are saved
        ext: file extension used to glob the input images, e.g. '.tif'
        multichannel: whether the input images have multiple channels
            (assumed channel-first on disk; transposed for stardist)
    """
    print("Loading images")
    image_files = get_image_files(root, image_folder, ext)
    print("Found", len(image_files), "images for prediction")

    print("Start prediction ...")
    # NOTE(review): the diff rendered both the old and the new call here;
    # this is the post-refactor version that forwards the multichannel flag.
    run_prediction(image_files, model_path, root, prediction_folder, multichannel)
    print("Finished prediction")
......@@ -66,9 +69,11 @@ def main():
parser.add_argument('--prediction_folder', type=str, default='predictions',
help="Name of the folder where the predictions should be stored, default: predictions.")
parser.add_argument('--ext', type=str, default='.tif', help="Image file extension, default: .tif")
parser.add_argument('--multichannel', type=int, default=0, help="Do we have multichannel images? Default: 0")
args = parser.parse_args()
predict_stardist(args.root, args.model_path, args.image_folder, args.prediction_folder, args.ext)
predict_stardist(args.root, args.model_path, args.image_folder, args.prediction_folder,
args.ext, args.multichannel)
if __name__ == '__main__':
......
......@@ -37,7 +37,7 @@ def check_training_images(train_images, train_labels):
return n_channels
def load_training_data(root, image_folder, labels_folder, ext):
def load_training_data(root, image_folder, labels_folder, ext, multichannel):
# get the image and label mask paths and validate them
image_pattern = os.path.join(root, image_folder, f'*{ext}')
......@@ -61,7 +61,11 @@ def load_training_data(root, image_folder, labels_folder, ext):
ax_norm = (0, 1) # independent normalization for multichannel images
    # load the images, check them and preprocess the data
train_images = [imageio.imread(im) for im in train_images]
if multichannel:
        # NOTE: we assume that images are stored channel-first, but stardist expects channel-last
train_images = [imageio.volread(im).transpose((1, 2, 0)) for im in train_images]
else:
train_images = [imageio.imread(im) for im in train_images]
train_labels = [imageio.imread(im) for im in train_labels]
n_channels = check_training_images(train_images, train_labels)
train_images = [normalize(im, lower_percentile, upper_percentile, axis=ax_norm) for im in train_images]
......@@ -146,9 +150,10 @@ def train_model(x_train, y_train, x_val, y_val, save_path,
def train_stardist_model(root, model_save_path, image_folder, labels_folder, ext,
validation_fraction, patch_size):
validation_fraction, patch_size, multichannel):
print("Loading training data")
train_images, train_labels, n_channels = load_training_data(root, image_folder, labels_folder, ext)
train_images, train_labels, n_channels = load_training_data(root, image_folder, labels_folder,
ext, multichannel)
print("Found", len(train_images), "images and label masks for training")
x_train, y_train, x_val, y_val = make_train_val_split(train_images, train_labels,
......@@ -181,12 +186,14 @@ def main():
help="The fraction of available data that is used for validation, default: .1")
parser.add_argument('--patch_size', type=int, nargs=2, default=[256, 256],
help="Size of the image patches used to train the network, default: 256, 256")
parser.add_argument('--multichannel', type=int, default=0,
help="Do we have multichannel images? Default: 0")
args = parser.parse_args()
train_stardist_model(args.root, args.model_save_path,
args.image_folder, args.labels_folder,
args.ext, args.validation_fraction,
tuple(args.patch_size))
tuple(args.patch_size), bool(args.multichannel))
if __name__ == '__main__':
......
......@@ -9,8 +9,8 @@ setup(
license='MIT',
entry_points={
"console_scripts": [
"view_data = visualisation.view_data:main",
"submit_to_slurm = cluster.submit_to_slurm:main"
"view_data = utils_impl.view_data:main",
"submit_slurm = utils_impl.submit_to_slurm:main"
]
},
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment