From c8bfd29b458e68e9b5dfefe6fd4c43f1dd3e7c34 Mon Sep 17 00:00:00 2001
From: Bobholamovic <bob1998425@hotmail.com>
Date: Sat, 1 Feb 2020 11:18:12 +0800
Subject: [PATCH] Fix load yaml

---
 src/train.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/train.py b/src/train.py
index 9b8349a..9a21a8b 100644
--- a/src/train.py
+++ b/src/train.py
@@ -4,7 +4,7 @@ import os
 import shutil
 import random
 import ast
-from os.path import basename, exists
+from os.path import basename, exists, splitext
 
 import torch
 import torch.backends.cudnn as cudnn
@@ -16,15 +16,14 @@ from utils.misc import OutPathGetter, Logger, register
 
 
 def read_config(config_path):
-    f = open(config_path, 'r')
-    cfg = yaml.load(f.read(), Loader=yaml.FullLoader)
-    f.close()
+    with open(config_path, 'r') as f:
+        cfg = yaml.load(f.read(), Loader=yaml.FullLoader)
     return cfg or {}
 
 
 def parse_config(cfg_name, cfg):
     # Parse the name of config file
-    sp = cfg_name.split('.')[0].split('_')
+    sp = splitext(cfg_name)[0].split('_')
     if len(sp) >= 2:
         cfg.setdefault('tag', sp[1])
         cfg.setdefault('suffix', '_'.join(sp[2:]))
-- 
GitLab