-
Notifications
You must be signed in to change notification settings - Fork 2
/
common.py
122 lines (100 loc) · 4.11 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import sys
import yaml
import random
import logging
import numpy as np
import pandas as pd
import tensorflow as tf
from datetime import datetime
def set_seed(SEED=42):
    """Seed every RNG in use (hash, random, NumPy, TensorFlow) for reproducibility.

    Args:
        SEED: integer seed applied to all generators (default 42).
    """
    # PYTHONHASHSEED fixes str/bytes hash randomization for this process' children.
    os.environ['PYTHONHASHSEED'] = str(SEED)
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(SEED)
def get_logger(name):
    """Return a DEBUG-level logger that writes timestamped lines to stdout.

    ``logging.getLogger(name)`` returns the same logger object for the same
    name, so the original version attached one more StreamHandler on every
    call, duplicating every log line. The handler is now added only once.

    Args:
        name: logger name (conventionally the module's ``__name__``).

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:  # guard against duplicate handlers on repeat calls
        formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s',
                                      datefmt='%Y-%m-%d %H:%M:%S')
        screen_handler = logging.StreamHandler(stream=sys.stdout)
        screen_handler.setFormatter(formatter)
        logger.addHandler(screen_handler)
    return logger
def get_session(args):
    """Configure TensorFlow GPU visibility and enable memory growth.

    Args:
        args: namespace with a ``gpus`` attribute — a CUDA device string such
            as ``'0,1'``, or ``'-1'`` to run CPU-only.

    Side effects:
        Sets ``CUDA_VISIBLE_DEVICES`` and, when GPUs are requested, turns on
        memory growth for each visible physical GPU.
    """
    # Require TF 2.x. The major version is an int, so compare against the
    # int literal 2 (was `>= 2.0`, an int-vs-float comparison).
    assert int(tf.__version__.split('.')[0]) >= 2
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    if args.gpus != '-1':
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            try:
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                # Memory growth must be set before GPUs have been initialized.
                print(e)
def create_stamp():
    """Build a run-identifier stamp like ``'200903_Thu_05_38_31'``.

    Format: YYMMDD (see note), three-letter weekday, then HH_MM_SS,
    all zero-padded to two digits.

    NOTE(review): ``now.year // 100`` yields the *century* ("20" for any
    year 2000-2099), not the two-digit year — ``year % 100`` was likely
    intended. Preserved as-is so stamps stay comparable with old runs;
    the two coincide only in 2020.
    """
    day_names = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    now = datetime.now()
    return "{:02d}{:02d}{:02d}_{}_{:02d}_{:02d}_{:02d}".format(
        now.year // 100, now.month, now.day,
        day_names[now.weekday()], now.hour, now.minute, now.second)
def search_same(args):
    """Find a previous run whose hyper-parameters match ``args`` and decide
    whether to skip, resume, or start fresh.

    Scans ``<result_path>/<dataset>/<stamp>/model_desc.yml`` for every stored
    stamp and compares all args except bookkeeping keys (and any extra names
    in ``args.ignore_search``, comma-separated).

    Returns:
        ``(args, initial_epoch)`` where ``initial_epoch`` is
        ``-1`` when the matched run already finished (or diverged to
        NaN/inf val_loss), ``0`` when no resumable match exists, or a
        positive epoch to resume from.

    Side effects:
        May set ``args.stamp`` (matched run) and ``args.snapshot``
        (latest checkpoint path to resume from).
    """
    search_ignore = ['checkpoint', 'history', 'snapshot', 'summary',
                     'src_path', 'data_path', 'result_path',
                     'epochs', 'stamp', 'gpus', 'ignore_search']
    if len(args.ignore_search) > 0:
        search_ignore += args.ignore_search.split(',')

    initial_epoch = 0
    stamps = os.listdir(os.path.join(args.result_path, args.dataset))
    for stamp in stamps:
        desc_path = os.path.join(args.result_path, args.dataset,
                                 stamp, 'model_desc.yml')
        try:
            # Close the file deterministically (original leaked the handle)
            # and narrow the bare except so Ctrl-C is not swallowed.
            with open(desc_path) as f:
                desc = yaml.full_load(f)
        except Exception:
            continue  # unreadable/missing description -> not a candidate

        # A stamp matches when every non-ignored arg equals the stored value.
        matched = True
        for k, v in vars(args).items():
            if k in search_ignore:
                continue
            if v != desc[k]:
                matched = False
                break
        if not matched:
            continue

        args.stamp = stamp
        try:
            df = pd.read_csv(os.path.join(args.result_path, args.dataset,
                                          args.stamp, 'history', 'epoch.csv'))
        except Exception:
            continue  # no history yet -> keep looking (was a bare except)

        if len(df) > 0:
            last_epoch = int(df['epoch'].values[-1] + 1)
            last_val_loss = df['val_loss'].values[-1]
            if last_epoch == args.epochs:
                print('{} Training already finished!!!'.format(stamp))
                return args, -1
            elif np.isnan(last_val_loss) or np.isinf(last_val_loss):
                print('{} | Epoch {:04d}: Invalid loss, terminating training'.format(stamp, last_epoch))
                return args, -1
            else:
                # Resume from the newest checkpoint; filenames start with the
                # epoch number ('<epoch>_...h5'), so sort numerically on it.
                ckpt_dir = os.path.join(args.result_path, args.dataset,
                                        args.stamp, 'checkpoint')
                ckpt_list = sorted([c for c in os.listdir(ckpt_dir) if 'h5' in c],
                                   key=lambda x: int(x.split('_')[0]))
                if len(ckpt_list) > 0:
                    args.snapshot = os.path.join(ckpt_dir, ckpt_list[-1])
                    initial_epoch = int(ckpt_list[-1].split('_')[0])
                else:
                    print('{} Training already finished!!!'.format(stamp))
                    return args, -1
        break
    return args, initial_epoch