ZFNet
ConvNet
class lucid.models.ZFNet(config: ZFNetConfig)
The ZFNet module in lucid.models implements ZFNet (Zeiler and Fergus Net), a refinement of AlexNet that uses a smaller first-layer convolutional filter. The architecture was introduced alongside visualization techniques for understanding what convolutional features learn. The model is configured through ZFNetConfig.
%%{init: {"flowchart":{"curve":"monotoneX","nodeSpacing":50,"rankSpacing":50}} }%%
flowchart LR
linkStyle default stroke-width:2.0px
subgraph sg_m0["<span style='font-size:20px;font-weight:700'>zfnet</span>"]
style sg_m0 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
subgraph sg_m1["conv"]
style sg_m1 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
m2["Conv2d<br/><span style='font-size:11px;color:#c53030;font-weight:400'>(1,3,224,224) → (1,96,110,110)</span>"];
m3["ReLU"];
m4["MaxPool2d<br/><span style='font-size:11px;color:#b7791f;font-weight:400'>(1,96,110,110) → (1,96,54,54)</span>"];
m5["Conv2d<br/><span style='font-size:11px;color:#c53030;font-weight:400'>(1,96,54,54) → (1,256,27,27)</span>"];
m6["ReLU"];
m7["MaxPool2d<br/><span style='font-size:11px;color:#b7791f;font-weight:400'>(1,256,27,27) → (1,256,13,13)</span>"];
m8["Conv2d<br/><span style='font-size:11px;color:#c53030;font-weight:400'>(1,256,13,13) → (1,384,13,13)</span>"];
m9["ReLU"];
m10["Conv2d"];
m11["ReLU"];
m12["Conv2d<br/><span style='font-size:11px;color:#c53030;font-weight:400'>(1,384,13,13) → (1,256,13,13)</span>"];
m13["ReLU"];
m14["MaxPool2d<br/><span style='font-size:11px;color:#b7791f;font-weight:400'>(1,256,13,13) → (1,256,6,6)</span>"];
end
m15["AdaptiveAvgPool2d"];
subgraph sg_m16["fc"]
style sg_m16 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
m17["Dropout"];
m18["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,9216) → (1,4096)</span>"];
m19["ReLU"];
m20["Dropout"];
m21["Linear"];
m22["ReLU"];
m23["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,4096) → (1,1000)</span>"];
end
end
input["Input<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,3,224,224)</span>"];
output["Output<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,1000)</span>"];
style input fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
style output fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
style m2 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
style m3 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
style m4 fill:#fefcbf,stroke:#b7791f,stroke-width:1px;
style m5 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
style m6 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
style m7 fill:#fefcbf,stroke:#b7791f,stroke-width:1px;
style m8 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
style m9 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
style m10 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
style m11 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
style m12 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
style m13 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
style m14 fill:#fefcbf,stroke:#b7791f,stroke-width:1px;
style m15 fill:#fefcbf,stroke:#b7791f,stroke-width:1px;
style m17 fill:#edf2f7,stroke:#4a5568,stroke-width:1px;
style m18 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
style m19 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
style m20 fill:#edf2f7,stroke:#4a5568,stroke-width:1px;
style m21 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
style m22 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
style m23 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
input --> m2;
m10 --> m11;
m11 --> m12;
m12 --> m13;
m13 --> m14;
m14 --> m15;
m15 --> m17;
m17 --> m18;
m18 --> m19;
m19 --> m20;
m2 --> m3;
m20 --> m21;
m21 --> m22;
m22 --> m23;
m23 --> output;
m3 --> m4;
m4 --> m5;
m5 --> m6;
m6 --> m7;
m7 --> m8;
m8 --> m9;
m9 --> m10;
Class Signature
class ZFNet(nn.Module):
    def __init__(self, config: ZFNetConfig)
Parameters
config (ZFNetConfig): A configuration object describing the output class count, input channels, dropout rate, and classifier hidden dimensions.
Attributes
config (ZFNetConfig): The configuration used to build the model.
conv (nn.Sequential): The convolutional layers, including pooling and ReLU activations.
avgpool (nn.AdaptiveAvgPool2d): Adaptive average pooling layer to reduce spatial dimensions to (6, 6).
fc (nn.Sequential): The fully connected layers with dropout and ReLU activations for classification.
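As a rough illustration of how the classifier dimensions in the config translate into model size, the sketch below counts the parameters of the fc stack. It assumes that every Linear layer carries a bias term and that the flattened input has 256 x 6 x 6 = 9216 features, as the (6, 6) pooling output and the diagram above suggest. `fc_param_count` is a hypothetical helper written for this page, not part of lucid.

```python
def fc_param_count(hidden_features, num_classes, in_features=256 * 6 * 6):
    """Count weights and biases of a stack of Linear layers (bias assumed)."""
    sizes = [in_features, *hidden_features, num_classes]
    # Each Linear(a, b) holds a * b weights plus b biases.
    return sum(a * b + b for a, b in zip(sizes, sizes[1:]))


# Hidden sizes of (4096, 4096) with 1000 classes, as described in this page
default_params = fc_param_count((4096, 4096), 1000)
# A much smaller classifier: hidden sizes (512, 256) with 10 classes
small_params = fc_param_count((512, 256), 10)
print(default_params)
print(small_params)
```

Under these assumptions, the first Linear layer dominates the count, so shrinking the first hidden dimension is the most effective way to reduce the classifier's size.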
Architecture
The architecture of ZFNet is as follows:
Convolutional Layers:
- 5 convolutional layers with ReLU activations.
- A smaller first-layer kernel (7x7) than AlexNet's 11x11, for better feature learning.
- MaxPooling layers after the 1st, 2nd, and 5th convolutional layers.
Fully Connected Layers:
- 2 hidden fully connected layers, each with 4096 units and ReLU activations by default.
- Output layer with num_classes units for classification.
Regularization:
- Dropout is applied to fully connected layers to reduce overfitting.
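The spatial sizes shown in the diagram can be reproduced with the usual output-size rule for convolution and pooling windows, out = floor((n + 2p - k) / s) + 1. The sketch below is a plain-Python trace, not lucid code. Only the 7x7 first-layer kernel is stated on this page; the remaining kernel sizes, strides, and paddings are assumptions chosen so that the trace matches the shapes in the diagram, and the actual lucid implementation may use different values.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Output length of a conv or pooling window along one spatial axis."""
    return (n + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding)
# Everything except the 7x7 first kernel is an assumption made for illustration.
layers = [
    ("conv1", 7, 2, 1),
    ("pool1", 3, 2, 0),
    ("conv2", 5, 2, 2),
    ("pool2", 3, 2, 0),
    ("conv3", 3, 1, 1),
    ("conv4", 3, 1, 1),
    ("conv5", 3, 1, 1),
    ("pool3", 3, 2, 0),
]

size = 224  # input height and width
trace = []
for name, k, s, p in layers:
    size = out_size(size, k, s, p)
    trace.append((name, size))
    print(f"{name}: {size}x{size}")

# 256 channels at the final spatial size give the classifier's input width
flattened = 256 * size * size
print("flattened features:", flattened)
```

With these assumed settings the final feature map is 6x6, so the flattened width matches the (1,9216) input of the first Linear layer in the diagram.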
Examples
Basic Example
import lucid
import lucid.models as models

# Build ZFNet with the default configuration
config = models.ZFNetConfig()
model = models.ZFNet(config)

# Random input tensor with shape (1, 3, 224, 224)
# (assumes lucid.random.randn is available in your Lucid version)
input_ = lucid.random.randn(1, 3, 224, 224)

# Perform forward pass
output = model(input_)
print(output.shape)  # Shape: (1, 1000)
Explanation
The model processes the input through its convolutional and fully connected layers, producing logits for 1000 classes.
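Because the model returns raw logits rather than probabilities, a softmax is normally applied before reading off a prediction. The following is a minimal, framework-free sketch of that step. The four-element `logits` list is made up for illustration and stands in for one row of the model output; `softmax` here is a local helper, not a lucid function.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities, shifting by the max for stability."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for a 4-class problem
# (a real ZFNet output row has num_classes entries)
logits = [1.2, -0.3, 3.1, 0.4]
probs = softmax(logits)

# Index of the most probable class
predicted = max(range(len(probs)), key=probs.__getitem__)
print(predicted, round(sum(probs), 6))
```

Softmax preserves the ordering of the logits, so when only the predicted class is needed, taking the argmax of the logits directly gives the same answer.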
Custom Number of Classes
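import lucid
import lucid.models as models

# 10 output classes, single-channel input, lighter dropout, smaller classifier
config = models.ZFNetConfig(
    num_classes=10,
    in_channels=1,
    dropout=0.25,
    classifier_hidden_features=(512, 256),
)
model = models.ZFNet(config)

# Single-channel input (assumes lucid.random.randn is available)
input_ = lucid.random.randn(1, 1, 224, 224)
output = model(input_)
print(output.shape)  # Shape: (1, 10)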