GRUCell
Module
GRUCell(input_size: int, hidden_size: int, bias: bool = True, device: DeviceLike = None, dtype: DTypeLike = None)
Single time-step Gated Recurrent Unit (GRU) cell.
Computes one recurrent update using the three GRU gates: reset (r), update (z), and candidate (n).
The reset gate r controls how much of the previous hidden state feeds into the candidate state n; setting r near zero makes the cell ignore past context. The update gate z interpolates between the old hidden state and the candidate, which lets the cell retain information over many steps without a separate cell state.
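For reference, the update can be written out as below. This page does not spell out the formulas, so this assumes the PyTorch-style convention that the [r; z; n] gate order suggests; in particular, where z multiplies and where r enters the candidate are assumptions. Here σ is the logistic sigmoid, ⊙ is elementwise multiplication, and W_{i·}, W_{h·}, b_{i·}, b_{h·} stand for the per-gate row blocks of weight_ih, weight_hh, bias_ih and bias_hh.

```latex
\begin{aligned}
r_t &= \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right) \\
z_t &= \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
```

Under this convention, r_t = 0 makes the candidate depend only on x_t, and z_t = 1 copies h_{t-1} through unchanged.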
The three per-gate weight matrices are stacked into single parameters of shape (3H, *) in gate order [r; z; n], where H is hidden_size.
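A minimal NumPy sketch of how a stacked (3H, *) parameter divides into per-gate blocks in [r; z; n] order. It uses plain NumPy arrays as hypothetical stand-ins for weight_ih (not lucid tensors) and a hypothetical helper split_gates, and shows that one fused matrix product gives the same result as three separate per-gate products, which is the usual reason for stacking.

```python
import numpy as np

def split_gates(stacked: np.ndarray, hidden_size: int):
    """Split a stacked (3H, ...) array into (r, z, n) blocks, each (H, ...)."""
    if stacked.shape[0] != 3 * hidden_size:
        raise ValueError(f"expected leading dim {3 * hidden_size}, got {stacked.shape[0]}")
    # Rows [0, H) -> reset gate, [H, 2H) -> update gate, [2H, 3H) -> candidate.
    return (
        stacked[:hidden_size],
        stacked[hidden_size:2 * hidden_size],
        stacked[2 * hidden_size:],
    )

rng = np.random.default_rng(0)
input_size, hidden_size, batch = 8, 16, 4
weight_ih = rng.standard_normal((3 * hidden_size, input_size))  # stand-in for weight_ih
x = rng.standard_normal((batch, input_size))

# One fused product computes all three input projections at once: (N, 3H).
fused = x @ weight_ih.T
w_ir, w_iz, w_in = split_gates(weight_ih, hidden_size)
separate = np.concatenate([x @ w_ir.T, x @ w_iz.T, x @ w_in.T], axis=1)
print(fused.shape, np.allclose(fused, separate))  # (4, 48) True
```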
Parameters
input_size : int
    Number of features in each input vector.
hidden_size : int
    Number of features in the hidden state.
bias : bool = True
    If False, no bias terms are used. Default: True.
device : DeviceLike = None
dtype : DTypeLike = None

Attributes
weight_ih : Parameter, shape (3 * hidden_size, input_size)
weight_hh : Parameter, shape (3 * hidden_size, hidden_size)
bias_ih : Parameter or None, shape (3 * hidden_size,)
    None when bias=False.
bias_hh : Parameter or None, shape (3 * hidden_size,)
    None when bias=False.

Notes
- x: (N, input_size), a batch of input vectors.
- hx (optional): (N, hidden_size), the initial hidden state. Defaults to zeros when None.
- Output h_t: (N, hidden_size).
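The shape contract above can be illustrated with a small NumPy reference of one forward step that works directly on stacked [r; z; n] parameters. It is a sketch, not lucid's implementation: the function name gru_cell_forward is hypothetical, and the gate equations follow the assumed PyTorch-style convention (reset gate applied to the hidden-side candidate term, h_t = (1 - z) * n + z * h_prev).

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_cell_forward(x, weight_ih, weight_hh, bias_ih=None, bias_hh=None, hx=None):
    """One GRU step on stacked [r; z; n] parameters (assumed PyTorch-style equations)."""
    n_batch = x.shape[0]
    hidden_size = weight_hh.shape[1]
    if hx is None:
        # Mirrors the documented default: zeros of shape (N, hidden_size).
        hx = np.zeros((n_batch, hidden_size), dtype=x.dtype)
    gi = x @ weight_ih.T    # (N, 3H) input-side projections
    gh = hx @ weight_hh.T   # (N, 3H) hidden-side projections
    if bias_ih is not None:
        gi = gi + bias_ih
    if bias_hh is not None:
        gh = gh + bias_hh
    i_r, i_z, i_n = np.split(gi, 3, axis=1)
    h_r, h_z, h_n = np.split(gh, 3, axis=1)
    r = sigmoid(i_r + h_r)          # reset gate
    z = sigmoid(i_z + h_z)          # update gate
    n = np.tanh(i_n + r * h_n)      # candidate; r scales the hidden term, including b_hn
    return (1.0 - z) * n + z * hx   # new hidden state, (N, hidden_size)

rng = np.random.default_rng(0)
I, H, N = 4, 12, 3
w_ih = rng.standard_normal((3 * H, I)) * 0.1
w_hh = rng.standard_normal((3 * H, H)) * 0.1
h1 = gru_cell_forward(rng.standard_normal((N, I)), w_ih, w_hh)  # hx omitted -> zeros
print(h1.shape)  # (3, 12)
```

Computing both projections as single (N, 3H) products and splitting afterwards is the practical payoff of the stacked layout.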
Weights are initialised from .
The GRU has fewer parameters than the LSTM (no cell state, three gates instead of four) and often converges faster on shorter sequences while matching LSTM quality on many benchmarks.
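The parameter saving can be checked by counting entries in the attribute shapes listed above: each gate contributes one (H, input_size) block, one (H, H) block and, with bias=True, two length-H bias blocks. The helper rnn_cell_params is hypothetical, and the LSTM figure assumes a four-gate cell with the same two-bias layout, which this page does not document.

```python
def rnn_cell_params(input_size: int, hidden_size: int, n_gates: int, bias: bool = True) -> int:
    """Count parameters of a gated cell with stacked (n_gates * H, *) weights."""
    per_gate = hidden_size * input_size + hidden_size * hidden_size
    if bias:
        per_gate += 2 * hidden_size  # one block each from bias_ih and bias_hh
    return n_gates * per_gate

gru = rnn_cell_params(8, 16, n_gates=3)   # matches GRUCell(8, 16) shapes
lstm = rnn_cell_params(8, 16, n_gates=4)  # assumed four-gate LSTM layout
print(gru, lstm, gru / lstm)  # 1248 1664 0.75
```

Under these assumptions the ratio is exactly 3/4 for any sizes, since both cells share the same per-gate block structure.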
See also
GRU : Multi-layer, multi-step GRU module.
LSTMCell : Single-step LSTM cell with a separate cell state.
RNNCell : Vanilla single-step cell without gating.
Examples
Manual sequence loop:
>>> import lucid, lucid.nn as nn
>>> cell = nn.GRUCell(input_size=8, hidden_size=16)
>>> x_seq = lucid.randn(6, 4, 8) # (L=6, N=4, I=8)
>>> h = lucid.zeros(4, 16)
>>> for t in range(6):
... h = cell(x_seq[t], h)
>>> h.shape
(4, 16)
No explicit initial state (defaults to zeros):
>>> cell2 = nn.GRUCell(4, 12)
>>> h2 = cell2(lucid.randn(3, 4))
>>> h2.shape
(3, 12)

Methods
__init__(input_size: int, hidden_size: int, bias: bool = True, device: DeviceLike = None, dtype: DTypeLike = None) → None
Initialise the GRUCell module. See the class docstring for parameter semantics.
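forward(x: Tensor, hx: Tensor | None = None) → Tensor or tuple of Tensor
Run the recurrent forward pass.
Parameters
x : Tensor
    Input batch, shape (N, input_size).
hx : Tensor | None = None
    Initial hidden state, shape (N, hidden_size). Zeros are used when None.
Returns
Tensor or tuple of Tensor
    For GRUCell, the new hidden state h_t of shape (N, hidden_size); see the class docstring.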
extra_repr() → str
Return a string representation of the layer's configuration.