Attachment 'mri_easyatlas.py'
Download 1 import os
2 import argparse
3 import numpy as np
4 import voxelmorph as vxm
5 import torch
6 import surfa as sf
7 import nibabel as nib
8 import glob
9 from scipy.ndimage import gaussian_filter, binary_dilation, binary_erosion, distance_transform_edt, binary_fill_holes
10 from scipy.ndimage import label as scipy_label
11
# These environment variables are read by TensorFlow when it is imported, so they must be set
# before the `import tensorflow` below.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # silence TensorFlow's C++ log output
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # hide CUDA devices: the models below run on the CPU
import tensorflow as tf
import keras
import keras.backend as K
import keras.layers as KL

# set tensorflow logging (Python side) and the tensor layout used by all layers built here
tf.get_logger().setLevel('ERROR')
K.set_image_data_format('channels_last')
24
25
def _ensure_dir(path, description):
    """Create directory `path` (including missing parents) unless it already exists.

    `description` only names the directory in the progress message, e.g. 'Output' prints
    'Output directory already exists; no need to create it'.
    """
    if os.path.isdir(path):
        print('%s directory already exists; no need to create it' % description)
    else:
        os.makedirs(path)


def _affine_map(A, X, Y, Z):
    """Apply the top three rows of the 4x4 matrix `A` to the coordinate grids X, Y, Z.

    The computation is elementwise, so X, Y, Z may be numpy arrays or torch tensors of any
    (matching) shape. Returns the three transformed coordinate grids as a tuple.
    """
    return (A[0, 0] * X + A[0, 1] * Y + A[0, 2] * Z + A[0, 3],
            A[1, 0] * X + A[1, 1] * Y + A[1, 2] * Z + A[1, 3],
            A[2, 0] * X + A[2, 1] * Y + A[2, 2] * Z + A[2, 3])


def main():
    """EasyAtlas command-line entry point: build an atlas from a directory of scans with EasyReg.

    Pipeline, as implemented below:
      1. Segment every input scan with SynthSeg into <o>/SynthSeg/. Cases whose segmentation
         file already exists are skipped.
      2. Compute robust (median) centroids of a fixed list of labels per case. Fit a rigid
         transform of every case to a template of centroids and average the result. Then fit
         an affine transform (getM) from that averaged template to every case.
      3. Resample every scan into the linear atlas space. Mask it with its segmentation and
         normalise it by the median white-matter intensity. Average all cases into
         <o>/atlas.affine.nii.gz, optionally weighted by reliability maps.
      4. Run MAX_IT rounds of symmetric EasyReg nonlinear registration to the current atlas,
         re-averaging after each round into <o>/atlas.iteration.<k>.nii.gz. In the last round,
         save per case (<o>/Registrations/) the input-scan world coordinates that every
         atlas voxel maps to.

    FREESURFER_HOME must point to a FreeSurfer install providing the models referenced below.
    """
    parser = argparse.ArgumentParser(description="EasyAtlas: fast atlas construction with EasyReg", epilog='\n')

    # input/outputs
    parser.add_argument("--i", help="Input directory with scans")
    parser.add_argument("--o", help="Output directory where atlas and other files will be written")
    parser.add_argument("--threads", type=int, default=-1, help="(optional) Number of cores to be used. You can use -1 to use all available cores. Default is -1.")
    parser.add_argument('--use_reliability_maps', action='store_true', help='Use reliability maps when averaging into atlas (recommended if data are not 1mm isotropic!')

    # parse commandline
    args = parser.parse_args()

    # Very first thing: we require FreeSurfer (all models are read from $FREESURFER_HOME/models)
    fs_home = os.environ.get('FREESURFER_HOME')
    if not fs_home:
        sf.system.fatal('FREESURFER_HOME is not set. Please source freesurfer.')

    if args.i is None:
        sf.system.fatal('Input directory must be provided')
    if args.o is None:
        sf.system.fatal('Output directory must be provided')
    if not os.path.isdir(args.i):
        sf.system.fatal('Input directory does not exist: %s' % args.i)

    # limit the number of threads to be used if running on CPU.
    # torch.set_num_threads rejects values < 1, and os.cpu_count() may return None.
    if args.threads < 1:
        args.threads = os.cpu_count() or 1
        print('using all available threads ( %s )' % args.threads)
    else:
        print('using %s threads' % args.threads)
    tf.config.threading.set_inter_op_parallelism_threads(args.threads)
    tf.config.threading.set_intra_op_parallelism_threads(args.threads)
    torch.set_num_threads(args.threads)

    # path models
    model_dir = os.path.join(fs_home, 'models')
    path_model_segmentation = os.path.join(model_dir, 'synthseg_2.0.h5')
    path_model_parcellation = os.path.join(model_dir, 'synthseg_parc_2.0.h5')
    path_model_registration_trained = os.path.join(model_dir, 'easyreg_v10_230103.h5')

    # path labels
    labels_segmentation = os.path.join(model_dir, 'synthseg_segmentation_labels_2.0.npy')
    labels_parcellation = os.path.join(model_dir, 'synthseg_parcellation_labels.npy')

    # atlas grid: voxel dimensions and vox2ras matrix
    atlas_volsize = [160, 160, 192]
    atlas_aff = np.array([[-1, 0, 0, 79], [0, 0, 1, -104], [0, -1, 0, 79], [0, 0, 0, 1]], dtype=float)

    # get label lists (sorted, duplicates removed)
    labels_segmentation, _ = get_list_labels(label_list=labels_segmentation)
    labels_segmentation = np.unique(labels_segmentation)
    labels_parcellation = np.unique(get_list_labels(labels_parcellation)[0])

    # Build list of input, segmentation, registration and temporary files (supports nii, mgz, nii.gz).
    # The directory is escaped so that glob metacharacters in it (e.g. '[') are taken literally.
    input_dir = glob.escape(args.i)
    input_files = sorted(glob.glob(os.path.join(input_dir, '*.nii.gz'))
                         + glob.glob(os.path.join(input_dir, '*.nii'))
                         + glob.glob(os.path.join(input_dir, '*.mgz')))
    if not input_files:
        sf.system.fatal('No .nii, .nii.gz or .mgz scans found in input directory: %s' % args.i)
    n_files = len(input_files)

    # Create output (and SynthSeg / registration / temporary) directories if needed
    segdir = os.path.join(args.o, 'SynthSeg')
    regdir = os.path.join(args.o, 'Registrations')
    tempdir = os.path.join(args.o, 'temp')
    _ensure_dir(args.o, 'Output')
    _ensure_dir(segdir, 'SynthSeg')
    _ensure_dir(regdir, 'Registration')
    _ensure_dir(tempdir, 'Temporary')

    tails = [os.path.basename(file) for file in input_files]
    seg_files = [os.path.join(segdir, tail) for tail in tails]
    reg_files = [os.path.join(regdir, tail) for tail in tails]
    linear_files = [os.path.join(tempdir, tail + '.npy') for tail in tails]

    # Run SynthSeg if needed
    if all(os.path.exists(file) for file in seg_files):
        print('SynthSeg already there for all input files; no need to segment anything')
    else:
        print('Setting up segmentation net')
        segmentation_net = build_seg_model(model_file_segmentation=path_model_segmentation,
                                           model_file_parcellation=path_model_parcellation,
                                           labels_segmentation=labels_segmentation,
                                           labels_parcellation=labels_parcellation)
        for i, (input_file, seg_file) in enumerate(zip(input_files, seg_files)):
            if os.path.exists(seg_file):
                print(f'Image {i + 1} of {n_files}: segmentation already there')
                continue
            print(f'Image {i + 1} of {n_files}: segmenting')
            image, aff, h, im_res, shape, pad_idx, crop_idx = preprocess(path_image=input_file, crop=None,
                                                                         min_pad=128, path_resample=None)
            post_patch_segmentation, post_patch_parcellation = segmentation_net.predict(image)
            seg_buffer, _, _ = postprocess(post_patch_seg=post_patch_segmentation,
                                           post_patch_parc=post_patch_parcellation,
                                           shape=shape,
                                           pad_idx=pad_idx,
                                           crop_idx=crop_idx,
                                           labels_segmentation=labels_segmentation,
                                           labels_parcellation=labels_parcellation,
                                           aff=aff,
                                           im_res=im_res)
            save_volume(seg_buffer, aff, h, seg_file, dtype='int32')

    # Now the linear registration part
    print('Linear registration with centroids of segmentations')

    # Labels whose centroids drive the linear registration (must stay sorted ascending:
    # np.searchsorted below relies on it), and their reference centroids (one column per label)
    labels = np.array([2,4,5,7,8,10,11,12,13,14,15,16,17,18,26,28,41,43,44,46,47,49,50,51,52,53,54,58,60,
                       1001,1002,1003,1005,1006,1007,1008,1009,1010,1011,1012,1013,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023,1024,1025,1026,1027,1028,1029,1030,1031,1032,1033,1034,1035,
                       2001,2002,2003,2005,2006,2007,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018,2019,2020,2021,2022,2023,2024,2025,2026,2027,2028,2029,2030,2031,2032,2033,2034,2035])
    nlab = len(labels)
    atlasCOG = np.array([[-28.,-18.,-37.,-19.,-27.,-19.,-23.,-31.,-26.,-2.,-3.,-3.,-29.,-26.,-14.,-14.,24.,14.,31.,12.,18.,14.,19.,26.,21.,25.,22.,11.,8.,-52.,-6.,-36.,-7.,-24.,-37.,-39.,-52.,-9.,-27.,-26.,-14.,-8.,-59.,-28.,-7.,-49.,-43.,-47.,-12.,-46.,-6.,-43.,-10.,-7.,-33.,-11.,-23.,-55.,-50.,-10.,-29.,-46.,-38.,48.,4.,31.,3.,21.,33.,37.,47.,3.,24.,20.,8.,4.,54.,21.,5.,45.,38.,46.,8.,45.,3.,38.,6.,4.,29.,9.,19.,51.,49.,10.,24.,43.,33.],
                         [-30.,-17.,-13.,-36.,-40.,-22.,-3.,-5.,-9.,-14.,-31.,-21.,-15.,-1.,3.,-16.,-32.,-20.,-14.,-37.,-42.,-24.,-3.,-6.,-10.,-15.,-2.,3.,-17.,-44.,-5.,-15.,-71.,2.,-29.,-70.,-23.,-44.,-73.,22.,-57.,27.,-19.,-23.,-45.,4.,31.,20.,-68.,-38.,-33.,-26.,-60.,23.,22.,0.,-72.,-12.,-49.,49.,17.,-25.,-3.,-42.,-1.,-16.,-76.,0.,-34.,-69.,-16.,-44.,-73.,22.,-56.,28.,-18.,-25.,-45.,-3.,30.,14.,-69.,-37.,-32.,-30.,-60.,21.,21.,0.,-72.,-11.,-49.,48.,15.,-27.,-3.],
                         [12.,14.,-13.,-41.,-51.,1.,13.,3.,1.,0.,-40.,-28.,-15.,-10.,2.,-7.,11.,14.,-12.,-40.,-51.,2.,14.,4.,2.,-14.,-10.,4.,-7.,-8.,32.,40.,-14.,-21.,-28.,-4.,-28.,-3.,-35.,3.,-29.,4.,-17.,-21.,35.,18.,9.,20.,-24.,28.,25.,34.,7.,18.,35.,48.,16.,-5.,12.,22.,-18.,1.,4.,-12.,32.,43.,-11.,-21.,-29.,-3.,-27.,0.,-34.,3.,-25.,6.,-18.,-20.,36.,18.,11.,20.,-20.,26.,25.,34.,4.,24.,34.,47.,17.,-5.,10.,20.,-18.,0.,4.]])

    # voxel index grids of the atlas volume
    II, JJ, KK = np.meshgrid(np.arange(atlas_volsize[0]), np.arange(atlas_volsize[1]), np.arange(atlas_volsize[2]), indexing='ij')
    II = torch.tensor(II, device='cpu')
    JJ = torch.tensor(JJ, device='cpu')
    KK = torch.tensor(KK, device='cpu')

    # Loop over segmentations and get (homogeneous, world-space) COGs of ROIs
    COGs = np.zeros([n_files, 4, nlab])
    OKs = np.zeros([n_files, nlab])
    for i, seg_file in enumerate(seg_files):
        print(f'Getting centroids of ROIs: case {i + 1} of {n_files}')
        seg_buffer, seg_aff, _ = load_volume(seg_file, im_only=False, squeeze=True, dtype=None, aff_ref=None)
        # voxel coordinates (N x 3) and label values of all voxels carrying one of `labels`
        vox = np.array(np.nonzero(seg_buffer)).T
        vals = seg_buffer[tuple(vox.T)]
        keep = np.isin(vals, labels)
        vox = vox[keep]
        idxs = np.searchsorted(labels, vals[keep])
        COG = np.zeros([4, nlab])
        ok = np.ones(nlab)
        for ii in range(nlab):
            coords = vox[idxs == ii]
            # labels covering 50 voxels or fewer are considered unreliable and excluded for this case
            if coords.shape[0] > 50:
                COG[:3, ii] = np.median(coords, axis=0)
                COG[3, ii] = 1
            else:
                ok[ii] = 0
        COGs[i] = seg_aff @ COG
        OKs[i] = ok

    # Linear registration matrices; first rigid, then affine.
    # The rigidly aligned centroids are averaged into a new template (a label missing in every
    # case gives 0/0 = NaN there, but such a column is never used below).
    NUM = np.zeros(atlasCOG.shape)
    DEN = np.zeros(atlasCOG.shape)
    for i in range(n_files):
        ok = OKs[i] > 0
        M = getMrigid(COGs[i, :-1, ok].T, atlasCOG[:, ok])
        NUM[:, ok] += (M @ COGs[i, :, ok].T)[:-1, :]
        DEN[:, ok] += 1
    rigidAtlasCOG = NUM / DEN
    Ms = np.zeros([n_files, 4, 4])
    for i in range(n_files):
        ok = OKs[i] > 0
        Ms[i] = getM(rigidAtlasCOG[:, ok], COGs[i, :, ok].T)

    # OK now we can deform to linear space (and compute linear atlas, while at it)
    NUM = np.zeros(atlas_volsize)
    DEN = np.zeros(atlas_volsize)
    for i in range(n_files):
        print(f'Deforming to linear space: case {i + 1} of {n_files}')
        im_buffer, im_aff, _ = load_volume(input_files[i], im_only=False, squeeze=True, dtype=None, aff_ref=None)
        im_buffer = torch.tensor(im_buffer, device='cpu')
        voxdim = np.sqrt(np.sum(im_aff[:-1, :-1] ** 2, axis=0))
        # atlas voxel -> atlas world -> input world -> input voxel
        affine = torch.tensor(np.linalg.inv(im_aff) @ (Ms[i] @ atlas_aff), device='cpu')
        II2, JJ2, KK2 = _affine_map(affine, II, JJ, KK)
        im_lin = fast_3D_interp_torch(im_buffer, II2, JJ2, KK2, 'linear')
        if args.use_reliability_maps:
            # down-weight samples that fall far (in mm) from the nearest input voxel centre
            lin_dists = torch.sqrt(((II2 - II2.round()) * voxdim[0]) ** 2 +
                                   ((JJ2 - JJ2.round()) * voxdim[1]) ** 2 +
                                   ((KK2 - KK2.round()) * voxdim[2]) ** 2)
            lin_rel = torch.exp(-1.0 * lin_dists)
        else:
            # same dtype as the sampling grid, so both channels stacked below agree
            lin_rel = torch.ones_like(II2)

        seg_buffer, seg_aff, _ = load_volume(seg_files[i], im_only=False, squeeze=True, dtype=None, aff_ref=None)
        affine = torch.tensor(np.linalg.inv(seg_aff) @ (Ms[i] @ atlas_aff), device='cpu')
        II2, JJ2, KK2 = _affine_map(affine, II, JJ, KK)
        seg_lin = fast_3D_interp_torch(torch.tensor(seg_buffer.copy(), device='cpu'), II2, JJ2, KK2, 'nearest')
        im_lin[seg_lin == 0] = 0
        # intensity normalisation by the median of labels 2 and 41 (FreeSurfer cerebral white matter)
        wm_mask = torch.logical_or(seg_lin == 2, seg_lin == 41)
        if not torch.any(wm_mask):
            sf.system.fatal('No white matter (labels 2/41) found after linear alignment of %s; '
                            'cannot normalize its intensities' % input_files[i])
        im_lin /= torch.median(im_lin[wm_mask])
        np.save(linear_files[i], torch.stack([im_lin, lin_rel]).detach().cpu().numpy())
        NUM += (im_lin * lin_rel).detach().cpu().numpy()
        DEN += lin_rel.detach().cpu().numpy()

    print('Computing and saving affine atlas')
    ATLAS = NUM / (1e-9 + DEN)
    save_volume(ATLAS, atlas_aff, None, os.path.join(args.o, 'atlas.affine.nii.gz'))

    print('Building nonlinear registration model')
    # Build model
    source = tf.keras.Input(shape=(*atlas_volsize, 1))
    target = tf.keras.Input(shape=(*atlas_volsize, 1))

    config = {'name': 'vxm_dense', 'fill_value': None, 'input_model': None, 'unet_half_res': True, 'trg_feats': 1,
              'src_feats': 1, 'use_probs': False, 'bidir': False, 'int_downsize': 2, 'int_steps': 10,
              'nb_unet_conv_per_level': 1, 'unet_feat_mult': 1, 'nb_unet_levels': None,
              'nb_unet_features': [[256, 256, 256, 256], [256, 256, 256, 256, 256, 256]], 'inshape': atlas_volsize}
    cnn = vxm.networks.VxmDense(**config)
    cnn.load_weights(path_model_registration_trained, by_name=True)
    # run the network in both directions and average the velocity fields (svf1 - svf2) / 2,
    # so that the forward and backward deformations are exact inverses of the same field
    svf1 = cnn([source, target])[1]
    svf2 = cnn([target, source])[1]
    pos_svf = KL.Lambda(lambda x: 0.5 * x[0] - 0.5 * x[1])([svf1, svf2])
    neg_svf = KL.Lambda(lambda x: -x)(pos_svf)
    pos_def_small = vxm.layers.VecInt(method='ss', int_steps=10)(pos_svf)
    neg_def_small = vxm.layers.VecInt(method='ss', int_steps=10)(neg_svf)
    # bring the integrated fields from the int_downsize=2 grid back to full resolution
    pos_def = vxm.layers.RescaleTransform(2)(pos_def_small)
    neg_def = vxm.layers.RescaleTransform(2)(neg_def_small)
    model = tf.keras.Model(inputs=[source, target],
                           outputs=[pos_def, neg_def])
    model.load_weights(path_model_registration_trained)

    # Global atlas building iterations
    MAX_IT = 5
    for it in range(MAX_IT):
        last_iteration = it == MAX_IT - 1
        # Initialize new atlas to zeros
        NUM = np.zeros_like(ATLAS)
        DEN = np.zeros_like(ATLAS)
        for i in range(n_files):
            print(f'Iteration {it + 1} of {MAX_IT}, image {i + 1} of {n_files}')
            lin = np.load(linear_files[i])  # channel 0: linearly aligned image, channel 1: reliability
            pred = model.predict([lin[0:1, ..., np.newaxis] / np.max(lin[0]),
                                  ATLAS[np.newaxis, ..., np.newaxis]])
            field = torch.tensor(pred[0], device='cpu').squeeze()
            II2 = II + field[..., 0]
            JJ2 = JJ + field[..., 1]
            KK2 = KK + field[..., 2]
            deformed_im = fast_3D_interp_torch(torch.tensor(lin[0], device='cpu'), II2, JJ2, KK2, 'linear')
            deformed_rel = fast_3D_interp_torch(torch.tensor(lin[1], device='cpu'), II2, JJ2, KK2, 'linear')
            NUM += (deformed_im * deformed_rel).detach().cpu().numpy()
            DEN += deformed_rel.detach().cpu().numpy()
            if last_iteration:
                # world coordinates (input-scan space) that every atlas voxel maps to
                RR, AA, SS = _affine_map(Ms[i] @ atlas_aff, II2, JJ2, KK2)
                save_volume(torch.stack([RR, AA, SS], dim=-1).detach().cpu().numpy(), atlas_aff, None, reg_files[i])
        ATLAS = NUM / (1e-9 + DEN)
        save_volume(ATLAS, atlas_aff, None, os.path.join(args.o, 'atlas.iteration.' + str(it + 1) + '.nii.gz'))

    # Clean up
    print('Deleting temporary files')
    for linear_file in linear_files:
        os.remove(linear_file)
    os.rmdir(tempdir)

    print(' ')
    print('All done!')
    print(' ')
    print('If you use EasyReg in your analysis, please cite:')
    print('A ready-to-use machine learning tool for symmetric multi-modality registration of brain MRI.')
    print('JE Iglesias. Scientific Reports, 13, article number 6657 (2023).')
    print('https://www.nature.com/articles/s41598-023-33781-0')
    print(' ')
302
303
#######################
# Auxiliary functions #
#######################
307
308
def get_list_labels(label_list=None, save_label_list=None, FS_sort=False):
    """Load a list of integer labels and optionally regroup it following FreeSurfer numbering.

    Args:
        label_list: labels given as an int, a sequence, a numpy array, or a path to a ``.npy``
            file holding them (parsed with ``reformat_to_list``).
        save_label_list: if not None, path where the resulting list is written with ``np.save``
            (as int32).
        FS_sort: if True, duplicates are removed and the labels are reordered as
            [sorted neutral, sorted left-hemisphere, sorted right-hemisphere] labels.

    Returns:
        ``(labels, n_neutral_labels)``. ``labels`` is an int32 numpy array. ``n_neutral_labels``
        is None unless ``FS_sort`` is True. In that case it is the number of neutral labels
        when both or neither hemisphere is present, and the total number of labels otherwise.

    Raises:
        Exception: if ``FS_sort`` is True and a label fits none of the groups.
    """
    # load label list if previously computed
    label_list = np.array(reformat_to_list(label_list, load_as_numpy=True, dtype='int'))

    # sort labels in neutral/left/right according to FS labels
    n_neutral_labels = 0
    if FS_sort:
        neutral_FS_labels = frozenset([0, 14, 15, 16, 21, 22, 23, 24, 72, 77, 80, 85, 100, 101, 102, 103, 104, 105, 106, 107, 108,
                                       109, 165, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
                                       251, 252, 253, 254, 255, 258, 259, 260, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340,
                                       502, 506, 507, 508, 509, 511, 512, 514, 515, 516, 517, 530,
                                       531, 532, 533, 534, 535, 536, 537])
        # sets deduplicate in O(1) per label; the groups are sorted when concatenated below
        neutral, left, right = set(), set(), set()
        for la in label_list:
            if la in neutral_FS_labels:
                neutral.add(la)
            elif (0 < la < 14) or (16 < la < 21) or (24 < la < 40) or (135 < la < 139) or (1000 <= la <= 1035) or \
                    (la == 865) or (20100 < la < 20110):
                left.add(la)
            elif (39 < la < 72) or (162 < la < 165) or (2000 <= la <= 2035) or (20000 < la < 20010) or (la == 139) or \
                    (la == 866):
                right.add(la)
            else:
                raise Exception('label {} not in our current FS classification, '
                                'please update get_list_labels in utils.py'.format(la))
        label_list = np.concatenate([sorted(neutral), sorted(left), sorted(right)])
        # both hemispheres present, or neither: only the neutral labels count as neutral
        if bool(left) == bool(right):
            n_neutral_labels = len(neutral)
        else:
            n_neutral_labels = len(label_list)

    # save labels if specified
    if save_label_list is not None:
        np.save(save_label_list, np.int32(label_list))

    return np.int32(label_list), (n_neutral_labels if FS_sort else None)
355
def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None):
    """Convert ``var`` into a Python list, optionally broadcasting it and casting its items.

    Args:
        var: None, a scalar (Python or numpy int/float), str, tuple, list, numpy array, or (when
            ``load_as_numpy`` is True) a path to a ``.npy`` file loaded with ``np.load``.
        length: if given, a 1-element list is repeated to this length; any other length than
            1 or ``length`` raises ValueError.
        load_as_numpy: whether a string ``var`` is a path to a numpy file to load.
        dtype: None, or one of 'int', 'float', 'bool', 'str' to cast every item.

    Returns:
        The list, or None if ``var`` is None.

    Raises:
        TypeError: for unsupported input types.
        ValueError: for a length mismatch or an unknown ``dtype``.
    """
    # convert to list
    if var is None:
        return None
    var = load_array_if_path(var, load_as_numpy=load_as_numpy)
    # bool is a subclass of int, so booleans are covered by this branch as well
    if isinstance(var, (int, float, np.integer, np.floating)):
        var = [var]
    elif isinstance(var, tuple):
        var = list(var)
    elif isinstance(var, np.ndarray):
        if var.shape == (1,):
            var = [var[0]]
        else:
            # squeezing a single-element array (shape () or (1, 1), ...) yields a 0-d array whose
            # tolist() is a bare scalar; atleast_1d keeps the result a list
            var = np.atleast_1d(np.squeeze(var)).tolist()
    elif isinstance(var, str):
        var = [var]

    if not isinstance(var, list):
        raise TypeError('var should be an int, float, tuple, list, numpy array, or path to numpy array')
    if length is not None:
        if len(var) == 1:
            var = var * length
        elif len(var) != length:
            raise ValueError('if var is a list/tuple/numpy array, it should be of length 1 or {0}, '
                             'had {1}'.format(length, len(var)))

    # convert items type
    if dtype is not None:
        if dtype == 'int':
            var = [int(v) for v in var]
        elif dtype == 'float':
            var = [float(v) for v in var]
        elif dtype == 'bool':
            var = [bool(v) for v in var]
        elif dtype == 'str':
            var = [str(v) for v in var]
        else:
            raise ValueError("dtype should be 'str', 'float', 'int', or 'bool'; had {}".format(dtype))
    return var
397
def load_array_if_path(var, load_as_numpy=True):
    """Return ``np.load(var)`` if ``var`` is a string and ``load_as_numpy`` is set, else ``var`` unchanged.

    Raises:
        AssertionError: if ``var`` is a string to be loaded but is not an existing file.
    """
    if isinstance(var, str) and load_as_numpy:
        assert os.path.isfile(var), 'No such path: %s' % var
        var = np.load(var)
    return var
403
404
def load_volume(path_volume, im_only=True, squeeze=True, dtype=None, aff_ref=None):
    """Load a volume from a .nii/.nii.gz/.mgz file (via nibabel) or a .npz archive.

    Args:
        path_volume: file path; for .npz the array is read from the 'vol_data' entry and the
            affine is the identity.
        im_only: if True return only the array, otherwise ``(volume, affine, header)``.
        squeeze: drop singleton dimensions.
        dtype: optional target dtype (string or numpy dtype); integer targets are rounded first
            instead of truncated.
        aff_ref: optional reference affine; if given, the volume (and returned affine) are
            reoriented with ``align_volume_to_ref``.

    Raises:
        AssertionError: on an unsupported file extension.
    """
    assert path_volume.endswith(('.nii', '.nii.gz', '.mgz', '.npz')), 'Unknown data file: %s' % path_volume

    if path_volume.endswith('.npz'):
        # context manager closes the underlying zip file handle
        with np.load(path_volume) as archive:
            volume = archive['vol_data']
        aff = np.eye(4)
        header = nib.Nifti1Header()
    else:
        x = nib.load(path_volume)
        volume = x.get_fdata()
        aff = x.affine
        header = x.header
    if squeeze:
        volume = np.squeeze(volume)

    if dtype is not None:
        # also covers numpy dtype objects and type codes such as 'i4' or 'uint8'
        if np.issubdtype(np.dtype(dtype), np.integer):
            volume = np.round(volume)
        volume = volume.astype(dtype=dtype)

    # align image to reference affine matrix
    if aff_ref is not None:
        n_dims, _ = get_dims(list(volume.shape), max_channels=10)
        volume, aff = align_volume_to_ref(volume, aff, aff_ref=aff_ref, return_aff=True, n_dims=n_dims)

    if im_only:
        return volume
    return volume, aff, header
437
438
439
440
def preprocess(path_image, n_levels=5, crop=None, min_pad=None, path_resample=None):
    """Load and prepare an image for the segmentation networks.

    Steps, in order:
      1. Load the image and keep a single 3D channel.
      2. Resample to voxel size 1 if any voxel size lies outside [0.95, 1.05].
      3. Reorient the axes to match the identity affine.
      4. Optionally crop with ``crop_volume``.
      5. Rescale intensities to [0, 1] using the 0.5 / 99.5 percentiles.
      6. Pad every spatial size to a multiple of ``2 ** n_levels``, and to at least
         ``min_pad`` when given.
      7. Add batch and channel axes.

    Returns:
        ``(im, aff, h, im_res, shape, pad_idx, crop_idx)``. ``aff`` is the affine after any
        resampling (before reorientation). ``shape`` is the spatial shape after reorientation,
        before cropping and padding. ``crop_idx`` is None when ``crop`` is None.
    """
    # read image and corresponding info
    im, _, aff, n_dims, n_channels, h, im_res = get_volume_info(path_image, True)
    if n_dims < 3:
        sf.system.fatal('input should have 3 dimensions, had %s' % n_dims)
    elif n_dims == 4 and n_channels == 1:
        n_dims = 3
        im = im[..., 0]
    elif n_dims > 3:
        sf.system.fatal('input should have 3 dimensions, had %s' % n_dims)
    elif n_channels > 1:
        print('WARNING: detected more than 1 channel, only keeping the first channel.')
        im = im[..., 0]

    # resample image if necessary
    if np.any((im_res > 1.05) | (im_res < 0.95)):
        im_res = np.array([1.] * 3)
        im, aff = resample_volume(im, aff, im_res)
        if path_resample is not None:
            save_volume(im, aff, h, path_resample)

    # align image
    im = align_volume_to_ref(im, aff, aff_ref=np.eye(4), n_dims=n_dims, return_copy=False)
    shape = list(im.shape[:n_dims])

    # crop image if necessary
    if crop is not None:
        crop = reformat_to_list(crop, length=n_dims, dtype='int')
        crop_shape = [find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in crop]
        im, crop_idx = crop_volume(im, cropping_shape=crop_shape, return_crop_idx=True)
    else:
        crop_idx = None

    # normalise image
    im = rescale_volume(im, new_min=0, new_max=1, min_percentile=0.5, max_percentile=99.5)

    # pad image so the U-Net's 2**n_levels downsampling divides every spatial size
    input_shape = im.shape[:n_dims]
    pad_shape = np.array([find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in input_shape])
    if min_pad is not None:  # the default min_pad=None previously crashed here
        min_pad = reformat_to_list(min_pad, length=n_dims, dtype='int')
        min_pad = [find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in min_pad]
        pad_shape = np.maximum(pad_shape, min_pad)
    im, pad_idx = pad_volume(im, padding_shape=pad_shape, return_pad_idx=True)

    # add batch and channel axes
    im = add_axis(im, axis=[0, -1])

    return im, aff, h, im_res, shape, pad_idx, crop_idx
489
490
def resample_volume(volume, aff, new_vox_size, interpolation='linear'):
    """Resample a 3D volume to voxel size ``new_vox_size``; return ``(volume, affine)``.

    Axes being downsampled are first Gaussian-blurred (sigma = 0.25 / factor voxels, where
    factor = old / new voxel size). The volume is then sampled with linear interpolation
    (``fast_3D_interp_torch``) at the centres of the new voxels, clamped to the input extent.
    The returned affine has scaled columns and a shifted origin so that it describes the
    new grid.

    ``interpolation`` is accepted for API compatibility but ignored: linear is always used.
    """
    pixdim = np.sqrt(np.sum(aff * aff, axis=0))[:-1]
    new_vox_size = np.array(new_vox_size)
    factor = pixdim / new_vox_size
    sigmas = 0.25 / factor
    sigmas[factor > 1] = 0  # don't blur if upsampling

    volume_filt = gaussian_filter(volume, sigmas)

    # Sampling positions (in input voxel units) of the new voxel centres. Built from an explicit
    # sample count rather than np.arange with a float step, whose length can be off by one
    # because of rounding.
    in_shape = np.array(volume_filt.shape[:3])
    start = - (factor - 1) / (2 * factor)
    step = 1.0 / factor
    n_out = np.ceil(in_shape * factor).astype(int)
    xi, yi, zi = [np.clip(start[d] + step[d] * np.arange(n_out[d]), 0, in_shape[d] - 1) for d in range(3)]

    xig, yig, zig = np.meshgrid(xi, yi, zi, indexing='ij', sparse=False)
    xig = torch.tensor(xig, device='cpu')
    yig = torch.tensor(yig, device='cpu')
    zig = torch.tensor(zig, device='cpu')
    volume2 = fast_3D_interp_torch(torch.tensor(volume_filt, device='cpu'), xig, yig, zig, 'linear')

    aff2 = aff.copy()
    for c in range(3):
        aff2[:-1, c] = aff2[:-1, c] / factor[c]
    aff2[:-1, -1] = aff2[:-1, -1] - np.matmul(aff2[:-1, :-1], 0.5 * (factor - 1))

    return volume2.numpy(), aff2
531
def find_closest_number_divisible_by_m(n, m, answer_type='lower'):
    """Return a multiple of ``m`` near ``n``; ``n`` itself if it is already a multiple.

    Args:
        n: the number to round.
        m: the divisor.
        answer_type: 'lower' (largest multiple <= n), 'higher' (smallest multiple >= n) or
            'closer' (nearest of the two; ties go to 'higher').

    Any other ``answer_type`` is reported through ``sf.system.fatal``.
    """
    if n % m == 0:
        return n
    # floor division: int(n / m) truncated toward zero (wrong for negative n) and went
    # through float (imprecise for very large integers)
    q = int(n // m)
    lower = q * m
    higher = (q + 1) * m
    if answer_type == 'lower':
        return lower
    elif answer_type == 'higher':
        return higher
    elif answer_type == 'closer':
        return lower if (n - lower) < (higher - n) else higher
    else:
        sf.system.fatal('answer_type should be lower, higher, or closer, had : %s' % answer_type)
547
548
549
def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels=10):
    """Load a volume and report its spatial shape, affine, dimensionality, header and voxel size.

    Args:
        path_volume: .nii/.nii.gz/.mgz/.npz file (see ``load_volume``).
        return_volume: also return the image array as first element.
        aff_ref: optional reference affine. The image is reoriented to it, and the reported
            shape and resolution are permuted accordingly. The returned affine is still the
            file's own.
        max_channels: a trailing dimension of at most this size is treated as channels
            (``get_dims``).

    Returns:
        ``(im, im_shape, aff, n_dims, n_channels, header, data_res)`` if ``return_volume``,
        otherwise the same tuple without ``im``.
    """
    im, aff, header = load_volume(path_volume, im_only=False)

    # understand if image is multichannel
    im_shape = list(im.shape)
    n_dims, n_channels = get_dims(im_shape, max_channels=max_channels)
    im_shape = im_shape[:n_dims]

    # voxel size from the header. Match on the extension: a substring test would also fire on
    # directory names such as '/data/study.nii_files/scan.mgz'.
    if path_volume.endswith(('.nii', '.nii.gz')):
        data_res = np.array(header['pixdim'][1:n_dims + 1])
    elif path_volume.endswith('.mgz'):
        data_res = np.array(header['delta'])  # mgz image
    else:
        data_res = np.array([1.0] * n_dims)

    # align to given affine matrix
    if aff_ref is not None:
        ras_axes = get_ras_axes(aff, n_dims=n_dims)
        ras_axes_ref = get_ras_axes(aff_ref, n_dims=n_dims)
        im = align_volume_to_ref(im, aff, aff_ref=aff_ref, n_dims=n_dims)
        im_shape = np.array(im_shape)
        data_res = np.array(data_res)
        im_shape[ras_axes_ref] = im_shape[ras_axes]
        data_res[ras_axes_ref] = data_res[ras_axes]
        im_shape = im_shape.tolist()

    # return info
    if return_volume:
        return im, im_shape, aff, n_dims, n_channels, header, data_res
    return im_shape, aff, n_dims, n_channels, header, data_res
583
def get_dims(shape, max_channels=10):
    """Split an array shape into ``(n_dims, n_channels)``.

    A last dimension of size ``<= max_channels`` is taken to be a channel axis. Otherwise
    every dimension is spatial and there is one channel.
    """
    if shape[-1] <= max_channels:
        return len(shape) - 1, shape[-1]
    return len(shape), 1
592
593
def get_ras_axes(aff, n_dims=3):
    """For each world axis j (R, A, S), return the index of the voxel axis most aligned with it.

    Alignment is read from the largest absolute entry of column j of ``inv(aff)``. If two world
    axes pick the same voxel axis, the later duplicate is reassigned to the missing axis, so
    the result is a permutation of ``range(n_dims)``.
    """
    inv_block = np.absolute(np.linalg.inv(aff)[0:n_dims, 0:n_dims])
    img_ras_axes = np.argmax(inv_block, axis=0)
    for axis in range(n_dims):
        if axis in img_ras_axes:
            continue
        values, counts = np.unique(img_ras_axes, return_counts=True)
        duplicated = values[np.argmax(counts)]
        img_ras_axes[np.flatnonzero(img_ras_axes == duplicated)[-1]] = axis

    return img_ras_axes
604
def align_volume_to_ref(volume, aff, aff_ref=None, return_aff=False, n_dims=None, return_copy=True):
    """Permute and flip the axes of ``volume`` so that its orientation matches ``aff_ref``.

    Only axis permutations and flips are applied (no interpolation). The affine is updated so
    that it still maps the reoriented voxels to the same world positions.

    Args:
        volume: array to reorient; channel axes beyond ``n_dims`` are left untouched.
        aff: vox2ras matrix of ``volume`` (not modified).
        aff_ref: reference affine; defaults to the identity.
        return_aff: also return the updated affine.
        n_dims: number of spatial dimensions; inferred with ``get_dims`` when None.
        return_copy: work on a copy of ``volume`` instead of (views of) the input.

    Returns:
        The reoriented volume, or ``(volume, affine)`` when ``return_aff`` is True.
    """
    # work on copy
    new_volume = volume.copy() if return_copy else volume
    aff_flo = aff.copy()

    # default value for aff_ref
    if aff_ref is None:
        aff_ref = np.eye(4)

    # extract ras axes
    if n_dims is None:
        n_dims, _ = get_dims(new_volume.shape)
    ras_axes_ref = get_ras_axes(aff_ref, n_dims=n_dims)
    ras_axes_flo = get_ras_axes(aff_flo, n_dims=n_dims)

    # align axes: permute affine columns, then swap volume axes one pair at a time while
    # keeping ras_axes_flo in sync with the current axis order
    aff_flo[:, ras_axes_ref] = aff_flo[:, ras_axes_flo]
    for i in range(n_dims):
        if ras_axes_flo[i] != ras_axes_ref[i]:
            new_volume = np.swapaxes(new_volume, ras_axes_flo[i], ras_axes_ref[i])
            j = np.flatnonzero(ras_axes_flo == ras_axes_ref[i])[0]
            # explicit scalar swap (avoids assigning a size-1 array into a scalar slot)
            ras_axes_flo[i], ras_axes_flo[j] = ras_axes_flo[j], ras_axes_flo[i]

    # align directions: flip axes whose direction opposes the reference, moving the origin to
    # the voxel that used to be last along that axis
    dot_products = np.sum(aff_flo[:3, :3] * aff_ref[:3, :3], axis=0)
    for i in range(n_dims):
        if dot_products[i] < 0:
            new_volume = np.flip(new_volume, axis=i)
            aff_flo[:, i] = - aff_flo[:, i]
            aff_flo[:3, 3] = aff_flo[:3, 3] - aff_flo[:3, i] * (new_volume.shape[i] - 1)

    if return_aff:
        return new_volume, aff_flo
    return new_volume
641
def build_seg_model(model_file_segmentation,
                    model_file_parcellation,
                    labels_segmentation,
                    labels_parcellation):
    """Build the two-stage SynthSeg network (segmentation, then cortical parcellation).

    Stage 1 is a U-Net on a single-channel image, predicting ``len(labels_segmentation)``
    classes, followed by a Gaussian blur (sigma 0.5). Its argmax is mapped to label values,
    and then to a binary mask of labels 3 and 42 (one-hot, 2 channels). That mask is
    concatenated with the input image (3 channels total) and fed to stage 2, a U-Net
    predicting ``len(labels_parcellation)`` classes, also blurred.

    Returns:
        A keras Model with outputs ``[output of layer 'unet_prediction', blurred parcellation
        posteriors]``.

    Both weight files must exist; otherwise ``sf.system.fatal`` is called.
    """
    if not os.path.isfile(model_file_segmentation):
        sf.system.fatal("The provided model path does not exist.")
    if not os.path.isfile(model_file_parcellation):
        sf.system.fatal("The provided parcellation model path does not exist.")

    # get labels
    n_labels_seg = len(labels_segmentation)

    # build UNet
    net = unet(nb_features=24,
               input_shape=[None, None, None, 1],
               nb_levels=5,
               conv_size=3,
               nb_labels=n_labels_seg,
               feat_mult=2,
               activation='elu',
               nb_conv_per_level=2,
               batch_norm=-1,
               name='unet')
    net.load_weights(model_file_segmentation, by_name=True)
    input_image = net.inputs[0]
    name_segm_prediction_layer = 'unet_prediction'

    # smooth posteriors
    last_tensor = net.output
    last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list())
    last_tensor = GaussianBlur(sigma=0.5)(last_tensor)
    net = keras.Model(inputs=net.inputs, outputs=last_tensor)

    # add aparc segmenter: hard labels -> mask of labels 3/42 -> one-hot, concatenated with the image
    n_labels_parcellation = len(labels_parcellation)

    last_tensor = net.output
    last_tensor = KL.Lambda(lambda x: tf.cast(tf.argmax(x, axis=-1), 'int32'))(last_tensor)
    last_tensor = ConvertLabels(np.arange(n_labels_seg), labels_segmentation)(last_tensor)
    parcellation_masking_values = np.array([1 if ((ll == 3) | (ll == 42)) else 0 for ll in labels_segmentation])
    last_tensor = ConvertLabels(labels_segmentation, parcellation_masking_values)(last_tensor)
    last_tensor = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, 'int32'), depth=2, axis=-1))(last_tensor)
    last_tensor = KL.Lambda(lambda x: tf.cast(tf.concat(x, axis=-1), 'float32'))([input_image, last_tensor])
    net = keras.Model(inputs=net.inputs, outputs=last_tensor)

    # build UNet
    net = unet(nb_features=24,
               input_shape=[None, None, None, 3],
               nb_levels=5,
               conv_size=3,
               nb_labels=n_labels_parcellation,
               feat_mult=2,
               activation='elu',
               nb_conv_per_level=2,
               batch_norm=-1,
               name='unet_parc',
               input_model=net)
    net.load_weights(model_file_parcellation, by_name=True)

    # smooth predictions
    last_tensor = net.output
    last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list())
    last_tensor = GaussianBlur(sigma=0.5)(last_tensor)
    net = keras.Model(inputs=net.inputs, outputs=[net.get_layer(name_segm_prediction_layer).output, last_tensor])

    return net
707
def unet(nb_features,
         input_shape,
         nb_levels,
         conv_size,
         nb_labels,
         name='unet',
         prefix=None,
         feat_mult=1,
         pool_size=2,
         padding='same',
         dilation_rate_mult=1,
         activation='elu',
         skip_n_concatenations=0,
         use_residuals=False,
         final_pred_activation='softmax',
         nb_conv_per_level=1,
         layer_nb_feats=None,
         conv_dropout=0,
         batch_norm=None,
         input_model=None):
    """Build a U-Net: a ``conv_enc`` encoder followed by a ``conv_dec`` decoder with skip connections.

    All layer names are prefixed with ``prefix`` (defaults to ``name``). If ``layer_nb_feats``
    is given, its first ``nb_levels * nb_conv_per_level`` entries configure the encoder and the
    rest the decoder. ``input_model``, if given, is chained in front of the encoder.

    Returns:
        The decoder model (whose graph includes the encoder).
    """
    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    # volume size data
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int):
        pool_size = (pool_size,) * ndims

    # get encoding model
    enc_model = conv_enc(nb_features,
                         input_shape,
                         nb_levels,
                         conv_size,
                         name=model_name,
                         prefix=prefix,
                         feat_mult=feat_mult,
                         pool_size=pool_size,
                         padding=padding,
                         dilation_rate_mult=dilation_rate_mult,
                         activation=activation,
                         use_residuals=use_residuals,
                         nb_conv_per_level=nb_conv_per_level,
                         layer_nb_feats=layer_nb_feats,
                         conv_dropout=conv_dropout,
                         batch_norm=batch_norm,
                         input_model=input_model)

    # get decoder; use_skip_connections=True makes it a u-net
    lnf = layer_nb_feats[(nb_levels * nb_conv_per_level):] if layer_nb_feats is not None else None
    return conv_dec(nb_features,
                    None,
                    nb_levels,
                    conv_size,
                    nb_labels,
                    name=model_name,
                    prefix=prefix,
                    feat_mult=feat_mult,
                    pool_size=pool_size,
                    use_skip_connections=True,
                    skip_n_concatenations=skip_n_concatenations,
                    padding=padding,
                    dilation_rate_mult=dilation_rate_mult,
                    activation=activation,
                    use_residuals=use_residuals,
                    final_pred_activation=final_pred_activation,
                    nb_conv_per_level=nb_conv_per_level,
                    batch_norm=batch_norm,
                    layer_nb_feats=lnf,
                    conv_dropout=conv_dropout,
                    input_model=enc_model)
785
def conv_enc(nb_features,
             input_shape,
             nb_levels,
             conv_size,
             name=None,
             prefix=None,
             feat_mult=1,
             pool_size=2,
             dilation_rate_mult=1,
             padding='same',
             activation='elu',
             layer_nb_feats=None,
             use_residuals=False,
             nb_conv_per_level=2,
             conv_dropout=0,
             batch_norm=None,
             input_model=None):
    """Build the contracting (encoder) arm of a U-Net as a Keras model.

    Each of the nb_levels levels applies nb_conv_per_level convolutions (with optional
    dropout, residual connection and batch normalisation); every level except the last
    ends with max pooling. Convolutions are named '<prefix>_conv_downarm_<level>_<conv>';
    conv_dec fetches those layers by name for its skip connections, so the names (and the
    layer creation order, which matters when loading pretrained weights) must not change.

    Args:
        nb_features: feature count at level 0; level l uses round(nb_features * feat_mult ** l).
        input_shape: shape of one input without the batch axis, channels last. Its length
            minus one gives the number of spatial dims (used even when input_model is given).
        nb_levels: number of resolution levels.
        conv_size: convolution kernel size.
        name: model name; also the default layer-name prefix.
        prefix: layer-name prefix (defaults to name).
        feat_mult: per-level multiplier of the feature count.
        pool_size: max-pooling size (int is expanded to all spatial dims).
        dilation_rate_mult: level l uses dilation rate dilation_rate_mult ** l.
        padding: padding mode for convolutions and pooling.
        activation: activation of the convolutions.
        layer_nb_feats: optional flat list overriding the feature count of each conv, consumed in order.
        use_residuals: if True, the last conv of a level has no activation, is added to the level
            input (projected with a conv when feature counts differ) and then activated.
        nb_conv_per_level: convolutions per level.
        conv_dropout: dropout rate after each conv; one mask per feature, broadcast over space.
        batch_norm: axis passed to BatchNormalization, or None to disable it.
        input_model: optional model whose (first) output feeds the encoder instead of a new Input.

    Returns:
        keras.Model from the input tensor(s) to the output of the last level.
    """

    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    # first layer: input
    name = '%s_input' % prefix
    if input_model is None:
        input_tensor = KL.Input(shape=input_shape, name=name)
        last_tensor = input_tensor
    else:
        input_tensor = input_model.inputs
        last_tensor = input_model.outputs
        if isinstance(last_tensor, list):
            last_tensor = last_tensor[0]

    # volume size data
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int):
        pool_size = (pool_size,) * ndims

    # prepare layers
    convL = getattr(KL, 'Conv%dD' % ndims)
    conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'}
    maxpool = getattr(KL, 'MaxPooling%dD' % ndims)

    # down arm:
    # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers
    lfidx = 0  # level feature index
    for level in range(nb_levels):
        lvl_first_tensor = last_tensor
        nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int)
        conv_kwargs['dilation_rate'] = dilation_rate_mult ** level

        for conv in range(nb_conv_per_level):  # does several conv per level, max pooling applied at the end
            if layer_nb_feats is not None:  # None or List of all the feature numbers
                nb_lvl_feats = layer_nb_feats[lfidx]
                lfidx += 1

            name = '%s_conv_downarm_%d_%d' % (prefix, level, conv)
            if conv < (nb_conv_per_level - 1) or (not use_residuals):
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
            else:  # no activation
                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)

            if conv_dropout > 0:
                # conv dropout along feature space only
                name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv)
                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)

        if use_residuals:
            convarm_layer = last_tensor

            # the "add" layer is the original input
            # However, it may not have the right number of features to be added
            nb_feats_in = lvl_first_tensor.get_shape()[-1]
            nb_feats_out = convarm_layer.get_shape()[-1]
            add_layer = lvl_first_tensor
            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
                name = '%s_expand_down_merge_%d' % (prefix, level)
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor)
                add_layer = last_tensor

                # NOTE(review): name and noise_shape are computed but no Dropout layer is applied
                # here; last_tensor is overwritten by the add below in any case.
                if conv_dropout > 0:
                    name = '%s_dropout_down_merge_%d_%d' % (prefix, level, conv)
                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]

            name = '%s_res_down_merge_%d' % (prefix, level)
            last_tensor = KL.add([add_layer, convarm_layer], name=name)

            name = '%s_res_down_merge_act_%d' % (prefix, level)
            last_tensor = KL.Activation(activation, name=name)(last_tensor)

        if batch_norm is not None:
            name = '%s_bn_down_%d' % (prefix, level)
            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)

        # max pool if we're not at the last level
        if level < (nb_levels - 1):
            name = '%s_maxpool_%d' % (prefix, level)
            last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor)

    # create the model and return
    model = keras.Model(inputs=input_tensor, outputs=[last_tensor], name=model_name)
    return model
890
891
def conv_dec(nb_features,
             input_shape,
             nb_levels,
             conv_size,
             nb_labels,
             name=None,
             prefix=None,
             feat_mult=1,
             pool_size=2,
             use_skip_connections=False,
             skip_n_concatenations=0,
             padding='same',
             dilation_rate_mult=1,
             activation='elu',
             use_residuals=False,
             final_pred_activation='softmax',
             nb_conv_per_level=2,
             layer_nb_feats=None,
             batch_norm=None,
             conv_dropout=0,
             input_model=None):
    """Build the expanding (decoder) arm of a U-Net, optionally on top of an encoder model.

    Performs nb_levels - 1 upsampling steps. Each step optionally concatenates the encoder
    layer '<prefix>_conv_downarm_<l>_<nb_conv_per_level - 1>' (fetched from input_model by
    name), then applies nb_conv_per_level convolutions. A final kernel-size-1 convolution
    produces nb_labels channels ('<prefix>_likelihood'), followed by a softmax over the channel
    axis when final_pred_activation == 'softmax', otherwise a linear activation.

    Args:
        nb_features, feat_mult: up-arm level `level` uses
            round(nb_features * feat_mult ** (nb_levels - 2 - level)) features.
        input_shape: only used when input_model is None; otherwise recomputed from its output.
        nb_levels: number of levels of the matching encoder.
        conv_size: convolution kernel size.
        nb_labels: number of output channels.
        name, prefix: model name and layer-name prefix (prefix defaults to name).
        pool_size: upsampling size (int expanded to all spatial dims when ndims > 1).
        use_skip_connections: if True, concatenate encoder features (requires input_model).
        skip_n_concatenations: the last (highest-resolution) this-many up levels skip the concatenation.
        padding, activation, dilation_rate_mult: as in conv_enc.
        use_residuals: add the upsampled tensor (projected if needed) to each level's conv output.
        final_pred_activation: 'softmax' or anything else for a linear output.
        nb_conv_per_level: convolutions per level.
        layer_nb_feats: optional flat list overriding each conv's feature count, consumed in order.
        batch_norm: axis passed to BatchNormalization, or None to disable it.
        conv_dropout: dropout rate after each conv, one mask per feature broadcast over space.
        input_model: encoder model whose output is upsampled.

    Returns:
        keras.Model from input_model's input (or a new Input) to the prediction tensor.
    """

    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    # if using skip connections, make sure need to use them.
    if use_skip_connections:
        assert input_model is not None, "is using skip connections, tensors dictionary is required"

    # first layer: input
    input_name = '%s_input' % prefix
    if input_model is None:
        input_tensor = KL.Input(shape=input_shape, name=input_name)
        last_tensor = input_tensor
    else:
        input_tensor = input_model.input
        last_tensor = input_model.output
        # NOTE(review): .as_list() assumes a TF TensorShape; confirm against the installed Keras version
        input_shape = last_tensor.shape.as_list()[1:]

    # vol size info
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int):
        if ndims > 1:
            pool_size = (pool_size,) * ndims

    # prepare layers
    convL = getattr(KL, 'Conv%dD' % ndims)
    conv_kwargs = {'padding': padding, 'activation': activation}
    upsample = getattr(KL, 'UpSampling%dD' % ndims)

    # up arm:
    # nb_levels - 1 layers of Deconvolution3D
    #    (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu
    lfidx = 0
    for level in range(nb_levels - 1):
        nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int)
        conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level)

        # upsample matching the max pooling layers size
        name = '%s_up_%d' % (prefix, nb_levels + level)
        last_tensor = upsample(size=pool_size, name=name)(last_tensor)
        up_tensor = last_tensor

        # merge layers combining previous layer
        # (bitwise & on two bools; the encoder level matching this up level is nb_levels - 2 - level)
        if use_skip_connections & (level < (nb_levels - skip_n_concatenations - 1)):
            conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1)
            cat_tensor = input_model.get_layer(conv_name).output
            name = '%s_merge_%d' % (prefix, nb_levels + level)
            last_tensor = KL.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name)

        # convolution layers
        for conv in range(nb_conv_per_level):
            if layer_nb_feats is not None:
                nb_lvl_feats = layer_nb_feats[lfidx]
                lfidx += 1

            name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv)
            if conv < (nb_conv_per_level - 1) or (not use_residuals):
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
            else:
                # last conv of a residual level: no activation before the add
                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)

            if conv_dropout > 0:
                name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv)
                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)

        # residual block
        if use_residuals:

            # the "add" layer is the original input
            # However, it may not have the right number of features to be added
            add_layer = up_tensor
            nb_feats_in = add_layer.get_shape()[-1]
            nb_feats_out = last_tensor.get_shape()[-1]
            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
                name = '%s_expand_up_merge_%d' % (prefix, level)
                add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer)

                if conv_dropout > 0:
                    name = '%s_dropout_up_merge_%d_%d' % (prefix, level, conv)
                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                    last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)

            name = '%s_res_up_merge_%d' % (prefix, level)
            last_tensor = KL.add([last_tensor, add_layer], name=name)

            name = '%s_res_up_merge_act_%d' % (prefix, level)
            last_tensor = KL.Activation(activation, name=name)(last_tensor)

        if batch_norm is not None:
            name = '%s_bn_up_%d' % (prefix, level)
            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)

    # Compute likelyhood prediction (no activation yet)
    name = '%s_likelihood' % prefix
    last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor)
    like_tensor = last_tensor

    # output prediction layer
    # we use a softmax to compute P(L_x|I) where x is each location
    if final_pred_activation == 'softmax':
        name = '%s_prediction' % prefix
        # channel axis is ndims + 1 (batch axis first, channels last)
        softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=ndims + 1)
        pred_tensor = KL.Lambda(softmax_lambda_fcn, name=name)(last_tensor)

    # otherwise create a layer that does nothing.
    else:
        name = '%s_prediction' % prefix
        pred_tensor = KL.Activation('linear', name=name)(like_tensor)

    # create the model and retun
    model = keras.Model(inputs=input_tensor, outputs=pred_tensor, name=model_name)
    return model
1028
def _normalise_posteriors(posteriors):
    """Divide posteriors in place by their sum over the last (channel) axis and return them.

    Voxels whose channels sum to zero are left at zero instead of becoming NaN.
    """
    totals = np.sum(posteriors, axis=-1, keepdims=True)
    posteriors /= np.where(totals > 0, totals, 1)
    return posteriors


def postprocess(post_patch_seg, post_patch_parc, shape, pad_idx, crop_idx,
                labels_segmentation, labels_parcellation, aff, im_res):
    """Turn network posteriors into a hard segmentation, full-size posteriors and volumes.

    Args:
        post_patch_seg: segmentation posteriors of the (padded) patch; squeezed, then channels
            on the last axis, channel 0 being background.
        post_patch_parc: cortical parcellation posteriors of the same patch, or None to skip
            the parcellation step.
        shape: spatial shape of the full volume the patch is pasted into (used if crop_idx is set).
        pad_idx: crop indices that remove the padding added before prediction.
        crop_idx: [min0, min1, min2, max0, max1, max2] location of the patch in `shape`, or None.
        labels_segmentation: label value of each segmentation channel (channel 0 = background).
        labels_parcellation: label value of each parcellation channel.
        aff: reference affine passed to align_volume_to_ref.
        im_res: voxel size; volumes are multiplied by its product.

    Returns:
        (seg, posteriors, volumes). volumes[0] is the total foreground volume, volumes[k] the
        volume of labels_segmentation[k]; when parcellation is given, the parcel volumes
        (rescaled to sum to the cortical volume) are appended. Rounded to 3 decimals.
    """
    # get posteriors (undo the padding applied before prediction)
    post_patch_seg = np.squeeze(post_patch_seg)
    post_patch_seg = crop_volume_with_idx(post_patch_seg, pad_idx, n_dims=3, return_copy=False)

    # keep biggest connected component of the foreground (all non-background channels together)
    tmp_post_patch_seg = post_patch_seg[..., 1:]
    post_patch_seg_mask = np.sum(tmp_post_patch_seg, axis=-1) > 0.25
    post_patch_seg_mask = get_largest_connected_component(post_patch_seg_mask)
    post_patch_seg_mask = np.stack([post_patch_seg_mask] * tmp_post_patch_seg.shape[-1], axis=-1)
    tmp_post_patch_seg = mask_volume(tmp_post_patch_seg, mask=post_patch_seg_mask, return_copy=False)
    post_patch_seg[..., 1:] = tmp_post_patch_seg

    # reset foreground posteriors that are not above 0.2 to zero
    post_patch_seg_mask = post_patch_seg > 0.2
    post_patch_seg[..., 1:] *= post_patch_seg_mask[..., 1:]

    # renormalise and get hard segmentation
    post_patch_seg = _normalise_posteriors(post_patch_seg)
    seg_patch = labels_segmentation[post_patch_seg.argmax(-1).astype('int32')].astype('int32')

    # postprocess parcellation: relabel cortex voxels (labels 3 and 42) with parcellation labels
    if post_patch_parc is not None:
        post_patch_parc = np.squeeze(post_patch_parc)
        post_patch_parc = crop_volume_with_idx(post_patch_parc, pad_idx, n_dims=3, return_copy=False)
        mask = (seg_patch == 3) | (seg_patch == 42)
        # channel 0 becomes 1 outside the cortex mask and 0 inside it
        post_patch_parc[..., 0] = np.ones_like(post_patch_parc[..., 0])
        post_patch_parc[..., 0] = mask_volume(post_patch_parc[..., 0], mask=mask < 0.1, return_copy=False)
        post_patch_parc = _normalise_posteriors(post_patch_parc)
        parc_patch = labels_parcellation[post_patch_parc.argmax(-1).astype('int32')].astype('int32')
        seg_patch[mask] = parc_patch[mask]

    # paste patches back to matrix of original image size
    if crop_idx is not None:
        # we need to go through this because of the posteriors of the background, otherwise pad_volume would work
        seg = np.zeros(shape=shape, dtype='int32')
        posteriors = np.zeros(shape=[*shape, labels_segmentation.shape[0]])
        posteriors[..., 0] = np.ones(shape)  # place background around patch
        seg[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5]] = seg_patch
        posteriors[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], :] = post_patch_seg
    else:
        seg = seg_patch
        posteriors = post_patch_seg

    # align prediction back to first orientation
    seg = align_volume_to_ref(seg, aff=np.eye(4), aff_ref=aff, n_dims=3, return_copy=False)
    posteriors = align_volume_to_ref(posteriors, np.eye(4), aff_ref=aff, n_dims=3, return_copy=False)

    # compute volumes: after the concatenation, index 0 is the total and index k matches labels_segmentation[k]
    volumes = np.sum(posteriors[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1)))
    volumes = np.concatenate([np.array([np.sum(volumes)]), volumes])
    if post_patch_parc is not None:
        volumes_parc = np.sum(post_patch_parc[..., 1:], axis=tuple(range(0, len(post_patch_parc.shape) - 1)))
        # volumes is already indexed like labels_segmentation here, so no "- 1" shift is needed
        cortex_channels = np.where((labels_segmentation == 3) | (labels_segmentation == 42))[0]
        total_volume_cortex = np.sum(volumes[cortex_channels])
        # rescale parcel volumes so they add up to the cortical volume of the segmentation
        parc_sum = np.sum(volumes_parc)
        if parc_sum > 0:
            volumes_parc = volumes_parc / parc_sum * total_volume_cortex
        volumes = np.concatenate([volumes, volumes_parc])
    volumes = np.around(volumes * np.prod(im_res), 3)

    return seg, posteriors, volumes
1089
def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3):
    """Write volume to path, as a compressed .npz archive or as a NIfTI image.

    The parent directory is created if needed.

    Args:
        volume: array to save.
        aff: 4x4 affine, the string 'FS' for the FreeSurfer-style orientation matrix, or None
            for identity (NIfTI only).
        header: NIfTI header to reuse, or None for a fresh Nifti1Header.
        path: output path; a path containing '.npz' is saved with np.savez_compressed
            under the key 'vol_data' (aff, header, res and dtype are then ignored).
        res: optional voxel size written into the header zooms.
        dtype: optional output dtype (string or numpy dtype); integer dtypes round the data first.
        n_dims: number of spatial dims used to expand res (inferred via get_dims if None).
    """
    mkdir(os.path.dirname(path))
    if '.npz' in path:
        np.savez_compressed(path, vol_data=volume)
    else:
        if header is None:
            header = nib.Nifti1Header()
        if isinstance(aff, str):
            if aff == 'FS':
                aff = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
        elif aff is None:
            aff = np.eye(4)
        if dtype is not None:
            # round before casting to an integer type so values are not truncated
            if np.issubdtype(np.dtype(dtype), np.integer):
                volume = np.round(volume)
            volume = volume.astype(dtype=dtype)
        # build the image only after the cast, otherwise it would keep the un-rounded array
        nifty = nib.Nifti1Image(volume, aff, header)
        if dtype is not None:
            nifty.set_data_dtype(dtype)
        if res is not None:
            if n_dims is None:
                n_dims, _ = get_dims(volume.shape)
            res = reformat_to_list(res, length=n_dims, dtype=None)
            nifty.header.set_zooms(res)
        nib.save(nifty, path)
1114
1115
1116
def mkdir(path_dir):
    """Create directory path_dir and any missing parents; no-op if it exists or path_dir is ''.

    The hand-rolled parent walk this replaces looped forever on relative paths whose top-level
    component did not exist (os.path.dirname('x') == '' and os.path.isdir('') is False).
    Raises FileExistsError if path_dir exists but is not a directory.
    """
    if path_dir:
        # makedirs also copes with trailing separators
        os.makedirs(path_dir, exist_ok=True)
1128
1129
def getM(ref, mov):
    """Least-squares affine transform mapping 3-D points ref onto mov.

    Args:
        ref: array of shape (3, N); each column is a point.
        mov: array with at least 3 rows and N columns; only the first 3 rows are used.

    Returns:
        4x4 float matrix M such that M[:3, :3] @ ref + M[:3, 3:] ~= mov[:3].

    Raises:
        np.linalg.LinAlgError: if the points do not determine an affine transform
            (e.g. fewer than 4 points or all points coplanar).
    """
    ref = np.asarray(ref, dtype=float)
    mov = np.asarray(mov, dtype=float)
    # design matrix with rows [x, y, z, 1]; each output coordinate is an independent linear fit.
    # lstsq avoids forming and inverting A^T A, which squares the condition number.
    design = np.concatenate([ref.T, np.ones((ref.shape[1], 1))], axis=1)
    coeffs, _, rank, _ = np.linalg.lstsq(design, mov[:3].T, rcond=None)
    if rank < design.shape[1]:
        raise np.linalg.LinAlgError('Singular matrix')
    M = np.eye(4)
    M[:3, :] = coeffs.T
    return M
1147
def getMrigid(A, B):
    """Best-fit rigid transform (rotation + translation, no reflection) taking points A onto B.

    Kabsch / SVD solution. A and B hold corresponding 3-D points as columns, shape (3, N).
    Returns a 4x4 homogeneous matrix T with T[:3, :3] @ A + T[:3, 3:] ~= B.
    """
    mean_a = np.mean(A, axis=1, keepdims=True)
    mean_b = np.mean(B, axis=1, keepdims=True)
    cross_cov = (A - mean_a) @ (B - mean_b).T
    u_mat, _, v_t = np.linalg.svd(cross_cov)
    v_mat = v_t.T  # view: editing v_mat edits v_t
    rotation = v_mat @ u_mat.T
    if np.linalg.det(rotation) < 0:
        # SVD produced a reflection: flip the direction of the last singular vector
        v_mat[:, 2] *= -1
        rotation = v_mat @ u_mat.T
    transform = np.eye(4)
    transform[:3, :3] = rotation
    transform[:3, 3] = (mean_b - rotation @ mean_a).flatten()
    return transform
1164
1165
def fast_3D_interp_torch(X, II, JJ, KK, mode):
    """Sample the 3-D tensor X at (possibly fractional) voxel coordinates (II, JJ, KK).

    Args:
        X: 3-D tensor to sample.
        II, JJ, KK: tensors of identical shape with coordinates along X's axes 0, 1 and 2.
        mode: 'nearest' rounds the coordinates and clamps them into the volume (result keeps
            X's dtype); 'linear' does trilinear interpolation and returns a CPU tensor of torch's
            default float dtype that is 0 wherever a coordinate lies outside [0, size - 1].
            Any other value is reported via sf.system.fatal.

    Returns:
        Tensor with the shape of II.
    """
    if mode == 'nearest':
        IIr = torch.round(II).long()
        JJr = torch.round(JJ).long()
        KKr = torch.round(KK).long()
        # clamp indices into the volume (edge values are repeated outside)
        IIr[IIr < 0] = 0
        JJr[JJr < 0] = 0
        KKr[KKr < 0] = 0
        IIr[IIr > (X.shape[0] - 1)] = (X.shape[0] - 1)
        JJr[JJr > (X.shape[1] - 1)] = (X.shape[1] - 1)
        KKr[KKr > (X.shape[2] - 1)] = (X.shape[2] - 1)
        Y = X[IIr, JJr, KKr]
    elif mode == 'linear':
        # the valid range is inclusive at BOTH ends: coordinate 0 is a regular sample
        ok = ((II >= 0) & (JJ >= 0) & (KK >= 0)
              & (II <= X.shape[0] - 1) & (JJ <= X.shape[1] - 1) & (KK <= X.shape[2] - 1))
        IIv = II[ok]
        JJv = JJ[ok]
        KKv = KK[ok]

        # floor/ceil neighbours and weights per axis; the ceil index is clamped at the last voxel
        fx = torch.floor(IIv).long()
        cx = fx + 1
        cx[cx > (X.shape[0] - 1)] = (X.shape[0] - 1)
        wcx = IIv - fx
        wfx = 1 - wcx

        fy = torch.floor(JJv).long()
        cy = fy + 1
        cy[cy > (X.shape[1] - 1)] = (X.shape[1] - 1)
        wcy = JJv - fy
        wfy = 1 - wcy

        fz = torch.floor(KKv).long()
        cz = fz + 1
        cz[cz > (X.shape[2] - 1)] = (X.shape[2] - 1)
        wcz = KKv - fz
        wfz = 1 - wcz

        # interpolate along x, then y, then z
        c00 = X[fx, fy, fz] * wfx + X[cx, fy, fz] * wcx
        c01 = X[fx, fy, cz] * wfx + X[cx, fy, cz] * wcx
        c10 = X[fx, cy, fz] * wfx + X[cx, cy, fz] * wcx
        c11 = X[fx, cy, cz] * wfx + X[cx, cy, cz] * wcx

        c0 = c00 * wfy + c10 * wcy
        c1 = c01 * wfy + c11 * wcy

        c = c0 * wfz + c1 * wcz

        Y = torch.zeros(II.shape, device='cpu')
        Y[ok] = c.float()

    else:
        sf.system.fatal('mode must be linear or nearest')

    return Y
1228
1229
1230
def fast_3D_interp_field_torch(X, II, JJ, KK):
    """Trilinearly sample the first 3 channels of a 4-D tensor X at voxel coordinates (II, JJ, KK).

    Args:
        X: tensor of shape (D0, D1, D2, C) with C >= 3 (e.g. a displacement field).
        II, JJ, KK: tensors of identical shape with coordinates along X's axes 0, 1 and 2.

    Returns:
        CPU tensor of shape (*II.shape, 3) in torch's default float dtype. Points whose coordinate
        lies outside [0, size - 1] on any axis are 0.
    """
    # the valid range is inclusive at BOTH ends: coordinate 0 is a regular sample
    ok = ((II >= 0) & (JJ >= 0) & (KK >= 0)
          & (II <= X.shape[0] - 1) & (JJ <= X.shape[1] - 1) & (KK <= X.shape[2] - 1))
    IIv = II[ok]
    JJv = JJ[ok]
    KKv = KK[ok]

    # floor/ceil neighbours and weights per axis; the ceil index is clamped at the last voxel
    fx = torch.floor(IIv).long()
    cx = fx + 1
    cx[cx > (X.shape[0] - 1)] = (X.shape[0] - 1)
    wcx = IIv - fx
    wfx = 1 - wcx

    fy = torch.floor(JJv).long()
    cy = fy + 1
    cy[cy > (X.shape[1] - 1)] = (X.shape[1] - 1)
    wcy = JJv - fy
    wfy = 1 - wcy

    fz = torch.floor(KKv).long()
    cz = fz + 1
    cz[cz > (X.shape[2] - 1)] = (X.shape[2] - 1)
    wcz = KKv - fz
    wfz = 1 - wcz

    # gather the 3 channels of each corner at once (shape (n, 3)); weights broadcast over channels
    wfx, wcx = wfx.unsqueeze(-1), wcx.unsqueeze(-1)
    wfy, wcy = wfy.unsqueeze(-1), wcy.unsqueeze(-1)
    wfz, wcz = wfz.unsqueeze(-1), wcz.unsqueeze(-1)

    c00 = X[fx, fy, fz, :3] * wfx + X[cx, fy, fz, :3] * wcx
    c01 = X[fx, fy, cz, :3] * wfx + X[cx, fy, cz, :3] * wcx
    c10 = X[fx, cy, fz, :3] * wfx + X[cx, cy, fz, :3] * wcx
    c11 = X[fx, cy, cz, :3] * wfx + X[cx, cy, cz, :3] * wcx

    c0 = c00 * wfy + c10 * wcy
    c1 = c01 * wfy + c11 * wcy

    c = c0 * wfz + c1 * wcz

    Y = torch.zeros([*II.shape, 3], device='cpu')
    Y[ok] = c.float()

    return Y
1286
1287
def crop_volume_with_idx(volume, crop_idx, aff=None, n_dims=None, return_copy=True):
    """Crop volume to crop_idx = [min_0, ..., min_{d-1}, max_0, ..., max_{d-1}] (max exclusive).

    Args:
        volume: array whose first n_dims axes are spatial; trailing axes are kept whole.
        crop_idx: 2 * n_dims crop bounds.
        aff: optional 4x4 affine. It is updated IN PLACE so that its translation points to the
            first voxel of the cropped volume (also for 2D volumes, whose affine is still 4x4).
        n_dims: number of spatial dims (defaults to len(crop_idx) / 2); only 2 and 3 are supported
            (others are reported via sf.system.fatal).
        return_copy: crop a copy of volume instead of the original array.

    Returns:
        The cropped volume, or (cropped volume, aff) when aff is given.
    """
    new_volume = volume.copy() if return_copy else volume
    n_dims = int(np.array(crop_idx).shape[0] / 2) if n_dims is None else n_dims

    # crop image
    if n_dims == 2:
        new_volume = new_volume[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3], ...]
    elif n_dims == 3:
        new_volume = new_volume[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], ...]
    else:
        sf.system.fatal('cannot crop volumes with more than 3 dimensions')

    if aff is not None:
        # voxel offset of the crop = lower bounds only; pad with zeros up to 3 entries for 2D
        shift = np.zeros(3)
        shift[:n_dims] = np.asarray(crop_idx[:n_dims])
        aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ shift
        return new_volume, aff
    return new_volume
1307
1308
def get_largest_connected_component(mask, structure=None):
    """Return a boolean mask of the largest connected component of mask.

    structure is the connectivity element forwarded to scipy's label. If mask has no
    component at all, a copy of mask is returned.
    """
    labelled, count = scipy_label(mask, structure)
    if count == 0:
        return mask.copy()
    # bincount index 0 counts the background, so skip it and shift the winner back by one
    sizes = np.bincount(labelled.flat)
    return labelled == 1 + np.argmax(sizes[1:])
1312
1313
def mask_volume(volume, mask=None, threshold=0.1, dilate=0, erode=0, fill_holes=False, masking_value=0,
                return_mask=False, return_copy=True):
    """Set every voxel of volume outside a binary mask to masking_value.

    Args:
        volume: array to mask (spatial axes first, optional channel axis last, as told by get_dims).
        mask: mask whose spatial shape matches volume (values > 0 are kept); when None the mask
            is volume >= threshold.
        threshold: threshold used when mask is None.
        dilate, erode: if > 0, radius of the structuring element (build_binary_structure)
            used to dilate and then erode the mask.
        fill_holes: fill holes of the mask after dilation/erosion.
        masking_value: value written outside the mask.
        return_mask: also return the mask actually applied.
        return_copy: work on a copy of volume instead of modifying it in place.

    Returns:
        The masked volume, or (masked volume, applied mask) when return_mask is True.
    """
    masked = volume.copy() if return_copy else volume
    full_shape = list(masked.shape)
    n_dims, n_channels = get_dims(full_shape)

    if mask is not None:
        assert list(mask.shape[:n_dims]) == full_shape[:n_dims], \
            'mask should have shape {0}, or {1}, had {2}'.format(
                full_shape[:n_dims], full_shape[:n_dims] + [n_channels], list(mask.shape))
        keep = mask > 0
    else:
        keep = masked >= threshold

    if dilate > 0:
        keep = binary_dilation(keep, build_binary_structure(dilate, n_dims))
    if erode > 0:
        keep = binary_erosion(keep, build_binary_structure(erode, n_dims))
    if fill_holes:
        keep = binary_fill_holes(keep)

    outside = np.logical_not(keep)
    if keep.shape != masked.shape:
        # spatial-only mask: replicate it over the channel axis
        outside = np.stack([outside] * n_channels, axis=-1)
    masked[outside] = masking_value

    if return_mask:
        return masked, keep
    return masked
1350
1351
def build_binary_structure(connectivity, n_dims, shape=None):
    """Return an int array equal to 1 within Euclidean distance `connectivity` of its centre, else 0.

    By default the array has side 2 * connectivity + 1 along each of the n_dims axes; `shape`
    overrides that (expanded to n_dims entries with reformat_to_list). Used as the structuring
    element for binary dilation/erosion in mask_volume.
    """
    if shape is not None:
        shape = reformat_to_list(shape, length=n_dims)
    else:
        shape = [connectivity * 2 + 1] * n_dims
    distance = np.ones(shape)
    # single zero at the centre voxel, so the EDT measures distance to the centre
    distance[tuple(int(side / 2) for side in shape)] = 0
    distance = distance_transform_edt(distance)
    return (distance <= connectivity) * 1
1363
1364
1365
def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, return_crop_idx=False, mode='center'):
    """Crop volume either by a margin or to a target shape.

    Exactly one of cropping_margin / cropping_shape must be given.

    Args:
        volume: array (spatial axes first, optional channel axis last, as told by get_dims).
        cropping_margin: voxels removed on each side of each axis (axes not larger than twice
            the margin are left uncropped).
        cropping_shape: target spatial shape, placed at the centre ('center') or at a random
            position ('random') chosen with np.random.randint.
        aff: optional 4x4 affine, updated IN PLACE so that its translation points to the first
            voxel of the crop (also for 2D volumes, whose affine is still 4x4).
        return_crop_idx: also return crop_idx = [mins..., maxs...].
        mode: 'center' or 'random'; anything else raises ValueError.

    Returns:
        The cropped volume, followed by aff and/or crop_idx when requested (as a tuple).
    """
    assert (cropping_margin is not None) | (cropping_shape is not None), \
        'cropping_margin or cropping_shape should be provided'
    assert not ((cropping_margin is not None) & (cropping_shape is not None)), \
        'only one of cropping_margin or cropping_shape should be provided'

    # get info
    new_volume = volume.copy()
    vol_shape = new_volume.shape
    n_dims, _ = get_dims(vol_shape)

    # find cropping indices
    if cropping_margin is not None:
        cropping_margin = reformat_to_list(cropping_margin, length=n_dims)
        do_cropping = np.array(vol_shape[:n_dims]) > 2 * np.array(cropping_margin)
        min_crop_idx = [cropping_margin[i] if do_cropping[i] else 0 for i in range(n_dims)]
        max_crop_idx = [vol_shape[i] - cropping_margin[i] if do_cropping[i] else vol_shape[i] for i in range(n_dims)]
    else:
        cropping_shape = reformat_to_list(cropping_shape, length=n_dims)
        if mode == 'center':
            min_crop_idx = np.maximum([int((vol_shape[i] - cropping_shape[i]) / 2) for i in range(n_dims)], 0)
            max_crop_idx = np.minimum([min_crop_idx[i] + cropping_shape[i] for i in range(n_dims)],
                                      np.array(vol_shape)[:n_dims])
        elif mode == 'random':
            crop_max_val = np.maximum(np.array([vol_shape[i] - cropping_shape[i] for i in range(n_dims)]), 0)
            min_crop_idx = np.random.randint(0, high=crop_max_val + 1)
            max_crop_idx = np.minimum(min_crop_idx + np.array(cropping_shape), np.array(vol_shape)[:n_dims])
        else:
            raise ValueError('mode should be either "center" or "random", had %s' % mode)
    crop_idx = np.concatenate([np.array(min_crop_idx), np.array(max_crop_idx)])

    # crop volume
    if n_dims == 2:
        new_volume = new_volume[crop_idx[0]: crop_idx[2], crop_idx[1]: crop_idx[3], ...]
    elif n_dims == 3:
        new_volume = new_volume[crop_idx[0]: crop_idx[3], crop_idx[1]: crop_idx[4], crop_idx[2]: crop_idx[5], ...]

    # sort outputs
    output = [new_volume]
    if aff is not None:
        # the affine is 4x4 even for 2D volumes: pad the voxel offset with zeros up to 3 entries
        n_aff = min(n_dims, 3)
        offset = np.zeros(3)
        offset[:n_aff] = np.array(min_crop_idx)[:n_aff]
        aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ offset
        output.append(aff)
    if return_crop_idx:
        output.append(crop_idx)
    return output[0] if len(output) == 1 else tuple(output)
1412
1413
1414
def rescale_volume(volume, new_min=0, new_max=255, min_percentile=2., max_percentile=98., use_positive_only=False):
    """Linearly map the robust intensity range of volume onto [new_min, new_max].

    The range is given by the min_percentile / max_percentile percentiles (plain min / max
    when those are 0 / 100), computed over all voxels or only the positive ones. Values outside
    the range are clipped first. A volume whose robust range is a single value yields zeros.
    The input is not modified.
    """
    clipped = volume.copy()
    if use_positive_only:
        sample = clipped[clipped > 0]
    else:
        sample = clipped.flatten()

    low = np.percentile(sample, min_percentile) if min_percentile != 0 else np.min(sample)
    high = np.percentile(sample, max_percentile) if max_percentile != 100 else np.max(sample)

    clipped = np.clip(clipped, low, high)

    if low == high:
        # constant robust range: nothing to stretch (avoids dividing by zero)
        return np.zeros_like(clipped)
    return new_min + (clipped - low) / (high - low) * (new_max - new_min)
1433
1434
1435
1436
def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx=False):
    """Pad the spatial axes of volume (centred) up to padding_shape.

    Axes already at least as large as padding_shape are left alone; trailing non-spatial axes
    (channels, including a singleton one) are never padded.

    Args:
        volume: array (spatial axes first, as told by get_dims).
        padding_shape: target spatial shape (expanded to n_dims entries).
        padding_value: constant used for padding.
        aff: optional 4x4 affine, updated IN PLACE so it still maps the original voxels correctly.
        return_pad_idx: also return pad_idx = [mins..., maxs...], the location of the original
            volume inside the padded one.

    Returns:
        The padded volume, followed by aff and/or pad_idx when requested (as a tuple).
    """
    # get info
    new_volume = volume.copy()
    vol_shape = new_volume.shape
    n_dims, _ = get_dims(vol_shape)
    padding_shape = reformat_to_list(padding_shape, length=n_dims, dtype='int')

    # check if need to pad
    if np.any(np.array(padding_shape, dtype='int32') > np.array(vol_shape[:n_dims], dtype='int32')):

        # get padding margins
        min_margins = np.maximum(np.int32(np.floor((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2)), 0)
        max_margins = np.maximum(np.int32(np.ceil((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2)), 0)
        pad_idx = np.concatenate([min_margins, min_margins + np.array(vol_shape[:n_dims])])
        pad_margins = tuple([(min_margins[i], max_margins[i]) for i in range(n_dims)])
        # np.pad needs one (before, after) pair per axis: leave every trailing axis unpadded,
        # including a singleton channel axis (previously only added when n_channels > 1)
        if new_volume.ndim > n_dims:
            pad_margins = pad_margins + ((0, 0),) * (new_volume.ndim - n_dims)

        # pad volume
        new_volume = np.pad(new_volume, pad_margins, mode='constant', constant_values=padding_value)

        if aff is not None:
            if n_dims == 2:
                min_margins = np.append(min_margins, 0)
            aff[:-1, -1] = aff[:-1, -1] - aff[:-1, :-1] @ min_margins

    else:
        pad_idx = np.concatenate([np.array([0] * n_dims), np.array(vol_shape[:n_dims])])

    # sort outputs
    output = [new_volume]
    if aff is not None:
        output.append(aff)
    if return_pad_idx:
        output.append(pad_idx)
    return output[0] if len(output) == 1 else tuple(output)
1473
1474
1475
def add_axis(x, axis=0):
    """Insert size-1 axes into x at each position of axis (an int or a list), applied in order."""
    for position in reformat_to_list(axis):
        x = np.expand_dims(x, axis=position)
    return x
1481
1482
def volshape_to_meshgrid(volshape, **kwargs):
    """Return a meshgrid of tensors covering a volume of shape volshape.

    Each entry of volshape must be integer-valued (ValueError otherwise); kwargs (e.g.
    indexing='ij') are forwarded to meshgrid.
    """
    if not all(float(d).is_integer() for d in volshape):
        raise ValueError("volshape needs to be a list of integers")
    return meshgrid(*[tf.range(0, d) for d in volshape], **kwargs)
1494
1495
def meshgrid(*args, **kwargs):
    """Return broadcast coordinate grids for the 1-D tensors in args (tf.meshgrid replacement).

    Keyword `indexing` is 'xy' (default) or 'ij'; any other keyword raises TypeError and any
    other indexing value raises ValueError. With 'xy' and at least two inputs, the first two
    output axes are swapped. Grids are expanded with tf.tile (see comment below), which needs
    statically known input lengths (get_shape().as_list()[0]).

    Returns:
        List with one tensor per input, all of the full grid shape.
    """

    indexing = kwargs.pop("indexing", "xy")
    if kwargs:
        key = list(kwargs.keys())[0]
        raise TypeError("'{}' is an invalid keyword argument "
                        "for this function".format(key))

    if indexing not in ("xy", "ij"):
        raise ValueError("indexing parameter must be either 'xy' or 'ij'")

    # with ops.name_scope(name, "meshgrid", args) as name:
    ndim = len(args)
    s0 = (1,) * ndim

    # Prepare reshape by inserting dimensions with size 1 where needed
    output = []
    for i, x in enumerate(args):
        output.append(tf.reshape(tf.stack(x), (s0[:i] + (-1,) + s0[i + 1::])))
    # Create parameters for broadcasting each tensor to the full size
    # NOTE(review): `shapes` is computed (and swapped below) but never used afterwards
    shapes = [tf.size(x) for x in args]
    sz = [x.get_shape().as_list()[0] for x in args]

    # output_dtype = tf.convert_to_tensor(args[0]).dtype.base_dtype
    if indexing == "xy" and ndim > 1:
        # Cartesian indexing: first input varies along axis 1, second along axis 0
        output[0] = tf.reshape(output[0], (1, -1) + (1,) * (ndim - 2))
        output[1] = tf.reshape(output[1], (-1, 1) + (1,) * (ndim - 2))
        shapes[0], shapes[1] = shapes[1], shapes[0]
        sz[0], sz[1] = sz[1], sz[0]

    # This is the part of the implementation from tf that is slow.
    # We replace it below to get a ~6x speedup (essentially using tile instead of * tf.ones())
    # mult_fact = tf.ones(shapes, output_dtype)
    # return [x * mult_fact for x in output]
    for i in range(len(output)):
        # tile every axis except the one the i-th input already spans
        stack_sz = [*sz[:i], 1, *sz[(i + 1):]]
        if indexing == 'xy' and ndim > 1 and i < 2:
            stack_sz[0], stack_sz[1] = stack_sz[1], stack_sz[0]
        output[i] = tf.tile(output[i], tf.stack(stack_sz))
    return output
1536
1537
def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):
    """Build Gaussian blurring kernel(s) for tf.nn.convNd.

    Args:
        sigma: per-axis standard deviation (number or list, expanded by reformat_to_list), or a
            tensor. A tensor whose first static dimension is None is treated as batched
            (batch, n_dims), producing per-sample kernels; tensors require max_sigma.
        max_sigma: standard deviation used to size the window (defaults to sigma).
        blur_range: if given and != 1, sigma is multiplied by a factor drawn uniformly in
            [1 / blur_range, blur_range].
        separable: if True, return a list with one 1-D kernel per axis (None where the window
            size is 1); otherwise return a single n-D kernel.

    Returns:
        Kernel tensor(s) with two trailing singleton axes (input/output channels). The window size
        per axis is int32(ceil(2.5 * max_sigma) / 2) * 2 + 1.
    """
    # `combinations` is not imported in this file's header; import it here so the
    # separable branch does not fail with a NameError
    from itertools import combinations

    # convert sigma into a tensor
    if not tf.is_tensor(sigma):
        sigma_tens = tf.convert_to_tensor(reformat_to_list(sigma), dtype='float32')
    else:
        assert max_sigma is not None, 'max_sigma must be provided when sigma is given as a tensor'
        sigma_tens = sigma
    shape = sigma_tens.get_shape().as_list()

    # get n_dims and batchsize
    if shape[0] is not None:
        n_dims = shape[0]
        batchsize = None
    else:
        n_dims = shape[1]
        batchsize = tf.split(tf.shape(sigma_tens), [1, -1])[0]

    # reformat max_sigma
    if max_sigma is not None:  # dynamic blurring
        max_sigma = np.array(reformat_to_list(max_sigma, length=n_dims))
    else:  # sigma is fixed
        max_sigma = np.array(reformat_to_list(sigma, length=n_dims))

    # randomise the burring std dev and/or split it between dimensions
    if blur_range is not None:
        if blur_range != 1:
            sigma_tens = sigma_tens * tf.random.uniform(tf.shape(sigma_tens), minval=1 / blur_range, maxval=blur_range)

    # get size of blurring kernels
    windowsize = np.int32(np.ceil(2.5 * max_sigma) / 2) * 2 + 1

    if separable:

        split_sigma = tf.split(sigma_tens, [1] * n_dims, axis=-1)

        kernels = list()
        # for axis i, the n_dims - 1 other axes along which the 1-D kernel gets singleton dims
        comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1])
        for (i, wsize) in enumerate(windowsize):

            if wsize > 1:

                # build meshgrid and replicate it along batch dim if dynamic blurring
                locations = tf.cast(tf.range(0, wsize), 'float32') - (wsize - 1) / 2
                if batchsize is not None:
                    locations = tf.tile(tf.expand_dims(locations, axis=0),
                                        tf.concat([batchsize, tf.ones(tf.shape(tf.shape(locations)), dtype='int32')],
                                                  axis=0))
                    comb[i] += 1

                # compute gaussians
                exp_term = -K.square(locations) / (2 * split_sigma[i] ** 2)
                g = tf.exp(exp_term - tf.math.log(np.sqrt(2 * np.pi) * split_sigma[i]))
                g = g / tf.reduce_sum(g)

                for axis in comb[i]:
                    g = tf.expand_dims(g, axis=axis)
                kernels.append(tf.expand_dims(tf.expand_dims(g, -1), -1))

            else:
                kernels.append(None)

    else:

        # build meshgrid
        mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')]
        diff = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1)

        # replicate meshgrid to batch size and reshape sigma_tens
        if batchsize is not None:
            diff = tf.tile(tf.expand_dims(diff, axis=0),
                           tf.concat([batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype='int32')], axis=0))
            for i in range(n_dims):
                sigma_tens = tf.expand_dims(sigma_tens, axis=1)
        else:
            for i in range(n_dims):
                sigma_tens = tf.expand_dims(sigma_tens, axis=0)

        # compute gaussians (axes with sigma == 0 contribute a constant, i.e. no blurring)
        sigma_is_0 = tf.equal(sigma_tens, 0)
        exp_term = -K.square(diff) / (2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens)**2)
        norms = exp_term - tf.math.log(tf.where(sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens))
        kernels = K.sum(norms, -1)
        kernels = tf.exp(kernels)
        kernels /= tf.reduce_sum(kernels)
        kernels = tf.expand_dims(tf.expand_dims(kernels, -1), -1)

    return kernels
1626
1627
def get_mapping_lut(source, dest=None):
    """Build a look-up table mapping the N values of source to the matching values of dest.

    dest must have the same length as source and defaults to [0, ..., N-1]. Returns an int32
    array of length max(source) + 1 with lut[source[i]] == dest[i]; entries for values absent
    from source are 0, and if source repeats a value the last pairing wins. Presumably source
    holds non-negative label values (negative values would index from the end) - verify at callers.
    """
    src = np.array(reformat_to_list(source), dtype='int32')
    n_labels = src.shape[0]

    if dest is not None:
        assert len(src) == len(dest), 'label_list and new_label_list should have the same length'
        dst = np.array(reformat_to_list(dest, dtype='int'))
    else:
        dst = np.arange(n_labels, dtype='int32')

    lut = np.zeros(np.max(src) + 1, dtype='int32')
    for src_value, dst_value in zip(src, dst):
        lut[src_value] = dst_value
    return lut
1649
1650
class GaussianBlur(KL.Layer):
    """Keras layer applying a Gaussian blur independently to each channel of an image.

    The layer input is the image, or ``[image, mask]`` when ``use_mask=True``. In ``build`` the
    number of spatial dimensions is taken as ``rank - 2`` and the channel count as the size of the
    last axis (assumes a leading batch axis and channels_last layout — TODO confirm for all callers).

    :param sigma: standard deviation of the Gaussian, either one value or one per spatial dimension
        (expanded with ``reformat_to_list`` in ``build``). All values must be >= 0. If every sigma
        is 0, the image is returned unchanged.
    :param random_blur_range: (optional) when not None, kernels are regenerated at every call by
        ``gaussian_kernel(sigma, blur_range=random_blur_range, ...)`` instead of being built once.
    :param use_mask: when True, a second input (mask) is expected, with the same number of channels
        as the image. After each convolution, the image is divided by the mask blurred with the same
        kernel, and set to 0 outside the mask.
    """

    def __init__(self, sigma, random_blur_range=None, use_mask=False, **kwargs):
        self.sigma = reformat_to_list(sigma)
        assert np.all(np.array(self.sigma) >= 0), 'sigma should be superior or equal to 0'
        self.use_mask = use_mask

        self.n_dims = None  # set in build()
        self.n_channels = None  # set in build()
        self.blur_range = random_blur_range
        self.stride = None  # set in build()
        self.separable = None  # set in build()
        self.kernels = None  # set in build(), or at every call() when blur_range is not None
        self.convnd = None  # set in build()
        super(GaussianBlur, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["sigma"] = self.sigma
        config["random_blur_range"] = self.blur_range
        config["use_mask"] = self.use_mask
        return config

    def build(self, input_shape):
        """Infer dimensions from the input shape(s) and prepare the kernels and convolution op."""

        # get shapes
        if self.use_mask:
            assert len(input_shape) == 2, 'please provide a mask as second layer input when use_mask=True'
            self.n_dims = len(input_shape[0]) - 2
            self.n_channels = input_shape[0][-1]
        else:
            self.n_dims = len(input_shape) - 2
            self.n_channels = input_shape[-1]

        # prepare blurring kernel
        self.stride = [1]*(self.n_dims+2)
        self.sigma = reformat_to_list(self.sigma, length=self.n_dims)
        # large sigmas use a list of kernels applied one after the other (presumably one 1D
        # kernel per axis, which keeps large blurs cheap)
        self.separable = np.linalg.norm(np.array(self.sigma)) > 5
        if self.blur_range is None:  # fixed kernels
            self.kernels = gaussian_kernel(self.sigma, separable=self.separable)
        else:
            self.kernels = None

        # prepare convolution (tf.nn.conv2d / conv3d depending on the number of spatial dims)
        self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims)

        self.built = True
        super(GaussianBlur, self).build(input_shape)

    def _convolve_channels(self, tensor, kernel):
        """Convolve every channel of `tensor` separately with `kernel` ('SAME' padding)."""
        return tf.concat([self.convnd(tf.expand_dims(tensor[..., n], -1), kernel, self.stride, 'SAME')
                          for n in range(self.n_channels)], -1)

    def _normalise_with_mask(self, image, mask, mask_float, kernel):
        """Divide `image` by the mask blurred with `kernel`, then zero it outside the mask."""
        blurred_mask = self._convolve_channels(mask_float, kernel)
        image = image / (blurred_mask + keras.backend.epsilon())
        return tf.where(mask, image, tf.zeros_like(image))

    def call(self, inputs, **kwargs):

        if self.use_mask:
            image = inputs[0]
            mask = tf.cast(inputs[1], 'bool')
            # loop-invariant: cast the mask once instead of once per kernel
            mask_float = tf.cast(mask, 'float32')
        else:
            image = inputs
            mask = mask_float = None

        # redefine the kernels at each new step when blur_range is activated
        if self.blur_range is not None:
            self.kernels = gaussian_kernel(self.sigma, blur_range=self.blur_range, separable=self.separable)

        # list of kernels to apply in sequence: the non-None separable kernels, or the single
        # full kernel (skipped entirely when every sigma is 0)
        if self.separable:
            kernels = [k for k in self.kernels if k is not None]
        elif any(self.sigma):
            kernels = [self.kernels]
        else:
            kernels = []

        for kernel in kernels:
            image = self._convolve_channels(image, kernel)
            if self.use_mask:
                image = self._normalise_with_mask(image, mask, mask_float, kernel)

        return image
1737
1738
class ConvertLabels(KL.Layer):
    """Keras layer that remaps integer label values through a look-up table.

    Every input value v (cast to int32) is replaced by ``lut[v]``, where ``lut`` is built in
    ``build`` by ``get_mapping_lut(source_values, dest=dest_values)``. Values below the table
    length that are not listed in `source_values` map to 0. Per the ``tf.gather`` documentation,
    values beyond the table raise an error on CPU and yield 0 on GPU.

    :param source_values: label values to convert from.
    :param dest_values: (optional) label values to convert to, one per source value; forwarded as
        ``dest`` to ``get_mapping_lut``.
    """

    def __init__(self, source_values, dest_values=None, **kwargs):
        # kept as given so that get_config() can return them unchanged
        self.source_values = source_values
        self.dest_values = dest_values
        self.lut = None  # int32 tensor, created in build()
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config.update(source_values=self.source_values, dest_values=self.dest_values)
        return config

    def build(self, input_shape):
        mapping = get_mapping_lut(self.source_values, dest=self.dest_values)
        self.lut = tf.convert_to_tensor(mapping, dtype='int32')
        self.built = True
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        indices = tf.cast(inputs, dtype='int32')
        return tf.gather(self.lut, indices)
1760
1761
1762
1763
# execute script: only when run directly, not when imported as a module
if __name__ == '__main__':
    # command-line entry point (argument parsing and atlas construction live in main())
    main()
Attached Files
To refer to attachments on a page, use attachment:filename, as shown below in the list of files. Do NOT use the URL of the [get] link, since this is subject to change and can break easily. You are not allowed to attach a file to this page.
