diff --git a/src/utils/utils.py b/src/utils/utils.py
index 8e0a7c77880c4f6993f44861ee7b1bdf5249ecc7..6bdadb11b67d6030c3da6338ed023c1f6ec89f3b 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -1,5 +1,5 @@
 import math
-import weakref
+from collections import OrderedDict
 
 import torch
 import numpy as np
@@ -21,61 +21,92 @@ def mod_crop(blob, N):
             return blob[..., :nh, :nw]
 
 
+class FeatureContainer:
+    r"""A simple wrapper for OrderedDict."""
+    def __init__(self):
+        self._dict = OrderedDict()
+
+    def __setitem__(self, key, val):
+        if key not in self._dict:
+            self._dict[key] = list()
+        self._dict[key].append(val)
+
+    def __getitem__(self, key):
+        return self._dict[key]
+
+    def __repr__(self):
+        return self._dict.__repr__()
+
+    def items(self):
+        return self._dict.items()
+
+    def keys(self):
+        return self._dict.keys()
+
+    def values(self):
+        return self._dict.values()
+
+
 class HookHelper:
     def __init__(self, model, fetch_dict, out_dict, hook_type='forward_out'):
-        self.model = weakref.proxy(model)
+        # XXX: A HookHelper object should only be used as a context manager and should not 
+        # persist in memory since it may keep references to some very large objects.
+        self.model = model
         self.fetch_dict = fetch_dict
-        # Subclass the built-in list to make it weak referenceable
-        class _list(list):
-            pass
-        for entry in self.fetch_dict.values():
-            # entry is expected to be a string or a non-nested tuple
-            if isinstance(entry, tuple):
-                for key in entry:
-                    out_dict[key] = _list()
-            else:
-                out_dict[entry] = _list()
-        self.out_dict = weakref.WeakValueDictionary(out_dict)
+        self.out_dict = out_dict
         self._handles = []
 
-        if hook_type not in ('forward_in', 'forward_out', 'backward_out'):
+        if hook_type not in ('forward_in', 'forward_out', 'backward'):
             raise NotImplementedError("Hook type is not implemented.")
+        self.hook_type = hook_type
 
-        def _proto_hook(x, entry):
-            # x should be a tensor or a tuple
+    def __enter__(self):
+        def _proto_forward_hook(x, entry):
+            # x should be a tensor or a tuple;
+            # entry is expected to be a string or a non-nested tuple.
             if isinstance(entry, tuple):
                 for key, f in zip(entry, x):
-                    self.out_dict[key].append(f.detach().clone())
+                    self.out_dict[key] = f.data.clone()
             else:
-                self.out_dict[entry].append(x.detach().clone())
-
-        def _forward_in_hook(m, x, y, entry):
-            # x is a tuple
-            return _proto_hook(x[0] if len(x)==1 else x, entry)
-
-        def _forward_out_hook(m, x, y, entry):
-            # y is a tensor or a tuple
-            return _proto_hook(y, entry)
-
-        def _backward_out_hook(m, grad_in, grad_out, entry):
-            # grad_out is a tuple
-            return _proto_hook(grad_out[0] if len(grad_out)==1 else grad_out, entry)
+                self.out_dict[entry] = x.data.clone()
 
-        self._hook_func, self._reg_func_name = {
-            'forward_in': (_forward_in_hook, 'register_forward_hook'),
-            'forward_out': (_forward_out_hook, 'register_forward_hook'),
-            'backward_out': (_backward_out_hook, 'register_backward_hook'),
-        }[hook_type]
-
-    def __enter__(self):
-        for name, module in self.model.named_modules():
-            if name in self.fetch_dict:
-                entry = self.fetch_dict[name]
-                self._handles.append(
-                    getattr(module, self._reg_func_name)(
-                        lambda *args, entry=entry: self._hook_func(*args, entry=entry)
+        if self.hook_type == 'forward_in':
+            # NOTE: Register forward hooks for MODULEs.
+            for name, module in self.model.named_modules():
+                if name in self.fetch_dict:
+                    entry = self.fetch_dict[name]
+                    self._handles.append(
+                        module.register_forward_hook(
+                            lambda m, x, y, entry=entry:
+                                # x is a tuple
+                                _proto_forward_hook(x[0] if len(x)==1 else x, entry)
+                        )
+                    )
+        elif self.hook_type == 'forward_out':
+            # NOTE: Register forward hooks for MODULEs.
+            for name, module in self.model.named_modules():
+                if name in self.fetch_dict:
+                    entry = self.fetch_dict[name]
+                    self._handles.append(
+                        module.register_forward_hook(
+                            lambda m, x, y, entry=entry:
+                                # y is a tensor or a tuple
+                                _proto_forward_hook(y, entry)
+                        )
+                    )
+        elif self.hook_type == 'backward':
+            # NOTE: Register backward hooks for TENSORs.
+            for name, param in self.model.named_parameters():
+                if name in self.fetch_dict:
+                    entry = self.fetch_dict[name]
+                    self._handles.append(
+                        param.register_hook(
+                            lambda grad, entry=entry:
+                                _proto_forward_hook(grad, entry)
+                        )
                     )
-                )
+        else:
+            raise NotImplementedError
 
     def __exit__(self, exc_type, exc_val, ext_tb):
         for handle in self._handles: