PVT_V2


class lucid.models.PVT_V2(config: PVTV2Config)

The PVT_V2 class implements the second version of the Pyramid Vision Transformer (PVT-v2), a hierarchical vision transformer that improves on its predecessor, PVT, in both computational efficiency and representational power.

Key Enhancements

  • Linear Attention (Optional): PVT-v2 adds an optional linear attention mechanism that reduces attention cost from quadratic to linear in the number of spatial tokens, which speeds up inference on high-resolution inputs.

  • Deeper Spatial Reduction Control: sr_ratios is retained from PVT and sets a spatial reduction ratio for each stage, so the trade-off between attention cost and feature resolution can be tuned stage by stage.
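
To make the two points above concrete, the following sketch counts query-key score entries for one attention head under three schemes: full attention, spatial reduction by a ratio (as controlled by sr_ratios), and a linear variant that pools keys and values to a fixed grid. The fixed 7x7 pooled grid follows the PVT-v2 paper's description of linear spatial-reduction attention and is an assumption here; lucid's internal implementation may differ, and the function name is hypothetical.

```python
def attention_scores(h, w, sr_ratio=1, pool_size=None):
    """Count query-key score entries for one head on an h x w token grid.

    Illustrative only: `pool_size` models a linear variant that pools
    keys/values to a fixed grid (assumed 7x7, per the PVT-v2 paper).
    """
    n_query = h * w
    if pool_size is not None:
        # Key/value count is fixed, independent of the input resolution.
        n_kv = pool_size * pool_size
    else:
        # Key/value grid shrinks by sr_ratio along each spatial axis.
        n_kv = (h // sr_ratio) * (w // sr_ratio)
    return n_query * n_kv


for side in (56, 112):
    full = attention_scores(side, side)
    reduced = attention_scores(side, side, sr_ratio=8)
    linear = attention_scores(side, side, pool_size=7)
    print(f"{side}x{side}: full={full} reduced={reduced} linear={linear}")
```

Doubling the side length quadruples the token count. The full and ratio-based counts then grow 16-fold, while the pooled count grows only 4-fold, which is the quadratic-versus-linear distinction described in the first point.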

Model structure is defined through PVTV2Config.

        %%{init: {"flowchart":{"curve":"monotoneX","nodeSpacing":50,"rankSpacing":50}} }%%
flowchart LR
  linkStyle default stroke-width:2.0px
  subgraph sg_m0["<span style='font-size:20px;font-weight:700'>pvt_v2_b0</span>"]
  style sg_m0 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
    subgraph sg_m1["_OverlapPatchEmbed"]
      direction TB;
    style sg_m1 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      m2["Conv2d<br/><span style='font-size:11px;color:#c53030;font-weight:400'>(1,3,224,224) → (1,32,56,56)</span>"];
      m3["LayerNorm"];
    end
    subgraph sg_m4["block1"]
      direction TB;
    style sg_m4 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      subgraph sg_m5["_Block_V2 x 2"]
        direction TB;
      style sg_m5 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m5_in(["Input"]);
        m5_out(["Output"]);
  style m5_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m5_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
        m6["LayerNorm"];
        m7["_LSRAttention"];
        m8["Identity"];
        m9["LayerNorm"];
        m10["_ConvMLP"];
      end
    end
    m11["LayerNorm"];
    subgraph sg_m12["_OverlapPatchEmbed"]
      direction TB;
    style sg_m12 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      m13["Conv2d<br/><span style='font-size:11px;color:#c53030;font-weight:400'>(1,32,56,56) → (1,64,28,28)</span>"];
      m14["LayerNorm"];
    end
    subgraph sg_m15["block2"]
      direction TB;
    style sg_m15 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      subgraph sg_m16["_Block_V2 x 2"]
        direction TB;
      style sg_m16 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m16_in(["Input"]);
        m16_out(["Output"]);
  style m16_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m16_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
        m17["LayerNorm"];
        m18["_LSRAttention"];
        m19["Identity"];
        m20["LayerNorm"];
        m21["_ConvMLP"];
      end
    end
    m22["LayerNorm"];
    subgraph sg_m23["_OverlapPatchEmbed"]
      direction TB;
    style sg_m23 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      m24["Conv2d<br/><span style='font-size:11px;color:#c53030;font-weight:400'>(1,64,28,28) → (1,160,14,14)</span>"];
      m25["LayerNorm"];
    end
    subgraph sg_m26["block3"]
      direction TB;
    style sg_m26 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      subgraph sg_m27["_Block_V2 x 2"]
        direction TB;
      style sg_m27 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m27_in(["Input"]);
        m27_out(["Output"]);
  style m27_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m27_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
        m28["LayerNorm"];
        m29["_LSRAttention"];
        m30["Identity"];
        m31["LayerNorm"];
        m32["_ConvMLP"];
      end
    end
    m33["LayerNorm"];
    subgraph sg_m34["_OverlapPatchEmbed"]
      direction TB;
    style sg_m34 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      m35["Conv2d<br/><span style='font-size:11px;color:#c53030;font-weight:400'>(1,160,14,14) → (1,256,7,7)</span>"];
      m36["LayerNorm"];
    end
    subgraph sg_m37["block4"]
      direction TB;
    style sg_m37 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      subgraph sg_m38["_Block_V2 x 2"]
        direction TB;
      style sg_m38 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m38_in(["Input"]);
        m38_out(["Output"]);
  style m38_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m38_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
        m39["LayerNorm"];
        m40["_LSRAttention"];
        m41["Identity"];
        m42["LayerNorm"];
        m43["_ConvMLP"];
      end
    end
    m44["LayerNorm"];
    m45["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,256) → (1,1000)</span>"];
  end
  input["Input<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,3,224,224)</span>"];
  output["Output<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,1000)</span>"];
  style input fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
  style output fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
  style m2 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
  style m3 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m6 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m8 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m9 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m11 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m13 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
  style m14 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m17 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m19 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m20 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m22 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m24 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
  style m25 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m28 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m30 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m31 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m33 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m35 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
  style m36 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m39 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m41 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m42 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m44 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m45 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  input --> m2;
  m10 --> m5_out;
  m10 -.-> m8;
  m11 --> m13;
  m13 --> m14;
  m14 -.-> m17;
  m16_in -.-> m17;
  m16_out --> m22;
  m17 --> m18;
  m18 -.-> m19;
  m19 --> m16_in;
  m19 --> m20;
  m2 --> m3;
  m20 --> m21;
  m21 --> m16_out;
  m21 -.-> m19;
  m22 --> m24;
  m24 --> m25;
  m25 -.-> m28;
  m27_in -.-> m28;
  m27_out --> m33;
  m28 --> m29;
  m29 -.-> m30;
  m3 -.-> m6;
  m30 --> m27_in;
  m30 --> m31;
  m31 --> m32;
  m32 --> m27_out;
  m32 -.-> m30;
  m33 --> m35;
  m35 --> m36;
  m36 -.-> m39;
  m38_in -.-> m39;
  m38_out --> m44;
  m39 --> m40;
  m40 -.-> m41;
  m41 --> m38_in;
  m41 --> m42;
  m42 --> m43;
  m43 --> m38_out;
  m43 -.-> m41;
  m44 --> m45;
  m45 --> output;
  m5_in -.-> m6;
  m5_out --> m11;
  m6 --> m7;
  m7 -.-> m8;
  m8 --> m5_in;
  m8 --> m9;
  m9 --> m10;
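
The stage resolutions annotated in the diagram (224 → 56 → 28 → 14 → 7) follow from the standard convolution output-size formula. The sketch below reproduces them, assuming a 7x7 kernel with stride 4 and padding 3 in the first overlapping patch embedding and 3x3 kernels with stride 2 and padding 1 in the later ones. These kernel, stride, and padding values come from the PVT-v2 paper and are assumptions about what _OverlapPatchEmbed uses in lucid.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a 2D convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) per stage; assumed values, not read from lucid.
stages = [(7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1)]

size = 224
sizes = []
for kernel, stride, padding in stages:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [56, 28, 14, 7]
```

Because each kernel is larger than its stride, neighbouring patches overlap, which is where the "overlap" in the module name comes from.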
    

Class Signature

class PVT_V2(nn.Module):
    def __init__(self, config: PVTV2Config) -> None

Parameters

  • config (PVTV2Config): Configuration object describing the overlap patch embedding, stage widths, attention heads, depth schedule, spatial reduction ratios, and optional linear attention path.

Examples

>>> import lucid.models as models
>>> config = models.PVTV2Config(
...     img_size=224,
...     patch_size=7,
...     in_channels=3,
...     num_classes=1000,
...     embed_dims=(64, 128, 320, 512),
...     num_heads=(1, 2, 5, 8),
...     depths=(3, 4, 6, 3),
... )
>>> model = models.PVT_V2(config)
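
As a quick sanity check on a configuration like the one above, each stage's embedding width should divide evenly across its attention heads, and the per-stage tuples should all have the same length. This is a general property of multi-head attention rather than a documented lucid check, and the helper name below is hypothetical.

```python
def head_dims(embed_dims, num_heads, depths):
    """Return the per-head width for each stage, validating the tuples.

    Hypothetical helper for illustration; not part of lucid.
    """
    if not (len(embed_dims) == len(num_heads) == len(depths)):
        raise ValueError("per-stage tuples must have equal length")
    dims = []
    for width, heads in zip(embed_dims, num_heads):
        if width % heads != 0:
            raise ValueError(f"width {width} is not divisible by {heads} heads")
        dims.append(width // heads)
    return dims


# Values copied from the example configuration above.
per_head = head_dims((64, 128, 320, 512), (1, 2, 5, 8), (3, 4, 6, 3))
print(per_head)  # [64, 64, 64, 64]
```

With these values every stage uses 64 channels per head, so the number of heads grows in step with the stage width.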