Commit e0d4ba90 authored by Constantin Pape's avatar Constantin Pape
Browse files

Update stardist scripts

parent 403d2c36
......@@ -6,7 +6,8 @@ import imageio
import napari
def check_training_data(root, image_folder, labels_folder, ext):
def check_training_data(root, image_folder, labels_folder, ext, prediction_folder,
prediction_is_labels):
image_folder = os.path.join(root, image_folder)
assert os.path.exists(image_folder), f"Could not find {image_folder}"
labels_folder = os.path.join(root, labels_folder)
......@@ -15,26 +16,38 @@ def check_training_data(root, image_folder, labels_folder, ext):
files = glob(os.path.join(image_folder, f"*{ext}"))
files.sort()
for ff in files:
try:
im = imageio.imread(ff)
name = os.path.split(ff)[1]
except Exception as e:
print(f"Could not open {ff}")
print(f"Failed with {e}")
continue
def _load(path):
try:
label_file = os.path.join(labels_folder, name)
labels = imageio.imread(label_file)
im = imageio.imread(path)
name = os.path.split(path)[1]
except Exception as e:
print(f"Could not open {label_file}")
print(f"Could not open {path}")
print(f"Failed with {e}")
im, name = None, None
return im, name
for ff in files:
im, name = _load(ff)
if im is None:
continue
label_file = os.path.join(labels_folder, name)
labels, _ = _load(label_file)
if prediction_folder is not None:
pred_file = os.path.join(prediction_folder, name)
pred, _ = _load(pred_file)
with napari.gui_qt():
viewer = napari.Viewer(title=name)
viewer.add_image(im)
viewer.add_labels(labels)
if prediction_folder is not None:
if prediction_is_labels:
viewer.add_labels(pred)
else:
viewer.add_image(pred)
if __name__ == '__main__':
......@@ -42,7 +55,10 @@ if __name__ == '__main__':
parser.add_argument('root')
parser.add_argument('--image_folder', type=str, default='images')
parser.add_argument('--labels_folder', type=str, default='labels')
parser.add_argument('--prediction_folder', type=str, default=None)
parser.add_argument('--prediction_is_labels', type=int, default=1)
parser.add_argument('--ext', type=str, default='.tif')
args = parser.parse_args()
check_training_data(args.root, args.image_folder, args.labels_folder, args.ext)
check_training_data(args.root, args.image_folder, args.labels_folder, args.ext,
args.prediction_folder, bool(args.prediction_is_labels))
......@@ -8,6 +8,7 @@ dependencies:
- h5py
- napari
- pip
- python 3.7
- scikit-image
- tensorflow 1.15
- tqdm
......
......@@ -7,9 +7,14 @@ dependencies:
- imageio
- h5py
- napari
- python 3.7
- pip
- scikit-image
- tqdm
- pip:
- stardist
# I think this is the gpu version; it appears that tensorflow changed its naming convention at some point:
# tensorflow < 2 gpu version is called tensorflow
# tensorflow >= 2 gpu version is called tensorflow-gpu
# In summary, tensorflow is a hot mess...
- tensorflow==1.15
- stardist
import argparse


def main():
    # Placeholder entry point: the training CLI is not implemented yet.
    # Registered as a console script (see setup.py), so the name and the
    # zero-argument signature must stay stable.
    pass


if __name__ == '__main__':
    main()
import argparse
import os
from glob import glob
import imageio
from tqdm import tqdm
from csbdeep.utils import normalize
from stardist.models import StarDist2D
def get_image_files(root, image_folder, ext):
    """Collect all image paths under root/image_folder with the given extension.

    Arguments:
        root: root data folder
        image_folder: name of the sub-folder containing the images
        ext: file extension of the images, e.g. '.tif'

    Returns:
        sorted list of matching file paths

    Raises:
        AssertionError: if no image matches the pattern
    """
    image_pattern = os.path.join(root, image_folder, f'*{ext}')
    print("Looking for images with the pattern", image_pattern)
    matches = sorted(glob(image_pattern))
    assert matches, "Did not find any images"
    return matches
# TODO could be done more efficiently, see
# https://github.com/hci-unihd/batchlib/blob/master/batchlib/segmentation/stardist_prediction.py
def run_prediction(image_files, model_path, root, prediction_folder):
    """Run stardist prediction for all images and save the label images.

    Arguments:
        image_files: list of paths to the input images
        model_path: path to the folder containing the trained stardist model
        root: root folder for the output
        prediction_folder: name of the output sub-folder (created if absent)
    """
    # load the model
    # FIX: os.path.split returns (head, tail), i.e. (parent directory, basename);
    # the previous code unpacked this as (model_name, model_root), which passed
    # the parent directory as the model name and the basename as the basedir.
    model_root, model_name = os.path.split(model_path)
    model = StarDist2D(None, name=model_name, basedir=model_root)

    res_folder = os.path.join(root, prediction_folder)
    os.makedirs(res_folder, exist_ok=True)

    # normalization parameters: lower and upper percentile used for image normalization
    # maybe these should be exposed
    lower_percentile = 1
    upper_percentile = 99.8
    ax_norm = (0, 1)  # independent normalization for multichannel images

    for im_file in tqdm(image_files, desc="run stardist prediction"):
        im = imageio.imread(im_file)
        im = normalize(im, lower_percentile, upper_percentile, axis=ax_norm)
        pred, _ = model.predict_instances(im)

        # save the prediction under the same file name as the input image
        im_name = os.path.split(im_file)[1]
        save_path = os.path.join(res_folder, im_name)
        imageio.imsave(save_path, pred)
def predict_stardist(root, model_path, image_folder, prediction_folder, ext):
    """Segment all images in root/image_folder with a trained stardist model.

    Predictions are written to root/prediction_folder, one label image per input.
    """
    print("Loading images")
    files = get_image_files(root, image_folder, ext)
    n_images = len(files)
    print("Found", n_images, "images for prediction")

    print("Start prediction ...")
    run_prediction(files, model_path, root, prediction_folder)
    print("Finished prediction")
def main():
    """Command-line entry point for 2d stardist prediction."""
    parser = argparse.ArgumentParser(description="Predict new images with a stardist model")

    # positional arguments
    parser.add_argument('root', type=str, help="Root folder with image data.")
    parser.add_argument('model_path', type=str, help="Where the model is saved.")

    # optional arguments with defaults
    parser.add_argument('--image_folder', type=str, default='images',
                        help="Name of the folder with the training images, default: images.")
    parser.add_argument('--prediction_folder', type=str, default='predictions',
                        help="Name of the folder where the predictions should be stored, default: predictions.")
    parser.add_argument('--ext', type=str, default='.tif', help="Image file extension, default: .tif")

    args = parser.parse_args()
    predict_stardist(args.root, args.model_path,
                     args.image_folder, args.prediction_folder, args.ext)


if __name__ == '__main__':
    main()
......@@ -10,7 +10,7 @@ setup(
entry_points={
"console_scripts": [
"train_stardist_2d = training.train_stardist_2d:main",
"predict_stardist = prediction.predict_stardist:main"
"predict_stardist_2d = prediction.predict_stardist_2d:main"
]
},
)
......@@ -58,12 +58,13 @@ def load_training_data(root, image_folder, labels_folder, ext):
# maybe these should be exposed
lower_percentile = 1
upper_percentile = 99.8
ax_norm = (0, 1) # independent normalization for multichannel images
# load the images, check them and preprocess the data
train_images = [imageio.imread(im) for im in train_images]
train_labels = [imageio.imread(im) for im in train_labels]
n_channels = check_training_images(train_images, train_labels)
train_images = [normalize(im, lower_percentile, upper_percentile) for im in train_images]
train_images = [normalize(im, lower_percentile, upper_percentile, axis=ax_norm) for im in train_images]
train_labels = [fill_label_holes(im) for im in train_labels]
return train_images, train_labels, n_channels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment