Skip to content
Snippets Groups Projects
Commit f68ed034 authored by Bobholamovic's avatar Bobholamovic
Browse files

Disable module backward hooks in HookHelper

parent 3fe47d14
No related branches found
No related tags found
1 merge request!2Update outdated code
import math
import weakref
from collections import OrderedDict
import torch
import numpy as np
......@@ -21,61 +21,92 @@ def mod_crop(blob, N):
return blob[..., :nh, :nw]
class FeatureContainer:
r"""A simple wrapper for OrderedDict."""
def __init__(self):
self._dict = OrderedDict()
def __setitem__(self, key, val):
if key not in self._dict:
self._dict[key] = list()
self._dict[key].append(val)
def __getitem__(self, key):
return self._dict[key]
def __repr__(self):
return self._dict.__repr__()
def items(self):
return self._dict.items()
def keys(self):
return self._dict.keys()
def values(self):
return self._dict.values()
class HookHelper:
def __init__(self, model, fetch_dict, out_dict, hook_type='forward_out'):
self.model = weakref.proxy(model)
# XXX: A HookHelper object should only be used as a context manager and should not
# persist in memory since it may keep references to some very large objects.
self.model = model
self.fetch_dict = fetch_dict
# Subclass the built-in list to make it weak referenceable
class _list(list):
pass
for entry in self.fetch_dict.values():
# entry is expected to be a string or a non-nested tuple
if isinstance(entry, tuple):
for key in entry:
out_dict[key] = _list()
else:
out_dict[entry] = _list()
self.out_dict = weakref.WeakValueDictionary(out_dict)
self.out_dict = out_dict
self._handles = []
if hook_type not in ('forward_in', 'forward_out', 'backward_out'):
if hook_type not in ('forward_in', 'forward_out', 'backward'):
raise NotImplementedError("Hook type is not implemented.")
self.hook_type = hook_type
def _proto_hook(x, entry):
# x should be a tensor or a tuple
def __enter__(self):
def _proto_forward_hook(x, entry):
# x should be a tensor or a tuple;
# entry is expected to be a string or a non-nested tuple.
if isinstance(entry, tuple):
for key, f in zip(entry, x):
self.out_dict[key].append(f.detach().clone())
self.out_dict[key] = f.data.clone()
else:
self.out_dict[entry].append(x.detach().clone())
def _forward_in_hook(m, x, y, entry):
# x is a tuple
return _proto_hook(x[0] if len(x)==1 else x, entry)
def _forward_out_hook(m, x, y, entry):
# y is a tensor or a tuple
return _proto_hook(y, entry)
def _backward_out_hook(m, grad_in, grad_out, entry):
# grad_out is a tuple
return _proto_hook(grad_out[0] if len(grad_out)==1 else grad_out, entry)
self.out_dict[entry] = x.data.clone()
self._hook_func, self._reg_func_name = {
'forward_in': (_forward_in_hook, 'register_forward_hook'),
'forward_out': (_forward_out_hook, 'register_forward_hook'),
'backward_out': (_backward_out_hook, 'register_backward_hook'),
}[hook_type]
def __enter__(self):
for name, module in self.model.named_modules():
if name in self.fetch_dict:
entry = self.fetch_dict[name]
self._handles.append(
getattr(module, self._reg_func_name)(
lambda *args, entry=entry: self._hook_func(*args, entry=entry)
if self.hook_type == 'forward_in':
# NOTE: Register forward hooks for MODULEs.
for name, module in self.model.named_modules():
if name in self.fetch_dict:
entry = self.fetch_dict[name]
self._handles.append(
module.register_forward_hook(
lambda m, x, y, entry=entry:
# x is a tuple
_proto_forward_hook(x[0] if len(x)==1 else x, entry)
)
)
elif self.hook_type == 'forward_out':
# NOTE: Register forward hooks for MODULEs.
for name, module in self.model.named_modules():
if name in self.fetch_dict:
entry = self.fetch_dict[name]
self._handles.append(
module.register_forward_hook(
lambda m, x, y, entry=entry:
# y is a tensor or a tuple
_proto_forward_hook(y, entry)
)
)
elif self.hook_type == 'backward':
# NOTE: Register backward hooks for TENSORs.
for name, param in self.model.named_parameters():
if name in self.fetch_dict:
entry = self.fetch_dict[name]
self._handles.append(
param.register_hook(
lambda grad, entry=entry:
_proto_forward_hook(grad, entry)
)
)
)
else:
raise NotImplementedError
def __exit__(self, exc_type, exc_val, ext_tb):
for handle in self._handles:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment