
register_module_load_state_dict_post_hook

register_module_load_state_dict_post_hook(hook: Callable[..., object]) -> RemovableHandle

Register a global post-hook fired after every Module's load_state_dict.

Invoked after the values in the state dict have been copied into the module's parameters. The hook receives the module and a NamedTuple-like incompatible_keys object summarising any missing_keys and unexpected_keys found during the load. Typical uses: warning the user, clearing those keys to silence strict=True errors, or re-initialising any parameter that wasn't loaded.
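To make the ordering concrete, here is a minimal pure-Python sketch of how a global post-hook mechanism of this kind can work: copy first, then fire every registered hook, then run the strict check against the (possibly mutated) key lists. It is a toy model, not lucid's implementation. `IncompatibleKeys`, `_GLOBAL_POST_HOOKS`, `register_post_hook` and `ToyModule` are all hypothetical names introduced for illustration.

```python
from typing import Callable, Dict, List, NamedTuple


# Hypothetical stand-in for the object passed to hooks as `incompatible_keys`.
class IncompatibleKeys(NamedTuple):
    missing_keys: List[str]
    unexpected_keys: List[str]


# Hypothetical global registry; the real library's storage may differ.
_GLOBAL_POST_HOOKS: List[Callable[..., object]] = []


def register_post_hook(hook: Callable[..., object]) -> Callable[..., object]:
    """Hypothetical registrar; the documented function returns a RemovableHandle."""
    _GLOBAL_POST_HOOKS.append(hook)
    return hook


class ToyModule:
    """Hypothetical module whose parameters are plain floats keyed by name."""

    def __init__(self, **params: float) -> None:
        self.params: Dict[str, float] = dict(params)

    def load_state_dict(self, state_dict: Dict[str, float], strict: bool = True) -> IncompatibleKeys:
        missing = [k for k in self.params if k not in state_dict]
        unexpected = [k for k in state_dict if k not in self.params]
        # 1. Copy every matching entry into the module's parameters.
        for key, value in state_dict.items():
            if key in self.params:
                self.params[key] = value
        keys = IncompatibleKeys(missing, unexpected)
        # 2. Fire every global post-hook after the copy has happened.
        for hook in _GLOBAL_POST_HOOKS:
            hook(self, keys)
        # 3. The strict check reads the lists *after* hooks may have mutated them.
        if strict and (keys.missing_keys or keys.unexpected_keys):
            raise RuntimeError(f"Error(s) in loading state_dict: {keys}")
        return keys


# Usage: a hook that clears missing_keys lets a strict load of a partial
# checkpoint succeed, while unexpected keys would still raise.
def ignore_missing(module: ToyModule, incompatible_keys: IncompatibleKeys) -> None:
    incompatible_keys.missing_keys.clear()


register_post_hook(ignore_missing)
m = ToyModule(weight=0.0, bias=0.0, new_head=0.5)
result = m.load_state_dict({"weight": 1.0, "bias": 2.0})  # no RuntimeError
```

Because the hook mutates the list in place, the strict check in step 3 sees an empty missing_keys; replacing the list (rather than clearing it) would have no effect on the caller in this sketch.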

Parameters

hook : Callable
Signature: hook(module, incompatible_keys) -> None. incompatible_keys exposes two list fields, missing_keys and unexpected_keys. Mutating either list in place (for example, clearing it) is reflected in the caller's error handling, including the strict=True check.

Returns

RemovableHandle

Handle for later deregistration.

Notes

Common application: re-initialise newly added layers after loading a legacy checkpoint that has no entries for them. The rest of the model keeps its loaded weights, while the new parameters start from a sensible initialisation.
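A sketch of such a re-initialising hook follows. It assumes a hypothetical `ToyModule` whose parameters are plain lists of floats in a `params` dict and a hypothetical `IncompatibleKeys` NamedTuple; real lucid modules store parameters differently, so treat the body as a pattern rather than working lucid code. The uniform range of ±0.1 is an arbitrary illustrative choice.

```python
import random
from typing import Dict, List, NamedTuple


# Hypothetical stand-in for the `incompatible_keys` argument.
class IncompatibleKeys(NamedTuple):
    missing_keys: List[str]
    unexpected_keys: List[str]


class ToyModule:
    """Hypothetical module: each parameter is a list of floats keyed by name."""

    def __init__(self, params: Dict[str, List[float]]) -> None:
        self.params = params


def reinit_missing(module: ToyModule, incompatible_keys: IncompatibleKeys) -> None:
    # Re-initialise only the parameters the checkpoint did not provide;
    # everything that was loaded is left untouched.
    for key in incompatible_keys.missing_keys:
        size = len(module.params[key])
        module.params[key] = [random.uniform(-0.1, 0.1) for _ in range(size)]
    # Optional: clear the list in place so a strict=True load does not raise.
    incompatible_keys.missing_keys.clear()
```

Clearing missing_keys is a separate design decision: keep it if the legacy checkpoint is loaded with strict=True, drop it if you still want the caller to see which keys were absent.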

Examples

>>> from lucid.nn.hooks import register_module_load_state_dict_post_hook
>>> def report_missing(mod, incompatible):
...     for key in incompatible.missing_keys:
...         print(f'(post-load) {key} not in checkpoint; keeping init values')
>>> h = register_module_load_state_dict_post_hook(report_missing)