Skip to content
Snippets Groups Projects
Commit 39d0b776 authored by Bobholamovic's avatar Bobholamovic
Browse files

Fix direct param not updated by optimizer

parent 31e99e9f
Branches
No related tags found
1 merge request!2Update outdated code
...@@ -274,12 +274,16 @@ def optim_factory(optim_names, models, C): ...@@ -274,12 +274,16 @@ def optim_factory(optim_names, models, C):
optims = [] optims = []
for name, model in zip(name_list, models): for name, model in zip(name_list, models):
param_groups = [{'params': module.parameters(), 'name': module_name} for module_name, module in model.named_children()] param_groups = [{'params': module.parameters(), 'name': module_name} for module_name, module in model.named_children()]
if next(model.parameters(recurse=False), None) is not None:
param_groups.append({'params': model.parameters(recurse=False), 'name': '_direct'})
optims.append(single_optim_factory(name, param_groups, C)) optims.append(single_optim_factory(name, param_groups, C))
return DuckOptimizer(*optims) return DuckOptimizer(*optims)
else: else:
return single_optim_factory( return single_optim_factory(
optim_names, optim_names,
[{'params': module.parameters(), 'name': module_name} for module_name, module in models.named_children()], [{'params': module.parameters(), 'name': module_name} for module_name, module in models.named_children()] +
([{'params': models.parameters(recurse=False), 'name': '_direct'}]
if next(models.parameters(recurse=False), None) is not None else []),
C C
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment