SwinTransformer_V2

Tags: Transformer, Vision Transformer

class lucid.models.SwinTransformer_V2(config: SwinTransformerV2Config)

The SwinTransformer_V2 class extends the original Swin Transformer architecture with the changes introduced in Swin Transformer V2, such as a log-spaced relative position bias and normalization improvements (the V2 paper moves layer normalization to the end of each residual branch). It keeps the hierarchical structure and shifted-window self-attention of the original while scaling better to larger models and transferring better across image resolutions. The model structure is defined through SwinTransformerV2Config.
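
To make the log-spaced bias concrete, here is a minimal sketch of the coordinate transform described in the Swin Transformer V2 paper, sign(d) * log(1 + |d|), applied to the relative offsets inside one attention window. The helper names `log_spaced_coord` and `log_spaced_table` are hypothetical, the window sizes are picked only for illustration, and lucid's internal implementation may differ in normalization or scaling.

```python
import math


def log_spaced_coord(delta: int) -> float:
    """Map a linear relative offset to sign(delta) * log(1 + |delta|)."""
    return math.copysign(math.log1p(abs(delta)), delta)


def log_spaced_table(window_size: int) -> list[list[tuple[float, float]]]:
    """Log-spaced (dy, dx) pairs for every relative offset inside one window.

    Offsets run from -(window_size - 1) to window_size - 1 on each axis,
    so the table has (2 * window_size - 1) rows and columns.
    """
    offsets = range(-(window_size - 1), window_size)
    return [
        [(log_spaced_coord(dy), log_spaced_coord(dx)) for dx in offsets]
        for dy in offsets
    ]


# Hypothetical window sizes, chosen only for illustration.
small = log_spaced_table(8)
large = log_spaced_table(16)
```

Growing the window from 8 to 16 widens the linear offset range from [-7, 7] to [-15, 15], while the log-spaced range only grows from about [-2.08, 2.08] to [-2.77, 2.77]. A bias learned on small windows therefore needs less extrapolation when the model is applied at a higher resolution.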

%%{init: {"flowchart":{"curve":"monotoneX","nodeSpacing":50,"rankSpacing":50}} }%%
flowchart LR
  linkStyle default stroke-width:2.0px
  subgraph sg_m0["<span style='font-size:20px;font-weight:700'>swin_v2_base</span>"]
  style sg_m0 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
    subgraph sg_m1["_PatchEmbed"]
    style sg_m1 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      m2["Conv2d<br/><span style='font-size:11px;color:#c53030;font-weight:400'>(1,3,224,224) → (1,128,56,56)</span>"];
      m3["LayerNorm"];
    end
    m4["Dropout"];
    subgraph sg_m5["layers"]
    style sg_m5 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      subgraph sg_m6["_BasicLayer x 3"]
        direction TB;
      style sg_m6 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m6_in(["Input"]);
        m6_out(["Output"]);
  style m6_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m6_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
        subgraph sg_m7["blocks"]
          direction TB;
        style sg_m7 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
          m8(["_SwinTransformerBlock_V2 x 2"]);
        end
        subgraph sg_m9["_PatchMerging"]
          direction TB;
        style sg_m9 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
          m10["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,784,512) → (1,784,256)</span>"];
          m11["LayerNorm"];
        end
      end
      subgraph sg_m12["_BasicLayer"]
        direction TB;
      style sg_m12 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        subgraph sg_m13["blocks"]
          direction TB;
        style sg_m13 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
          m14(["_SwinTransformerBlock_V2 x 2"]);
        end
      end
    end
    m15["LayerNorm"];
    m16["AdaptiveAvgPool1d<br/><span style='font-size:11px;color:#b7791f;font-weight:400'>(1,1024,49) → (1,1024,1)</span>"];
    m17["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,1024) → (1,1000)</span>"];
  end
  input["Input<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,3,224,224)</span>"];
  output["Output<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,1000)</span>"];
  style input fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
  style output fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
  style m2 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
  style m3 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m4 fill:#edf2f7,stroke:#4a5568,stroke-width:1px;
  style m10 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m11 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m15 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m16 fill:#fefcbf,stroke:#b7791f,stroke-width:1px;
  style m17 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  input --> m2;
  m10 -.-> m6_in;
  m11 --> m10;
  m11 --> m6_out;
  m14 --> m15;
  m15 --> m16;
  m16 --> m17;
  m17 --> output;
  m2 --> m3;
  m3 --> m4;
  m4 -.-> m8;
  m6_in -.-> m8;
  m6_out --> m14;
  m6_out -.-> m6_in;
  m8 --> m11;
    

Class Signature

class SwinTransformer_V2(nn.Module):
    def __init__(self, config: SwinTransformerV2Config) -> None

Parameters

  • config (SwinTransformerV2Config): Configuration object describing the image resolution, patch embedding, hierarchical stage layout, attention heads, window size, and classifier setup.
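
The hierarchical stage layout can be checked with a little shape arithmetic. The sketch below reproduces the token counts and channel widths shown in the swin_v2_base diagram above, assuming square inputs and that each patch-merging step halves the spatial side and doubles the channel width. `stage_shapes` is a hypothetical helper, not part of lucid, and the patch size of 4 is inferred from the diagram's 224 to 56 reduction.

```python
def stage_shapes(img_size: int, patch_size: int, embed_dim: int, num_stages: int):
    """Return (tokens, channels) seen by the blocks of each stage.

    Assumes square inputs, and that every patch-merging step halves each
    spatial side and doubles the channel width.
    """
    side = img_size // patch_size  # patch embedding: 224 // 4 = 56
    channels = embed_dim
    shapes = []
    for stage in range(num_stages):
        shapes.append((side * side, channels))
        if stage < num_stages - 1:  # no merging after the last stage
            side //= 2
            channels *= 2
    return shapes


# Values read off the swin_v2_base diagram: 128 channels after patch embedding.
base = stage_shapes(img_size=224, patch_size=4, embed_dim=128, num_stages=4)
```

The last entry, 49 tokens with 1024 channels, matches the (1,1024,49) tensor that the diagram feeds into AdaptiveAvgPool1d before the 1000-way classifier.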

Examples

>>> import lucid.models as models
>>> config = models.SwinTransformerV2Config(
...     img_size=224,
...     patch_size=4,
...     in_channels=3,
...     num_classes=1000,
...     embed_dim=96,
...     depths=(2, 2, 6, 2),
...     num_heads=(3, 6, 12, 24),
... )
>>> swin_v2 = models.SwinTransformer_V2(config)
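
Before constructing a model, it can help to sanity-check that the per-stage head counts divide the per-stage channel widths. The sketch below does this for the values used in the example above, in plain Python and without calling lucid. It assumes the channel width doubles at every stage, and whether SwinTransformerV2Config performs a similar validation itself is not stated here.

```python
embed_dim = 96
depths = (2, 2, 6, 2)
num_heads = (3, 6, 12, 24)

# One head count per stage.
assert len(depths) == len(num_heads)

head_dims = []
for stage, heads in enumerate(num_heads):
    channels = embed_dim * 2 ** stage  # assumed doubling per stage
    # Multi-head attention splits the channels evenly across heads.
    assert channels % heads == 0, f"stage {stage}: {channels} not divisible by {heads}"
    head_dims.append(channels // heads)

# Width of the features reaching the classifier under the same assumption.
final_width = embed_dim * 2 ** (len(depths) - 1)
```

With these values every stage ends up with the same per-head dimension, because the head count doubles in step with the channel width.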