DETR

DEtection TRansformer

class lucid.models.DETR(config: DETRConfig)

DETR (DEtection TRansformer) is a fully end-to-end object detector that replaces hand-crafted components (anchors, NMS) with a Transformer encoder-decoder. It predicts a fixed set of objects via learned object queries and trains with bipartite (Hungarian) matching and a set-based loss (classification + box L1 + GIoU).

Note

This implementation follows the baseline DETR: ResNet backbone -> 1x1 conv to d_model -> Transformer (6 enc/6 dec by default) -> class and box heads. It supports auxiliary losses from intermediate decoder layers.

        %%{init: {"flowchart":{"curve":"monotoneX","nodeSpacing":50,"rankSpacing":50}} }%%
flowchart LR
  linkStyle default stroke-width:2.0px
  subgraph sg_m0["<span style='font-size:20px;font-weight:700'>detr_r50</span>"]
  style sg_m0 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
    subgraph sg_m1["_BackboneBase"]
    style sg_m1 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      subgraph sg_m2["body"]
        direction TB;
      style sg_m2 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        subgraph sg_m3["stem"]
          direction TB;
        style sg_m3 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
          m4["Conv2d<br/><span style='font-size:11px;color:#c53030;font-weight:400'>(1,3,224,224) → (1,64,112,112)</span>"];
          m5["BatchNorm2d"];
          m6["ReLU"];
        end
        m7["MaxPool2d<br/><span style='font-size:11px;color:#b7791f;font-weight:400'>(1,64,112,112) → (1,64,56,56)</span>"];
        subgraph sg_m8["layer1 x 4"]
          direction TB;
        style sg_m8 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
          m8_in(["Input"]);
          m8_out(["Output"]);
  style m8_in fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
  style m8_out fill:#e2e8f0,stroke:#64748b,stroke-width:1px;
          m9(["_Bottleneck x 3<br/><span style='font-size:11px;font-weight:400'>(1,64,56,56) → (1,256,56,56)</span>"]);
        end
      end
    end
    m10["_SpatialPosEncoding<br/><span style='font-size:11px;font-weight:400'>(1,7,7) → (1,256,7,7)</span>"];
    subgraph sg_m11["_Transformer"]
    style sg_m11 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      subgraph sg_m12["_TransformerEncoder"]
        direction TB;
      style sg_m12 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        subgraph sg_m13["layers"]
          direction TB;
        style sg_m13 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
          m14(["_TransformerEncoderLayer x 6"]);
        end
        m15["LayerNorm"];
      end
      subgraph sg_m16["_TransformerDecoder"]
        direction TB;
      style sg_m16 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        subgraph sg_m17["layers"]
          direction TB;
        style sg_m17 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
          m18(["_TransformerDecoderLayer x 6<br/><span style='font-size:11px;font-weight:400'>(1,100,256)x2 → (1,100,256)</span>"]);
        end
        m19["LayerNorm"];
      end
    end
    m20["Embedding"];
    m21["Conv2d<br/><span style='font-size:11px;color:#c53030;font-weight:400'>(1,2048,7,7) → (1,256,7,7)</span>"];
    m22["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,100,256) → (1,100,92)</span>"];
    subgraph sg_m23["_MLP"]
    style sg_m23 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      subgraph sg_m24["layers"]
        direction TB;
      style sg_m24 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m25["Linear"];
        m26["ReLU"];
        m27["Linear"];
        m28["ReLU"];
        m29["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,100,256) → (1,100,4)</span>"];
      end
    end
    m30["_HungarianMatcher"];
  end
  input["Input<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,3,224,224)</span>"];
  output["Output<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,100,92)x12</span>"];
  style input fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
  style output fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
  style m4 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
  style m5 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m6 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
  style m7 fill:#fefcbf,stroke:#b7791f,stroke-width:1px;
  style m15 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m19 fill:#e6fffa,stroke:#2c7a7b,stroke-width:1px;
  style m20 fill:#f1f5f9,stroke:#475569,stroke-width:1px;
  style m21 fill:#ffe8e8,stroke:#c53030,stroke-width:1px;
  style m22 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m25 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m26 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
  style m27 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m28 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
  style m29 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  input --> m4;
  m10 --> m21;
  m14 --> m15;
  m15 -.-> m18;
  m18 --> m19;
  m19 -.-> m18;
  m19 -.-> m22;
  m21 --> m14;
  m22 --> m25;
  m25 --> m26;
  m26 --> m27;
  m27 --> m28;
  m28 --> m29;
  m29 -.-> m22;
  m29 --> output;
  m4 --> m5;
  m5 --> m6;
  m6 --> m7;
  m7 -.-> m9;
  m8_in -.-> m9;
  m8_out --> m10;
  m8_out -.-> m8_in;
  m9 -.-> m8_in;
  m9 --> m8_out;
    

Class Signature

class DETR(nn.Module):
    def __init__(self, config: DETRConfig) -> None

Parameters

  • config (DETRConfig): Configuration object that packages the backbone, transformer, query count, loss coefficients, and optional matcher used to build the detector.

Configuration

  • backbone (nn.Module): CNN feature extractor (e.g., ResNet-50/101). It must expose a positive integer num_channels attribute and return a stride-32 feature map of shape (B, C_backbone, H, W).

  • transformer (nn.Module): Encoder-decoder with hidden size d_model (e.g., 256). It must expose a positive integer d_model attribute and return DETR decoder states from its forward.

  • num_classes (int): Number of foreground categories (COCO: 91).

  • num_queries (int, default 100): Learned object queries; maximum detections per image.

  • aux_loss (bool, default True): If True, returns and trains on intermediate decoder outputs.

  • matcher (nn.Module | None): Bipartite matcher used during training. Defaults to the standard DETR Hungarian matcher if None.

  • class_loss_coef (float, default 1.0): Weight for classification loss.

  • bbox_loss_coef (float, default 5.0): Weight for L1 box loss.

  • giou_loss_coef (float, default 2.0): Weight for generalized IoU loss.

  • eos_coef (float, default 0.1): Weight for the “no-object” class.
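The four loss coefficients above combine the per-term losses into one training objective. A minimal sketch of that weighting (plain Python with hypothetical names; not lucid's internal API):

```python
# Hypothetical helper illustrating how the coefficients above weight the
# per-term losses over matched prediction/target pairs (not lucid's API).
def total_loss(ce, l1, giou,
               class_loss_coef=1.0, bbox_loss_coef=5.0, giou_loss_coef=2.0):
    return class_loss_coef * ce + bbox_loss_coef * l1 + giou_loss_coef * giou

# With the defaults, L1 box error contributes 5x and GIoU error 2x
# relative to the classification term:
print(total_loss(ce=0.5, l1=0.1, giou=0.2))
```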

Inputs

  • x (Tensor): Input image batch of shape (B, 3, H, W). Images are typically resized so the short side ~ 800 and long side \(\le\) 1312, then padded to multiples of 32.

  • mask (Tensor[bool], optional): Padding mask (B, H, W) where True marks padded (invalid) pixels. If omitted, an all-False mask is assumed.
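The padding convention can be sketched framework-agnostically: round each spatial dimension up to a multiple of 32, then mark every padded pixel True (function names here are illustrative, not lucid's API):

```python
# Sketch of the padding convention described above.
def pad_to_32(h, w):
    """Round both spatial dims up to the next multiple of 32."""
    return ((h + 31) // 32) * 32, ((w + 31) // 32) * 32

def padding_mask(h, w):
    """Boolean (H_pad, W_pad) grid; True marks padded (invalid) pixels,
    assuming the image is placed in the top-left corner."""
    ph, pw = pad_to_32(h, w)
    return [[not (i < h and j < w) for j in range(pw)] for i in range(ph)]

print(pad_to_32(800, 1100))  # (800, 1120)
```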

Targets (Training)

Provide a list of length B. Each element is a dict with:

  • “class_id”: Tensor[int64] of shape (N_i,) with class indices in [0, num_classes-1].

  • “box”: Tensor[float32] of shape (N_i, 4) with normalized boxes in (cx, cy, w, h) format, values in [0, 1] relative to the input image size (after any resize/pad).

Important

Boxes must be center-x, center-y, width, height in [0,1]. If your dataset is in pixels and/or xyxy, convert before passing targets.
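A sketch of the conversion this warning asks for, from pixel-space (x1, y1, x2, y2) to normalized (cx, cy, w, h) (plain Python; the function name is hypothetical):

```python
# Convert a pixel-space corner box to DETR's normalized center format.
def xyxy_to_cxcywh_norm(box, img_w, img_h):
    x1, y1, x2, y2 = box
    return ((x1 + x2) / 2 / img_w,   # cx in [0, 1]
            (y1 + y2) / 2 / img_h,   # cy in [0, 1]
            (x2 - x1) / img_w,       # w  in [0, 1]
            (y2 - y1) / img_h)       # h  in [0, 1]

print(xyxy_to_cxcywh_norm((100, 200, 300, 400), 800, 800))
# (0.25, 0.375, 0.25, 0.25)
```

Remember that `img_w` and `img_h` are the input size after any resize/pad, matching the note above.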

Returns

Evaluation / Inference (`aux_loss=False`)

  • pred_logits: (B, num_queries, num_classes + 1) - raw class scores; the last channel is “no-object”.

  • pred_boxes: (B, num_queries, 4) - normalized (cx, cy, w, h) in [0,1].

Training / Aux mode (`aux_loss=True`)

  • A list: intermediate decoder outputs followed by the final output. Each element is a tuple: (pred_logits_l, pred_boxes_l) with the same shapes as above.

Loss & Matching

Training uses Hungarian matching between predictions and ground-truth objects with the following costs:

  • Classification (CE) with “no-object” eos_coef (e.g., 0.1).

  • L1 box loss (weight = bbox_loss_coef, default 5.0).

  • Generalized IoU loss (weight = giou_loss_coef, default 2.0).

The total DETR loss is the weighted sum over matched pairs; unmatched queries are trained toward “no-object”.
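The generalized IoU term weighted by giou_loss_coef can be illustrated for two axis-aligned corner-format boxes (a self-contained sketch, not lucid's implementation; assumes non-degenerate boxes):

```python
# Generalized IoU between two (x1, y1, x2, y2) boxes:
# GIoU = IoU - (enclosing_area - union) / enclosing_area, in (-1, 1].
def giou(a, b):
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = area_a + area_b - inter
    iou = inter / union
    # Smallest box enclosing both a and b
    c = (max(ax2, bx2) - min(ax1, bx1)) * (max(ay2, by2) - min(ay1, by1))
    return iou - (c - union) / c

print(giou((0, 0, 2, 2), (1, 1, 3, 3)))  # 1/7 - 2/9, slightly negative
```

Unlike plain IoU, GIoU stays informative for non-overlapping boxes, which is why DETR uses it alongside the L1 term.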

Details

  • Backbone projection: a 1x1 conv maps backbone channels to d_model before flattening.

  • Positional encoding: sine-cosine 2D encoding with num_pos_feats = d_model // 2; shape (B, d_model, H, W).

  • Queries: learnable embeddings of shape (num_queries, d_model).

  • Decoder outputs: each query yields one class distribution and one box.

  • Auxiliary losses: if enabled, identical heads are applied to intermediate decoder layers (deep supervision).
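The sine-cosine encoding can be sketched for a single spatial axis (plain Python; this omits DETR's normalization of positions to [0, 2pi] over the unpadded extent, and the helper name is hypothetical):

```python
import math

# One axis of the sine-cosine positional encoding: position `pos` maps to
# num_pos_feats = d_model // 2 channels, alternating sin/cos over
# geometrically spaced frequencies.
def axis_encoding(pos, num_pos_feats=128, temperature=10000.0):
    enc = []
    for i in range(num_pos_feats):
        freq = temperature ** (2 * (i // 2) / num_pos_feats)
        enc.append(math.sin(pos / freq) if i % 2 == 0 else math.cos(pos / freq))
    return enc

# The y-axis and x-axis encodings are concatenated per pixel to d_model
# channels, giving the (B, d_model, H, W) tensor described above.
row = axis_encoding(3.0)  # 128 values for the y coordinate
col = axis_encoding(5.0)  # 128 values for the x coordinate
print(len(row) + len(col))  # 256 == d_model
```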

Methods

DETR.forward(x: Tensor, mask: Tensor | None = None) -> tuple[Tensor, Tensor] | list[tuple[Tensor, Tensor]]
DETR.predict(x: Tensor, k: int = 100) -> list[list[DetectionDict]]
DETR.get_loss(x: Tensor, targets: Sequence[dict[str, Tensor]], mask: Tensor | None = None) -> Tensor

Examples

Forward (inference)

import lucid
from lucid.models import detr_r50

model = detr_r50(pretrained_backbone=True, num_classes=91, num_queries=100, aux_loss=False).eval()
x = lucid.random.randn(1, 3, 800, 800)
logits, boxes = model(x)  # (1, 100, 92), (1, 100, 4)
probs = lucid.softmax(logits, axis=-1)[..., :-1]   # drop no-object
scores = probs.max(axis=-1)                        # (1, 100)
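To draw predictions on the original image, the normalized (cx, cy, w, h) boxes are converted back to pixel (x1, y1, x2, y2). A framework-agnostic sketch (plain Python; the function name is hypothetical):

```python
# Map a normalized center-format box back to pixel corner coordinates.
def cxcywh_norm_to_xyxy(box, img_w, img_h):
    cx, cy, w, h = box
    return ((cx - w / 2) * img_w, (cy - h / 2) * img_h,
            (cx + w / 2) * img_w, (cy + h / 2) * img_h)

print(cxcywh_norm_to_xyxy((0.5, 0.5, 0.25, 0.25), 800, 800))
# (300.0, 300.0, 500.0, 500.0)
```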

Targets format

# Single image example (B=1) with two objects
target = {
    "class_id": lucid.tensor([5, 17], dtype=lucid.Int64),
    "box": lucid.tensor([[0.52, 0.44, 0.20, 0.25],
                         [0.28, 0.70, 0.15, 0.10]], dtype=lucid.Float32)
}
targets = [target]

Training step (sketch)

model.train()

# Set-based loss: classification + L1 box + GIoU, weighted by the
# coefficients above. Backpropagate and step your optimizer as usual.
loss = model.get_loss(x, targets)

Note

  • No NMS: DETR does not require non-maximum suppression at inference time.

  • num_queries: controls the maximum number of detections per image (default 100).

  • Input size: commonly short side ~ 800, long side \(\le\) 1312, padded to multiples of 32. The model supports arbitrary input sizes.