class

Module

Module()
source

Base class for all neural network modules.

Every custom model should subclass this and implement forward. Submodules assigned as attributes are tracked automatically.

Notes

Attribute routing: Setting an attribute follows this priority order:

  1. If the value is a lucid.nn.Parameter → stored in _parameters.
  2. If the value is a Module → stored in _modules.
  3. Otherwise → plain Python attribute.

To register a non-parameter tensor (e.g. a running mean), call register_buffer explicitly.
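
The routing order above can be pictured with a small stand-alone sketch. It does not import Lucid: ToyParameter and ToyModule are hypothetical stand-ins, and the real Module.__setattr__ may differ in detail.

```python
class ToyParameter:
    """Hypothetical stand-in for lucid.nn.Parameter."""

    def __init__(self, data):
        self.data = data


class ToyModule:
    """Hypothetical stand-in that mimics the documented routing order."""

    def __init__(self):
        # Bypass our own __setattr__ while creating the registries.
        object.__setattr__(self, "_parameters", {})
        object.__setattr__(self, "_modules", {})
        object.__setattr__(self, "_buffers", {})

    def __setattr__(self, name, value):
        if isinstance(value, ToyParameter):    # 1. Parameter
            self._parameters[name] = value
        elif isinstance(value, ToyModule):     # 2. Module
            self._modules[name] = value
        else:                                  # 3. plain Python attribute
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only called when normal lookup fails, so consult the registries.
        for registry in ("_parameters", "_modules", "_buffers"):
            store = self.__dict__.get(registry, {})
            if name in store:
                return store[name]
        raise AttributeError(name)

    def register_buffer(self, name, tensor):
        # Buffers are never routed automatically; they must be registered.
        self._buffers[name] = tensor


m = ToyModule()
m.weight = ToyParameter([1.0, 2.0])
m.child = ToyModule()
m.running_mean = [0.0, 0.0]                  # plain attribute, not a buffer
m.register_buffer("running_var", [1.0, 1.0])
```

In this sketch running_mean ends up in the instance __dict__ only, which is why a tensor that should be tracked as a buffer needs the explicit register_buffer call.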

Examples

>>> import lucid
>>> import lucid.nn as nn
>>> class MLP(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = nn.Linear(10, 20)
...         self.fc2 = nn.Linear(20, 1)
...
...     def forward(self, x):
...         return self.fc2(lucid.relu(self.fc1(x)))
...
>>> model = MLP()
>>> model(lucid.randn(4, 10)).shape
(4, 1)

Methods (47)

dunder

__init__

None
__init__()
source

Initialise the instance. See the class docstring for parameter semantics.

dunder

__call__

_ModuleOutput
__call__(*args: Tensor, **kwargs: object)
source

Invoke the module. Dispatches to forward, running any registered forward pre-hooks before it and forward hooks after it.

fn

forward

_ModuleOutput
forward(*args: Tensor, **kwargs: object)
source

Override in subclasses to define the computation.

fn

parameters

Iterator[Parameter]
parameters(recurse: bool = True)
source

Yield all Parameters in this module (and children if recurse=True).

fn

named_parameters

Iterator[tuple[str, Parameter]]
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True)
source

Yield (name, Parameter) pairs.

Parameters

remove_duplicate: bool = True
If True (default), each unique Parameter object is yielded only once, even if referenced by multiple attributes. Mirrors the reference framework.
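
The effect of remove_duplicate can be sketched without Lucid by de-duplicating on object identity, which is what a tied (shared) weight needs. The helper named_params and the nested-dict layout below are hypothetical and only illustrate the idea.

```python
def named_params(modules, remove_duplicate=True):
    """Yield (name, param) pairs from {module_name: {param_name: param}}.

    Hypothetical helper: duplicates are detected by object identity.
    """
    seen = set()
    for mod_name, params in modules.items():
        for p_name, param in params.items():
            if remove_duplicate:
                if id(param) in seen:
                    continue          # same object already yielded
                seen.add(id(param))
            yield f"{mod_name}.{p_name}", param


shared = [0.1, 0.2]                   # stand-in for one shared Parameter
toy = {"embed": {"weight": shared}, "head": {"weight": shared}}

unique_names = [n for n, _ in named_params(toy)]
all_names = [n for n, _ in named_params(toy, remove_duplicate=False)]
```

With de-duplication on, only the first name under which the shared object is reached appears in the output.
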
fn

buffers

Iterator[Tensor]
buffers(recurse: bool = True)
source

Yield all buffer tensors.

fn

named_buffers

Iterator[tuple[str, Tensor]]
named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True)
source

Yield (name, buffer) pairs.

fn

modules

Iterator[Module]
modules()
source

Yield this module and all submodules (depth-first).

fn

named_modules

Iterator[tuple[str, Module]]
named_modules(memo: set[int] | None = None, prefix: str = '', remove_duplicate: bool = True)
source

Yield (name, module) pairs.

fn

children

Iterator[Module]
children()
source

Yield direct child modules.

fn

named_children

Iterator[tuple[str, Module]]
named_children()
source

Yield (name, child_module) pairs.

fn

get_submodule

Module
get_submodule(target: str)
source

Return submodule at dotted path, e.g. 'encoder.layer.0'.
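
A dotted path is resolved one component at a time. The sketch below shows that walk on a hypothetical Node class with a _modules registry. Returning the root for an empty path and raising AttributeError for a missing component are assumptions borrowed from common practice, not confirmed Lucid behaviour.

```python
class Node:
    """Hypothetical stand-in for a module with a _modules registry."""

    def __init__(self, **children):
        self._modules = dict(children)


def get_submodule(root, target):
    """Walk a dotted path such as 'encoder.layer.0' from root."""
    if target == "":
        return root                   # assumption: empty path means root
    current = root
    for part in target.split("."):
        if part not in current._modules:
            raise AttributeError(
                f"no submodule named {part!r} in path {target!r}"
            )
        current = current._modules[part]
    return current


leaf = Node()
model = Node(encoder=Node(layer=Node(**{"0": leaf})))
found = get_submodule(model, "encoder.layer.0")
```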

fn

get_parameter

Parameter
get_parameter(target: str)
source

Return parameter at dotted path, e.g. 'fc.weight'.

fn

get_buffer

Tensor
get_buffer(target: str)
source

Return buffer at dotted path, e.g. 'bn.running_mean'.

fn

register_parameter

None
register_parameter(name: str, param: Parameter | None)
source

Register a Parameter under the given name.

fn

register_buffer

None
register_buffer(name: str, tensor: Tensor | None, persistent: bool = True)
source

Register a buffer tensor. Non-persistent buffers are excluded from state_dict.
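
The persistent flag can be pictured as a per-buffer marker that state_dict consults. ToyBuffers below is a hypothetical stand-in, not Lucid's implementation.

```python
class ToyBuffers:
    """Hypothetical stand-in tracking persistent and non-persistent buffers."""

    def __init__(self):
        self._buffers = {}
        self._non_persistent = set()

    def register_buffer(self, name, tensor, persistent=True):
        self._buffers[name] = tensor
        if persistent:
            self._non_persistent.discard(name)
        else:
            self._non_persistent.add(name)

    def state_dict(self):
        # Non-persistent buffers stay on the module but are not serialised.
        return {
            name: tensor
            for name, tensor in self._buffers.items()
            if name not in self._non_persistent
        }


bn = ToyBuffers()
bn.register_buffer("running_mean", [0.0, 0.0])
bn.register_buffer("scratch", [9.9], persistent=False)
saved = bn.state_dict()
```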

fn

add_module

None
add_module(name: str, module: Module | None)
source

Add a child module.

fn

register_module

None
register_module(name: str, module: Module | None)
source

Alias for add_module.

fn

train

Self
train(mode: bool = True)
source

Set this module and all children to training mode.

fn

eval

Self
eval()
source

Set this module and all children to evaluation mode.

fn

to

Self
to(*args: object, **kwargs: object)
source

Move/cast all parameters and buffers, preserving Parameter object identity.

Floating-point dtype casts (.float(), .double(), .half(), .bfloat16()) skip integer buffers — e.g. BatchNorm.num_batches_tracked stays int64 — matching the reference framework so checkpoint round-trips don't quietly widen / narrow the counter type. Device moves still apply to every tensor.
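
The rule above (floating-point casts skip integer tensors, device moves apply to everything) can be sketched on a toy tensor record. ToyTensor, convert, and the dtype and device strings are hypothetical names used only for illustration.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyTensor:
    """Hypothetical stand-in that records only dtype and device."""

    dtype: str
    device: str = "cpu"


FLOAT_DTYPES = {"float16", "bfloat16", "float32", "float64"}


def convert(t, dtype=None, device=None):
    out = t
    if dtype is not None:
        # A floating-point target leaves non-floating tensors untouched.
        skip = dtype in FLOAT_DTYPES and t.dtype not in FLOAT_DTYPES
        if not skip:
            out = replace(out, dtype=dtype)
    if device is not None:
        # Device moves apply to every tensor, whatever its dtype.
        out = replace(out, device=device)
    return out


weight = ToyTensor("float32")
counter = ToyTensor("int64")          # e.g. num_batches_tracked
half_weight = convert(weight, dtype="float16", device="metal")
half_counter = convert(counter, dtype="float16", device="metal")
```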

fn

metal

Self
metal()
source

Move all parameters and buffers to Apple Metal GPU.

fn

cpu

Self
cpu()
source

Move all parameters and buffers to CPU.

fn

half

Self
half()
source

Cast all floating-point parameters and buffers to float16. Integer buffers are left unchanged (see to).

fn

float

Self
float()
source

Cast all floating-point parameters and buffers to float32. Integer buffers are left unchanged (see to).

fn

double

Self
double()
source

Cast all floating-point parameters and buffers to float64. Integer buffers are left unchanged (see to).

fn

bfloat16

Self
bfloat16()
source

Cast all floating-point parameters and buffers to bfloat16. Integer buffers are left unchanged (see to).

fn

type

Self
type(dst_type: object)
source

Cast all parameters and buffers to dst_type.

dst_type may be a lucid.dtype, a Python type (float, int), or a string ("float32", "float16", etc.). Delegates to to, which handles the conversion.

fn

apply

Self
apply(fn: Callable[[Module], None])
source

Apply fn recursively to every submodule (including self).
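
A typical use of apply is weight initialisation. The sketch below shows the recursion on a hypothetical Layer class. Visiting children before self is an assumption taken from the reference framework, not something this page states.

```python
class Layer:
    """Hypothetical stand-in for a module with child modules."""

    def __init__(self, name, *children):
        self.name = name
        self.children = list(children)
        self.initialised = False

    def apply(self, fn):
        for child in self.children:   # assumption: children first
            child.apply(fn)
        fn(self)                      # then the module itself
        return self                   # returning self allows chaining


visited = []


def init_weights(module):
    module.initialised = True
    visited.append(module.name)


net = Layer("net", Layer("fc1"), Layer("fc2"))
result = net.apply(init_weights)
```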

fn

zero_grad

None
zero_grad(set_to_none: bool = True)
source

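Reset the gradients of all parameters. With set_to_none=True (default), gradients are set to None instead of being filled with zeros.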
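
The two modes implied by set_to_none can be sketched as follows. ToyParam and the free function zero_grad are hypothetical, and the exact behaviour in Lucid is assumed to follow the reference framework.

```python
class ToyParam:
    """Hypothetical stand-in holding only a gradient."""

    def __init__(self, grad=None):
        self.grad = grad


def zero_grad(params, set_to_none=True):
    for p in params:
        if p.grad is None:
            continue                          # nothing to reset
        if set_to_none:
            p.grad = None                     # drop the gradient entirely
        else:
            p.grad = [0.0] * len(p.grad)      # keep storage, fill with zeros


a, b = ToyParam([0.5, -1.0]), ToyParam([2.0])
zero_grad([a], set_to_none=True)
zero_grad([b], set_to_none=False)
```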

fn

requires_grad_

Self
requires_grad_(requires_grad: bool = True)
source

Set requires_grad for all parameters.

fn

share_memory

Self
share_memory()
source

No-op on Apple Silicon (unified memory is always shared).

fn

compile

Self
compile(*args: object, **kwargs: object)
source

No-op compatibility stub.

External codepaths often call model.compile() to opt into JIT acceleration; Lucid has no such layer, so this returns self unchanged rather than crashing the caller. Any positional or keyword arguments are accepted and ignored.

fn

to_empty

Self
to_empty(device: object = None, recurse: bool = True)
source

Move parameters and buffers to device without copying data.

The reference framework uses to_empty to materialise a model constructed on the meta device. Lucid has no meta device, so this method falls back to the standard to for parity with external code. recurse is already honoured by to.

fn

get_extra_state

object
get_extra_state()
source

Return extra state to include in state_dict. Override in subclasses.

fn

set_extra_state

None
set_extra_state(state: object)
source

Restore extra state loaded from state_dict. Override in subclasses.

fn

state_dict

dict[str, Tensor]
state_dict(destination: dict[str, Tensor] | None = None, prefix: str = '', keep_vars: bool = False)
source

Return a dict mapping parameter/buffer names to tensors.

The returned OrderedDict carries a _metadata attribute: {module_path: {"version": <int>}} for every module that defines a _version class attribute. lucid.save preserves this attribute across disk round-trips.
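
The _metadata layout described above can be sketched with a plain OrderedDict. The helper build_state_dict and its input format are hypothetical; only the shape of _metadata follows the description.

```python
from collections import OrderedDict


def build_state_dict(modules):
    """modules maps path -> (version or None, {local_name: tensor})."""
    out = OrderedDict()
    out._metadata = OrderedDict()     # attribute on the dict, not a key
    for path, (version, tensors) in modules.items():
        if version is not None:       # only modules that define _version
            out._metadata[path] = {"version": version}
        prefix = f"{path}." if path else ""
        for name, tensor in tensors.items():
            out[prefix + name] = tensor
    return out


sd = build_state_dict({
    "": (None, {}),
    "bn": (2, {"running_mean": [0.0]}),
    "fc": (None, {"weight": [1.0]}),
})
```

Because _metadata is an attribute rather than a key, copying the mapping with dict(sd) drops it; code that rebuilds the dictionary has to carry the attribute over explicitly.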

fn

load_state_dict

object
load_state_dict(state_dict: dict[str, Tensor], strict: bool = True, assign: bool = False)
source

Load parameters from a state_dict.

Calls each module's _load_from_state_dict recursively. Returns IncompatibleKeys(missing_keys, unexpected_keys) on success. Raises RuntimeError if strict=True and any keys are missing or unexpected, or if any error_msgs accumulated during loading.

Parameters

state_dict: dict
A mapping from parameter/buffer names to tensors.
strict: bool = True
If True (default), require an exact key match and raise on any missing or unexpected keys.
assign: bool = False
If True, replace each parameter/buffer object with the loaded tensor directly (allows shape/dtype changes). If False (default), copy data into the existing parameter, preserving its dtype and device.
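
The key-matching part of strict loading can be sketched in isolation. check_keys is a hypothetical helper, and IncompatibleKeys is modelled here as a namedtuple with the two fields named above.

```python
from collections import namedtuple

IncompatibleKeys = namedtuple(
    "IncompatibleKeys", ["missing_keys", "unexpected_keys"]
)


def check_keys(expected, state_dict, strict=True):
    """Compare the keys a model expects with those a checkpoint provides."""
    missing = [k for k in expected if k not in state_dict]
    unexpected = [k for k in state_dict if k not in expected]
    if strict and (missing or unexpected):
        raise RuntimeError(
            f"missing keys: {missing}; unexpected keys: {unexpected}"
        )
    return IncompatibleKeys(missing, unexpected)


expected = ["fc.weight", "fc.bias"]
checkpoint = {"fc.weight": [1.0], "fc.extra": [0.0]}
report = check_keys(expected, checkpoint, strict=False)
```

With strict=False the mismatch is reported instead of raised, which is useful when loading a partial checkpoint on purpose.
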
fn

register_load_state_dict_pre_hook

RemovableHandle
register_load_state_dict_pre_hook(hook: Callable[..., object])
source

Register a pre-hook called when this module loads state_dict.

Hook signature: hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None

The hook may mutate state_dict, missing_keys, unexpected_keys, and error_msgs in place.
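
One common use of a pre-hook is renaming keys from an older checkpoint layout before loading. The sketch below follows the hook signature given above. The key names gamma and weight are hypothetical, and the hook is called directly here instead of being registered on a real module.

```python
def rename_legacy_gamma(module, state_dict, prefix, local_metadata,
                        strict, missing_keys, unexpected_keys, error_msgs):
    """Move '<prefix>gamma' to '<prefix>weight' in place, if needed."""
    old_key = prefix + "gamma"        # hypothetical legacy name
    new_key = prefix + "weight"       # hypothetical current name
    if old_key in state_dict and new_key not in state_dict:
        state_dict[new_key] = state_dict.pop(old_key)


ckpt = {"norm.gamma": [1.0, 1.0], "norm.bias": [0.0, 0.0]}
rename_legacy_gamma(None, ckpt, "norm.", {}, True, [], [], [])
```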

fn

register_load_state_dict_post_hook

RemovableHandle
register_load_state_dict_post_hook(hook: Callable[..., object])
source

Register a post-hook called after this module loads state_dict.

Hook signature: hook(module, incompatible_keys) -> None.

fn

register_forward_pre_hook

RemovableHandle
register_forward_pre_hook(hook: _ForwardPreHook, prepend: bool = False, with_kwargs: bool = False)
source

Register a hook called before forward().

fn

register_forward_hook

RemovableHandle
register_forward_hook(hook: _ForwardHook, prepend: bool = False, with_kwargs: bool = False, always_call: bool = False)
source

Register a hook called after forward().
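
How forward hooks wrap a call, and how a handle removes one, can be sketched with a toy class. The hook signatures hook(module, args) and hook(module, args, output) and the handle's remove method are assumptions modelled on the reference framework; Handle and ToyCallable are hypothetical names.

```python
class Handle:
    """Hypothetical stand-in for RemovableHandle."""

    def __init__(self, registry, key):
        self._registry = registry
        self._key = key

    def remove(self):
        self._registry.pop(self._key, None)


class ToyCallable:
    def __init__(self):
        self._pre_hooks = {}
        self._post_hooks = {}
        self._next_id = 0

    def _register(self, registry, hook):
        key = self._next_id
        self._next_id += 1
        registry[key] = hook
        return Handle(registry, key)

    def register_forward_pre_hook(self, hook):
        return self._register(self._pre_hooks, hook)

    def register_forward_hook(self, hook):
        return self._register(self._post_hooks, hook)

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        for hook in list(self._pre_hooks.values()):
            hook(self, (x,))              # runs before forward
        out = self.forward(x)
        for hook in list(self._post_hooks.values()):
            hook(self, (x,), out)         # runs after forward
        return out


log = []
m = ToyCallable()
h_pre = m.register_forward_pre_hook(lambda mod, args: log.append(("pre", args)))
h_post = m.register_forward_hook(lambda mod, args, out: log.append(("post", out)))
first = m(3)
h_post.remove()                           # post hook no longer fires
second = m(4)
```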

fn

register_full_backward_pre_hook

RemovableHandle
register_full_backward_pre_hook(hook: _BackwardHook, prepend: bool = False)
source

Register a hook to be called before backward hooks.

fn

register_full_backward_hook

RemovableHandle
register_full_backward_hook(hook: _BackwardHook, prepend: bool = False)
source

Register a backward hook. Returns a RemovableHandle.

fn

register_backward_hook

RemovableHandle
register_backward_hook(hook: _BackwardHook)
source

Deprecated alias for register_full_backward_hook.

fn

extra_repr

str
extra_repr()
source

Override to add module-specific details to the repr (e.g. Linear shows in_features).
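
A minimal sketch of how extra_repr typically feeds into __repr__. ToyLinear is a hypothetical class, and the exact output format of Lucid's own __repr__ is not shown on this page.

```python
class ToyLinear:
    """Hypothetical layer whose repr is built from extra_repr."""

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features

    def extra_repr(self):
        # Subclasses return the details they want shown in the repr.
        return f"in_features={self.in_features}, out_features={self.out_features}"

    def __repr__(self):
        return f"{type(self).__name__}({self.extra_repr()})"


text = repr(ToyLinear(10, 20))
```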

dunder

__repr__

str
__repr__()
source

Return a developer-facing string representation of the instance.