Transformer

class lucid.models.Transformer(config: TransformerConfig)

The Transformer class in lucid.models provides a full implementation of the Transformer model, including token embeddings, positional encoding, and the final vocabulary projection. This is distinct from nn.Transformer, which serves as a generic module template for building Transformer components.

%%{init: {"flowchart":{"curve":"monotoneX","nodeSpacing":50,"rankSpacing":50}} }%%
flowchart LR
  linkStyle default stroke-width:2.0px
  subgraph sg_m0["<span style='font-size:20px;font-weight:700'>transformer_base</span>"]
  style sg_m0 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
    m1(["Embedding x 2<br/><span style='font-size:11px;color:#475569;font-weight:400'>(1,100) → (1,100,512)</span>"]);
    subgraph sg_m2["_PositionalEncoding"]
    style sg_m2 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      m3["Dropout"];
    end
    subgraph sg_m4["Transformer"]
    style sg_m4 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      subgraph sg_m5["TransformerEncoder"]
      style sg_m5 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        subgraph sg_m6["layers"]
        style sg_m6 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
          subgraph sg_m7["TransformerEncoderLayer x 6"]
            direction TB;
          style sg_m7 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
            m7_in(["Input"]);
            m7_out(["Output"]);
  style m7_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m7_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
            subgraph sg_m8["MultiHeadAttention"]
              direction TB;
            style sg_m8 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
              m9(["Linear x 4"]);
            end
            m10(["Linear x 2<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,100,512) → (1,100,2048)</span>"]);
            m11(["Dropout x 3"]);
            m12(["LayerNorm x 2"]);
          end
        end
        m13["LayerNorm"];
      end
      subgraph sg_m14["TransformerDecoder"]
      style sg_m14 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        subgraph sg_m15["layers"]
        style sg_m15 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
          subgraph sg_m16["TransformerDecoderLayer x 6"]
            direction TB;
          style sg_m16 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
            m16_in(["Input"]);
            m16_out(["Output"]);
  style m16_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m16_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
            subgraph sg_m17["MultiHeadAttention x 2"]
              direction TB;
            style sg_m17 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
              m17_in(["Input"]);
              m17_out(["Output"]);
  style m17_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m17_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
              m18(["Linear x 4"]);
            end
            m19(["Linear x 2<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,100,512) → (1,100,2048)</span>"]);
            m20(["Dropout x 4"]);
            m21(["LayerNorm x 3"]);
          end
        end
        m22["LayerNorm"];
      end
    end
    m23["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,100,512) → (1,100,12000)</span>"];
  end
  input["Input<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,100)x2</span>"];
  output["Output<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,100,12000)</span>"];
  style input fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
  style output fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
  style m1 fill:#f1f5f9,stroke:#475569,stroke-width:1px;
  style m3 fill:#edf2f7,stroke:#4a5568,stroke-width:1px;
  style m9 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m10 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m11 fill:#edf2f7,stroke:#4a5568,stroke-width:1px;
  style m12 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m13 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m18 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m19 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m20 fill:#edf2f7,stroke:#4a5568,stroke-width:1px;
  style m21 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m22 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m23 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  input --> m1;
  m1 --> m3;
  m10 -.-> m11;
  m11 -.-> m10;
  m11 --> m12;
  m12 -.-> m10;
  m12 -.-> m7_in;
  m12 --> m7_out;
  m13 -.-> m18;
  m16_in -.-> m18;
  m16_out -.-> m16_in;
  m16_out --> m22;
  m17_in -.-> m18;
  m17_out -.-> m20;
  m18 --> m17_out;
  m18 -.-> m20;
  m19 -.-> m20;
  m20 -.-> m19;
  m20 --> m21;
  m21 -.-> m16_in;
  m21 --> m16_out;
  m21 --> m17_in;
  m21 -.-> m19;
  m22 --> m23;
  m23 --> output;
  m3 -.-> m9;
  m7_in -.-> m9;
  m7_out --> m13;
  m7_out -.-> m7_in;
  m9 -.-> m11;

Class Signature

class Transformer(config: TransformerConfig)

Parameters

  • config (TransformerConfig): Configuration object that stores vocabulary sizes, model width, layer counts, feedforward width, dropout, and positional encoding length.

Configuration

  • src_vocab_size (int): Size of the source vocabulary.

  • tgt_vocab_size (int): Size of the target vocabulary.

  • d_model (int): Dimension of the model’s hidden representations.

  • num_heads (int): Number of attention heads.

  • num_encoder_layers (int): Number of encoder layers.

  • num_decoder_layers (int): Number of decoder layers.

  • dim_feedforward (int): Feedforward width inside each Transformer block.

  • dropout (float): Dropout probability applied throughout the model.

  • max_len (int): Maximum sequence length supported by the positional encoding.
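
For reference, the diagram above corresponds to a configuration along the following lines. This is a sketch, not verified settings: the diagram annotates 100-token inputs and a 12000-entry output vocabulary, while the source vocabulary size, head count, and dropout rate are not annotated and are assumed here.

>>> import lucid.models as models
>>> diagram_config = models.TransformerConfig(
...     src_vocab_size=12000,   # assumed; not annotated in the diagram
...     tgt_vocab_size=12000,   # from the final Linear: (1,100,512) -> (1,100,12000)
...     d_model=512,            # from the embedding output: (1,100) -> (1,100,512)
...     num_heads=8,            # assumed; head count is not annotated
...     num_encoder_layers=6,   # "TransformerEncoderLayer x 6"
...     num_decoder_layers=6,   # "TransformerDecoderLayer x 6"
...     dim_feedforward=2048,   # from the feedforward Linear: (1,100,512) -> (1,100,2048)
...     dropout=0.1,            # assumed; the diagram only shows where Dropout sits
...     max_len=100,            # must cover the 100-token sequences shown
... )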

Examples

>>> import lucid.models as models
>>> transformer = models.Transformer(
...     models.TransformerConfig(
...         src_vocab_size=5000,
...         tgt_vocab_size=5000,
...         d_model=512,
...         num_heads=8,
...         num_encoder_layers=6,
...         num_decoder_layers=6,
...         dim_feedforward=2048,
...         dropout=0.1,
...     )
... )
>>> print(transformer)
Transformer(src_vocab_size=5000, tgt_vocab_size=5000, d_model=512, ...)

This implementation follows the standard Transformer architecture and is ready to be trained for sequence-to-sequence tasks such as machine translation.
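
A forward pass might then look like the following. This is a sketch rather than verified usage: it assumes lucid.random.randint exists with a NumPy-like signature for creating integer token tensors, that the model is called as transformer(src, tgt) on token-ID tensors of shape (batch, seq_len) as in the (1, 100) inputs of the diagram above, and that .shape prints as a plain tuple.

>>> import lucid
>>> src = lucid.random.randint(0, 5000, (1, 100))  # assumed API: token IDs in [0, src_vocab_size)
>>> tgt = lucid.random.randint(0, 5000, (1, 100))  # assumed API: token IDs in [0, tgt_vocab_size)
>>> out = transformer(src, tgt)                    # assumed call signature: (src, tgt)
>>> print(out.shape)                               # logits over the target vocabulary
(1, 100, 5000)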

Differences from nn.Transformer

  • This class implements a complete Transformer model, including token embeddings, positional encoding, and the final projection to vocabulary space.

  • nn.Transformer, in contrast, provides a modular building block for constructing Transformer-based models; it leaves the embeddings, positional encoding, and output projection for the caller to supply.
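
For contrast, a minimal sketch of what building on nn.Transformer could look like. The module names and constructor arguments below (nn.Embedding, nn.Linear, and the nn.Transformer keyword names) are assumptions about lucid.nn rather than documented API; the point is that the pieces models.Transformer bundles are exactly the ones you would wire up yourself:

>>> import lucid.nn as nn
>>> embed = nn.Embedding(5000, 512)               # you supply the token embedding
>>> core = nn.Transformer(d_model=512, num_heads=8)  # the generic building block (assumed kwargs)
>>> proj = nn.Linear(512, 5000)                   # ...and the final vocabulary projection
>>> # ...plus your own positional encoding between embed and core.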