diff --git a/src/train.py b/src/train.py index 9b8349a4a3947d3d1d15401db53b986836f44440..9a21a8b86571c9c89e7b8c4cf3d2f6d9b7329533 100644 --- a/src/train.py +++ b/src/train.py @@ -4,7 +4,7 @@ import os import shutil import random import ast -from os.path import basename, exists +from os.path import basename, exists, splitext import torch import torch.backends.cudnn as cudnn @@ -16,15 +16,14 @@ from utils.misc import OutPathGetter, Logger, register def read_config(config_path): - f = open(config_path, 'r') - cfg = yaml.load(f.read(), Loader=yaml.FullLoader) - f.close() + with open(config_path, 'r') as f: + cfg = yaml.load(f.read(), Loader=yaml.FullLoader) return cfg or {} def parse_config(cfg_name, cfg): # Parse the name of config file - sp = cfg_name.split('.')[0].split('_') + sp = splitext(cfg_name)[0].split('_') if len(sp) >= 2: cfg.setdefault('tag', sp[1]) cfg.setdefault('suffix', '_'.join(sp[2:]))