FunctionCtx
FunctionCtx()
Per-call context shared between Function.forward and
Function.backward.
A fresh FunctionCtx is created on every Function.apply
invocation. forward populates it with anything backward
will need: saved tensors (via save_for_backward),
non-differentiable output markers (via
mark_non_differentiable), and arbitrary user-defined
attributes (cached shapes, axis indices, scalar hyperparameters,
...) set with ordinary attribute assignment. The context is the
only legal channel for passing state from forward to backward —
capturing tensors through Python closures bypasses autograd's
bookkeeping and leaks memory.
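The per-call lifecycle described above can be sketched in plain Python. This is only an illustration: ToyCtx, ToyFunction, and Scale are hypothetical stand-ins that operate on floats, not lucid's actual classes, and the toy apply returns the context so it can be inspected, which a real engine would not do.

```python
class ToyCtx:
    """Hypothetical stand-in for FunctionCtx (not lucid's implementation)."""

    def __init__(self):
        self._saved = ()

    def save_for_backward(self, *tensors):
        self._saved = tensors

    @property
    def saved_tensors(self):
        return self._saved


class ToyFunction:
    """Hypothetical stand-in for Function, working on plain floats."""

    @classmethod
    def apply(cls, *inputs):
        ctx = ToyCtx()                    # a fresh context for every call
        out = cls.forward(ctx, *inputs)   # forward populates the context
        return out, ctx                   # toy only: expose ctx for inspection


class Scale(ToyFunction):
    @staticmethod
    def forward(ctx, x, factor):
        ctx.save_for_backward(x)          # saved "tensor"
        ctx.factor = factor               # arbitrary user-defined attribute
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # d(x * factor)/dx = factor, read back from the same context
        return grad_out * ctx.factor


out_a, ctx_a = Scale.apply(3.0, 2.0)
out_b, ctx_b = Scale.apply(5.0, 10.0)
grad_a = Scale.backward(ctx_a, 1.0)
grad_b = Scale.backward(ctx_b, 1.0)
```

Because each call gets its own context, the state of the first call (factor 2.0) is not overwritten by the second (factor 10.0), which a single shared closure variable would not guarantee.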
Parameters
None
FunctionCtx is instantiated by Function.apply
with no arguments. User code never constructs one
directly; it receives the instance as the first
positional argument of forward / backward.
Attributes
needs_input_grad : tuple of bool
One flag per Tensor input to forward,
indicating whether autograd would propagate a gradient to
that input. Use this to skip unneeded branches in
backward.
saved_tensors : tuple of Tensor
Tensors registered via save_for_backward, in registration order.
Notes
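The context is what ties the forward and backward halves of a custom node together in the chain rule. For a node $y = f(x)$ feeding a scalar loss $L$, backward computes
$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\,\frac{\partial y}{\partial x},$$
where $\partial L / \partial y$ is the upstream gradient, and the same
ctx object is passed to both halves so backward can read
back whatever forward saved.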
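As a rough illustration of how needs_input_grad can prune work in backward, the sketch below uses plain floats and a SimpleNamespace in place of a real context. The function names are hypothetical, and returning None for an input that needs no gradient is an assumption borrowed from common autograd conventions, not something this page confirms for lucid.

```python
from types import SimpleNamespace


def mul_forward(ctx, x, y):
    # Keep both operands: each one is the local derivative of the other.
    ctx.saved = (x, y)
    return x * y


def mul_backward(ctx, grad_out):
    x, y = ctx.saved
    # Chain rule per input, computed only where a gradient is requested.
    grad_x = grad_out * y if ctx.needs_input_grad[0] else None
    grad_y = grad_out * x if ctx.needs_input_grad[1] else None
    return grad_x, grad_y


# Hypothetical context: only the first input requires a gradient.
ctx = SimpleNamespace(needs_input_grad=(True, False))
out = mul_forward(ctx, 3.0, 4.0)
grads = mul_backward(ctx, 1.0)
```

Here the branch for the second input is skipped entirely, so its gradient is never computed.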
Examples
>>> import lucid
>>> from lucid.autograd import Function
>>> class Square(Function):
...     @staticmethod
...     def forward(ctx, x):
...         ctx.save_for_backward(x)
...         ctx.shape = x.shape
...         return x * x
...     @staticmethod
...     def backward(ctx, grad_out):
...         (x,) = ctx.saved_tensors
...         return 2 * x * grad_out
Methods (4)
__init__
__init__() → None
Initialise an empty context with no saved tensors or extras.
save_for_backward
save_for_backward(*tensors: Tensor) → None
Store tensors needed to compute the backward pass.
Parameters
*tensors : Tensor
Tensors that the backward implementation will
require. They are retrieved later via the saved_tensors
property as a tuple in the same order.
Notes
Each call replaces any tensors previously saved on this context.
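The practical consequence is that everything backward needs should go into a single save_for_backward call. The sketch below mimics only the replace-on-save behaviour stated in this note; ToyCtx and the two forward functions are hypothetical and use floats instead of tensors.

```python
class ToyCtx:
    """Hypothetical stand-in that mirrors only the replace-on-save behaviour."""

    def __init__(self):
        self._saved = ()

    def save_for_backward(self, *tensors):
        self._saved = tuple(tensors)   # overwrite, never append

    @property
    def saved_tensors(self):
        return self._saved


def forward_two_calls(ctx, x, w):
    ctx.save_for_backward(x)
    ctx.save_for_backward(w)           # replaces x: only w survives
    return x * w


def forward_one_call(ctx, x, w):
    ctx.save_for_backward(x, w)        # both survive, in this order
    return x * w


lossy = ToyCtx()
forward_two_calls(lossy, 2.0, 3.0)

complete = ToyCtx()
forward_one_call(complete, 2.0, 3.0)
```

With the two-call version, a backward that unpacks two saved values would fail, because only the last saved value remains on the context.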
saved_tensors
saved_tensors: tuple[Tensor, ...]
Read back the tensors saved during forward.
Returns
tuple of Tensor
The tensors stored by save_for_backward, wrapped
back into Python Tensor instances if the engine stored
raw TensorImpl handles.
mark_non_differentiable
mark_non_differentiable(*tensors: Tensor) → None
Declare that the given output tensors carry no gradient.
Parameters
*tensors : Tensor
Outputs of Function.forward for which autograd should
not propagate gradients (e.g., integer indices, masks).
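A typical case is an operation that returns both a value and an integer index. The sketch below is a toy model under stated assumptions: ToyCtx and the max_with_index functions are hypothetical, Python lists and ints stand in for tensors, and the toy backward receives a gradient only for the differentiable output, which may differ from how lucid passes gradients for marked outputs.

```python
class ToyCtx:
    """Hypothetical stand-in that just records what was marked."""

    def __init__(self):
        self.non_differentiable = ()

    def mark_non_differentiable(self, *tensors):
        self.non_differentiable = tensors


def max_with_index_forward(ctx, values):
    idx = max(range(len(values)), key=values.__getitem__)
    best = values[idx]
    ctx.index = idx                    # cached for backward
    ctx.size = len(values)
    ctx.mark_non_differentiable(idx)   # the index carries no gradient
    return best, idx


def max_with_index_backward(ctx, grad_best):
    # Only the winning position contributed to the differentiable output.
    grad = [0.0] * ctx.size
    grad[ctx.index] = grad_best
    return grad


ctx = ToyCtx()
best, idx = max_with_index_forward(ctx, [1.0, 5.0, 2.0])
grad = max_with_index_backward(ctx, 1.0)
```

The index output is still returned to the caller; marking it only tells the engine not to route a gradient through it.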