-
Notifications
You must be signed in to change notification settings - Fork 38
/
utils.py
58 lines (44 loc) · 1.53 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Authors:
# Christian F. Baumgartner (c.f.baumgartner@gmail.com)
# Lisa M. Koch (lisa.margret.koch@gmail.com)
import nibabel as nib
import numpy as np
import os
import glob
def makefolder(folder):
    '''
    Helper function to make a new folder (and any missing parents) if it
    doesn't exist yet.

    :param folder: path to new folder
    :return: True if the folder was created by this call, False if the path
             already existed
    :raises OSError: if the folder could not be created for any other reason
                     (e.g. missing permissions)
    '''
    try:
        os.makedirs(folder)
    except OSError:
        # makedirs raises when the path already exists, including when another
        # process created it concurrently. Attempting the creation first avoids
        # the race of a separate exists() check. Only "already exists" maps to
        # False; any other failure is re-raised.
        if os.path.exists(folder):
            return False
        raise
    return True
def load_nii(img_path):
    '''
    Shortcut to load a nifti file

    :param img_path: path to the NIfTI file to read
    :return: tuple (data, affine, header) holding the image data as a numpy
             array, the image affine and the image header
    '''
    nimg = nib.load(img_path)
    # Image.get_data() was deprecated in nibabel 3.0 and removed in 5.0.
    # nibabel documents np.asanyarray(img.dataobj) as the way to obtain the
    # former get_data() behaviour.
    return np.asanyarray(nimg.dataobj), nimg.affine, nimg.header
def save_nii(img_path, data, affine, header):
    '''
    Write an array to disk as a NIfTI-1 image.

    :param img_path: destination file name
    :param data: image data array
    :param affine: affine stored with the image
    :param header: header stored with the image
    '''
    nib.Nifti1Image(data, affine=affine, header=header).to_filename(img_path)
def get_latest_model_checkpoint_path(folder, name):
    '''
    Returns the checkpoint with the highest iteration number with a given name

    Checkpoints are located via files named ``<name>-<iteration>.meta`` inside
    ``folder``. The returned path is ``<folder>/<name>-<iteration>`` for the
    largest numeric iteration found.

    :param folder: Folder where the checkpoints are saved
    :param name: Name under which you saved the model
    :return: The path to the checkpoint with the latest iteration
    :raises ValueError: if no ``<name>-<iteration>.meta`` file with an integer
                        iteration exists in ``folder``
    '''
    meta_suffix = '.meta'
    iteration_nums = []
    # Match '<name>-*' rather than '<name>*': the result is always built as
    # '<name>-<iteration>', so files of other models that merely share the
    # prefix (e.g. '<name>_old-...') must not contribute an iteration number.
    pattern = os.path.join(folder, '%s-*%s' % (name, meta_suffix))
    for path in glob.glob(pattern):
        stem = path[:-len(meta_suffix)]
        # The iteration is whatever follows the last '-' of the stem. Working
        # on the stem avoids splitting paths on '/' and does not depend on how
        # many dots the model name contains.
        try:
            iteration_nums.append(int(stem.rsplit('-', 1)[-1]))
        except ValueError:
            # Suffix is not an integer, so this is not a numbered checkpoint.
            continue
    if not iteration_nums:
        raise ValueError(
            'No checkpoint matching %s-<iteration>%s found in %s'
            % (name, meta_suffix, folder))
    latest_iteration = max(iteration_nums)
    return os.path.join(folder, name + '-' + str(latest_iteration))