CoAtNet

class lucid.models.CoAtNet(config: CoAtNetConfig)

The CoAtNet module in lucid.models implements the CoAtNet architecture, a hybrid model that combines convolutional and attention-based mechanisms. It leverages the strengths of both convolutional neural networks (CNNs) and vision transformers (ViTs), making it well suited to image classification tasks. The model structure is defined through a CoAtNetConfig.

CoAtNet utilizes depthwise convolutions, relative position encoding, and pre-normalization to enhance training stability and performance.

        %%{init: {"flowchart":{"curve":"monotoneX","nodeSpacing":50,"rankSpacing":50},"themeCSS":".nodeLabel, .edgeLabel, .cluster text, .node text { fill: #000000 !important; } .node foreignObject *, .cluster foreignObject * { color: #000000 !important; }"} }%%
flowchart LR
  linkStyle default stroke-width:2.0px
  subgraph sg_m0["<span style='font-size:20px;font-weight:700'>coatnet_0</span>"]
  style sg_m0 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
    subgraph sg_m1["s0"]
    style sg_m1 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      subgraph sg_m2["Sequential x 2"]
      style sg_m2 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m2_in(["Input"]);
        m2_out(["Output"]);
  style m2_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m2_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
        m3["Conv2d<br/><span style='font-size:11px;font-weight:400'>(1,3,224,224) → (1,64,112,112)</span>"];
        m4["BatchNorm2d"];
        m5["GELU"];
      end
    end
    subgraph sg_m6["s1 x 2"]
    style sg_m6 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      m6_in(["Input"]);
      m6_out(["Output"]);
  style m6_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m6_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
      subgraph sg_m7["_MBConv"]
      style sg_m7 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m8["MaxPool2d<br/><span style='font-size:11px;font-weight:400'>(1,64,112,112) → (1,64,56,56)</span>"];
        m9["Conv2d<br/><span style='font-size:11px;font-weight:400'>(1,64,56,56) → (1,96,56,56)</span>"];
        m10["_PreNorm<br/><span style='font-size:11px;font-weight:400'>(1,64,112,112) → (1,96,56,56)</span>"];
      end
      subgraph sg_m11["_MBConv"]
      style sg_m11 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m12["_PreNorm"];
      end
    end
    subgraph sg_m13["s3 x 2"]
    style sg_m13 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      m13_in(["Input"]);
      m13_out(["Output"]);
  style m13_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m13_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
      subgraph sg_m14["_Transformer"]
      style sg_m14 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m15(["MaxPool2d x 2<br/><span style='font-size:11px;font-weight:400'>(1,192,28,28) → (1,192,14,14)</span>"]);
        m16["Conv2d<br/><span style='font-size:11px;font-weight:400'>(1,192,14,14) → (1,384,14,14)</span>"];
        m17(["Sequential x 2<br/><span style='font-size:11px;font-weight:400'>(1,192,14,14) → (1,384,14,14)</span>"]);
      end
      subgraph sg_m18["_Transformer x 4"]
      style sg_m18 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m18_in(["Input"]);
        m18_out(["Output"]);
  style m18_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m18_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
        m19(["Sequential x 2"]);
      end
    end
    m20["AdaptiveAvgPool2d<br/><span style='font-size:11px;font-weight:400'>(1,768,7,7) → (1,768,1,1)</span>"];
    m21["Linear<br/><span style='font-size:11px;font-weight:400'>(1,768) → (1,1000)</span>"];
  end
  input["Input<br/><span style='font-size:11px;color:#000000;font-weight:400'>(1,3,224,224)</span>"];
  output["Output<br/><span style='font-size:11px;color:#000000;font-weight:400'>(1,1000)</span>"];
  style input fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
  style output fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
  style m3 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
  style m4 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m5 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
  style m8 fill:#fefcbf,stroke:#b7791f,stroke-width:1px;
  style m9 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
  style m15 fill:#fefcbf,stroke:#b7791f,stroke-width:1px;
  style m16 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
  style m20 fill:#fefcbf,stroke:#b7791f,stroke-width:1px;
  style m21 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  input -.-> m3;
  m10 --> m12;
  m12 --> m6_in;
  m12 --> m6_out;
  m13_in -.-> m15;
  m13_out --> m20;
  m15 --> m16;
  m15 --> m17;
  m16 -.-> m15;
  m17 -.-> m19;
  m18_in -.-> m19;
  m18_out --> m13_in;
  m18_out -.-> m18_in;
  m19 --> m13_out;
  m19 -.-> m18_in;
  m19 --> m18_out;
  m20 --> m21;
  m21 --> output;
  m2_in -.-> m3;
  m2_out -.-> m8;
  m3 --> m4;
  m4 --> m5;
  m5 --> m2_in;
  m5 --> m2_out;
  m6_in -.-> m8;
  m6_out -.-> m15;
  m8 --> m9;
  m9 --> m10;
    

Class Signature

class CoAtNet(nn.Module):
    def __init__(self, config: CoAtNetConfig) -> None

Parameters

  • config (CoAtNetConfig): Configuration object describing the input resolution, stage depths, stage widths, classifier size, attention heads, block types, and optional scaled tandem stage settings.

Hybrid Architecture

The CoAtNet model employs a hybrid structure that fuses convolutional and transformer blocks for enhanced representation learning:

  1. Early Convolutional Blocks:

    • The initial stages use convolution-based feature extraction (C blocks).

    • These layers focus on capturing local patterns efficiently.

    • Convolutions perform feature extraction using:

      \[Y = W * X + b\]
  2. Transformer-Based Blocks:

    • Later stages transition into transformer blocks (T blocks).

    • These layers incorporate self-attention to capture long-range dependencies.

    • Self-attention is computed as:

      \[\mathbf{A} = \text{softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right)V\]
  3. Pre-Normalization:

    • Each transformer block applies Layer Normalization before the attention mechanism.

    • Helps improve gradient flow and stability during training.

    • The normalization step follows:

      \[\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}\]
  4. Relative Position Encoding:

    • Unlike absolute position encoding in traditional ViTs, CoAtNet leverages relative position encoding.

    • The attention mechanism incorporates positional information dynamically:

      \[A_{ij} = \frac{Q_i K_j^T}{\sqrt{d_k}} + B_{ij}\]
    • The relative position bias matrix \( B_{ij} \) is learnable and helps model spatial relationships.

  5. Depthwise Convolutions:

    • Used to reduce computational complexity while maintaining strong feature extraction capabilities.

    • Reduces the number of parameters compared to traditional convolutional layers.

    • The depthwise convolution operation is:

      \[Y_{c,i,j} = \sum_{m,n} X_{c,\,i+m,\,j+n}\, W_{c,m,n}\]
  6. Scaling Strategy:

    • CoAtNet scales efficiently across depth (D), width (W), and resolution (R), making it highly versatile for various image sizes and computational constraints.
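
The biased attention described in items 2 and 4 can be sketched in a few lines of NumPy. This is a minimal illustration under assumed shapes, not Lucid's actual implementation; names such as `attention_with_relative_bias` are hypothetical:

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention_with_relative_bias(Q, K, V, B):
    """Scaled dot-product attention with an additive relative position bias.

    Implements A_ij = Q_i K_j^T / sqrt(d_k) + B_ij, then softmax over j.
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k) + B  # (n, n) attention logits plus bias
    return softmax(scores) @ V           # weighted sum of values

rng = np.random.default_rng(0)
n, d = 4, 8
Q, K, V = rng.normal(size=(3, n, d))
B = rng.normal(size=(n, n))              # learnable in the real model
out = attention_with_relative_bias(Q, K, V, B)
print(out.shape)  # (4, 8)
```

Because \( B_{ij} \) is added before the softmax, the bias reshapes the attention distribution itself rather than merely shifting its output.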

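The depthwise convolution in item 5 can likewise be sketched naively: each channel is convolved with its own single kernel, so there is no cross-channel mixing and the parameter count is \( C \cdot k^2 \) rather than \( C_{out} \cdot C_{in} \cdot k^2 \). A minimal sketch with assumed `(C, H, W)` layout (not Lucid's implementation):

```python
import numpy as np

def depthwise_conv2d(x, w):
    """Naive depthwise convolution: one kernel per channel, no channel mixing.

    x: (C, H, W) input; w: (C, k, k) kernels; valid padding, stride 1.
    """
    C, H, W = x.shape
    _, k, _ = w.shape
    out = np.zeros((C, H - k + 1, W - k + 1))
    for c in range(C):                        # channels stay independent
        for i in range(H - k + 1):
            for j in range(W - k + 1):
                out[c, i, j] = np.sum(x[c, i:i + k, j:j + k] * w[c])
    return out

x = np.arange(2 * 4 * 4, dtype=float).reshape(2, 4, 4)
w = np.ones((2, 3, 3))                        # all-ones kernels for a checkable sum
y = depthwise_conv2d(x, w)
print(y.shape)  # (2, 2, 2)
```

With all-ones kernels each output entry is just the sum of the corresponding 3×3 window, which makes the per-channel behavior easy to verify by hand.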
Examples

import lucid
import lucid.models as models

config = models.CoAtNetConfig(
    img_size=(224, 224),
    in_channels=3,
    num_blocks=(2, 2, 3, 5, 2),
    channels=(64, 96, 192, 384, 768),
    num_classes=1000,
    num_heads=32,
    block_types=("C", "C", "T", "T"),
)
model = models.CoAtNet(config)

input_ = lucid.random.randn(1, 3, 224, 224)
output = model(input_)
print(output.shape)