-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict_nuclei_multiple.OLD
77 lines (56 loc) · 2.47 KB
/
predict_nuclei_multiple.OLD
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# StarDist and HistomicsTK installed
import sys
import time
import numpy as np
import glob
import skimage as ski
from skimage import io
from skimage.color import rgb2hed, hed2rgb
from csbdeep.utils import Path, normalize
from csbdeep.io import save_tiff_imagej_compatible
from stardist import random_label_cmap, _draw_polygons, export_imagej_rois
from stardist.models import StarDist2D
# Seed NumPy's global random generator with a fixed value.
# NOTE(review): presumably so the random label colormap built on the next line
# is the same on every run -- confirm against stardist's random_label_cmap.
np.random.seed(6)
# Colormap returned by stardist's random_label_cmap(); it is not referenced
# again anywhere in the visible part of this script.
lbl_cmap = random_label_cmap()
from multiprocessing import Pool
import os
import logging
# Timestamp (YYYYMMDD-HHMMSS) taken at import time; it gives each run its own
# log file name.
timestr = time.strftime("%Y%m%d-%H%M%S")
log_file = f'/users/ad394h/Documents/nuclei_segment/logs/predict_nuclei_{timestr}.txt'

# Send INFO-and-above records to the per-run log file.
logging.basicConfig(
    format='%(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    filename=log_file,
)
logger = logging.getLogger(__name__)

# IN: directory the input images are read from.
# DECONV_OUT: directory the predicted label images are written to.
IN = "/users/ad394h/Documents/nuclei_segment/data/karin_he_images_40X/"
DECONV_OUT = "/users/ad394h/Documents/nuclei_segment/data/karin_he_image_labels_40X/"
# def extract_hematoxylin(IMG):
# img = ski.io.imread(os.path.join(IN,IMG))
# img = rgb2hed(img)
# null = np.zeros_like(img[:,:,0])
# deconv_img = hed2rgb(np.stack((img[:,:,0],null,null),axis=-1))
# deconv_img = deconv_img*255
# deconv_img = deconv_img.astype(np.uint8)
# # ski.io.imsave(DECONV_OUT+f"{IMG[:-4]}_deconv_image.jpg",deconv_img)
# return deconv_img
# Name of the pretrained StarDist weights used for every image.
_PRETRAINED_MODEL = '2D_versatile_he'
# Per-process cache. It is empty when the worker processes are created (the
# parent never calls segment_nuclei), so every worker loads its own model
# instance exactly once instead of once per image.
_MODEL_CACHE = {}


def _get_model():
    """Return this process's StarDist2D model, loading it on first use.

    Raises:
        RuntimeError: if ``StarDist2D.from_pretrained`` does not return a model.
    """
    model = _MODEL_CACHE.get(_PRETRAINED_MODEL)
    if model is None:
        model = StarDist2D.from_pretrained(_PRETRAINED_MODEL)
        if model is None:
            # Fail here with a clear message rather than with an
            # AttributeError on the later predict_instances call.
            logger.error("model has not been loaded")
            raise RuntimeError(
                f"could not load pretrained StarDist model {_PRETRAINED_MODEL!r}"
            )
        logger.info("model exists")
        _MODEL_CACHE[_PRETRAINED_MODEL] = model
    return model


def segment_nuclei(inp_image):
    """Segment the nuclei in one image from ``IN`` and save the label image.

    The image is read from ``IN``, normalized, passed to the pretrained
    StarDist model, and the resulting label image is written to
    ``DECONV_OUT`` as ``<image name without extension>_predicted_labels.png``.

    Args:
        inp_image: File name (not a full path) of an image inside ``IN``.

    Returns:
        A ``(num_nuclei, img_label)`` tuple: the number of distinct non-zero
        labels, and the label array returned by ``predict_instances``.

    Raises:
        RuntimeError: if the pretrained model cannot be loaded.
    """
    # Strip the extension whatever its length (".tif", ".tiff", ".jpeg", ...)
    # instead of assuming it is exactly four characters long.
    image_id = os.path.splitext(inp_image)[0]

    # Read the image.
    image = io.imread(os.path.join(IN, inp_image))
    # Normalize the image (called with the values 1 and 99.8, as before).
    image = normalize(image, 1, 99.8)

    # Run the model; the second return value is not used here.
    img_label, _ = _get_model().predict_instances(image)

    # Label 0 marks background in a StarDist label image, so it must not be
    # counted as a nucleus; count only the distinct non-zero labels.
    num_nuclei = np.count_nonzero(np.unique(img_label))

    out_image = f"{image_id}_predicted_labels.png"
    # NOTE(review): confirm that the label array's integer dtype survives the
    # PNG writer without clipping when an image contains many nuclei.
    io.imsave(os.path.join(DECONV_OUT, out_image), img_label)
    logger.info(f"image {image_id} has {num_nuclei} nuclei")
    return num_nuclei, img_label
if __name__ == '__main__':
    # Hand only regular files to the workers: a sub-directory inside IN would
    # make io.imread raise and abort the whole map. Sorting gives a
    # reproducible processing order.
    image_names = sorted(
        name for name in os.listdir(IN)
        if os.path.isfile(os.path.join(IN, name))
    )

    # Create a multiprocessing Pool of 10 workers. map() blocks until every
    # image is done, and the context manager then shuts the workers down
    # (previously the pool was never closed).
    with Pool(10) as pool:
        pool.map(segment_nuclei, image_names)