From c8bfd29b458e68e9b5dfefe6fd4c43f1dd3e7c34 Mon Sep 17 00:00:00 2001 From: Bobholamovic <bob1998425@hotmail.com> Date: Sat, 1 Feb 2020 11:18:12 +0800 Subject: [PATCH] Fix load yaml --- src/train.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/train.py b/src/train.py index 9b8349a..9a21a8b 100644 --- a/src/train.py +++ b/src/train.py @@ -4,7 +4,7 @@ import os import shutil import random import ast -from os.path import basename, exists +from os.path import basename, exists, splitext import torch import torch.backends.cudnn as cudnn @@ -16,15 +16,14 @@ from utils.misc import OutPathGetter, Logger, register def read_config(config_path): - f = open(config_path, 'r') - cfg = yaml.load(f.read(), Loader=yaml.FullLoader) - f.close() + with open(config_path, 'r') as f: + cfg = yaml.load(f.read(), Loader=yaml.FullLoader) return cfg or {} def parse_config(cfg_name, cfg): # Parse the name of config file - sp = cfg_name.split('.')[0].split('_') + sp = splitext(cfg_name)[0].split('_') if len(sp) >= 2: cfg.setdefault('tag', sp[1]) cfg.setdefault('suffix', '_'.join(sp[2:])) -- GitLab