register_module_load_state_dict_post_hook
register_module_load_state_dict_post_hook(hook: Callable[..., object]) → RemovableHandle
Register a global post-hook fired after every Module's load_state_dict.
The hook is invoked after values have been copied from the state dict
into the module's parameters.
It receives the module and a NamedTuple-like incompatible_keys
object listing any missing_keys and unexpected_keys found during the
load. This is useful for warning the user, clearing those keys to
silence strict=True errors, or re-initialising any parameter that
was not loaded.
Parameters
hook: Callable
Signature: hook(module, incompatible_keys) -> None.
incompatible_keys has two list attributes: missing_keys and
unexpected_keys. Mutating either list is reflected in the caller's
error handling.
Returns
RemovableHandle
Handle for later deregistration.
Notes
Common application: re-initialise newly added layers after loading a legacy checkpoint that has no entries for them. The parameters that were loaded stay untouched, while the new parameters are brought to a sensible starting point.
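A minimal sketch of such a re-initialisation hook follows. It runs without lucid: the module and the incompatible_keys object are simulated with SimpleNamespace, the params dict layout is an assumption made for illustration, and zero-filling stands in for whatever initialiser the real layer would use.

```python
from types import SimpleNamespace


def reinit_missing_params(module, incompatible_keys):
    """Reset every parameter the checkpoint did not provide."""
    for key in incompatible_keys.missing_keys:
        # Assumption: parameters live in a plain dict of lists on
        # `module.params`; a real module would expose tensors instead.
        module.params[key] = [0.0] * len(module.params[key])


# Simulate the state right after a legacy checkpoint has been loaded:
# the encoder was restored, the new head was absent from the checkpoint.
toy_module = SimpleNamespace(
    params={
        "encoder.weight": [0.3, -0.1],
        "new_head.weight": [9.0, 9.0, 9.0],
    }
)
incompatible = SimpleNamespace(
    missing_keys=["new_head.weight"], unexpected_keys=[]
)

reinit_missing_params(toy_module, incompatible)
```

Only the key listed in missing_keys is touched, so the restored encoder weights are left as loaded.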
Examples
>>> from lucid.nn.hooks import register_module_load_state_dict_post_hook
>>> def reinit_missing(mod, incompatible):
...     for key in incompatible.missing_keys:
...         print(f'(post-load) {key} not in checkpoint, keeping init values')
...
>>> h = register_module_load_state_dict_post_hook(reinit_missing)