Module Hooks¶
This page documents nn.Module hook APIs. Hooks let you observe or modify inputs, outputs, and gradients during forward and backward passes, as well as customize state_dict save/load flows.
Overview¶
Hooks are registered on a module and return a remover callable. Call the remover to detach the hook.
Tip
Use hooks for debugging, logging, and lightweight instrumentation. For core model behavior, prefer explicit code in forward.
Warning
Hooks run inside the forward/backward path. Heavy work inside hooks can slow training or change determinism.
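The register-then-remove pattern can be sketched framework-agnostically. The `HookList` class below is purely illustrative (it is not part of lucid); it shows the contract the remover callable satisfies:

```python
# Illustrative sketch of the hook-remover pattern (not lucid internals).
class HookList:
    def __init__(self):
        self._hooks = {}
        self._next_id = 0

    def register(self, hook):
        hook_id = self._next_id
        self._next_id += 1
        self._hooks[hook_id] = hook

        # The remover closure detaches exactly this hook.
        def remove():
            self._hooks.pop(hook_id, None)

        return remove

    def fire(self, *args):
        for hook in list(self._hooks.values()):
            hook(*args)


hooks = HookList()
seen = []
remove = hooks.register(lambda x: seen.append(x))
hooks.fire(1)   # hook runs
remove()
hooks.fire(2)   # hook no longer runs
print(seen)     # [1]
```

Removing a hook is idempotent here: calling the remover twice is harmless, which is a useful property for cleanup code.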
Forward Hooks¶
Forward pre-hook¶
def register_forward_pre_hook(self, hook: Callable, *, with_kwargs: bool = False)
Signature:
hook(module, args) -> args | None
If with_kwargs=True: hook(module, args, kwargs) -> (args, kwargs) | None
Caution
If you return new inputs from a pre-hook, ensure shapes/dtypes are compatible with the module’s forward.
Example:
import lucid
import lucid.nn as nn

class Scale(nn.Module):
    def forward(self, x):
        return x * 2

def pre_hook(module, args):
    (x,) = args
    return (x + 1,)

m = Scale()
remove = m.register_forward_pre_hook(pre_hook)
x = lucid.ones((2, 2))
y = m(x)  # effectively (x + 1) * 2
remove()
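When `with_kwargs=True`, the pre-hook receives and may replace both positional and keyword arguments. A minimal sketch of that dispatch contract in pure Python (the `apply_pre_hook` and `shift_hook` names are illustrative, not lucid internals):

```python
# Sketch of how a with_kwargs pre-hook rewrites a call (illustrative only).
def apply_pre_hook(hook, module, args, kwargs):
    result = hook(module, args, kwargs)
    if result is not None:       # returning None keeps the originals
        args, kwargs = result    # returning a pair replaces both
    return args, kwargs


def shift_hook(module, args, kwargs):
    (x,) = args
    # Replace the positional input and inject a keyword argument.
    return (x + 1,), {**kwargs, "scale": 2}


args, kwargs = apply_pre_hook(shift_hook, None, (10,), {})
print(args, kwargs)  # (11,) {'scale': 2}
```

Note the all-or-nothing shape of the return value: a kwargs-aware pre-hook must return the `(args, kwargs)` pair together, or `None` to leave both untouched.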
Forward hook¶
def register_forward_hook(self, hook: Callable, *, with_kwargs: bool = False)
Signature:
hook(module, args, output) -> output | None
If with_kwargs=True: hook(module, args, kwargs, output) -> output | None
Note
Returning None keeps the original output. Returning a value replaces the output seen by downstream modules.
Example:
import lucid
import lucid.nn as nn

class LinearBias(nn.Module):
    def __init__(self):
        super().__init__()
        self.b = nn.Parameter(lucid.ones((1,)))

    def forward(self, x):
        return x + self.b

def post_hook(module, args, output):
    return output * 3

m = LinearBias()
remove = m.register_forward_hook(post_hook)
x = lucid.ones((1,))
y = m(x)  # output is multiplied by 3
remove()
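The None-versus-replace rule from the note above amounts to the dispatch a forward hook goes through. A pure-Python sketch (the `apply_forward_hook` helper is illustrative, not lucid internals):

```python
# If the hook returns None, the original output survives; otherwise the
# returned value becomes the output seen by downstream modules.
def apply_forward_hook(hook, module, args, output):
    result = hook(module, args, output)
    return output if result is None else result


observe = lambda module, args, output: None        # logging-style hook
rewrite = lambda module, args, output: output * 3  # replacing hook

print(apply_forward_hook(observe, None, (1,), 5))  # 5  (kept)
print(apply_forward_hook(rewrite, None, (1,), 5))  # 15 (replaced)
```

A consequence worth remembering: a hook that merely logs must end without a `return` statement (implicitly returning None), or it will silently replace the module's output.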
Backward Hooks¶
Backward hook (output tensor)¶
def register_backward_hook(self, hook: Callable)
Signature:
hook(tensor, grad) -> None
Important
This hook attaches to the module’s output tensor and only runs when the module returns a single Tensor.
Example:
import lucid
import lucid.nn as nn

class Square(nn.Module):
    def forward(self, x):
        return x * x

def grad_hook(tensor, grad):
    print("grad:", grad)

m = Square()
m.register_backward_hook(grad_hook)
x = lucid.ones((1,), requires_grad=True)
y = m(x)
y.backward()
Full backward pre-hook¶
def register_full_backward_pre_hook(self, hook: Callable)
Signature:
hook(module, grad_output_tuple) -> grad_output_tuple | None
Note
grad_output_tuple contains gradients for each Tensor output. Non-Tensor outputs are omitted.
Example:
import lucid
import lucid.nn as nn

class Add(nn.Module):
    def forward(self, x, y):
        return x + y

def pre_full_backward(module, grad_out):
    print("grad_out:", grad_out)
    return grad_out

m = Add()
m.register_full_backward_pre_hook(pre_full_backward)
x = lucid.ones((1,), requires_grad=True)
y = lucid.ones((1,), requires_grad=True)
out = m(x, y)
out.backward()
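Because the pre-hook may return a replacement tuple, it can rescale gradients before they propagate into the module's backward. A framework-agnostic sketch of that contract (the `apply_backward_pre_hook` dispatcher is illustrative, not lucid internals):

```python
# A full backward pre-hook may return a replacement grad_output tuple;
# returning None keeps the original gradients.
def apply_backward_pre_hook(hook, module, grad_output):
    result = hook(module, grad_output)
    return grad_output if result is None else result


def halve_grads(module, grad_output):
    # Rescale every output gradient before backward proceeds.
    return tuple(g * 0.5 for g in grad_output)


grads = apply_backward_pre_hook(halve_grads, None, (2.0, 4.0))
print(grads)  # (1.0, 2.0)
```

The replacement tuple must have the same length as the original: one slot per Tensor output.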
Full backward hook¶
def register_full_backward_hook(self, hook: Callable)
Signature:
hook(module, grad_input_tuple, grad_output_tuple) -> None
Note
grad_input_tuple is aligned with positional inputs only. Keyword-only inputs are not included and non-Tensor inputs appear as None.
Example:
import lucid
import lucid.nn as nn

class Mul(nn.Module):
    def forward(self, x, y):
        return x * y

def full_backward(module, grad_in, grad_out):
    print("grad_in:", grad_in)
    print("grad_out:", grad_out)

m = Mul()
m.register_full_backward_hook(full_backward)
x = lucid.ones((1,), requires_grad=True)
y = lucid.ones((1,), requires_grad=True)
out = m(x, y)
out.backward()
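The alignment rule from the note above — one grad_input slot per positional input, with None for anything that is not a Tensor — can be sketched in pure Python. Everything here is illustrative: `align_grad_input` is not a lucid function, and a plain `float` stands in for a Tensor:

```python
# Build a grad_input tuple aligned with positional inputs: tensor-like
# inputs get a gradient slot, anything else maps to None.
def align_grad_input(positional_inputs, tensor_grads):
    grad_iter = iter(tensor_grads)
    return tuple(
        next(grad_iter) if isinstance(x, float) else None  # float stands in for Tensor
        for x in positional_inputs
    )


# Imagine forward(x, power) where power is a plain int (non-Tensor):
grad_input = align_grad_input((3.0, 2), [6.0])
print(grad_input)  # (6.0, None)
```

This is why code consuming `grad_input_tuple` should index it by positional-argument position and check for None, rather than assuming every slot holds a gradient.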
State Dict Hooks¶
State dict pre-hook¶
def register_state_dict_pre_hook(self, hook: Callable)
Signature:
hook(module, prefix, keep_vars) -> None
Tip
Use this to set up temporary metadata or logging before a save.
Example:
import lucid.nn as nn

def pre_state(module, prefix, keep_vars):
    print("saving with prefix:", prefix)

m = nn.Module()
m.register_state_dict_pre_hook(pre_state)
_ = m.state_dict()
State dict hook¶
def register_state_dict_hook(self, hook: Callable)
Signature:
hook(module, state_dict, prefix, keep_vars) -> None
Warning
Mutating the state_dict changes what gets saved. Keep changes minimal and well-documented.
Example:
import lucid.nn as nn

def post_state(module, state_dict, prefix, keep_vars):
    state_dict[prefix + "note"] = "custom"

m = nn.Module()
m.register_state_dict_hook(post_state)
sd = m.state_dict()
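The example above uses `prefix + "note"` so the injected key stays correctly scoped when the module is nested inside a parent. Assuming the usual dotted-prefix convention for nested modules (a reasonable reading of the `prefix` parameter, not a documented lucid guarantee), key composition looks like this:

```python
# Sketch of dotted-prefix composition for state_dict keys (illustrative).
def child_prefix(prefix, child_name):
    # Each nesting level appends its attribute name and a dot.
    return prefix + child_name + "."


p = child_prefix("", "encoder")   # "encoder."
p = child_prefix(p, "layer1")     # "encoder.layer1."
print(p + "weight")               # encoder.layer1.weight
```

Building keys from the supplied `prefix` (rather than hard-coding names) keeps a hook correct regardless of where its module sits in the model hierarchy.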
Load state dict pre-hook¶
def register_load_state_dict_pre_hook(self, hook: Callable)
Signature:
hook(module, state_dict, strict) -> None
Caution
If you modify keys here, ensure they still match the module’s current structure when strict=True.
Example:
import lucid.nn as nn

def pre_load(module, state_dict, strict):
    state_dict.pop("legacy_key", None)

m = nn.Module()
m.register_load_state_dict_pre_hook(pre_load)
m.load_state_dict({})
Load state dict post-hook¶
def register_load_state_dict_post_hook(self, hook: Callable)
Signature:
hook(module, missing_keys, unexpected_keys, strict) -> None
Tip
This is a good place to emit warnings or metrics when keys are missing.
Example:
import lucid.nn as nn

def post_load(module, missing, unexpected, strict):
    if missing:
        print("missing:", missing)

m = nn.Module()
m.register_load_state_dict_post_hook(post_load)
m.load_state_dict({}, strict=False)
Notes¶
register_backward_hook attaches to the output tensor. It only runs when the module returns a single Tensor.
grad_input_tuple in full backward hooks is aligned with positional args. Non-Tensor inputs appear as None. Keyword-only inputs are not included.
grad_output_tuple is built from Tensor outputs only.
Hook registration returns a callable to remove the hook.