Attention U-Net 2D

ConvNet Segmentation

class lucid.models.AttentionUNet2d(config: AttentionUNetConfig)

AttentionUNet2d is a paper-faithful attention-gated variant of lucid.models.UNet2d. It preserves the encoder-decoder segmentation structure of U-Net but inserts additive attention gates on skip connections so decoder-side gating features can suppress irrelevant encoder responses before concatenation.

For volumetric inputs with shape \((N, C, D, H, W)\), see lucid.models.AttentionUNet3d.

Oktay, Ozan, et al. “Attention U-Net: Learning Where to Look for the Pancreas.” arXiv preprint arXiv:1804.03999 (2018).
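To make the gating idea concrete, here is a minimal NumPy sketch of an additive attention gate in the spirit of Oktay et al. It is not Lucid's implementation. The function name, argument names, and weight shapes are assumptions chosen for illustration, the 1x1 convolutions are written as channel-wise einsum contractions, and the gating features are assumed to have already been resized to the spatial size of the skip features (the resampling step of grid attention is omitted).

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def additive_attention_gate(x, g, w_x, w_g, b, psi, b_psi):
    """Illustrative additive attention gate (hypothetical helper, not Lucid API).

    x:     (N, C_x, H, W) encoder skip features
    g:     (N, C_g, H, W) decoder gating features, assumed already
           resized to the spatial size of x
    w_x:   (C_int, C_x) 1x1 projection of x
    w_g:   (C_int, C_g) 1x1 projection of g
    b:     (C_int,) bias added to the summed projections
    psi:   (C_int,) 1x1 projection down to a single attention channel
    b_psi: scalar bias
    """
    # Project both inputs to the intermediate width and add them.
    theta_x = np.einsum("ic,nchw->nihw", w_x, x)
    phi_g = np.einsum("ic,nchw->nihw", w_g, g)
    f = np.maximum(theta_x + phi_g + b[None, :, None, None], 0.0)  # ReLU

    # Collapse to one channel and squash to (0, 1) with a sigmoid.
    q = np.einsum("i,nihw->nhw", psi, f) + b_psi
    alpha = sigmoid(q)[:, None, :, :]  # (N, 1, H, W)

    # Scale the skip features; the result is what gets concatenated.
    return x * alpha, alpha


rng = np.random.default_rng(0)
c_x, c_g, c_int = 64, 128, 32
x = rng.standard_normal((2, c_x, 16, 16))
g = rng.standard_normal((2, c_g, 16, 16))

gated, alpha = additive_attention_gate(
    x,
    g,
    w_x=rng.standard_normal((c_int, c_x)) * 0.1,
    w_g=rng.standard_normal((c_int, c_g)) * 0.1,
    b=np.zeros(c_int),
    psi=rng.standard_normal(c_int) * 0.1,
    b_psi=0.0,
)
print(gated.shape)  # (2, 64, 16, 16)
print(alpha.shape)  # (2, 1, 16, 16)
```

In this sketch, c_int plays the role of a gate's intermediate width, which is presumably what inter_channels controls per gate in AttentionUNetGateConfig.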

Class Signature

class AttentionUNet2d(UNet2d):
    def __init__(self, config: AttentionUNetConfig) -> None

Parameters

  • config (AttentionUNetConfig): Attention U-Net configuration describing the encoder/decoder stage layout, skip-gating strategy, and segmentation output space.

Methods

AttentionUNet2d.forward(x: Tensor) -> Tensor | dict[str, Tensor | list[Tensor]]

Examples

Build a Paper-Style Attention U-Net 2D

import lucid
import lucid.models as models

cfg = models.AttentionUNetConfig.from_channels(
    in_channels=1,
    out_channels=3,
    channels=(32, 64, 128, 256),
    num_blocks=2,
)
model = models.AttentionUNet2d(cfg)

x = lucid.random.randn(2, 1, 128, 128)
out = model(x)
print(out["out"].shape)  # (2, 3, 128, 128)
print(len(out["aux"]))   # 2

Customize Gate Widths

import lucid.models as models

cfg = models.AttentionUNetConfig(
    in_channels=1,
    out_channels=2,
    encoder_stages=[
        models.UNetStageConfig(channels=32, num_blocks=2),
        models.UNetStageConfig(channels=64, num_blocks=2),
        models.UNetStageConfig(channels=128, num_blocks=2),
        models.UNetStageConfig(channels=256, num_blocks=2),
    ],
    attention=models.AttentionUNetGateConfig(
        inter_channels=(32, 64, 64),
    ),
)
model = models.AttentionUNet2d(cfg)

Notes

  • AttentionUNet2d expects image tensors with shape \((N, C, H, W)\). For 3D volumetric inputs \((N, C, D, H, W)\), use lucid.models.AttentionUNet3d.

  • It is intentionally constrained to the paper-faithful setting: block="basic", skip_merge="concat", additive gates, sigmoid attention coefficients, and grid attention.

  • The current default enables deep supervision, so lucid.models.AttentionUNet2d.forward() returns a dictionary with out and aux predictions unless deep_supervision=False.
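Because forward() is typed to return either a Tensor or a dictionary, calling code may want to normalize the result before computing a loss or a metric. The helper below is a small sketch of that pattern. The name split_outputs is hypothetical, the "out" and "aux" keys are taken from the example above, and the assumption that a non-dictionary result is the main prediction is inferred from the return annotation rather than confirmed.

```python
def split_outputs(output):
    """Return (main, aux_list) for either return form of forward().

    Hypothetical helper: assumes a dict result uses the "out" and "aux"
    keys, and that any other result is the main prediction itself.
    """
    if isinstance(output, dict):
        return output["out"], list(output.get("aux", []))
    return output, []


# Stand-in values are used here in place of real tensors.
main, aux = split_outputs({"out": "main_pred", "aux": ["aux_0", "aux_1"]})
print(main, len(aux))  # main_pred 2

main, aux = split_outputs("main_pred")
print(main, len(aux))  # main_pred 0
```

With this in place, training and inference code can use the same call site whether or not deep supervision is enabled.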