UNetConfig¶
- class lucid.models.UNetConfig(in_channels: int, out_channels: int, encoder_stages: tuple[lucid.models.vision.unet.UNetStageConfig, ...] | list[lucid.models.vision.unet.UNetStageConfig], decoder_stages: tuple[lucid.models.vision.unet.UNetStageConfig, ...] | list[lucid.models.vision.unet.UNetStageConfig] | None = None, bottleneck: lucid.models.vision.unet.UNetStageConfig | None = None, block: Literal['basic', 'res', 'convnext'] = 'basic', norm: Literal['batch', 'group', 'instance', 'none'] = 'batch', act: Literal['relu', 'leaky_relu', 'gelu', 'silu'] = 'relu', skip_merge: Literal['concat', 'add'] = 'concat', downsample_mode: Literal['conv', 'maxpool', 'avgpool'] = 'conv', upsample_mode: Literal['transpose', 'bilinear', 'trilinear', 'nearest'] = 'transpose', stem_channels: int | None = None, final_kernel_size: int = 1, deep_supervision: bool = False, align_corners: bool = False, bias: bool | None = None)¶
UNetConfig stores the architectural choices used by
lucid.models.UNet2d and lucid.models.UNet3d. It defines the
stage layout for the encoder and decoder along with block type, normalization,
activation, skip merge behavior, sampling strategy, and output head options.
Class Signature¶
@dataclass
class UNetConfig:
in_channels: int
out_channels: int
encoder_stages: tuple[UNetStageConfig, ...] | list[UNetStageConfig]
decoder_stages: tuple[UNetStageConfig, ...] | list[UNetStageConfig] | None = None
bottleneck: UNetStageConfig | None = None
block: Literal["basic", "res", "convnext"] = "basic"
norm: Literal["batch", "group", "instance", "none"] = "batch"
act: Literal["relu", "leaky_relu", "gelu", "silu"] = "relu"
skip_merge: Literal["concat", "add"] = "concat"
downsample_mode: Literal["conv", "maxpool", "avgpool"] = "conv"
upsample_mode: Literal["transpose", "bilinear", "trilinear", "nearest"] = "transpose"
stem_channels: int | None = None
final_kernel_size: int = 1
deep_supervision: bool = False
align_corners: bool = False
bias: bool | None = None
Parameters¶
in_channels (int): Number of channels in the input image tensor.
out_channels (int): Number of channels predicted by the final segmentation head.
encoder_stages (tuple[UNetStageConfig, …] | list[UNetStageConfig]): Stage specifications for the encoder path.
decoder_stages (tuple[UNetStageConfig, …] | list[UNetStageConfig] | None): Stage specifications for the decoder path. If None, the decoder is mirrored automatically from the encoder except for the deepest stage.
bottleneck (UNetStageConfig | None): Bottleneck stage inserted between the encoder and decoder. If None, a default bottleneck is inferred from the deepest encoder stage.
block (Literal[“basic”, “res”, “convnext”]): Block family identifier. The current implementation supports basic and res.
norm (Literal[“batch”, “group”, “instance”, “none”]): Normalization layer used inside blocks and attention blocks.
act (Literal[“relu”, “leaky_relu”, “gelu”, “silu”]): Activation function used inside blocks and the stem.
skip_merge (Literal[“concat”, “add”]): Skip connection merge strategy.
downsample_mode (Literal[“conv”, “maxpool”, “avgpool”]): Downsampling operation used between encoder stages.
upsample_mode (Literal[“transpose”, “bilinear”, “trilinear”, “nearest”]): Upsampling operation used between decoder stages. Use “trilinear” (or “bilinear”, which is automatically remapped) for
lucid.models.UNet3d.stem_channels (int | None): Output width of the input stem. If None, the first encoder stage width is used.
final_kernel_size (int): Kernel size of the final output projection.
deep_supervision (bool): Whether to attach auxiliary output heads to intermediate decoder stages.
align_corners (bool): Value passed to interpolation operations that use corner alignment semantics.
bias (bool | None): Whether convolution layers use bias terms. If None, this is inferred from the normalization choice.
Usage¶
import lucid.models as models
cfg = models.UNetConfig.from_channels(
in_channels=3,
out_channels=2,
channels=(32, 64, 128, 256),
num_blocks=(2, 2, 2, 3),
block="res",
norm="group",
upsample_mode="bilinear",
)
model = models.UNet2d(cfg)