ResNet
ConvNet
- class lucid.models.ResNet(config: ResNetConfig)
The ResNet class provides a configurable implementation of the residual network architecture. Model structure is described by ResNetConfig, which controls the residual block family, stage depths, input stem, stage widths, and block-level keyword arguments for custom ResNet variants.
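The block family and stage depths together determine the depth that gives each variant its name. The sketch below reproduces that arithmetic under the usual naming convention: it counts one stem convolution, two convolutions per basic block or three per bottleneck block, and the final fully connected layer, and it ignores 1x1 projection shortcuts. The helper `resnet_depth` and the `"bottleneck"` key are hypothetical illustrations. Only `"basic"` is confirmed as a `block` value by this page's example.

```python
# Hypothetical helper: counts weighted layers (stem conv + block convs + fc)
# the way ResNet variant names are conventionally derived. Assumes a single
# stem convolution and ignores 1x1 projection shortcuts; not Lucid code.
CONVS_PER_BLOCK = {"basic": 2, "bottleneck": 3}


def resnet_depth(block: str, layers: list[int]) -> int:
    stem_convs = 1  # one stem convolution (standard, non-deep stem)
    classifier = 1  # the final fully connected layer
    return stem_convs + CONVS_PER_BLOCK[block] * sum(layers) + classifier


print(resnet_depth("basic", [2, 2, 2, 2]))        # 18
print(resnet_depth("basic", [3, 4, 6, 3]))        # 34
print(resnet_depth("bottleneck", [3, 4, 6, 3]))   # 50
print(resnet_depth("bottleneck", [3, 4, 23, 3]))  # 101
```

For example, the resnet_101 model diagrammed on this page uses bottleneck blocks with stage depths [3, 4, 23, 3], which gives 1 + 3 × 33 + 1 = 101.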
%%{init: {"flowchart":{"curve":"monotoneX","nodeSpacing":50,"rankSpacing":50}} }%%
flowchart LR
linkStyle default stroke-width:2.0px
subgraph sg_m0["<span style='font-size:20px;font-weight:700'>resnet_101</span>"]
style sg_m0 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
subgraph sg_m1["stem"]
style sg_m1 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
m2["Conv2d<br/><span style='font-size:11px;color:#c53030;font-weight:400'>(1,3,224,224) → (1,64,112,112)</span>"];
m3["BatchNorm2d"];
m4["ReLU"];
end
m5["MaxPool2d<br/><span style='font-size:11px;color:#b7791f;font-weight:400'>(1,64,112,112) → (1,64,56,56)</span>"];
subgraph sg_m6["layer1 x 4"]
style sg_m6 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
m6_in(["Input"]);
m6_out(["Output"]);
style m6_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
style m6_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
m7(["_Bottleneck x 3<br/><span style='font-size:11px;font-weight:400'>(1,64,56,56) → (1,256,56,56)</span>"]);
end
m8["AdaptiveAvgPool2d<br/><span style='font-size:11px;color:#b7791f;font-weight:400'>(1,2048,7,7) → (1,2048,1,1)</span>"];
m9["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,2048) → (1,1000)</span>"];
end
input["Input<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,3,224,224)</span>"];
output["Output<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,1000)</span>"];
style input fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
style output fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
style m2 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
style m3 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
style m4 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
style m5 fill:#fefcbf,stroke:#b7791f,stroke-width:1px;
style m8 fill:#fefcbf,stroke:#b7791f,stroke-width:1px;
style m9 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
input --> m2;
m2 --> m3;
m3 --> m4;
m4 --> m5;
m5 -.-> m7;
m6_in -.-> m7;
m6_out -.-> m6_in;
m6_out --> m8;
m7 -.-> m6_in;
m7 --> m6_out;
m8 --> m9;
m9 --> output;
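The tensor shapes in the resnet_101 diagram follow from the standard convolution output-size formula. The sketch below traces them under assumed hyperparameters: a 7x7 stride-2 stem convolution with padding 3, a 3x3 stride-2 max pool with padding 1, one stride-2 downsampling step in each of layer2 to layer4, stage widths 64/128/256/512, and a bottleneck expansion of 4. These are the textbook ResNet values and they reproduce the diagram's shapes. They are not read from Lucid's source, and the helper names are hypothetical.

```python
# Sketch of the spatial/channel bookkeeping behind the diagram. All kernel,
# stride, padding, width, and expansion values are assumed standard ResNet
# settings, chosen because they reproduce the shapes shown.


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution or pooling window along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(size: int = 224, widths=(64, 128, 256, 512), expansion: int = 4):
    shapes = []
    size = conv_out(size, kernel=7, stride=2, padding=3)  # stem: 224 -> 112
    shapes.append(("stem", 64, size))
    size = conv_out(size, kernel=3, stride=2, padding=1)  # max pool: 112 -> 56
    shapes.append(("maxpool", 64, size))
    for i, width in enumerate(widths):
        stride = 1 if i == 0 else 2  # layer1 keeps resolution; later stages halve it
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        shapes.append((f"layer{i + 1}", width * expansion, size))
    return shapes


for name, channels, size in trace_shapes():
    print(f"{name}: (1, {channels}, {size}, {size})")
```

The final entry, (1, 2048, 7, 7), is the input that the diagram shows entering AdaptiveAvgPool2d.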
Class Signature
class ResNet(nn.Module):
    def __init__(self, config: ResNetConfig)
Parameters
- config (ResNetConfig):
Configuration object that defines the residual block family, stage depths, classifier size, input channels, stem settings, stage widths, and extra per-block keyword arguments.
Attributes
- config (ResNetConfig):
The configuration used to construct the model.
- stem (nn.Module):
The initial stem layer that processes the input tensor.
- maxpool (nn.MaxPool2d):
Downsamples the stem activations before the residual stages.
- layer1, layer2, layer3, layer4 (nn.Sequential):
Residual stages generated from the stage depths and widths in the configuration.
- block (type[nn.Module]):
Stores the resolved residual block class used for building the stages.
- avgpool (nn.AdaptiveAvgPool2d):
Global average pooling layer applied before the classifier.
- fc (nn.Linear):
Final classification head.
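How layer1 to layer4 are generated from stage depths and widths can be sketched with the construction rule common to ResNet implementations: the first block of layer2 to layer4 downsamples with stride 2, and any block whose input and output shapes differ gets a 1x1 projection shortcut. The sketch below only plans the blocks. `BlockPlan`, `plan_stages`, the default widths, and the assumption that the stem emits 64 channels are all hypothetical and are not Lucid's internal code.

```python
# Hypothetical planner listing the blocks each residual stage would contain,
# following the common ResNet construction rule. Illustrative only.
from dataclasses import dataclass


@dataclass
class BlockPlan:
    in_channels: int
    out_channels: int
    stride: int
    projection: bool  # True when the shortcut needs a 1x1 conv to match shapes


def plan_stages(layers, widths=(64, 128, 256, 512), expansion=1, stem_channels=64):
    in_channels = stem_channels  # assumed stem output channels
    stages = []
    for i, (depth, width) in enumerate(zip(layers, widths)):
        out_channels = width * expansion
        stage = []
        for j in range(depth):
            stride = 2 if (i > 0 and j == 0) else 1  # downsample once per stage
            projection = stride != 1 or in_channels != out_channels
            stage.append(BlockPlan(in_channels, out_channels, stride, projection))
            in_channels = out_channels
        stages.append(stage)
    return stages


# Basic blocks (expansion 1) with layers=[2, 2, 2, 2], as in the Basic Example.
for n, stage in enumerate(plan_stages([2, 2, 2, 2]), start=1):
    print(f"layer{n}:", [(b.in_channels, b.out_channels, b.stride, b.projection) for b in stage])
```

With expansion=4 and layers [3, 4, 23, 3], the first block of layer1 maps 64 to 256 channels at stride 1 and still needs a projection, which matches the (1,64,56,56) → (1,256,56,56) step in the diagram. Under the same assumptions, the basic-block configuration ends with 512 channels, so its fc layer would map 512 features to num_classes.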
Forward Calculation
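The forward pass of the ResNet model applies, in order:
- Stem: initial convolutional layers for low-level feature extraction.
- Max pooling: downsamples the stem activations before the residual stages.
- Residual stages: four stages of residual blocks (layer1 to layer4) whose depths are given by ResNetConfig.layers.
- Global pooling: adaptive average pooling reduces each feature map to a 1x1 spatial size.
- Classifier: the pooled features are flattened, and a fully connected layer (fc) maps them to class scores.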
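The last steps of the forward pass can be sketched numerically in NumPy. This is a minimal illustration under stated assumptions. A simple scaling stands in for a residual block's convolutions, random arrays stand in for trained weights, and the 2048-channel 7x7 feature map is taken from the resnet_101 diagram. It mirrors the math of a residual addition, adaptive average pooling, and a linear head. It does not reproduce Lucid's implementation, and `residual_block` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)


def residual_block(x, transform):
    """Core residual rule: add the transform to an identity shortcut, then ReLU."""
    return np.maximum(transform(x) + x, 0.0)


# Toy stand-in for the last residual block: a shape-preserving transform
# (a scaling here) plays the role of the block's convolutions.
features = rng.standard_normal((1, 2048, 7, 7))
features = residual_block(features, lambda t: 0.5 * t)  # still (1, 2048, 7, 7)

# Averaging over both spatial axes equals AdaptiveAvgPool2d((1, 1)) + flatten.
pooled = features.mean(axis=(2, 3))  # (1, 2048)

# Linear classifier with random stand-in weights: (1, 2048) -> (1, 1000).
num_classes = 1000
weight = rng.standard_normal((num_classes, 2048)) * 0.01  # (out_features, in_features)
bias = np.zeros(num_classes)
logits = pooled @ weight.T + bias

print(features.shape, pooled.shape, logits.shape)
```

Because the residual addition requires matching shapes, blocks that change channels or resolution use a projection shortcut instead of the identity used here.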
Examples
Basic Example:
>>> import lucid
>>> import lucid.models as models
>>> config = models.ResNetConfig(
... block="basic",
... layers=[2, 2, 2, 2],
... num_classes=10,
... in_channels=1,
... stem_type="deep",
... stem_width=32,
... )
>>> model = models.ResNet(config)
>>> input_tensor = lucid.zeros(8, 1, 224, 224)
>>> output = model(input_tensor) # Forward pass
>>> print(output.shape)
(8, 10)
Note
Use ResNetConfig to define custom block families, stage widths, or stem settings.
Factory helpers such as resnet_50 and wide_resnet_50 create preset configurations for common historical variants.
Setting stem_type="deep" enables the multi-layer deep stem used by some ResNet-derived models.