class GRU

extends _CellNamingMixin, Module
GRU(input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, batch_first: bool = False, dropout: float = 0.0, bidirectional: bool = False, device: DeviceLike = None, dtype: DTypeLike = None)

Multi-layer Gated Recurrent Unit (GRU) recurrent layer.

Applies a stack of GRU cells over an input sequence. At each time step $t$ the following equations are evaluated (see GRUCell for the full derivation):

$$
\begin{aligned}
r_t &= \sigma(W_{ir}\,x_t + W_{hr}\,h_{t-1} + b_r) \\
z_t &= \sigma(W_{iz}\,x_t + W_{hz}\,h_{t-1} + b_z) \\
n_t &= \tanh\!\left(W_{in}\,x_t + b_{in} + r_t \odot (W_{hn}\,h_{t-1} + b_{hn})\right) \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot n_t
\end{aligned}
$$
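As a rough illustration of the update equations above, here is a minimal NumPy sketch of a single time step. It is not this library's implementation: the function `gru_step`, the parameter dictionary `p`, its key names, and the weight layout (`(H, I)` and `(H, H)` matrices applied as `x @ W.T`) are all assumptions made for the example.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x_t, h_prev, p):
    """One GRU time step. x_t: (N, I), h_prev: (N, H).

    `p` is a hypothetical dict of parameter arrays, not a library structure.
    """
    r = sigmoid(x_t @ p["W_ir"].T + h_prev @ p["W_hr"].T + p["b_r"])   # reset gate
    z = sigmoid(x_t @ p["W_iz"].T + h_prev @ p["W_hz"].T + p["b_z"])   # update gate
    n = np.tanh(x_t @ p["W_in"].T + p["b_in"]
                + r * (h_prev @ p["W_hn"].T + p["b_hn"]))              # candidate state
    return (1.0 - z) * h_prev + z * n                                  # new hidden state

rng = np.random.default_rng(0)
N, I, H = 2, 8, 16
p = {}
for gate in ("r", "z", "n"):
    p[f"W_i{gate}"] = rng.standard_normal((H, I)) * 0.1
    p[f"W_h{gate}"] = rng.standard_normal((H, H)) * 0.1
for bias in ("b_r", "b_z", "b_in", "b_hn"):
    p[bias] = np.zeros(H)

h = gru_step(rng.standard_normal((N, I)), np.zeros((N, H)), p)
print(h.shape)
```

With a zero initial hidden state the result reduces to `z * n`, so every entry lies strictly between -1 and 1.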

The output of layer $\ell$ is used as the input of layer $\ell + 1$. When bidirectional=True, two GRUs process the sequence in opposite directions and their outputs are concatenated along the feature axis at every time step.
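The stacking and bidirectional wiring described above can be sketched as follows. To keep the focus on data flow, the sketch uses a stand-in recurrent cell (a plain tanh update, not a GRU). The helpers `run_direction` and `make_step` are hypothetical, and the ordering of `h_n` (forward then reverse within each layer) is an assumption rather than something this page states.

```python
import numpy as np

def run_direction(step, x, h0, reverse=False):
    """Unroll `step` over x of shape (L, N, F); return (outputs, final hidden)."""
    L = x.shape[0]
    order = range(L - 1, -1, -1) if reverse else range(L)
    h = h0
    outs = [None] * L
    for t in order:
        h = step(x[t], h)
        outs[t] = h  # keep outputs aligned with the original time index
    return np.stack(outs), h

def make_step(in_features, hidden, rng):
    # Stand-in cell (plain tanh recurrence), NOT a GRU; only the wiring matters here.
    W_x = rng.standard_normal((in_features, hidden)) * 0.1
    W_h = rng.standard_normal((hidden, hidden)) * 0.1
    return lambda x_t, h: np.tanh(x_t @ W_x + h @ W_h)

rng = np.random.default_rng(0)
L, N, I, H, num_layers = 5, 2, 8, 16, 2
x = rng.standard_normal((L, N, I))

layer_input, finals = x, []
for layer in range(num_layers):
    F = layer_input.shape[-1]  # I for layer 0, 2 * H for later layers
    fwd, h_f = run_direction(make_step(F, H, rng), layer_input, np.zeros((N, H)))
    bwd, h_b = run_direction(make_step(F, H, rng), layer_input, np.zeros((N, H)), reverse=True)
    layer_input = np.concatenate([fwd, bwd], axis=-1)  # (L, N, 2 * H)
    finals += [h_f, h_b]

output, h_n = layer_input, np.stack(finals)
print(output.shape, h_n.shape)
```

Note that the reverse direction finishes at time index 0, so its final hidden state matches the reverse half of `output[0]`, while the forward final state matches the forward half of `output[-1]`.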

Inter-layer dropout (probability dropout) is applied between adjacent layers during training, but not after the final layer.
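A minimal sketch of this rule, assuming the common "inverted dropout" scaling by 1 / (1 - p) (the page does not say which scaling the library uses). The names `dropout` and `stack_forward` are hypothetical, and the layers are stand-in callables rather than real GRU layers.

```python
import numpy as np

def dropout(a, p, rng):
    """Inverted dropout: zero each entry with probability p, rescale the rest."""
    if p == 0.0:
        return a
    mask = rng.random(a.shape) >= p
    return a * mask / (1.0 - p)

def stack_forward(layers, x, p, training, rng):
    """Apply `layers` in order, with dropout between layers but not after the last."""
    out = x
    for i, layer in enumerate(layers):
        out = layer(out)
        is_last = i == len(layers) - 1
        if training and not is_last:
            out = dropout(out, p, rng)
    return out

rng = np.random.default_rng(0)
x = np.ones((4, 3))
identity = lambda a: a  # stand-in for a recurrent layer

train_out = stack_forward([identity, identity], x, 0.5, True, rng)   # entries are 0 or 2
eval_out = stack_forward([identity, identity], x, 0.5, False, rng)   # unchanged
single_out = stack_forward([identity], x, 0.5, True, rng)            # no dropout: one layer
```

A consequence of the rule is that `dropout` has no effect when `num_layers=1`, since there is no pair of adjacent layers to place it between.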

Parameters

input_size : int
    Number of expected features in the input $x_t$.
hidden_size : int
    Number of features in the hidden state $h_t$ (denoted H below).
num_layers : int = 1
    Number of stacked GRU layers. Default: 1.
bias : bool = True
    If False, all bias parameters are omitted. Default: True.
batch_first : bool = False
    If True the input/output tensors have shape (N, L, *) instead of the default (L, N, *). Default: False.
dropout : float = 0.0
    Dropout probability applied after every layer except the last. 0.0 disables dropout. Default: 0.0.
bidirectional : bool = False
    If True, a bidirectional GRU is used; the output feature dimension becomes 2 * hidden_size. Default: False.
device : DeviceLike = None
    Device for weight allocation.
dtype : DTypeLike = None
    Data type for weight tensors.

Notes

  • Input x: (L, N, input_size) or (N, L, input_size) when batch_first=True.
  • h_0 (optional): (D * num_layers, N, H) where D = 2 if bidirectional else 1. Defaults to zeros.
  • output: (L, N, D * H) or (N, L, D * H).
  • h_n: (D * num_layers, N, H) — hidden state at the final time step for each layer and direction.
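The shape rules above can be collected into a small pure-Python helper for sanity-checking tensor shapes before calling the module. `gru_shapes` is a hypothetical function written for this example, not part of the library.

```python
def gru_shapes(L, N, input_size, hidden_size, num_layers=1,
               bidirectional=False, batch_first=False):
    """Return the expected shapes of x, h_0, output and h_n as tuples."""
    D = 2 if bidirectional else 1
    if batch_first:
        x = (N, L, input_size)
        output = (N, L, D * hidden_size)
    else:
        x = (L, N, input_size)
        output = (L, N, D * hidden_size)
    # Hidden-state shapes do not depend on batch_first.
    h = (D * num_layers, N, hidden_size)
    return {"x": x, "h_0": h, "output": output, "h_n": h}

two_layer = gru_shapes(L=5, N=2, input_size=8, hidden_size=16,
                       num_layers=2, batch_first=True)
bidir = gru_shapes(L=10, N=3, input_size=8, hidden_size=16,
                   bidirectional=True, batch_first=True)
print(two_layer["output"], two_layer["h_n"])
print(bidir["output"], bidir["h_n"])
```

The two calls use the same configurations as the entries in the Examples section, so their results can be compared directly with the shapes shown there.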

Internally this module stores one GRUCell sub-module per layer per direction (named cell_l{layer} and cell_l{layer}_reverse). The _CellNamingMixin flattens these into weight_ih_l{layer} etc. for checkpoint compatibility with the reference framework.
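To make the naming scheme concrete, the sketch below maps a cell sub-module name plus a parameter name to a flat name. It is only a guess at the convention: the function `flat_name`, the per-cell parameter names (`weight_ih`, `bias_hh`, and so on), and the assumption that the `_reverse` suffix carries over to the flat name are not confirmed by this page.

```python
import re

# Matches the sub-module names given above: cell_l{layer} and cell_l{layer}_reverse.
_CELL_RE = re.compile(r"^cell_l(\d+)(_reverse)?$")

def flat_name(cell_name, param_name):
    """Hypothetical mapping, e.g. ("cell_l0", "weight_ih") -> "weight_ih_l0"."""
    m = _CELL_RE.match(cell_name)
    if m is None:
        raise ValueError(f"not a GRU cell name: {cell_name!r}")
    layer, suffix = m.group(1), m.group(2) or ""
    return f"{param_name}_l{layer}{suffix}"

print(flat_name("cell_l0", "weight_ih"))
print(flat_name("cell_l1_reverse", "bias_hh"))
```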

flatten_parameters is a no-op retained for API compatibility.

PackedSequence input is not yet supported.

See Also

GRUCell : Single time-step GRU cell.
LSTM : Long Short-Term Memory (carries a separate cell state).
RNN : Vanilla Elman RNN (no gating).

Examples

Two-layer GRU, batch-first:
>>> import lucid, lucid.nn as nn
>>> gru = nn.GRU(8, 16, num_layers=2, batch_first=True)
>>> x = lucid.randn(2, 5, 8)       # (N=2, L=5, I=8)
>>> out, h_n = gru(x)
>>> out.shape, h_n.shape
((2, 5, 16), (2, 2, 16))
Bidirectional GRU:
>>> gru_bi = nn.GRU(8, 16, bidirectional=True, batch_first=True)
>>> x2 = lucid.randn(3, 10, 8)
>>> out2, h_n2 = gru_bi(x2)
>>> out2.shape    # D*H = 2*16 = 32
(3, 10, 32)
>>> h_n2.shape    # D*num_layers = 2*1 = 2
(2, 3, 16)

Methods (4)

dunder

__init__

None
__init__(input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, batch_first: bool = False, dropout: float = 0.0, bidirectional: bool = False, device: DeviceLike = None, dtype: DTypeLike = None)

Initialise the GRU module. See the class docstring for parameter semantics.

fn

flatten_parameters

None
flatten_parameters()

No-op for API compatibility (see LSTM.flatten_parameters).

fn

forward

Tensor or tuple of Tensor
forward(x: Tensor, hx: Tensor | None = None)

Run the recurrent forward pass.

Parameters

x : Tensor
    Input sequence; see the class docstring.
hx : Tensor | None = None
    Initial hidden state h_0; see the class docstring.

Returns

Tensor or tuple of Tensor

Output and (optionally) the new hidden state; see the class docstring.

fn

extra_repr

str
extra_repr()

Return a string representation of the layer's configuration.