Skip to content
Snippets Groups Projects

Update outdated code

Open manli requested to merge github/fork/Bobholamovic/master into master
1 file
+ 75
44
Compare changes
  • Side-by-side
  • Inline
+ 75
44
import math
import weakref
from collections import OrderedDict
import torch
import numpy as np
@@ -21,61 +21,92 @@ def mod_crop(blob, N):
return blob[..., :nh, :nw]
class FeatureContainer:
r"""A simple wrapper for OrderedDict."""
def __init__(self):
self._dict = OrderedDict()
def __setitem__(self, key, val):
if key not in self._dict:
self._dict[key] = list()
self._dict[key].append(val)
def __getitem__(self, key):
return self._dict[key]
def __repr__(self):
return self._dict.__repr__()
def items(self):
return self._dict.items()
def keys(self):
return self._dict.keys()
def values(self):
return self._dict.values()
class HookHelper:
def __init__(self, model, fetch_dict, out_dict, hook_type='forward_out'):
self.model = weakref.proxy(model)
# XXX: A HookHelper object should only be used as a context manager and should not
# persist in memory since it may keep references to some very large objects.
self.model = model
self.fetch_dict = fetch_dict
# Subclass the built-in list to make it weak referenceable
class _list(list):
pass
for entry in self.fetch_dict.values():
# entry is expected to be a string or a non-nested tuple
if isinstance(entry, tuple):
for key in entry:
out_dict[key] = _list()
else:
out_dict[entry] = _list()
self.out_dict = weakref.WeakValueDictionary(out_dict)
self.out_dict = out_dict
self._handles = []
if hook_type not in ('forward_in', 'forward_out', 'backward_out'):
if hook_type not in ('forward_in', 'forward_out', 'backward'):
raise NotImplementedError("Hook type is not implemented.")
self.hook_type = hook_type
def _proto_hook(x, entry):
# x should be a tensor or a tuple
def __enter__(self):
def _proto_forward_hook(x, entry):
# x should be a tensor or a tuple;
# entry is expected to be a string or a non-nested tuple.
if isinstance(entry, tuple):
for key, f in zip(entry, x):
self.out_dict[key].append(f.detach().clone())
self.out_dict[key] = f.data.clone()
else:
self.out_dict[entry].append(x.detach().clone())
def _forward_in_hook(m, x, y, entry):
# x is a tuple
return _proto_hook(x[0] if len(x)==1 else x, entry)
def _forward_out_hook(m, x, y, entry):
# y is a tensor or a tuple
return _proto_hook(y, entry)
def _backward_out_hook(m, grad_in, grad_out, entry):
# grad_out is a tuple
return _proto_hook(grad_out[0] if len(grad_out)==1 else grad_out, entry)
self.out_dict[entry] = x.data.clone()
self._hook_func, self._reg_func_name = {
'forward_in': (_forward_in_hook, 'register_forward_hook'),
'forward_out': (_forward_out_hook, 'register_forward_hook'),
'backward_out': (_backward_out_hook, 'register_backward_hook'),
}[hook_type]
def __enter__(self):
for name, module in self.model.named_modules():
if name in self.fetch_dict:
entry = self.fetch_dict[name]
self._handles.append(
getattr(module, self._reg_func_name)(
lambda *args, entry=entry: self._hook_func(*args, entry=entry)
if self.hook_type == 'forward_in':
# NOTE: Register forward hooks for MODULEs.
for name, module in self.model.named_modules():
if name in self.fetch_dict:
entry = self.fetch_dict[name]
self._handles.append(
module.register_forward_hook(
lambda m, x, y, entry=entry:
# x is a tuple
_proto_forward_hook(x[0] if len(x)==1 else x, entry)
)
)
elif self.hook_type == 'forward_out':
# NOTE: Register forward hooks for MODULEs.
for name, module in self.model.named_modules():
if name in self.fetch_dict:
entry = self.fetch_dict[name]
self._handles.append(
module.register_forward_hook(
lambda m, x, y, entry=entry:
# y is a tensor or a tuple
_proto_forward_hook(y, entry)
)
)
elif self.hook_type == 'backward':
# NOTE: Register backward hooks for TENSORs.
for name, param in self.model.named_parameters():
if name in self.fetch_dict:
entry = self.fetch_dict[name]
self._handles.append(
param.register_hook(
lambda grad, entry=entry:
_proto_forward_hook(grad, entry)
)
)
)
else:
raise NotImplementedError
def __exit__(self, exc_type, exc_val, ext_tb):
for handle in self._handles:
Loading