Attachment 'mri_easyatlas.py'

Download

   1 import os
   2 import argparse
   3 import numpy as np
   4 import voxelmorph as vxm
   5 import torch
   6 import surfa as sf
   7 import nibabel as nib
   8 import glob
   9 from scipy.ndimage import gaussian_filter, binary_dilation, binary_erosion, distance_transform_edt, binary_fill_holes
  10 from scipy.ndimage import label as scipy_label
  11 
  12 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  13 os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
  14 import tensorflow as tf
  15 import keras
  16 import keras.backend as K
  17 import keras.layers as KL
  18 
  19 
  20 
  21 # set tensorflow logging
  22 tf.get_logger().setLevel('ERROR')
  23 K.set_image_data_format('channels_last')
  24 
  25 
  26 def main():
  27 
  28     parser = argparse.ArgumentParser(description="EasyAtlas: fast atlas construction with EasyReg", epilog='\n')
  29 
  30     # input/outputs
  31     parser.add_argument("--i", help="Input directory with scans")
  32     parser.add_argument("--o", help="Output directory where atlas and other files will be written")
  33     parser.add_argument("--threads", type=int, default=-1, help="(optional) Number of cores to be used. You can use -1 to use all available cores. Default is -1.")
  34     parser.add_argument('--use_reliability_maps', action='store_true', help='Use reliability maps when averaging into atlas (recommended if data are not 1mm isotropic!')
  35 
  36     # parse commandline
  37     args = parser.parse_args()
  38 
  39     #############
  40 
  41     # Very first thing: we require FreeSurfer
  42     if not os.environ.get('FREESURFER_HOME'):
  43         sf.system.fatal('FREESURFER_HOME is not set. Please source freesurfer.')
  44     fs_home = os.environ.get('FREESURFER_HOME')
  45 
  46     if args.i is None:
  47         sf.system.fatal('Input directory must be provided')
  48     if args.o is None:
  49         sf.system.fatal('Output directory must be provided')
  50 
  51     # limit the number of threads to be used if running on CPU
  52     if args.threads<0:
  53         args.threads = os.cpu_count()
  54         print('using all available threads ( %s )' % args.threads)
  55     else:
  56         print('using %s threads' % args.threads)
  57     tf.config.threading.set_inter_op_parallelism_threads(args.threads)
  58     tf.config.threading.set_intra_op_parallelism_threads(args.threads)
  59     torch.set_num_threads(args.threads)
  60 
  61     # path models
  62     path_model_segmentation = fs_home + '/models/synthseg_2.0.h5'
  63     path_model_parcellation = fs_home + '/models/synthseg_parc_2.0.h5'
  64     path_model_registration_trained = fs_home + '/models/easyreg_v10_230103.h5'
  65 
  66     # path labels
  67     labels_segmentation = fs_home +  '/models/synthseg_segmentation_labels_2.0.npy'
  68     labels_parcellation = fs_home +  '/models/synthseg_parcellation_labels.npy'
  69     atlas_volsize = [160, 160, 192]
  70     atlas_aff = np.matrix([[-1, 0, 0, 79], [0, 0, 1, -104], [0, -1, 0, 79], [0, 0, 0, 1]])
  71 
  72     # get label lists
  73     labels_segmentation, _ = get_list_labels(label_list=labels_segmentation)
  74     labels_segmentation, unique_idx = np.unique(labels_segmentation, return_index=True)
  75     labels_parcellation, _ = np.unique(get_list_labels(labels_parcellation)[0], return_index=True)
  76 
  77     # Create output (and SynthSeg) directory if needed
  78     if os.path.exists(args.o) and os.path.isdir(args.o):
  79         print('Output directory already exists; no need to create it')
  80     else:
  81         os.mkdir(args.o)
  82     segdir = args.o + '/SynthSeg/'
  83     if os.path.exists(segdir) and os.path.isdir(segdir):
  84         print('SynthSeg directory already exists; no need to create it')
  85     else:
  86         os.mkdir(segdir)
  87     regdir = args.o + '/Registrations/'
  88     if os.path.exists(regdir) and os.path.isdir(regdir):
  89         print('Registration directory already exists; no need to create it')
  90     else:
  91         os.mkdir(regdir)
  92     tempdir = args.o + '/temp/'
  93     if os.path.exists(tempdir) and os.path.isdir(tempdir):
  94         print('Temporary directory already exists; no need to create it')
  95     else:
  96         os.mkdir(tempdir)
  97 
  98     # Build list of input, affine, segmentation files (supports nii, mgz, nii.gz)
  99     input_files = sorted(glob.glob(args.i + '/*.nii.gz') + glob.glob(args.i + '/*.nii') + glob.glob(args.i + '/*.mgz'))
 100     seg_files = []
 101     reg_files = []
 102     linear_files = []
 103     for file in input_files:
 104         _, tail = os.path.split(file)
 105         seg_files.append(segdir + '/' + tail)
 106         reg_files.append(regdir + '/' + tail)
 107         linear_files.append(tempdir + '/' + tail + '.npy')
 108 
 109     # Decide if we need to segment anything
 110     all_segs_ready = True
 111     for file in seg_files:
 112         if os.path.exists(file) is False:
 113             all_segs_ready = False
 114 
 115     # Run SynthSeg if needed
 116     if all_segs_ready:
 117         print('SynthSeg already there for all input files; no need to segment anything')
 118     else:
 119         print('Setting up segmentation net')
 120         segmentation_net = build_seg_model(model_file_segmentation=path_model_segmentation,
 121                                            model_file_parcellation=path_model_parcellation,
 122                                            labels_segmentation=labels_segmentation,
 123                                            labels_parcellation=labels_parcellation)
 124         for i in range(len(input_files)):
 125             if os.path.exists(seg_files[i]):
 126                 print('Image ' + str(i + 1) + ' of ' + str(len(input_files)) + ': segmentation already there')
 127             else:
 128                 print('Image ' + str(i + 1) + ' of ' + str(len(input_files)) + ': segmenting')
 129                 image, aff, h, im_res, shape, pad_idx, crop_idx = preprocess(path_image=input_files[i], crop=None,
 130                                                                              min_pad=128, path_resample=None)
 131                 post_patch_segmentation, post_patch_parcellation = segmentation_net.predict(image)
 132                 seg_buffer, _, _ = postprocess(post_patch_seg=post_patch_segmentation,
 133                                                    post_patch_parc=post_patch_parcellation,
 134                                                    shape=shape,
 135                                                    pad_idx=pad_idx,
 136                                                    crop_idx=crop_idx,
 137                                                    labels_segmentation=labels_segmentation,
 138                                                    labels_parcellation=labels_parcellation,
 139                                                    aff=aff,
 140                                                    im_res=im_res)
 141                 save_volume(seg_buffer, aff, h, seg_files[i], dtype='int32')
 142 
 143     # Now the linear registration part
 144     print('Linear registration with centroids of segmentations')
 145 
 146     # First, prepare a bunch of common variables
 147     labels = np.array([2,4,5,7,8,10,11,12,13,14,15,16,17,18,26,28,41,43,44,46,47,49,50,51,52,53,54,58,60,
 148                                     1001,1002,1003,1005,1006,1007,1008,1009,1010,1011,1012,1013,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023,1024,1025,1026,1027,1028,1029,1030,1031,1032,1033,1034,1035,
 149                                     2001,2002,2003,2005,2006,2007,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018,2019,2020,2021,2022,2023,2024,2025,2026,2027,2028,2029,2030,2031,2032,2033,2034,2035])
 150     nlab = len(labels)
 151     atlasCOG = np.array([[-28.,-18.,-37.,-19.,-27.,-19.,-23.,-31.,-26.,-2.,-3.,-3.,-29.,-26.,-14.,-14.,24.,14.,31.,12.,18.,14.,19.,26.,21.,25.,22.,11.,8.,-52.,-6.,-36.,-7.,-24.,-37.,-39.,-52.,-9.,-27.,-26.,-14.,-8.,-59.,-28.,-7.,-49.,-43.,-47.,-12.,-46.,-6.,-43.,-10.,-7.,-33.,-11.,-23.,-55.,-50.,-10.,-29.,-46.,-38.,48.,4.,31.,3.,21.,33.,37.,47.,3.,24.,20.,8.,4.,54.,21.,5.,45.,38.,46.,8.,45.,3.,38.,6.,4.,29.,9.,19.,51.,49.,10.,24.,43.,33.],
 152                         [-30.,-17.,-13.,-36.,-40.,-22.,-3.,-5.,-9.,-14.,-31.,-21.,-15.,-1.,3.,-16.,-32.,-20.,-14.,-37.,-42.,-24.,-3.,-6.,-10.,-15.,-2.,3.,-17.,-44.,-5.,-15.,-71.,2.,-29.,-70.,-23.,-44.,-73.,22.,-57.,27.,-19.,-23.,-45.,4.,31.,20.,-68.,-38.,-33.,-26.,-60.,23.,22.,0.,-72.,-12.,-49.,49.,17.,-25.,-3.,-42.,-1.,-16.,-76.,0.,-34.,-69.,-16.,-44.,-73.,22.,-56.,28.,-18.,-25.,-45.,-3.,30.,14.,-69.,-37.,-32.,-30.,-60.,21.,21.,0.,-72.,-11.,-49.,48.,15.,-27.,-3.],
 153                         [12.,14.,-13.,-41.,-51.,1.,13.,3.,1.,0.,-40.,-28.,-15.,-10.,2.,-7.,11.,14.,-12.,-40.,-51.,2.,14.,4.,2.,-14.,-10.,4.,-7.,-8.,32.,40.,-14.,-21.,-28.,-4.,-28.,-3.,-35.,3.,-29.,4.,-17.,-21.,35.,18.,9.,20.,-24.,28.,25.,34.,7.,18.,35.,48.,16.,-5.,12.,22.,-18.,1.,4.,-12.,32.,43.,-11.,-21.,-29.,-3.,-27.,0.,-34.,3.,-25.,6.,-18.,-20.,36.,18.,11.,20.,-20.,26.,25.,34.,4.,24.,34.,47.,17.,-5.,10.,20.,-18.,0.,4.]])
 154 
 155     II, JJ, KK = np.meshgrid(np.arange(atlas_volsize[0]), np.arange(atlas_volsize[1]), np.arange(atlas_volsize[2]), indexing='ij')
 156     II = torch.tensor(II, device='cpu')
 157     JJ = torch.tensor(JJ, device='cpu')
 158     KK = torch.tensor(KK, device='cpu')
 159 
 160     # Loop over segmentations and get COGs of ROIs
 161     COGs = np.zeros([len(input_files), 4, nlab])
 162     OKs = np.zeros([len(input_files), nlab])
 163     for i in range(len(input_files)):
 164         print('Getting centroids of ROIs: case ' + str(i + 1) + ' of ' + str(len(input_files)))
 165         COG = np.zeros([4, nlab])
 166         ok = np.ones(nlab)
 167         seg_buffer, seg_aff, seg_h = load_volume(seg_files[i], im_only=False, squeeze=True, dtype=None, aff_ref=None)
 168         label_to_idx = {lab: ii for ii, lab in enumerate(labels)}
 169         coords_per_label = [[] for _ in range(nlab)]
 170         nz = np.array(np.nonzero(seg_buffer)).T
 171         vals = seg_buffer[tuple(nz.T)]
 172         valid_mask = np.isin(vals, labels)
 173         nz = nz[valid_mask]
 174         vals = vals[valid_mask]
 175         idxs = np.searchsorted(labels, vals)
 176         for ii in range(nlab):
 177             coords_per_label[ii] = nz[idxs == ii]
 178         # Compute per-label median centroids
 179         for ii, vox in enumerate(coords_per_label):
 180             if vox.shape[0] > 50:
 181                 COG[:3, ii] = np.median(vox, axis=0)
 182                 COG[3, ii] = 1
 183             else:
 184                 ok[ii] = 0
 185         COGs[i] = np.matmul(seg_aff, COG)
 186         OKs[i] = ok.copy()
 187 
 188     # Linear registration matrices; first rigid, then affine
 189     NUM = np.zeros(atlasCOG.shape)
 190     DEN = np.zeros(atlasCOG.shape)
 191     for i in range(len(input_files)):
 192         M = getMrigid(COGs[i, :-1, OKs[i] > 0].T, atlasCOG[:, OKs[i] > 0])
 193         NUM[:, OKs[i] > 0] = NUM[:, OKs[i] > 0] + (M @ COGs[i, :, OKs[i] > 0].T)[:-1, :]
 194         DEN[:, OKs[i] > 0] = DEN[:, OKs[i] > 0] + 1
 195     rigidAtlasCOG = NUM / DEN
 196     Ms = np.zeros([len(input_files), 4, 4])
 197     for i in range(len(input_files)):
 198         Ms[i] = getM(rigidAtlasCOG[:, OKs[i] > 0], COGs[i, :, OKs[i] > 0].T)
 199 
 200     # OK now we can deform to linear space (and compute linear atlas, while at it)
 201     NUM = np.zeros(atlas_volsize)
 202     DEN = np.zeros(atlas_volsize)
 203     for i in range(len(input_files)):
 204         print('Deforming to linear space: case ' + str(i + 1) + ' of ' + str(len(input_files)))
 205         im_buffer, im_aff, im_hh = load_volume(input_files[i], im_only=False, squeeze=True, dtype=None, aff_ref=None)
 206         im_buffer = torch.tensor(im_buffer, device='cpu')
 207         voxdim = np.sqrt(np.sum(im_aff[:-1, :-1] ** 2, axis=0))
 208         affine = torch.tensor(np.matmul(np.linalg.inv(im_aff), np.matmul(Ms[i], atlas_aff)), device='cpu')
 209         II2 = affine[0, 0] * II + affine[0, 1] * JJ + affine[0, 2] * KK + affine[0, 3]
 210         JJ2 = affine[1, 0] * II + affine[1, 1] * JJ + affine[1, 2] * KK + affine[1, 3]
 211         KK2 = affine[2, 0] * II + affine[2, 1] * JJ + affine[2, 2] * KK + affine[2, 3]
 212         im_lin = fast_3D_interp_torch(im_buffer, II2, JJ2, KK2, 'linear')
 213         if args.use_reliability_maps:
 214             lin_dists = torch.sqrt(((II2 - II2.round()) * voxdim[0]) ** 2 +
 215                                    ((JJ2 - JJ2.round()) * voxdim[1]) ** 2 +
 216                                    ((KK2 - KK2.round()) * voxdim[2]) ** 2)
 217             lin_rel = torch.exp(-1.0 * lin_dists)
 218         else:
 219             lin_rel = torch.ones(II2.shape)
 220 
 221         seg_buffer, seg_aff, seg_h = load_volume(seg_files[i], im_only=False, squeeze=True, dtype=None, aff_ref=None)
 222         affine = torch.tensor(np.matmul(np.linalg.inv(seg_aff), np.matmul(Ms[i], atlas_aff)), device='cpu')
 223         II2 = affine[0, 0] * II + affine[0, 1] * JJ + affine[0, 2] * KK + affine[0, 3]
 224         JJ2 = affine[1, 0] * II + affine[1, 1] * JJ + affine[1, 2] * KK + affine[1, 3]
 225         KK2 = affine[2, 0] * II + affine[2, 1] * JJ + affine[2, 2] * KK + affine[2, 3]
 226         seg_lin = fast_3D_interp_torch(torch.tensor(seg_buffer.copy(), device='cpu'), II2, JJ2, KK2, 'nearest')
 227         im_lin[seg_lin == 0] = 0
 228         im_lin /= torch.median(im_lin[torch.logical_or(seg_lin==2, seg_lin==41)])
 229         np.save(linear_files[i], torch.stack([im_lin, lin_rel]).detach().cpu().numpy())
 230         NUM += (im_lin * lin_rel).detach().cpu().numpy()
 231         DEN += lin_rel.detach().cpu().numpy()
 232 
 233     print('Computing and saving affine atlas')
 234     ATLAS = NUM / (1e-9 + DEN)
 235     save_volume(ATLAS, atlas_aff, None, args.o + '/atlas.affine.nii.gz')
 236 
 237     print('Building nonlinear registration model')
 238     # Build model
 239     source = tf.keras.Input(shape=(*atlas_volsize, 1))
 240     target = tf.keras.Input(shape=(*atlas_volsize, 1))
 241 
 242     config = {'name': 'vxm_dense', 'fill_value': None, 'input_model': None, 'unet_half_res': True, 'trg_feats': 1,
 243               'src_feats': 1, 'use_probs': False, 'bidir': False, 'int_downsize': 2, 'int_steps': 10,
 244               'nb_unet_conv_per_level': 1, 'unet_feat_mult': 1, 'nb_unet_levels': None,
 245               'nb_unet_features': [[256, 256, 256, 256], [256, 256, 256, 256, 256, 256]], 'inshape': atlas_volsize}
 246     cnn = vxm.networks.VxmDense(**config)
 247     cnn.load_weights(path_model_registration_trained, by_name=True)
 248     svf1 = cnn([source, target])[1]
 249     svf2 = cnn([target, source])[1]
 250     pos_svf = KL.Lambda(lambda x: 0.5 * x[0] - 0.5 * x[1])([svf1, svf2])
 251     neg_svf = KL.Lambda(lambda x: -x)(pos_svf)
 252     pos_def_small = vxm.layers.VecInt(method='ss', int_steps=10)(pos_svf)
 253     neg_def_small = vxm.layers.VecInt(method='ss', int_steps=10)(neg_svf)
 254     pos_def = vxm.layers.RescaleTransform(2)(pos_def_small)
 255     neg_def = vxm.layers.RescaleTransform(2)(neg_def_small)
 256     model = tf.keras.Model(inputs=[source, target],
 257                            outputs=[pos_def, neg_def])
 258     model.load_weights(path_model_registration_trained)
 259 
 260     # Global atlas building iterations
 261     MAX_IT = 5
 262     for it in range(MAX_IT):
 263         # Initialize new atlas to zeros
 264         NUM = np.zeros_like(ATLAS)
 265         DEN = np.zeros_like(ATLAS)
 266         for i in range(len(input_files)):
 267             print('Iteration ' + str(1 + it) + ' of ' + str(MAX_IT) + ', image ' + str(i+1) + ' of ' + str(len(input_files)))
 268             lin = np.load(linear_files[i])
 269             pred = model.predict([lin[0:1, ..., np.newaxis] / np.max(lin[0]) ,
 270                                   ATLAS[np.newaxis, ..., np.newaxis]])
 271             field = torch.tensor(pred[0], device='cpu').squeeze()
 272             II2 = II + field[..., 0]
 273             JJ2 = JJ + field[..., 1]
 274             KK2 = KK + field[..., 2]
 275             deformed_im = fast_3D_interp_torch(torch.tensor(lin[0], device='cpu'), II2 , JJ2, KK2, 'linear')
 276             deformed_rel = fast_3D_interp_torch(torch.tensor(lin[1], device='cpu'), II2, JJ2, KK2, 'linear')
 277             NUM += (deformed_im * deformed_rel).detach().cpu().numpy()
 278             DEN += deformed_rel.detach().cpu().numpy()
 279             if it == (MAX_IT-1):
 280                 T = Ms[i] @ atlas_aff
 281                 RR = T[0, 0] * II2 + T[0, 1] * JJ2 + T[0, 2] * KK2 + T[0, 3]
 282                 AA = T[1, 0] * II2 + T[1, 1] * JJ2 + T[1, 2] * KK2 + T[1, 3]
 283                 SS = T[2, 0] * II2 + T[2, 1] * JJ2 + T[2, 2] * KK2 + T[2, 3]
 284                 save_volume(torch.stack([RR, AA, SS], dim=-1).detach().cpu().numpy(), atlas_aff, None, reg_files[i])
 285         ATLAS = NUM / (1e-9 + DEN)
 286         save_volume(ATLAS, atlas_aff, None, args.o + '/atlas.iteration.' + str(it+1) + '.nii.gz')
 287 
 288     # Clean up
 289     print('Deleting temporary files')
 290     for i in range(len(linear_files)):
 291         os.remove(linear_files[i])
 292     os.rmdir(tempdir)
 293 
 294     print(' ')
 295     print('All done!')
 296     print(' ')
 297     print('If you use EasyReg in your analysis, please cite:')
 298     print('A ready-to-use machine learning tool for symmetric multi-modality registration of brain MRI.')
 299     print('JE Iglesias. Scientific Reports, 13, article number 6657 (2023).')
 300     print('https://www.nature.com/articles/s41598-023-33781-0')
 301     print(' ')
 302 
 303 
 304 #######################
 305 # Auxiliary functions #
 306 #######################
 307 
 308 
 309 def get_list_labels(label_list=None, save_label_list=None, FS_sort=False):
 310 
 311     # load label list if previously computed
 312     label_list = np.array(reformat_to_list(label_list, load_as_numpy=True, dtype='int'))
 313 
 314 
 315     # sort labels in neutral/left/right according to FS labels
 316     n_neutral_labels = 0
 317     if FS_sort:
 318         neutral_FS_labels = [0, 14, 15, 16, 21, 22, 23, 24, 72, 77, 80, 85, 100, 101, 102, 103, 104, 105, 106, 107, 108,
 319                              109, 165, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
 320                              251, 252, 253, 254, 255, 258, 259, 260, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340,
 321                              502, 506, 507, 508, 509, 511, 512, 514, 515, 516, 517, 530,
 322                              531, 532, 533, 534, 535, 536, 537]
 323         neutral = list()
 324         left = list()
 325         right = list()
 326         for la in label_list:
 327             if la in neutral_FS_labels:
 328                 if la not in neutral:
 329                     neutral.append(la)
 330             elif (0 < la < 14) | (16 < la < 21) | (24 < la < 40) | (135 < la < 139) | (1000 <= la <= 1035) | \
 331                     (la == 865) | (20100 < la < 20110):
 332                 if la not in left:
 333                     left.append(la)
 334             elif (39 < la < 72) | (162 < la < 165) | (2000 <= la <= 2035) | (20000 < la < 20010) | (la == 139) | \
 335                     (la == 866):
 336                 if la not in right:
 337                     right.append(la)
 338             else:
 339                 raise Exception('label {} not in our current FS classification, '
 340                                 'please update get_list_labels in utils.py'.format(la))
 341         label_list = np.concatenate([sorted(neutral), sorted(left), sorted(right)])
 342         if ((len(left) > 0) & (len(right) > 0)) | ((len(left) == 0) & (len(right) == 0)):
 343             n_neutral_labels = len(neutral)
 344         else:
 345             n_neutral_labels = len(label_list)
 346 
 347     # save labels if specified
 348     if save_label_list is not None:
 349         np.save(save_label_list, np.int32(label_list))
 350 
 351     if FS_sort:
 352         return np.int32(label_list), n_neutral_labels
 353     else:
 354         return np.int32(label_list), None
 355 
 356 def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None):
 357     # convert to list
 358     if var is None:
 359         return None
 360     var = load_array_if_path(var, load_as_numpy=load_as_numpy)
 361     if isinstance(var, (int, float, np.int8, np.int16, np.int32, np.int64, np.float16, np.float32, np.float64)):
 362         var = [var]
 363     elif isinstance(var, tuple):
 364         var = list(var)
 365     elif isinstance(var, np.ndarray):
 366         if var.shape == (1,):
 367             var = [var[0]]
 368         else:
 369             var = np.squeeze(var).tolist()
 370     elif isinstance(var, str):
 371         var = [var]
 372     elif isinstance(var, bool):
 373         var = [var]
 374     if isinstance(var, list):
 375         if length is not None:
 376             if len(var) == 1:
 377                 var = var * length
 378             elif len(var) != length:
 379                 raise ValueError('if var is a list/tuple/numpy array, it should be of length 1 or {0}, '
 380                                  'had {1}'.format(length, var))
 381     else:
 382         raise TypeError('var should be an int, float, tuple, list, numpy array, or path to numpy array')
 383 
 384     # convert items type
 385     if dtype is not None:
 386         if dtype == 'int':
 387             var = [int(v) for v in var]
 388         elif dtype == 'float':
 389             var = [float(v) for v in var]
 390         elif dtype == 'bool':
 391             var = [bool(v) for v in var]
 392         elif dtype == 'str':
 393             var = [str(v) for v in var]
 394         else:
 395             raise ValueError("dtype should be 'str', 'float', 'int', or 'bool'; had {}".format(dtype))
 396     return var
 397 
 398 def load_array_if_path(var, load_as_numpy=True):
 399     if (isinstance(var, str)) & load_as_numpy:
 400         assert os.path.isfile(var), 'No such path: %s' % var
 401         var = np.load(var)
 402     return var
 403 
 404 
 405 def load_volume(path_volume, im_only=True, squeeze=True, dtype=None, aff_ref=None):
 406 
 407     assert path_volume.endswith(('.nii', '.nii.gz', '.mgz', '.npz')), 'Unknown data file: %s' % path_volume
 408 
 409     if path_volume.endswith(('.nii', '.nii.gz', '.mgz')):
 410         x = nib.load(path_volume)
 411         if squeeze:
 412             volume = np.squeeze(x.get_fdata())
 413         else:
 414             volume = x.get_fdata()
 415         aff = x.affine
 416         header = x.header
 417     else:  # npz
 418         volume = np.load(path_volume)['vol_data']
 419         if squeeze:
 420             volume = np.squeeze(volume)
 421         aff = np.eye(4)
 422         header = nib.Nifti1Header()
 423     if dtype is not None:
 424         if 'int' in dtype:
 425             volume = np.round(volume)
 426         volume = volume.astype(dtype=dtype)
 427 
 428     # align image to reference affine matrix
 429     if aff_ref is not None:
 430         n_dims, _ = get_dims(list(volume.shape), max_channels=10)
 431         volume, aff = align_volume_to_ref(volume, aff, aff_ref=aff_ref, return_aff=True, n_dims=n_dims)
 432 
 433     if im_only:
 434         return volume
 435     else:
 436         return volume, aff, header
 437 
 438 
 439 
 440 
 441 def preprocess(path_image, n_levels=5, crop=None, min_pad=None, path_resample=None):
 442     # read image and corresponding info
 443     im, _, aff, n_dims, n_channels, h, im_res = get_volume_info(path_image, True)
 444     if n_dims < 3:
 445         sf.system.fatal('input should have 3 dimensions, had %s' % n_dims)
 446     elif n_dims == 4 and n_channels == 1:
 447         n_dims = 3
 448         im = im[..., 0]
 449     elif n_dims > 3:
 450         sf.system.fatal('input should have 3 dimensions, had %s' % n_dims)
 451     elif n_channels > 1:
 452         print('WARNING: detected more than 1 channel, only keeping the first channel.')
 453         im = im[..., 0]
 454 
 455     # resample image if necessary
 456     if np.any((im_res > 1.05) | (im_res < 0.95)):
 457         im_res = np.array([1.] * 3)
 458         im, aff = resample_volume(im, aff, im_res)
 459         if path_resample is not None:
 460             save_volume(im, aff, h, path_resample)
 461 
 462     # align image
 463     im = align_volume_to_ref(im, aff, aff_ref=np.eye(4), n_dims=n_dims, return_copy=False)
 464     shape = list(im.shape[:n_dims])
 465 
 466     # crop image if necessary
 467     if crop is not None:
 468         crop = reformat_to_list(crop, length=n_dims, dtype='int')
 469         crop_shape = [find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in crop]
 470         im, crop_idx = crop_volume(im, cropping_shape=crop_shape, return_crop_idx=True)
 471     else:
 472         crop_idx = None
 473 
 474     # normalise image
 475     im = rescale_volume(im, new_min=0, new_max=1, min_percentile=0.5, max_percentile=99.5)
 476 
 477     # pad image
 478     input_shape = im.shape[:n_dims]
 479     pad_shape = [find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in input_shape]
 480     min_pad = reformat_to_list(min_pad, length=n_dims, dtype='int')
 481     min_pad = [find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in min_pad]
 482     pad_shape = np.maximum(pad_shape, min_pad)
 483     im, pad_idx = pad_volume(im, padding_shape=pad_shape, return_pad_idx=True)
 484 
 485     # add batch and channel axes
 486     im = add_axis(im, axis=[0, -1])
 487 
 488     return im, aff, h, im_res, shape, pad_idx, crop_idx
 489 
 490 
 491 def resample_volume(volume, aff, new_vox_size, interpolation='linear'):
 492     pixdim = np.sqrt(np.sum(aff * aff, axis=0))[:-1]
 493     new_vox_size = np.array(new_vox_size)
 494     factor = pixdim / new_vox_size
 495     sigmas = 0.25 / factor
 496     sigmas[factor > 1] = 0  # don't blur if upsampling
 497 
 498     volume_filt = gaussian_filter(volume, sigmas)
 499 
 500     # volume2 = zoom(volume_filt, factor, order=1, mode='reflect', prefilter=False)
 501     x = np.arange(0, volume_filt.shape[0])
 502     y = np.arange(0, volume_filt.shape[1])
 503     z = np.arange(0, volume_filt.shape[2])
 504 
 505     start = - (factor - 1) / (2 * factor)
 506     step = 1.0 / factor
 507     stop = start + step * np.ceil(volume_filt.shape * factor)
 508 
 509     xi = np.arange(start=start[0], stop=stop[0], step=step[0])
 510     yi = np.arange(start=start[1], stop=stop[1], step=step[1])
 511     zi = np.arange(start=start[2], stop=stop[2], step=step[2])
 512     xi[xi < 0] = 0
 513     yi[yi < 0] = 0
 514     zi[zi < 0] = 0
 515     xi[xi > (volume_filt.shape[0] - 1)] = volume_filt.shape[0] - 1
 516     yi[yi > (volume_filt.shape[1] - 1)] = volume_filt.shape[1] - 1
 517     zi[zi > (volume_filt.shape[2] - 1)] = volume_filt.shape[2] - 1
 518 
 519     xig, yig, zig = np.meshgrid(xi, yi, zi, indexing='ij', sparse=False)
 520     xig = torch.tensor(xig, device='cpu')
 521     yig = torch.tensor(yig, device='cpu')
 522     zig = torch.tensor(zig, device='cpu')
 523     volume2 = fast_3D_interp_torch(torch.tensor(volume_filt, device='cpu'), xig, yig, zig, 'linear')
 524 
 525     aff2 = aff.copy()
 526     for c in range(3):
 527         aff2[:-1, c] = aff2[:-1, c] / factor[c]
 528     aff2[:-1, -1] = aff2[:-1, -1] - np.matmul(aff2[:-1, :-1], 0.5 * (factor - 1))
 529 
 530     return volume2.numpy(), aff2
 531 
 532 def find_closest_number_divisible_by_m(n, m, answer_type='lower'):
 533     if n % m == 0:
 534         return n
 535     else:
 536         q = int(n / m)
 537         lower = q * m
 538         higher = (q + 1) * m
 539         if answer_type == 'lower':
 540             return lower
 541         elif answer_type == 'higher':
 542             return higher
 543         elif answer_type == 'closer':
 544             return lower if (n - lower) < (higher - n) else higher
 545         else:
 546             sf.system.fatal('answer_type should be lower, higher, or closer, had : %s' % answer_type)
 547 
 548 
 549 
 550 def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels=10):
 551 
 552     im, aff, header = load_volume(path_volume, im_only=False)
 553 
 554     # understand if image is multichannel
 555     im_shape = list(im.shape)
 556     n_dims, n_channels = get_dims(im_shape, max_channels=max_channels)
 557     im_shape = im_shape[:n_dims]
 558 
 559     # get labels res
 560     if '.nii' in path_volume:
 561         data_res = np.array(header['pixdim'][1:n_dims + 1])
 562     elif '.mgz' in path_volume:
 563         data_res = np.array(header['delta'])  # mgz image
 564     else:
 565         data_res = np.array([1.0] * n_dims)
 566 
 567     # align to given affine matrix
 568     if aff_ref is not None:
 569         ras_axes = get_ras_axes(aff, n_dims=n_dims)
 570         ras_axes_ref = get_ras_axes(aff_ref, n_dims=n_dims)
 571         im = align_volume_to_ref(im, aff, aff_ref=aff_ref, n_dims=n_dims)
 572         im_shape = np.array(im_shape)
 573         data_res = np.array(data_res)
 574         im_shape[ras_axes_ref] = im_shape[ras_axes]
 575         data_res[ras_axes_ref] = data_res[ras_axes]
 576         im_shape = im_shape.tolist()
 577 
 578     # return info
 579     if return_volume:
 580         return im, im_shape, aff, n_dims, n_channels, header, data_res
 581     else:
 582         return im_shape, aff, n_dims, n_channels, header, data_res
 583 
 584 def get_dims(shape, max_channels=10):
 585     if shape[-1] <= max_channels:
 586         n_dims = len(shape) - 1
 587         n_channels = shape[-1]
 588     else:
 589         n_dims = len(shape)
 590         n_channels = 1
 591     return n_dims, n_channels
 592 
 593 
 594 def get_ras_axes(aff, n_dims=3):
 595     aff_inverted = np.linalg.inv(aff)
 596     img_ras_axes = np.argmax(np.absolute(aff_inverted[0:n_dims, 0:n_dims]), axis=0)
 597     for i in range(n_dims):
 598         if i not in img_ras_axes:
 599             unique, counts = np.unique(img_ras_axes, return_counts=True)
 600             incorrect_value = unique[np.argmax(counts)]
 601             img_ras_axes[np.where(img_ras_axes == incorrect_value)[0][-1]] = i
 602 
 603     return img_ras_axes
 604 
def align_volume_to_ref(volume, aff, aff_ref=None, return_aff=False, n_dims=None, return_copy=True):
    """Reorient *volume* so that its voxel axes match those of *aff_ref*.

    Voxel axes are first permuted (np.swapaxes) so that each one maps to the same
    world axis as in aff_ref (as reported by get_ras_axes). Axes whose direction
    disagrees in sign with aff_ref are then flipped. The affine is updated to
    match: columns are permuted and negated, and the translation is shifted for
    every flip.

    Args:
        volume: numpy array whose first n_dims axes are spatial.
        aff: voxel-to-world affine of volume (copied, not modified).
        aff_ref: reference affine; defaults to np.eye(4).
        return_aff: if True, also return the updated affine.
        n_dims: number of spatial dims; inferred with get_dims when None.
        return_copy: if False, volume is not copied first, so the result may be a
            view (np.swapaxes / np.flip) sharing memory with the input.

    Returns:
        The aligned volume, or (aligned volume, updated affine) if return_aff.
    """

    # work on copy
    new_volume = volume.copy() if return_copy else volume
    aff_flo = aff.copy()

    # default value for aff_ref
    if aff_ref is None:
        aff_ref = np.eye(4)

    # extract ras axes
    if n_dims is None:
        n_dims, _ = get_dims(new_volume.shape)
    ras_axes_ref = get_ras_axes(aff_ref, n_dims=n_dims)
    ras_axes_flo = get_ras_axes(aff_flo, n_dims=n_dims)

    # align axes: permute the affine columns in one go, then reorder the volume
    # with pairwise swaps, updating ras_axes_flo after each swap so later
    # comparisons see the current axis order
    aff_flo[:, ras_axes_ref] = aff_flo[:, ras_axes_flo]
    for i in range(n_dims):
        if ras_axes_flo[i] != ras_axes_ref[i]:
            new_volume = np.swapaxes(new_volume, ras_axes_flo[i], ras_axes_ref[i])
            swapped_axis_idx = np.where(ras_axes_flo == ras_axes_ref[i])
            ras_axes_flo[swapped_axis_idx], ras_axes_flo[i] = ras_axes_flo[i], ras_axes_flo[swapped_axis_idx]

    # align directions: a negative dot product between matching affine columns
    # means the axis points the opposite way; flip it, negate the column and move
    # the origin to the world position of the former last voxel along that axis
    dot_products = np.sum(aff_flo[:3, :3] * aff_ref[:3, :3], axis=0)
    for i in range(n_dims):
        if dot_products[i] < 0:
            new_volume = np.flip(new_volume, axis=i)
            aff_flo[:, i] = - aff_flo[:, i]
            aff_flo[:3, 3] = aff_flo[:3, 3] - aff_flo[:3, i] * (new_volume.shape[i] - 1)

    if return_aff:
        return new_volume, aff_flo
    else:
        return new_volume
 641 
 642 def build_seg_model(model_file_segmentation,
 643                 model_file_parcellation,
 644                 labels_segmentation,
 645                 labels_parcellation):
 646 
 647     if not os.path.isfile(model_file_segmentation):
 648         sf.system.fatal("The provided model path does not exist.")
 649 
 650     # get labels
 651     n_labels_seg = len(labels_segmentation)
 652 
 653     # build UNet
 654     net = unet(nb_features=24,
 655                input_shape=[None, None, None, 1],
 656                nb_levels=5,
 657                conv_size=3,
 658                nb_labels=n_labels_seg,
 659                feat_mult=2,
 660                activation='elu',
 661                nb_conv_per_level=2,
 662                batch_norm=-1,
 663                name='unet')
 664     net.load_weights(model_file_segmentation, by_name=True)
 665     input_image = net.inputs[0]
 666     name_segm_prediction_layer = 'unet_prediction'
 667 
 668     # smooth posteriors
 669     last_tensor = net.output
 670     last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list())
 671     last_tensor = GaussianBlur(sigma=0.5)(last_tensor)
 672     net = keras.Model(inputs=net.inputs, outputs=last_tensor)
 673 
 674     # add aparc segmenter
 675     n_labels_parcellation = len(labels_parcellation)
 676 
 677     last_tensor = net.output
 678     last_tensor = KL.Lambda(lambda x: tf.cast(tf.argmax(x, axis=-1), 'int32'))(last_tensor)
 679     last_tensor = ConvertLabels(np.arange(n_labels_seg), labels_segmentation)(last_tensor)
 680     parcellation_masking_values = np.array([1 if ((ll == 3) | (ll == 42)) else 0 for ll in labels_segmentation])
 681     last_tensor = ConvertLabels(labels_segmentation, parcellation_masking_values)(last_tensor)
 682     last_tensor = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, 'int32'), depth=2, axis=-1))(last_tensor)
 683     last_tensor = KL.Lambda(lambda x: tf.cast(tf.concat(x, axis=-1), 'float32'))([input_image, last_tensor])
 684     net = keras.Model(inputs=net.inputs, outputs=last_tensor)
 685 
 686     # build UNet
 687     net = unet(nb_features=24,
 688                input_shape=[None, None, None, 3],
 689                nb_levels=5,
 690                conv_size=3,
 691                nb_labels=n_labels_parcellation,
 692                feat_mult=2,
 693                activation='elu',
 694                nb_conv_per_level=2,
 695                batch_norm=-1,
 696                name='unet_parc',
 697                input_model=net)
 698     net.load_weights(model_file_parcellation, by_name=True)
 699 
 700     # smooth predictions
 701     last_tensor = net.output
 702     last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list())
 703     last_tensor = GaussianBlur(sigma=0.5)(last_tensor)
 704     net = keras.Model(inputs=net.inputs, outputs=[net.get_layer(name_segm_prediction_layer).output, last_tensor])
 705 
 706     return net
 707 
 708 def unet(nb_features,
 709          input_shape,
 710          nb_levels,
 711          conv_size,
 712          nb_labels,
 713          name='unet',
 714          prefix=None,
 715          feat_mult=1,
 716          pool_size=2,
 717          padding='same',
 718          dilation_rate_mult=1,
 719          activation='elu',
 720          skip_n_concatenations=0,
 721          use_residuals=False,
 722          final_pred_activation='softmax',
 723          nb_conv_per_level=1,
 724          layer_nb_feats=None,
 725          conv_dropout=0,
 726          batch_norm=None,
 727          input_model=None):
 728 
 729     # naming
 730     model_name = name
 731     if prefix is None:
 732         prefix = model_name
 733 
 734     # volume size data
 735     ndims = len(input_shape) - 1
 736     if isinstance(pool_size, int):
 737         pool_size = (pool_size,) * ndims
 738 
 739     # get encoding model
 740     enc_model = conv_enc(nb_features,
 741                          input_shape,
 742                          nb_levels,
 743                          conv_size,
 744                          name=model_name,
 745                          prefix=prefix,
 746                          feat_mult=feat_mult,
 747                          pool_size=pool_size,
 748                          padding=padding,
 749                          dilation_rate_mult=dilation_rate_mult,
 750                          activation=activation,
 751                          use_residuals=use_residuals,
 752                          nb_conv_per_level=nb_conv_per_level,
 753                          layer_nb_feats=layer_nb_feats,
 754                          conv_dropout=conv_dropout,
 755                          batch_norm=batch_norm,
 756                          input_model=input_model)
 757 
 758     # get decoder
 759     # use_skip_connections=True makes it a u-net
 760     lnf = layer_nb_feats[(nb_levels * nb_conv_per_level):] if layer_nb_feats is not None else None
 761     dec_model = conv_dec(nb_features,
 762                          None,
 763                          nb_levels,
 764                          conv_size,
 765                          nb_labels,
 766                          name=model_name,
 767                          prefix=prefix,
 768                          feat_mult=feat_mult,
 769                          pool_size=pool_size,
 770                          use_skip_connections=True,
 771                          skip_n_concatenations=skip_n_concatenations,
 772                          padding=padding,
 773                          dilation_rate_mult=dilation_rate_mult,
 774                          activation=activation,
 775                          use_residuals=use_residuals,
 776                          final_pred_activation=final_pred_activation,
 777                          nb_conv_per_level=nb_conv_per_level,
 778                          batch_norm=batch_norm,
 779                          layer_nb_feats=lnf,
 780                          conv_dropout=conv_dropout,
 781                          input_model=enc_model)
 782     final_model = dec_model
 783 
 784     return final_model
 785 
 786 def conv_enc(nb_features,
 787              input_shape,
 788              nb_levels,
 789              conv_size,
 790              name=None,
 791              prefix=None,
 792              feat_mult=1,
 793              pool_size=2,
 794              dilation_rate_mult=1,
 795              padding='same',
 796              activation='elu',
 797              layer_nb_feats=None,
 798              use_residuals=False,
 799              nb_conv_per_level=2,
 800              conv_dropout=0,
 801              batch_norm=None,
 802              input_model=None):
 803 
 804     # naming
 805     model_name = name
 806     if prefix is None:
 807         prefix = model_name
 808 
 809     # first layer: input
 810     name = '%s_input' % prefix
 811     if input_model is None:
 812         input_tensor = KL.Input(shape=input_shape, name=name)
 813         last_tensor = input_tensor
 814     else:
 815         input_tensor = input_model.inputs
 816         last_tensor = input_model.outputs
 817         if isinstance(last_tensor, list):
 818             last_tensor = last_tensor[0]
 819 
 820     # volume size data
 821     ndims = len(input_shape) - 1
 822     if isinstance(pool_size, int):
 823         pool_size = (pool_size,) * ndims
 824 
 825     # prepare layers
 826     convL = getattr(KL, 'Conv%dD' % ndims)
 827     conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'}
 828     maxpool = getattr(KL, 'MaxPooling%dD' % ndims)
 829 
 830     # down arm:
 831     # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers
 832     lfidx = 0  # level feature index
 833     for level in range(nb_levels):
 834         lvl_first_tensor = last_tensor
 835         nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int)
 836         conv_kwargs['dilation_rate'] = dilation_rate_mult ** level
 837 
 838         for conv in range(nb_conv_per_level):  # does several conv per level, max pooling applied at the end
 839             if layer_nb_feats is not None:  # None or List of all the feature numbers
 840                 nb_lvl_feats = layer_nb_feats[lfidx]
 841                 lfidx += 1
 842 
 843             name = '%s_conv_downarm_%d_%d' % (prefix, level, conv)
 844             if conv < (nb_conv_per_level - 1) or (not use_residuals):
 845                 last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
 846             else:  # no activation
 847                 last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)
 848 
 849             if conv_dropout > 0:
 850                 # conv dropout along feature space only
 851                 name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv)
 852                 noise_shape = [None, *[1] * ndims, nb_lvl_feats]
 853                 last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)
 854 
 855         if use_residuals:
 856             convarm_layer = last_tensor
 857 
 858             # the "add" layer is the original input
 859             # However, it may not have the right number of features to be added
 860             nb_feats_in = lvl_first_tensor.get_shape()[-1]
 861             nb_feats_out = convarm_layer.get_shape()[-1]
 862             add_layer = lvl_first_tensor
 863             if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
 864                 name = '%s_expand_down_merge_%d' % (prefix, level)
 865                 last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor)
 866                 add_layer = last_tensor
 867 
 868                 if conv_dropout > 0:
 869                     name = '%s_dropout_down_merge_%d_%d' % (prefix, level, conv)
 870                     noise_shape = [None, *[1] * ndims, nb_lvl_feats]
 871 
 872             name = '%s_res_down_merge_%d' % (prefix, level)
 873             last_tensor = KL.add([add_layer, convarm_layer], name=name)
 874 
 875             name = '%s_res_down_merge_act_%d' % (prefix, level)
 876             last_tensor = KL.Activation(activation, name=name)(last_tensor)
 877 
 878         if batch_norm is not None:
 879             name = '%s_bn_down_%d' % (prefix, level)
 880             last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
 881 
 882         # max pool if we're not at the last level
 883         if level < (nb_levels - 1):
 884             name = '%s_maxpool_%d' % (prefix, level)
 885             last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor)
 886 
 887     # create the model and return
 888     model = keras.Model(inputs=input_tensor, outputs=[last_tensor], name=model_name)
 889     return model
 890 
 891 
def conv_dec(nb_features,
             input_shape,
             nb_levels,
             conv_size,
             nb_labels,
             name=None,
             prefix=None,
             feat_mult=1,
             pool_size=2,
             use_skip_connections=False,
             skip_n_concatenations=0,
             padding='same',
             dilation_rate_mult=1,
             activation='elu',
             use_residuals=False,
             final_pred_activation='softmax',
             nb_conv_per_level=2,
             layer_nb_feats=None,
             batch_norm=None,
             conv_dropout=0,
             input_model=None):
    """Build the upsampling (decoder) arm of a U-Net plus its prediction head.

    For each of the ``nb_levels - 1`` up levels: upsample, optionally
    concatenate the matching encoder tensor (skip connection), apply
    ``nb_conv_per_level`` convolutions (each optionally followed by
    channel-wise dropout), an optional residual merge and an optional batch
    normalisation. A 1x1 convolution then produces ``nb_labels`` channels,
    followed by a softmax over the channel axis or, for any other
    ``final_pred_activation``, a linear activation. The final layer is named
    ``<prefix>_prediction``.

    Layer names are built deterministically from ``prefix`` and the level
    indices. NOTE(review): callers load weights by layer name, so these
    formats must not change.

    Args:
        nb_features: base feature count; up level ``level`` uses
            ``round(nb_features * feat_mult ** (nb_levels - 2 - level))``
            unless ``layer_nb_feats`` overrides it.
        input_shape: input shape (no batch axis); ignored when
            ``input_model`` is given, in which case it is read from
            ``input_model.output``.
        nb_levels: number of resolution levels of the matching encoder.
        conv_size: convolution kernel size.
        nb_labels: number of output channels.
        name: Keras model name; also the default ``prefix``.
        prefix: prefix of every layer name.
        pool_size: upsampling factor (an int is expanded per spatial axis
            when there is more than one).
        use_skip_connections: concatenate encoder layers looked up by name in
            ``input_model``; requires ``input_model``.
        skip_n_concatenations: number of the last up levels that skip the
            concatenation.
        use_residuals: add a residual connection around each level's convs.
        final_pred_activation: 'softmax' or anything else for linear output.
        layer_nb_feats: optional flat list of per-convolution feature counts.
        batch_norm: if not None, the ``axis`` of a BatchNormalization per level.
        conv_dropout: dropout rate with one mask value per feature map; 0
            disables it.
        input_model: encoder model whose output feeds the decoder.

    Returns:
        keras.Model from the encoder's input (or a fresh Input) to the
        prediction tensor.
    """

    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    # skip connections look layers up in input_model, so it must be provided
    if use_skip_connections:
        assert input_model is not None, "is using skip connections, tensors dictionary is required"

    # first layer: input
    input_name = '%s_input' % prefix
    if input_model is None:
        input_tensor = KL.Input(shape=input_shape, name=input_name)
        last_tensor = input_tensor
    else:
        input_tensor = input_model.input
        last_tensor = input_model.output
        input_shape = last_tensor.shape.as_list()[1:]

    # vol size info
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int):
        if ndims > 1:
            pool_size = (pool_size,) * ndims

    # prepare layers
    convL = getattr(KL, 'Conv%dD' % ndims)
    conv_kwargs = {'padding': padding, 'activation': activation}
    upsample = getattr(KL, 'UpSampling%dD' % ndims)

    # up arm:
    # nb_levels - 1 levels, each approximating a transposed convolution
    #    (upsample + conv + activation) + merge + conv + activation + conv + activation
    lfidx = 0  # index into layer_nb_feats
    for level in range(nb_levels - 1):
        nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int)
        conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level)

        # upsample matching the max pooling layers size
        name = '%s_up_%d' % (prefix, nb_levels + level)
        last_tensor = upsample(size=pool_size, name=name)(last_tensor)
        up_tensor = last_tensor

        # merge with the last conv of the encoder level at the same resolution
        # (bitwise & on two booleans; the last skip_n_concatenations levels are skipped)
        if use_skip_connections & (level < (nb_levels - skip_n_concatenations - 1)):
            conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1)
            cat_tensor = input_model.get_layer(conv_name).output
            name = '%s_merge_%d' % (prefix, nb_levels + level)
            last_tensor = KL.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name)

        # convolution layers
        for conv in range(nb_conv_per_level):
            if layer_nb_feats is not None:
                nb_lvl_feats = layer_nb_feats[lfidx]
                lfidx += 1

            name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv)
            if conv < (nb_conv_per_level - 1) or (not use_residuals):
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
            else:
                # last conv of a residual level: no activation, it is applied after the add
                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)

            if conv_dropout > 0:
                # dropout along feature space only
                name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv)
                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)

        # residual block
        if use_residuals:

            # the "add" layer is the upsampled tensor
            # However, it may not have the right number of features to be added
            add_layer = up_tensor
            nb_feats_in = add_layer.get_shape()[-1]
            nb_feats_out = last_tensor.get_shape()[-1]
            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
                name = '%s_expand_up_merge_%d' % (prefix, level)
                add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer)

                if conv_dropout > 0:
                    # applied to the conv arm; `conv` is the last value of the inner loop
                    name = '%s_dropout_up_merge_%d_%d' % (prefix, level, conv)
                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                    last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)

            name = '%s_res_up_merge_%d' % (prefix, level)
            last_tensor = KL.add([last_tensor, add_layer], name=name)

            name = '%s_res_up_merge_act_%d' % (prefix, level)
            last_tensor = KL.Activation(activation, name=name)(last_tensor)

        if batch_norm is not None:
            name = '%s_bn_up_%d' % (prefix, level)
            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)

    # Compute likelihood prediction with a 1x1 conv (no activation yet)
    name = '%s_likelihood' % prefix
    last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor)
    like_tensor = last_tensor

    # output prediction layer
    # we use a softmax over the channel axis to compute P(L_x|I) where x is each location
    if final_pred_activation == 'softmax':
        name = '%s_prediction' % prefix
        softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=ndims + 1)
        pred_tensor = KL.Lambda(softmax_lambda_fcn, name=name)(last_tensor)

    # otherwise create a layer that does nothing.
    else:
        name = '%s_prediction' % prefix
        pred_tensor = KL.Activation('linear', name=name)(like_tensor)

    # create the model and return
    model = keras.Model(inputs=input_tensor, outputs=pred_tensor, name=model_name)
    return model
1028 
def postprocess(post_patch_seg, post_patch_parc, shape, pad_idx, crop_idx,
                labels_segmentation, labels_parcellation, aff, im_res):
    """Turn network posteriors into a hard segmentation, full-size posteriors and volumes.

    Args:
        post_patch_seg: segmentation posteriors of the (padded) patch, channels
            last with channel 0 as background; singleton axes are squeezed.
        post_patch_parc: parcellation posteriors of the same patch, or None to
            skip the parcellation step.
        shape: spatial shape of the full volume; only used when ``crop_idx``
            is not None.
        pad_idx: crop indices that remove the padding added before inference.
        crop_idx: [min0, min1, min2, max0, max1, max2] of the patch inside the
            full volume, or None if the patch is the full volume.
        labels_segmentation: numpy array of label values, one per seg channel.
        labels_parcellation: numpy array of label values, one per parc channel.
        aff: affine the outputs are realigned to (from an identity affine).
        im_res: voxel size(s); volumes are multiplied by their product.

    Returns:
        seg: int32 label map; voxels labelled 3 or 42 are replaced by
            parcellation labels when ``post_patch_parc`` is given.
        posteriors: normalised segmentation posteriors.
        volumes: [total foreground, one entry per non-background seg channel,
            then (if parcellation is given) one per non-background parc
            channel, rescaled to sum to the volume of labels 3 and 42],
            rounded to 3 decimals.
    """
    # get posteriors
    post_patch_seg = np.squeeze(post_patch_seg)
    post_patch_seg = crop_volume_with_idx(post_patch_seg, pad_idx, n_dims=3, return_copy=False)

    # keep the biggest connected component of the union of foreground classes
    tmp_post_patch_seg = post_patch_seg[..., 1:]
    post_patch_seg_mask = np.sum(tmp_post_patch_seg, axis=-1) > 0.25
    post_patch_seg_mask = get_largest_connected_component(post_patch_seg_mask)
    post_patch_seg_mask = np.stack([post_patch_seg_mask] * tmp_post_patch_seg.shape[-1], axis=-1)
    tmp_post_patch_seg = mask_volume(tmp_post_patch_seg, mask=post_patch_seg_mask, return_copy=False)
    post_patch_seg[..., 1:] = tmp_post_patch_seg

    # zero out weak foreground posteriors (<= 0.2); background channel untouched
    post_patch_seg[..., 1:] *= post_patch_seg[..., 1:] > 0.2

    # get hard segmentation
    post_patch_seg /= np.sum(post_patch_seg, axis=-1)[..., np.newaxis]
    seg_patch = labels_segmentation[post_patch_seg.argmax(-1).astype('int32')].astype('int32')

    # postprocess parcellation: relabel voxels of labels 3/42 with parcellation labels
    if post_patch_parc is not None:
        post_patch_parc = np.squeeze(post_patch_parc)
        post_patch_parc = crop_volume_with_idx(post_patch_parc, pad_idx, n_dims=3, return_copy=False)
        mask = (seg_patch == 3) | (seg_patch == 42)
        # background parcel is certain outside the mask and impossible inside it
        post_patch_parc[..., 0] = np.logical_not(mask)
        post_patch_parc /= np.sum(post_patch_parc, axis=-1)[..., np.newaxis]
        parc_patch = labels_parcellation[post_patch_parc.argmax(-1).astype('int32')].astype('int32')
        seg_patch[mask] = parc_patch[mask]

    # paste patches back to matrix of original image size
    if crop_idx is not None:
        # we need to go through this because of the posteriors of the background, otherwise pad_volume would work
        seg = np.zeros(shape=shape, dtype='int32')
        posteriors = np.zeros(shape=[*shape, labels_segmentation.shape[0]])
        posteriors[..., 0] = np.ones(shape)  # place background around patch
        seg[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5]] = seg_patch
        posteriors[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], :] = post_patch_seg
    else:
        seg = seg_patch
        posteriors = post_patch_seg

    # align prediction back to first orientation
    seg = align_volume_to_ref(seg, aff=np.eye(4), aff_ref=aff, n_dims=3, return_copy=False)
    posteriors = align_volume_to_ref(posteriors, np.eye(4), aff_ref=aff, n_dims=3, return_copy=False)

    # compute volumes: after prepending the total, volumes[k] belongs to label index k (k >= 1)
    volumes = np.sum(posteriors[..., 1:], axis=tuple(range(0, posteriors.ndim - 1)))
    volumes = np.concatenate([np.array([np.sum(volumes)]), volumes])
    if post_patch_parc is not None:
        volumes_parc = np.sum(post_patch_parc[..., 1:], axis=tuple(range(0, post_patch_parc.ndim - 1)))
        cortex_idx = np.where((labels_segmentation == 3) | (labels_segmentation == 42))[0]
        # no "- 1" here: the prepended total already shifts label index k to volumes[k]
        total_volume_cortex = np.sum(volumes[cortex_idx])
        volumes_parc = volumes_parc / np.sum(volumes_parc) * total_volume_cortex
        volumes = np.concatenate([volumes, volumes_parc])
    volumes = np.around(volumes * np.prod(im_res), 3)

    return seg, posteriors, volumes
1089 
def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3):
    """Write ``volume`` to ``path`` as a compressed .npz archive or a NIfTI image.

    Args:
        volume: numpy array to save.
        aff: 4x4 affine, the string 'FS' (a fixed axis-permutation/flip
            matrix), or None for the identity.
        header: nibabel header to reuse, or None for a new Nifti1Header.
        path: output path; its parent directory is created if needed. Any path
            containing '.npz' is written with numpy under key 'vol_data', and
            aff/header/res/dtype are then ignored.
        res: optional voxel size(s) written to the header zooms.
        dtype: optional output dtype (string or numpy dtype); data is rounded
            first when it is an integer type.
        n_dims: number of spatial dims used to expand ``res``; inferred from
            ``volume`` when None.

    Raises:
        ValueError: if ``aff`` is a string other than 'FS'.
    """
    mkdir(os.path.dirname(path))
    if '.npz' in path:
        np.savez_compressed(path, vol_data=volume)
        return

    if header is None:
        header = nib.Nifti1Header()
    if isinstance(aff, str):
        if aff != 'FS':
            raise ValueError("aff should be an array, None or 'FS', had %s" % aff)
        aff = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
    elif aff is None:
        aff = np.eye(4)

    # Convert BEFORE building the image: rebinding `volume` after the image is
    # created does not change the data the image holds, so the rounding was lost.
    if dtype is not None:
        if np.issubdtype(np.dtype(dtype), np.integer):
            volume = np.round(volume)
        volume = volume.astype(dtype=dtype)
    nifty = nib.Nifti1Image(volume, aff, header)
    if dtype is not None:
        nifty.set_data_dtype(dtype)
    if res is not None:
        if n_dims is None:
            n_dims, _ = get_dims(volume.shape)
        res = reformat_to_list(res, length=n_dims, dtype=None)
        nifty.header.set_zooms(res)
    nib.save(nifty, path)
1114 
1115 
1116 
def mkdir(path_dir):
    """Create directory ``path_dir`` and any missing parents.

    Does nothing if ``path_dir`` is empty or already exists as a directory.
    A trailing slash is accepted. Relative paths are resolved against the
    current working directory. (The previous hand-rolled parent walk looped
    forever when the top-level component of a relative path, or '/', was
    involved, because ``os.path.dirname`` eventually returns '' and
    ``os.path.isdir('')`` is False.)

    Raises:
        FileExistsError: if ``path_dir`` exists but is not a directory.
    """
    if path_dir:
        os.makedirs(path_dir, exist_ok=True)
1128 
1129 
def getM(ref, mov):
    """Fit the 3D affine transform that best maps ``ref`` points onto ``mov`` points.

    Args:
        ref: array with one point per column; rows 0-2 are x, y, z. A 4th
            (homogeneous) row, if present, is ignored.
        mov: corresponding target points in the same layout; only rows 0-2
            are used.

    Returns:
        4x4 float array ``M`` such that
        ``M[:3, :3] @ ref[:3] + M[:3, 3:] ~= mov[:3]`` in the least-squares
        sense, with last row [0, 0, 0, 1].

    Raises:
        numpy.linalg.LinAlgError: if the points do not determine an affine
            transform (fewer than 4 points, or all points coplanar).
    """
    ref = np.asarray(ref, dtype=float)[:3]
    mov = np.asarray(mov, dtype=float)[:3]
    n_points = ref.shape[1]

    # Each output coordinate is an independent linear fit on the design matrix
    # [x y z 1]. lstsq avoids forming and inverting A^T A, which squares the
    # condition number of the problem.
    design = np.concatenate([ref.T, np.ones((n_points, 1))], axis=1)
    params, _, rank, _ = np.linalg.lstsq(design, mov.T, rcond=None)
    if rank < 4:
        raise np.linalg.LinAlgError('Singular matrix')

    # params is (4, 3): column r holds [a_r0, a_r1, a_r2, t_r] for output row r
    M = np.eye(4)
    M[:3, :3] = params[:3].T
    M[:3, 3] = params[3]
    return M
1147 
def getMrigid(A, B):
    """Least-squares rigid transform (rotation + translation) mapping A onto B.

    Kabsch algorithm: center both point sets, take the SVD of their 3x3
    cross-covariance, and correct a reflection (det < 0) by negating the last
    right-singular vector.

    Args:
        A: 3 x N array of source points (one per column).
        B: 3 x N array of corresponding target points.

    Returns:
        4x4 array ``T`` with ``T[:3, :3] @ A + T[:3, 3:] ~= B``.
    """
    mu_a = np.mean(A, axis=1, keepdims=True)
    mu_b = np.mean(B, axis=1, keepdims=True)
    cross_cov = (A - mu_a) @ (B - mu_b).T

    u, _, vt = np.linalg.svd(cross_cov)
    rotation = vt.T @ u.T
    if np.linalg.det(rotation) < 0:
        # reflection: flip the axis of the smallest singular value
        vt[2, :] *= -1
        rotation = vt.T @ u.T

    shift = mu_b - rotation @ mu_a
    transform = np.eye(4)
    transform[:3, :3] = rotation
    transform[:3, 3] = shift.flatten()
    return transform
1164 
1165 
def fast_3D_interp_torch(X, II, JJ, KK, mode):
    """Sample a 3D tensor ``X`` at (possibly fractional) voxel coordinates.

    Args:
        X: 3D tensor indexed as ``X[i, j, k]``.
        II, JJ, KK: tensors of identical shape with the i, j, k coordinates.
            They are assumed to be on the same device as ``X``.
        mode: 'nearest' rounds and clamps coordinates into the volume;
            'linear' interpolates trilinearly, and points outside
            ``[0, size - 1]`` along any axis get 0.

    Returns:
        Tensor with the shape of ``II``. In 'nearest' mode it has the dtype
        of ``X``; in 'linear' mode it is float32 on ``X``'s device.
    """
    if mode == 'nearest':
        IIr = torch.round(II).long()
        JJr = torch.round(JJ).long()
        KKr = torch.round(KK).long()
        IIr[IIr < 0] = 0
        JJr[JJr < 0] = 0
        KKr[KKr < 0] = 0
        IIr[IIr > (X.shape[0] - 1)] = (X.shape[0] - 1)
        JJr[JJr > (X.shape[1] - 1)] = (X.shape[1] - 1)
        KKr[KKr > (X.shape[2] - 1)] = (X.shape[2] - 1)
        Y = X[IIr, JJr, KKr]
    elif mode == 'linear':
        # inclusive on both ends: coordinate 0 is a valid voxel, just like size - 1
        ok = (II >= 0) & (JJ >= 0) & (KK >= 0) & (II <= X.shape[0] - 1) & (JJ <= X.shape[1] - 1) & (KK <= X.shape[2] - 1)
        IIv = II[ok]
        JJv = JJ[ok]
        KKv = KK[ok]

        # floor/ceil neighbours (ceil clamped at the last voxel) and their weights
        fx = torch.floor(IIv).long()
        cx = fx + 1
        cx[cx > (X.shape[0] - 1)] = (X.shape[0] - 1)
        wcx = IIv - fx
        wfx = 1 - wcx

        fy = torch.floor(JJv).long()
        cy = fy + 1
        cy[cy > (X.shape[1] - 1)] = (X.shape[1] - 1)
        wcy = JJv - fy
        wfy = 1 - wcy

        fz = torch.floor(KKv).long()
        cz = fz + 1
        cz[cz > (X.shape[2] - 1)] = (X.shape[2] - 1)
        wcz = KKv - fz
        wfz = 1 - wcz

        c000 = X[fx, fy, fz]
        c100 = X[cx, fy, fz]
        c010 = X[fx, cy, fz]
        c110 = X[cx, cy, fz]
        c001 = X[fx, fy, cz]
        c101 = X[cx, fy, cz]
        c011 = X[fx, cy, cz]
        c111 = X[cx, cy, cz]

        # interpolate along x, then y, then z
        c00 = c000 * wfx + c100 * wcx
        c01 = c001 * wfx + c101 * wcx
        c10 = c010 * wfx + c110 * wcx
        c11 = c011 * wfx + c111 * wcx

        c0 = c00 * wfy + c10 * wcy
        c1 = c01 * wfy + c11 * wcy

        c = c0 * wfz + c1 * wcz

        # allocate on X's device instead of hard-wiring the CPU
        Y = torch.zeros(II.shape, device=X.device)
        Y[ok] = c.float()

    else:
        sf.system.fatal('mode must be linear or nearest')

    return Y
1228 
1229 
1230 
def fast_3D_interp_field_torch(X, II, JJ, KK):
    """Trilinearly sample a multi-channel 3D field at fractional voxel coordinates.

    Args:
        X: tensor of shape (D0, D1, D2, C), e.g. a 3-channel displacement field.
        II, JJ, KK: tensors of identical (arbitrary) shape with the i, j, k
            coordinates to sample. They are assumed to be on ``X``'s device.

    Returns:
        float32 tensor of shape ``(*II.shape, C)`` on ``X``'s device. Points
        outside ``[0, size - 1]`` along any axis are 0.
    """
    # inclusive on both ends: coordinate 0 is a valid voxel, just like size - 1
    ok = (II >= 0) & (JJ >= 0) & (KK >= 0) & (II <= X.shape[0] - 1) & (JJ <= X.shape[1] - 1) & (KK <= X.shape[2] - 1)
    IIv = II[ok]
    JJv = JJ[ok]
    KKv = KK[ok]

    # floor/ceil neighbours (ceil clamped at the last voxel); weights are made
    # column vectors so they broadcast over the channel axis of X[fx, fy, fz]
    fx = torch.floor(IIv).long()
    cx = fx + 1
    cx[cx > (X.shape[0] - 1)] = (X.shape[0] - 1)
    wcx = (IIv - fx)[:, None]
    wfx = 1 - wcx

    fy = torch.floor(JJv).long()
    cy = fy + 1
    cy[cy > (X.shape[1] - 1)] = (X.shape[1] - 1)
    wcy = (JJv - fy)[:, None]
    wfy = 1 - wcy

    fz = torch.floor(KKv).long()
    cz = fz + 1
    cz[cz > (X.shape[2] - 1)] = (X.shape[2] - 1)
    wcz = (KKv - fz)[:, None]
    wfz = 1 - wcz

    # all channels at once: each corner lookup is (n_valid, C)
    c00 = X[fx, fy, fz] * wfx + X[cx, fy, fz] * wcx
    c01 = X[fx, fy, cz] * wfx + X[cx, fy, cz] * wcx
    c10 = X[fx, cy, fz] * wfx + X[cx, cy, fz] * wcx
    c11 = X[fx, cy, cz] * wfx + X[cx, cy, cz] * wcx

    c0 = c00 * wfy + c10 * wcy
    c1 = c01 * wfy + c11 * wcy

    c = c0 * wfz + c1 * wcz

    Y = torch.zeros([*II.shape, X.shape[3]], device=X.device)
    Y[ok] = c.float()
    return Y
1286 
1287 
def crop_volume_with_idx(volume, crop_idx, aff=None, n_dims=None, return_copy=True):
    """Crop a 2D or 3D volume with explicit bounding-box indices.

    Args:
        volume: array whose first ``n_dims`` axes are spatial; any trailing
            axes (e.g. channels) are kept.
        crop_idx: [min_0, ..., min_{n-1}, max_0, ..., max_{n-1}], max exclusive.
        aff: optional affine whose translation column is shifted by the crop
            offset. It is modified in place and also returned.
        n_dims: number of spatial dims (2 or 3); ``len(crop_idx) // 2`` when None.
        return_copy: crop a copy of ``volume`` rather than the input array.

    Returns:
        The cropped volume, or ``(cropped volume, aff)`` when ``aff`` is given.
    """
    # get info
    new_volume = volume.copy() if return_copy else volume
    n_dims = int(np.array(crop_idx).shape[0] / 2) if n_dims is None else n_dims

    # crop image
    if n_dims == 2:
        new_volume = new_volume[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3], ...]
    elif n_dims == 3:
        new_volume = new_volume[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], ...]
    else:
        sf.system.fatal('cannot crop volumes with more than 3 dimensions')

    if aff is not None:
        # only the first n_dims entries of crop_idx are the lower corner; for 2D,
        # crop_idx[:3] would wrongly include max_0 in the origin shift
        offset = np.asarray(crop_idx[:n_dims])
        aff[0:3, -1] = aff[0:3, -1] + aff[:3, :n_dims] @ offset
        return new_volume, aff
    return new_volume
1307 
1308 
def get_largest_connected_component(mask, structure=None):
    """Keep only the largest connected component of a binary mask.

    Components are labelled with ``scipy.ndimage.label`` (optionally using a
    custom connectivity ``structure``). On ties, the component with the
    lowest label wins. If the mask has no foreground, a copy of ``mask`` is
    returned unchanged.
    """
    labelled, n_found = scipy_label(mask, structure)
    if n_found <= 0:
        return mask.copy()
    voxel_counts = np.bincount(labelled.ravel())
    # skip label 0 (background) when picking the biggest component
    biggest = int(np.argmax(voxel_counts[1:])) + 1
    return labelled == biggest
1312 
1313 
def mask_volume(volume, mask=None, threshold=0.1, dilate=0, erode=0, fill_holes=False, masking_value=0,
                return_mask=False, return_copy=True):
    """Set every voxel outside a binary mask to ``masking_value``.

    The mask is ``mask > 0`` when ``mask`` is given (its first spatial axes
    must match the volume's), otherwise ``volume >= threshold``. It can then
    be dilated and/or eroded with a ball of radius ``dilate`` / ``erode``
    (see build_binary_structure), and have its holes filled. A spatial-only
    mask is broadcast over the channels of a multi-channel volume.

    Returns:
        The masked volume, or ``(masked volume, final mask)`` when
        ``return_mask`` is True.
    """
    out = volume.copy() if return_copy else volume
    shape = list(out.shape)
    n_dims, n_channels = get_dims(shape)

    if mask is not None:
        assert list(mask.shape[:n_dims]) == shape[:n_dims], 'mask should have shape {0}, or {1}, had {2}'.format(
            shape[:n_dims], shape[:n_dims] + [n_channels], list(mask.shape))
        binary = mask > 0
    else:
        binary = out >= threshold

    if dilate > 0:
        binary = binary_dilation(binary, build_binary_structure(dilate, n_dims))
    if erode > 0:
        binary = binary_erosion(binary, build_binary_structure(erode, n_dims))
    if fill_holes:
        binary = binary_fill_holes(binary)

    # replace values outside of the mask, repeating a spatial mask per channel
    outside = np.logical_not(binary)
    if binary.shape != out.shape:
        outside = np.stack([outside] * n_channels, axis=-1)
    out[outside] = masking_value

    return (out, binary) if return_mask else out
1350 
1351 
def build_binary_structure(connectivity, n_dims, shape=None):
    """Return a ball-shaped structuring element.

    Voxels whose Euclidean distance to the central voxel is at most
    ``connectivity`` are 1, all others 0.

    Args:
        connectivity: ball radius in voxels.
        n_dims: number of dimensions of the structure.
        shape: optional explicit shape; defaults to ``2 * connectivity + 1``
            along every axis.

    Returns:
        Integer numpy array of the requested shape.
    """
    if shape is not None:
        shape = reformat_to_list(shape, length=n_dims)
    else:
        shape = [2 * connectivity + 1] * n_dims

    grid = np.ones(shape)
    grid[tuple(int(s // 2) for s in shape)] = 0
    distances = distance_transform_edt(grid)
    return (distances <= connectivity) * 1
1363 
1364 
1365 
def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, return_crop_idx=False, mode='center'):
    """Crop a 2D/3D volume either by a margin or to a target shape.

    Args:
        volume: array whose leading axes are spatial (copied, never modified).
        cropping_margin: voxels removed on each side of every axis (an axis is
            left untouched if it is not larger than twice its margin).
        cropping_shape: target spatial shape, clipped to the volume's shape.
        aff: optional affine whose translation is shifted by the crop offset;
            modified in place and also returned.
        return_crop_idx: also return [min_0, ..., max_0, ...] (max exclusive).
        mode: 'center' or 'random' placement when ``cropping_shape`` is used.

    Exactly one of ``cropping_margin`` / ``cropping_shape`` must be given.

    Returns:
        The cropped volume, followed by ``aff`` and/or the crop indices when
        requested (as a tuple).

    Raises:
        ValueError: if ``mode`` is neither 'center' nor 'random'.
    """
    assert (cropping_margin is not None) | (cropping_shape is not None), \
        'cropping_margin or cropping_shape should be provided'
    assert not ((cropping_margin is not None) & (cropping_shape is not None)), \
        'only one of cropping_margin or cropping_shape should be provided'

    # get info
    new_volume = volume.copy()
    vol_shape = new_volume.shape
    n_dims, _ = get_dims(vol_shape)

    # find cropping indices
    if cropping_margin is not None:
        cropping_margin = reformat_to_list(cropping_margin, length=n_dims)
        do_cropping = np.array(vol_shape[:n_dims]) > 2 * np.array(cropping_margin)
        min_crop_idx = [cropping_margin[i] if do_cropping[i] else 0 for i in range(n_dims)]
        max_crop_idx = [vol_shape[i] - cropping_margin[i] if do_cropping[i] else vol_shape[i] for i in range(n_dims)]
    else:
        cropping_shape = reformat_to_list(cropping_shape, length=n_dims)
        if mode == 'center':
            min_crop_idx = np.maximum([int((vol_shape[i] - cropping_shape[i]) / 2) for i in range(n_dims)], 0)
            max_crop_idx = np.minimum([min_crop_idx[i] + cropping_shape[i] for i in range(n_dims)],
                                      np.array(vol_shape)[:n_dims])
        elif mode == 'random':
            crop_max_val = np.maximum(np.array([vol_shape[i] - cropping_shape[i] for i in range(n_dims)]), 0)
            min_crop_idx = np.random.randint(0, high=crop_max_val + 1)
            max_crop_idx = np.minimum(min_crop_idx + np.array(cropping_shape), np.array(vol_shape)[:n_dims])
        else:
            raise ValueError('mode should be either "center" or "random", had %s' % mode)
    crop_idx = np.concatenate([np.array(min_crop_idx), np.array(max_crop_idx)])

    # crop volume
    if n_dims == 2:
        new_volume = new_volume[crop_idx[0]: crop_idx[2], crop_idx[1]: crop_idx[3], ...]
    elif n_dims == 3:
        new_volume = new_volume[crop_idx[0]: crop_idx[3], crop_idx[1]: crop_idx[4], crop_idx[2]: crop_idx[5], ...]

    # sort outputs
    output = [new_volume]
    if aff is not None:
        # min_crop_idx has n_dims entries, so only the first n_dims affine columns
        # apply (a fixed 3-column block raised a shape error for 2D volumes)
        aff[0:3, -1] = aff[0:3, -1] + aff[:3, :n_dims] @ np.array(min_crop_idx)
        output.append(aff)
    if return_crop_idx:
        output.append(crop_idx)
    return output[0] if len(output) == 1 else tuple(output)
1412 
1413 
1414 
def rescale_volume(volume, new_min=0, new_max=255, min_percentile=2., max_percentile=98., use_positive_only=False):
    """Linearly map robust intensity bounds of ``volume`` onto ``[new_min, new_max]``.

    The bounds are the ``min_percentile`` / ``max_percentile`` percentiles of the
    sampled intensities (the exact min / max when the percentile is 0 / 100). All
    voxels are clipped to these bounds before rescaling. The input is not modified.

    Args:
        volume: numpy array of intensities.
        new_min: value the lower bound is mapped to.
        new_max: value the upper bound is mapped to.
        min_percentile: percentile (0-100) used as the lower bound.
        max_percentile: percentile (0-100) used as the upper bound.
        use_positive_only: if True, only strictly positive voxels are sampled to
            compute the bounds (every voxel is still clipped and rescaled).

    Returns:
        A new array. If both bounds are equal, an array of zeros shaped like the
        clipped volume is returned to avoid a division by zero.
    """
    samples = volume[volume > 0] if use_positive_only else volume.flatten()

    if min_percentile == 0:
        lower = np.min(samples)
    else:
        lower = np.percentile(samples, min_percentile)
    if max_percentile == 100:
        upper = np.max(samples)
    else:
        upper = np.percentile(samples, max_percentile)

    clipped = np.clip(volume, lower, upper)

    if lower == upper:  # degenerate range: nothing to stretch
        return np.zeros_like(clipped)
    scale = (new_max - new_min) / (upper - lower)
    return new_min + (clipped - lower) * scale if False else new_min + (clipped - lower) / (upper - lower) * (new_max - new_min)
1433 
1434 
1435 
1436 
def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx=False):
    """Pad a volume so that each spatial axis is at least as long as ``padding_shape``.

    Padding is split between both ends of each axis; when the difference is odd, the
    extra voxel goes to the end. Axes already long enough are not padded. Trailing
    non-spatial axes (as identified by ``get_dims``) are never padded.

    Args:
        volume: numpy array; its spatial/channel layout is interpreted by ``get_dims``.
        padding_shape: target spatial shape (scalar or one value per spatial axis).
        padding_value: constant written into the new voxels.
        aff: optional affine matrix (indexed as a 4x4 matrix, also for 2D data; TODO
            confirm for all callers). When padding occurs, a shifted copy is returned
            so that the original voxels keep their world position; the caller's array
            is not modified.
        return_pad_idx: if True, also return
            ``[start_0, ..., start_{d-1}, end_0, ..., end_{d-1}]``, the location of the
            original data inside the padded volume.

    Returns:
        The padded volume. When ``aff`` and/or ``return_pad_idx`` are requested, a
        tuple ``(volume[, aff][, pad_idx])`` is returned instead.
    """
    # get info
    new_volume = volume.copy()
    vol_shape = new_volume.shape
    n_dims, _ = get_dims(vol_shape)
    padding_shape = reformat_to_list(padding_shape, length=n_dims, dtype='int')
    spatial_shape = np.array(vol_shape[:n_dims])

    # check if need to pad
    if np.any(np.array(padding_shape, dtype='int32') > np.array(vol_shape[:n_dims], dtype='int32')):

        # get padding margins
        shape_diff = np.array(padding_shape) - spatial_shape
        min_margins = np.maximum(np.int32(np.floor(shape_diff / 2)), 0)
        max_margins = np.maximum(np.int32(np.ceil(shape_diff / 2)), 0)
        pad_idx = np.concatenate([min_margins, min_margins + spatial_shape])

        # One (before, after) pair per spatial axis, plus (0, 0) for every trailing axis.
        # This includes a singleton channel axis, for which np.pad would otherwise get
        # too few pairs.
        pad_margins = [(min_margins[i], max_margins[i]) for i in range(n_dims)]
        pad_margins += [(0, 0)] * (new_volume.ndim - n_dims)

        # pad volume
        new_volume = np.pad(new_volume, pad_margins, mode='constant', constant_values=padding_value)

        if aff is not None:
            aff = np.array(aff)  # copy: never shift the caller's affine in place
            if n_dims == 2:
                # the affine is indexed as 4x4, so the offset needs a third coordinate
                min_margins = np.append(min_margins, 0)
            aff[:-1, -1] = aff[:-1, -1] - aff[:-1, :-1] @ min_margins

    else:
        pad_idx = np.concatenate([np.array([0] * n_dims), spatial_shape])

    # sort outputs
    output = [new_volume]
    if aff is not None:
        output.append(aff)
    if return_pad_idx:
        output.append(pad_idx)
    return output[0] if len(output) == 1 else tuple(output)
1473 
1474 
1475 
def add_axis(x, axis=0):
    """Insert singleton axes into ``x``.

    Args:
        x: array passed to ``np.expand_dims``.
        axis: position of the new axis, or several positions (normalised with
            ``reformat_to_list``). Positions are applied one after another, so each
            refers to the shape produced by the previous insertion.

    Returns:
        ``x`` with the extra singleton axes.
    """
    expanded = x
    for position in reformat_to_list(axis):
        expanded = np.expand_dims(expanded, axis=position)
    return expanded
1481 
1482 
def volshape_to_meshgrid(volshape, **kwargs):
    """Build a dense TensorFlow meshgrid of integer coordinates for a volume size.

    Args:
        volshape: sequence of integer-valued sizes, one per dimension.
        **kwargs: forwarded to ``meshgrid`` (which only accepts ``indexing``).

    Returns:
        The list of coordinate tensors produced by ``meshgrid``, built from one
        ``tf.range(0, size)`` per dimension.

    Raises:
        ValueError: if an entry of ``volshape`` is not integer-valued.
    """
    # evaluate every entry (list, not generator) before checking
    integer_flags = [float(size).is_integer() for size in volshape]
    if not all(integer_flags):
        raise ValueError("volshape needs to be a list of integers")

    ranges = [tf.range(0, size) for size in volshape]
    return meshgrid(*ranges, **kwargs)
1494 
1495 
def meshgrid(*args, **kwargs):
    """TensorFlow meshgrid that broadcasts with ``tf.tile``.

    Args:
        *args: 1-D tensors, one per output dimension. Their static length must be
            known, since it is read with ``get_shape()``.
        indexing: keyword argument, ``'xy'`` (default) or ``'ij'``. As in
            ``numpy.meshgrid``, ``'xy'`` swaps the first two output axes.

    Returns:
        List of ``len(args)`` tensors. All have shape ``(n1, n0, n2, ...)`` for
        ``'xy'`` and ``(n0, n1, n2, ...)`` for ``'ij'``, where ``nk`` is the length of
        ``args[k]``.

    Raises:
        TypeError: if a keyword other than ``indexing`` is passed.
        ValueError: if ``indexing`` is not 'xy' or 'ij'.
    """
    indexing = kwargs.pop("indexing", "xy")
    if kwargs:
        key = list(kwargs.keys())[0]
        raise TypeError("'{}' is an invalid keyword argument "
                        "for this function".format(key))

    if indexing not in ("xy", "ij"):
        raise ValueError("indexing parameter must be either 'xy' or 'ij'")

    ndim = len(args)
    s0 = (1,) * ndim
    swap_xy = indexing == "xy" and ndim > 1

    # Reshape each 1-D input so that it only extends along its own output axis
    output = [tf.reshape(tf.stack(x), s0[:i] + (-1,) + s0[i + 1:]) for i, x in enumerate(args)]
    # static lengths; after the optional swap below, sz is exactly the output shape
    sz = [x.get_shape().as_list()[0] for x in args]

    if swap_xy:
        output[0] = tf.reshape(output[0], (1, -1) + (1,) * (ndim - 2))
        output[1] = tf.reshape(output[1], (-1, 1) + (1,) * (ndim - 2))
        sz[0], sz[1] = sz[1], sz[0]

    # Broadcast with tf.tile rather than multiplying by tf.ones(shape) as tf.meshgrid
    # does (the original implementation note reports a ~6x speedup).
    # Tile multiples = output shape with 1 on the axis the input already spans. For
    # 'xy', inputs 0 and 1 span axes 1 and 0 respectively. The previous version
    # swapped the multiples a second time, which gave wrong shapes whenever n0 != n1.
    for i in range(ndim):
        own_axis = 1 - i if (swap_xy and i < 2) else i
        multiples = list(sz)
        multiples[own_axis] = 1
        output[i] = tf.tile(output[i], tf.stack(multiples))
    return output
1536 
1537 
def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):
    """Build TensorFlow Gaussian blurring kernel(s).

    Args:
        sigma: standard deviation(s) in voxels. Either a scalar/list (fixed blurring),
            or a tensor of static shape ``(n_dims,)`` or ``(None, n_dims)`` (batched).
            A tensor requires ``max_sigma``.
        max_sigma: value(s) used to size the kernel windows. Defaults to ``sigma``.
        blur_range: if given and != 1, sigma is multiplied by a factor drawn uniformly
            in ``[1 / blur_range, blur_range]`` (independently per entry).
            NOTE(review): the window is still sized from ``max_sigma``, so enlarged
            sigmas may be truncated; confirm this is intended.
        separable: if True, return a list with one 1-D kernel per axis (``None`` for
            axes whose window size is 1). Otherwise return a single n-D kernel.

    Returns:
        A list of kernels (separable) or a single kernel tensor. Each kernel has two
        trailing singleton axes appended, used as input/output channel axes by the
        ``tf.nn.convNd`` calls.
    """
    # `combinations` is not among the module-level imports, so the separable branch
    # would otherwise fail with a NameError.
    from itertools import combinations

    # convert sigma into a tensor
    if not tf.is_tensor(sigma):
        sigma_tens = tf.convert_to_tensor(reformat_to_list(sigma), dtype='float32')
    else:
        assert max_sigma is not None, 'max_sigma must be provided when sigma is given as a tensor'
        sigma_tens = sigma
    shape = sigma_tens.get_shape().as_list()

    # get n_dims and batchsize (batchsize stays None unless sigma has a dynamic batch dim)
    if shape[0] is not None:
        n_dims = shape[0]
        batchsize = None
    else:
        n_dims = shape[1]
        batchsize = tf.split(tf.shape(sigma_tens), [1, -1])[0]

    # reformat max_sigma
    if max_sigma is not None:  # dynamic blurring
        max_sigma = np.array(reformat_to_list(max_sigma, length=n_dims))
    else:  # sigma is fixed
        max_sigma = np.array(reformat_to_list(sigma, length=n_dims))

    # randomise the blurring std dev
    if blur_range is not None:
        if blur_range != 1:
            sigma_tens = sigma_tens * tf.random.uniform(tf.shape(sigma_tens), minval=1 / blur_range, maxval=blur_range)

    # odd window size per axis: 2 * floor(ceil(2.5 * max_sigma) / 2) + 1
    windowsize = np.int32(np.ceil(2.5 * max_sigma) / 2) * 2 + 1

    if separable:

        split_sigma = tf.split(sigma_tens, [1] * n_dims, axis=-1)

        kernels = list()
        # comb[i] lists every axis except i: the axes along which kernel i is singleton
        comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1])
        for (i, wsize) in enumerate(windowsize):

            if wsize > 1:

                # build meshgrid and replicate it along batch dim if dynamic blurring
                locations = tf.cast(tf.range(0, wsize), 'float32') - (wsize - 1) / 2
                if batchsize is not None:
                    locations = tf.tile(tf.expand_dims(locations, axis=0),
                                        tf.concat([batchsize, tf.ones(tf.shape(tf.shape(locations)), dtype='int32')],
                                                  axis=0))
                    comb[i] += 1  # shift axes past the leading batch axis

                # compute gaussians
                exp_term = -K.square(locations) / (2 * split_sigma[i] ** 2)
                g = tf.exp(exp_term - tf.math.log(np.sqrt(2 * np.pi) * split_sigma[i]))
                g = g / tf.reduce_sum(g)

                for axis in comb[i]:
                    g = tf.expand_dims(g, axis=axis)
                kernels.append(tf.expand_dims(tf.expand_dims(g, -1), -1))

            else:
                kernels.append(None)

    else:

        # build meshgrid of offsets from the window centre
        mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')]
        diff = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1)

        # replicate meshgrid to batch size and reshape sigma_tens
        if batchsize is not None:
            diff = tf.tile(tf.expand_dims(diff, axis=0),
                           tf.concat([batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype='int32')], axis=0))
            for i in range(n_dims):
                sigma_tens = tf.expand_dims(sigma_tens, axis=1)
        else:
            for i in range(n_dims):
                sigma_tens = tf.expand_dims(sigma_tens, axis=0)

        # compute gaussians (axes with sigma == 0 use 1 in place of sigma to avoid division by 0)
        sigma_is_0 = tf.equal(sigma_tens, 0)
        exp_term = -K.square(diff) / (2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens)**2)
        norms = exp_term - tf.math.log(tf.where(sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens))
        kernels = K.sum(norms, -1)
        kernels = tf.exp(kernels)
        kernels /= tf.reduce_sum(kernels)
        kernels = tf.expand_dims(tf.expand_dims(kernels, -1), -1)

    return kernels
1626 
1627 
def get_mapping_lut(source, dest=None):
    """Return a look-up table mapping a list of N values (source) to another list (dest).

    The table is an int32 array of length ``max(source) + 1`` with
    ``lut[source[k]] == dest[k]``. Entries for values absent from ``source`` are 0.
    If ``dest`` is not given, it is taken to be ``[0, ..., N-1]``. When a source
    value is repeated, the last matching ``dest`` value wins.

    Raises:
        AssertionError: if ``source`` and ``dest`` have different lengths.
        ValueError: if ``source`` contains negative values. These cannot index the
            table and would otherwise silently wrap around to its end.
    """

    # initialise
    source = np.array(reformat_to_list(source), dtype='int32')
    n_labels = source.shape[0]

    # build new label list if necessary
    if dest is None:
        dest = np.arange(n_labels, dtype='int32')
    else:
        # normalise first so that scalars/tuples/arrays can be length-checked
        dest = np.array(reformat_to_list(dest, dtype='int'))
        assert len(source) == len(dest), 'label_list and new_label_list should have the same length'

    if np.any(source < 0):
        raise ValueError('source values must be non-negative to be used as look-up table indices')

    # build look-up table
    lut = np.zeros(np.max(source) + 1, dtype='int32')
    for source_value, dest_value in zip(source, dest):
        lut[source_value] = dest_value

    return lut
1649 
1650 
class GaussianBlur(KL.Layer):
    """Applies gaussian blur to an input image, channel by channel.

    Args:
        sigma: standard deviation(s) of the Gaussian in voxels (scalar or one value per
            spatial axis); every value must be >= 0.
        random_blur_range: if not None, kernels are rebuilt at every call with sigma
            randomly scaled in ``[1 / random_blur_range, random_blur_range]``.
        use_mask: if True, the layer takes ``[image, mask]``. After each blurring
            pass the image is divided by the equally blurred mask and zeroed outside it.
    """

    def __init__(self, sigma, random_blur_range=None, use_mask=False, **kwargs):
        self.sigma = reformat_to_list(sigma)
        assert np.all(np.array(self.sigma) >= 0), 'sigma should be superior or equal to 0'
        self.use_mask = use_mask

        self.n_dims = None
        self.n_channels = None
        self.blur_range = random_blur_range
        self.stride = None
        self.separable = None
        self.kernels = None
        self.convnd = None
        super(GaussianBlur, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["sigma"] = self.sigma
        config["random_blur_range"] = self.blur_range
        config["use_mask"] = self.use_mask
        return config

    def build(self, input_shape):

        # get shapes (input_shape is [image_shape, mask_shape] when use_mask=True)
        if self.use_mask:
            assert len(input_shape) == 2, 'please provide a mask as second layer input when use_mask=True'
            self.n_dims = len(input_shape[0]) - 2
            self.n_channels = input_shape[0][-1]
        else:
            self.n_dims = len(input_shape) - 2
            self.n_channels = input_shape[-1]

        # prepare blurring kernel: 1-D separable kernels when the norm of sigma exceeds 5
        self.stride = [1]*(self.n_dims+2)
        self.sigma = reformat_to_list(self.sigma, length=self.n_dims)
        self.separable = np.linalg.norm(np.array(self.sigma)) > 5
        if self.blur_range is None:  # fixed kernels
            self.kernels = gaussian_kernel(self.sigma, separable=self.separable)
        else:
            self.kernels = None

        # prepare convolution (tf.nn.conv2d / conv3d depending on n_dims)
        self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims)

        self.built = True
        super(GaussianBlur, self).build(input_shape)

    def _blur_channels(self, x, kernel):
        """Convolve each channel of ``x`` separately with ``kernel`` and re-stack them."""
        return tf.concat([self.convnd(tf.expand_dims(x[..., n], -1), kernel, self.stride, 'SAME')
                          for n in range(self.n_channels)], -1)

    def call(self, inputs, **kwargs):

        if self.use_mask:
            image = inputs[0]
            mask = tf.cast(inputs[1], 'bool')
            # loop-invariant: cast once instead of once per kernel pass
            mask_float = tf.cast(mask, 'float32')
        else:
            image = inputs
            mask = None
            mask_float = None

        # redefine the kernels at each new step when blur_range is activated
        if self.blur_range is not None:
            self.kernels = gaussian_kernel(self.sigma, blur_range=self.blur_range, separable=self.separable)

        # Kernels to apply in sequence: the non-None 1-D kernels, or the single n-D
        # kernel (skipped entirely when every sigma is 0).
        if self.separable:
            kernels = [k for k in self.kernels if k is not None]
        elif any(self.sigma):
            kernels = [self.kernels]
        else:
            kernels = []

        for kernel in kernels:
            image = self._blur_channels(image, kernel)
            if self.use_mask:
                # Divide by the blurred mask (normalised-convolution style; presumably
                # assumes the image is already 0 outside the mask, TODO confirm), then
                # zero every voxel outside the mask.
                mask_blurred = self._blur_channels(mask_float, kernel)
                image = image / (mask_blurred + keras.backend.epsilon())
                image = tf.where(mask, image, tf.zeros_like(image))

        return image
1737 
1738 
class ConvertLabels(KL.Layer):
    """Keras layer remapping integer label values through a look-up table.

    The input value ``source_values[k]`` becomes ``dest_values[k]``, or ``k`` when
    ``dest_values`` is None. The table comes from ``get_mapping_lut``.
    NOTE(review): values above ``max(source_values)`` are out of range for
    ``tf.gather``; see the TF documentation for its CPU/GPU behaviour.
    """

    def __init__(self, source_values, dest_values=None, **kwargs):
        self.source_values = source_values
        self.dest_values = dest_values
        self.lut = None  # int32 look-up table tensor, created in build()
        super(ConvertLabels, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config.update(source_values=self.source_values, dest_values=self.dest_values)
        return config

    def build(self, input_shape):
        table = get_mapping_lut(self.source_values, dest=self.dest_values)
        self.lut = tf.convert_to_tensor(table, dtype='int32')
        self.built = True
        super(ConvertLabels, self).build(input_shape)

    def call(self, inputs, **kwargs):
        # label values are used directly as indices into the table
        label_indices = tf.cast(inputs, dtype='int32')
        return tf.gather(self.lut, label_indices)
1760 
1761 
1762 
1763 
# Entry point: run the command-line tool only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

Attached Files

To refer to attachments on a page, use attachment:filename, as shown below in the list of files. Do NOT use the URL of the [get] link, since this is subject to change and can break easily.
  • [get | view] (2025-10-23 15:15:18, 70.7 KB) [[attachment:mri_easyatlas.py]]
 All files | Selected Files: delete move to page copy to page

You are not allowed to attach a file to this page.