PVT


class lucid.models.PVT(config: PVTConfig)

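The PVT class implements the Pyramid Vision Transformer (PVT), a hierarchical vision transformer. This class adds a classification head, so it is set up for image classification. PVT processes the input in four stages. Each stage opens with a patch embedding that lowers the spatial resolution and widens the channels, so together the stages form a feature pyramid. The transformer blocks use spatial-reduction attention (SRA), which downsamples keys and values with a strided convolution before attention. This keeps global attention affordable on the large early-stage feature maps. Embedding widths, attention heads, block depths, and other hyperparameters are configurable, and the model structure is defined through PVTConfig.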
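The four-stage pyramid can be made concrete by tracing feature-map sizes for a 224x224 input. The sketch below is a minimal illustration, not lucid's code. It assumes the first patch embedding uses a patch size of 4 and each later stage downsamples by 2, as in the original PVT paper. Under those assumptions it reproduces the 56x56 map and 3136 tokens shown for stage 1 in the diagram. `pyramid_shapes` is a hypothetical helper.

```python
# Minimal sketch of per-stage feature-map sizes in a four-stage PVT pyramid.
# Assumption (not taken from lucid's source): patch sizes (4, 2, 2, 2), i.e. the
# first embedding divides each side by 4 and every later stage divides it by 2.

def pyramid_shapes(img_size=224, patch_sizes=(4, 2, 2, 2),
                   embed_dims=(64, 128, 320, 512)):
    """Return (height, width, num_tokens, channels) for each stage."""
    shapes = []
    side = img_size
    for patch, dim in zip(patch_sizes, embed_dims):
        side //= patch  # non-overlapping patch embedding shrinks each side by `patch`
        shapes.append((side, side, side * side, dim))
    return shapes

for stage, (h, w, n, c) in enumerate(pyramid_shapes(), start=1):
    print(f"stage {stage}: {h}x{w} map, {n} tokens of width {c}")
```

The final stage in the diagram shows 50 tokens instead of 49, which suggests a class token is appended there. The sketch counts patch tokens only.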

        %%{init: {"flowchart":{"curve":"monotoneX","nodeSpacing":50,"rankSpacing":50}} }%%
flowchart LR
  linkStyle default stroke-width:2.0px
  subgraph sg_m0["<span style='font-size:20px;font-weight:700'>pvt_medium</span>"]
  style sg_m0 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
    subgraph sg_m1["_PatchEmbed x 4"]
    style sg_m1 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      m1_in(["Input"]);
      m1_out(["Output"]);
  style m1_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m1_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
      m2["Conv2d<br/><span style='font-size:11px;color:#c53030;font-weight:400'>(1,3,224,224) → (1,64,56,56)</span>"];
      m3["LayerNorm"];
    end
    m4(["Dropout x 4"]);
    subgraph sg_m5["block1 x 3"]
    style sg_m5 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      m5_in(["Input"]);
      m5_out(["Output"]);
  style m5_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m5_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
      subgraph sg_m6["_Block x 3"]
        direction TB;
      style sg_m6 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m6_in(["Input"]);
        m6_out(["Output"]);
  style m6_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m6_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
        m7["LayerNorm"];
        subgraph sg_m8["_SRAttention"]
          direction TB;
        style sg_m8 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
          m9(["Linear x 2"]);
          m10(["Dropout x 3"]);
          m11["Conv2d<br/><span style='font-size:11px;color:#c53030;font-weight:400'>(1,64,56,56) → (1,64,7,7)</span>"];
          m12["LayerNorm"];
        end
        m13["Identity"];
        m14["LayerNorm"];
        subgraph sg_m15["_MLP"]
          direction TB;
        style sg_m15 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
          m16["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,3136,64) → (1,3136,512)</span>"];
          m17["GELU"];
          m18["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,3136,512) → (1,3136,64)</span>"];
          m19["Dropout"];
        end
      end
    end
    subgraph sg_m20["block4"]
    style sg_m20 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      subgraph sg_m21["_Block x 3"]
        direction TB;
      style sg_m21 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m21_in(["Input"]);
        m21_out(["Output"]);
  style m21_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m21_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
        m22["LayerNorm"];
        subgraph sg_m23["_SRAttention"]
          direction TB;
        style sg_m23 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
          m24(["Linear x 2"]);
          m25(["Dropout x 3"]);
        end
        m26["Identity"];
        m27["LayerNorm"];
        subgraph sg_m28["_MLP"]
          direction TB;
        style sg_m28 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
          m29["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,50,512) → (1,50,2048)</span>"];
          m30["GELU"];
          m31["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,50,2048) → (1,50,512)</span>"];
          m32["Dropout"];
        end
      end
    end
    m33["LayerNorm"];
    m34["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,512) → (1,1000)</span>"];
  end
  input["Input<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,3,224,224)</span>"];
  output["Output<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,1000)</span>"];
  style input fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
  style output fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
  style m2 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
  style m3 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m4 fill:#edf2f7,stroke:#4a5568,stroke-width:1px;
  style m7 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m9 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m10 fill:#edf2f7,stroke:#4a5568,stroke-width:1px;
  style m11 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
  style m12 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m13 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m14 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m16 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m17 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
  style m18 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m19 fill:#edf2f7,stroke:#4a5568,stroke-width:1px;
  style m22 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m24 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m25 fill:#edf2f7,stroke:#4a5568,stroke-width:1px;
  style m26 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m27 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m29 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m30 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
  style m31 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m32 fill:#edf2f7,stroke:#4a5568,stroke-width:1px;
  style m33 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m34 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  input -.-> m2;
  m10 -.-> m13;
  m11 --> m12;
  m12 -.-> m9;
  m13 --> m14;
  m13 -.-> m6_in;
  m14 --> m16;
  m16 --> m17;
  m17 -.-> m19;
  m18 -.-> m19;
  m19 -.-> m13;
  m19 --> m18;
  m19 --> m5_out;
  m19 --> m6_out;
  m1_in -.-> m2;
  m1_out -.-> m4;
  m2 --> m3;
  m21_in -.-> m22;
  m21_out -.-> m21_in;
  m21_out --> m33;
  m22 --> m24;
  m24 --> m25;
  m25 -.-> m26;
  m26 -.-> m21_in;
  m26 --> m27;
  m27 --> m29;
  m29 --> m30;
  m3 --> m1_out;
  m3 -.-> m4;
  m30 -.-> m32;
  m31 -.-> m32;
  m32 --> m21_out;
  m32 -.-> m26;
  m32 --> m31;
  m33 --> m34;
  m34 --> output;
  m4 -.-> m22;
  m4 --> m5_in;
  m4 -.-> m7;
  m5_in -.-> m7;
  m5_out -.-> m1_in;
  m6_in -.-> m7;
  m6_out -.-> m1_in;
  m6_out -.-> m6_in;
  m7 -.-> m9;
  m9 --> m10;
  m9 --> m11;
    
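In the diagram, the stage-1 `_SRAttention` uses a Conv2d to shrink the key/value map from 56x56 to 7x7, a reduction ratio of 8. The stage-4 attention has no such convolution. The arithmetic below sketches what that buys. It assumes spatial-reduction ratios of (8, 4, 2, 1) per stage, the values used in the original PVT paper. The stage-1 and stage-4 ratios match the diagram; the middle two are assumptions. `sra_cost` is a hypothetical helper that counts attention-score entries per head.

```python
# Minimal sketch: how spatial reduction shrinks the attention matrix per stage.
# A convolution with kernel = stride = R turns an H x W map into (H/R) x (W/R),
# so keys/values shrink from N = H*W tokens to N / R**2 tokens.

def sra_cost(side, sr_ratio):
    """Return (query_tokens, kv_tokens, scores_with_sr, scores_without_sr)."""
    n_query = side * side
    kv_side = side // sr_ratio
    n_kv = kv_side * kv_side
    return n_query, n_kv, n_query * n_kv, n_query * n_query

sides = (56, 28, 14, 7)     # per-stage map sides for a 224x224 input
sr_ratios = (8, 4, 2, 1)    # assumed spatial-reduction ratios
for stage, (side, r) in enumerate(zip(sides, sr_ratios), start=1):
    nq, nkv, with_sr, without_sr = sra_cost(side, r)
    print(f"stage {stage}: {nq} queries x {nkv} keys "
          f"({without_sr // with_sr}x fewer scores than full attention)")
```

With these ratios, every stage attends over 49 key/value tokens. The largest saving, a factor of R² = 64, lands in stage 1, where full attention would otherwise compare all 3,136 tokens with each other.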

Class Signature

class PVT(nn.Module):
    def __init__(self, config: PVTConfig) -> None

Parameters

  • config (PVTConfig): Configuration object specifying the input image size, the number of output classes, the patch size of the initial patch embedding, and the per-stage embedding widths, attention heads, block depths, and spatial reduction ratios.
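The per-stage tuples in the configuration have to agree with each other. Each needs one entry per stage, and standard multi-head attention splits a stage's width evenly across its heads. The sketch below is a hypothetical helper, not part of lucid, that checks both conditions before a model is built. Whether lucid's PVTConfig enforces the same rules internally is an assumption.

```python
# Hypothetical helper (not part of lucid): sanity-check the per-stage tuples that
# a PVTConfig-style configuration carries, returning the per-stage head dimension.

def check_stage_config(embed_dims, num_heads, depths, num_stages=4):
    """Return per-stage head dims; raise ValueError on inconsistent settings."""
    for name, values in (("embed_dims", embed_dims),
                         ("num_heads", num_heads),
                         ("depths", depths)):
        if len(values) != num_stages:
            raise ValueError(f"{name} must have {num_stages} entries, got {len(values)}")
    head_dims = []
    for stage, (dim, heads) in enumerate(zip(embed_dims, num_heads), start=1):
        # multi-head attention splits the channel width evenly across heads
        if dim % heads != 0:
            raise ValueError(f"stage {stage}: embed_dim {dim} is not divisible by {heads} heads")
        head_dims.append(dim // heads)
    return head_dims

print(check_stage_config((64, 128, 320, 512), (1, 2, 5, 8), (3, 4, 6, 3)))
```

For the widths and heads used in the example below, every stage works out to a head dimension of 64.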

Examples

>>> import lucid.models as models
>>> config = models.PVTConfig(
...     img_size=224,
...     num_classes=1000,
...     patch_size=4,
...     embed_dims=(64, 128, 320, 512),
...     num_heads=(1, 2, 5, 8),
...     depths=(3, 4, 6, 3),
... )
>>> model = models.PVT(config)