CoAtNet¶
- class lucid.models.CoAtNet(config: CoAtNetConfig)¶
The CoAtNet module in lucid.models implements the CoAtNet architecture, a hybrid model combining convolutional and attention-based mechanisms. It leverages the strengths of both convolutional neural networks (CNNs) and vision transformers (ViTs), making it highly efficient for image classification tasks. The model structure is defined through CoAtNetConfig.
CoAtNet utilizes depthwise convolutions, relative position encoding, and pre-normalization to enhance training stability and performance.
[Architecture diagram: coatnet_0]
Input (1, 3, 224, 224)
s0: [Conv2d → BatchNorm2d → GELU] × 2, (1, 3, 224, 224) → (1, 64, 112, 112)
s1: _MBConv × 2 with MaxPool2d downsampling and _PreNorm, (1, 64, 112, 112) → (1, 96, 56, 56)
s3: _Transformer with MaxPool2d downsampling, (1, 192, 28, 28) → (1, 384, 14, 14), followed by _Transformer × 4, reaching (1, 768, 7, 7)
Head: AdaptiveAvgPool2d, (1, 768, 7, 7) → (1, 768, 1, 1), then Linear, (1, 768) → (1, 1000)
Output (1, 1000)
Class Signature¶
class CoAtNet(nn.Module):
    def __init__(self, config: CoAtNetConfig) -> None
Parameters¶
config (CoAtNetConfig): Configuration object describing the input resolution, stage depths, stage widths, classifier size, attention heads, block types, and optional scaled tandem stage settings.
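For reference, these fields correspond one-to-one to the configuration used in the Examples section below; an annotated construction looks like this:

import lucid.models as models

config = models.CoAtNetConfig(
    img_size=(224, 224),               # input resolution
    in_channels=3,                     # input channels (RGB)
    num_blocks=(2, 2, 3, 5, 2),        # stage depths (stem + four stages)
    channels=(64, 96, 192, 384, 768),  # stage widths
    num_classes=1000,                  # classifier size
    num_heads=32,                      # attention heads
    block_types=("C", "C", "T", "T"),  # convolution ("C") vs. transformer ("T") stages
)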
Hybrid Architecture¶
The CoAtNet model employs a hybrid structure that fuses convolutional and transformer blocks for enhanced representation learning:
Early Convolutional Blocks:
The initial stages use convolution-based feature extraction (C blocks).
These layers focus on capturing local patterns efficiently.
Convolutions perform feature extraction using:
\[Y = W * X + b\]
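The s0 stem in the diagram above chains exactly this operation with normalization and an activation. A minimal sketch, assuming lucid.nn exposes PyTorch-style constructors for the Conv2d, BatchNorm2d, GELU, and Sequential modules named in the diagram (the keyword arguments here are assumptions):

import lucid.nn as nn

# Stem block in the style of stage s0: Conv2d -> BatchNorm2d -> GELU,
# halving the spatial resolution with stride 2.
stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
    nn.BatchNorm2d(64),
    nn.GELU(),
)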
Transformer-Based Blocks:
Later stages transition into transformer blocks (T blocks).
These layers incorporate self-attention to capture long-range dependencies.
Self-attention is computed as:
\[\mathbf{A} = \text{softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right)V\]
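A minimal NumPy sketch of this computation (single head, no masking; an illustration only, not lucid's internal implementation):

import numpy as np

def self_attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                 # scaled dot-product logits
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ V                              # attention-weighted values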
Pre-Normalization:
Each transformer block applies Layer Normalization before the attention mechanism.
Helps improve gradient flow and stability during training.
The normalization step follows:
\[\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}\]
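Sketched in NumPy (per-token normalization, omitting the learnable scale and shift that Layer Normalization usually adds):

import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)   # per-token mean
    var = x.var(axis=-1, keepdims=True)   # per-token variance
    return (x - mu) / np.sqrt(var + eps)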
Relative Position Encoding:
Unlike absolute position encoding in traditional ViTs, CoAtNet leverages relative position encoding.
The attention mechanism incorporates positional information dynamically:
\[A_{ij} = \frac{Q_i K_j^T}{\sqrt{d_k}} + B_{ij}\]
The relative position bias \(B_{ij}\) is learnable and helps in modeling spatial relationships.
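A small NumPy illustration of adding such a bias to the attention logits. The 1-D bias table indexed by relative offset is an assumption for illustration; 2-D attention over images indexes the table by relative offsets along both spatial axes:

import numpy as np

rng = np.random.default_rng(0)
n, d_k = 4, 8                                    # sequence length, head dim
Q, K = rng.normal(size=(n, d_k)), rng.normal(size=(n, d_k))

bias_table = rng.normal(size=2 * n - 1)          # learnable, one entry per offset
offsets = np.arange(n)[:, None] - np.arange(n)[None, :]  # relative offsets i - j
B = bias_table[offsets + (n - 1)]                # (n, n) bias matrix B_ij
logits = Q @ K.T / np.sqrt(d_k) + B              # biased attention logits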
Depthwise Convolutions:
Used to reduce computational complexity while maintaining strong feature extraction capabilities.
Reduces the number of parameters compared to traditional convolutional layers.
The depthwise convolution operation, applied independently to each channel \(c\), is:
\[Y_{c,i,j} = \sum_{m,n} X_{c,\,i+m,\,j+n}\, W_{c,m,n}\]
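A direct (unoptimized) NumPy sketch of this per-channel operation:

import numpy as np

def depthwise_conv(x, w):
    # x: (C, H, W) input; w: (C, k, k), one filter per channel
    C, H, W = x.shape
    k = w.shape[-1]
    out = np.zeros((C, H - k + 1, W - k + 1))
    for c in range(C):                     # no cross-channel mixing
        for i in range(H - k + 1):
            for j in range(W - k + 1):
                out[c, i, j] = (x[c, i:i + k, j:j + k] * w[c]).sum()
    return out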
Scaling Strategy:
CoAtNet scales efficiently across depth (D), width (W), and resolution (R), making it highly versatile for various image sizes and computational constraints.
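For instance, a deeper variant can be expressed purely through the configuration. The stage depths below follow those reported for CoAtNet-1 in the original paper and are shown only as an illustration; lucid's own presets may differ:

import lucid.models as models

deeper_config = models.CoAtNetConfig(
    img_size=(224, 224),
    in_channels=3,
    num_blocks=(2, 2, 6, 14, 2),       # scaled depth (D)
    channels=(64, 96, 192, 384, 768),  # width (W) kept as in the example above
    num_classes=1000,
    num_heads=32,
    block_types=("C", "C", "T", "T"),
)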
Examples¶
import lucid
import lucid.models as models
config = models.CoAtNetConfig(
img_size=(224, 224),
in_channels=3,
num_blocks=(2, 2, 3, 5, 2),
channels=(64, 96, 192, 384, 768),
num_classes=1000,
num_heads=32,
block_types=("C", "C", "T", "T"),
)
model = models.CoAtNet(config)
input_ = lucid.random.randn(1, 3, 224, 224)
output = model(input_)
print(output.shape)
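The printed shape is (1, 1000): one logit per class, matching num_classes=1000 and the final (1, 768) → (1, 1000) classifier head shown in the architecture diagram.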