
triplet_margin_with_distance_loss

triplet_margin_with_distance_loss(anchor: Tensor, positive: Tensor, negative: Tensor, distance_function: object | None = None, margin: float = 1.0, swap: bool = False, reduction: str = 'mean') -> Tensor

Triplet margin loss with a user-supplied distance function.

Identical in form to triplet_margin_loss, but lets the caller plug in any two-argument distance callable. This is useful when the embedding space is non-Euclidean (e.g., learned Mahalanobis distances, cosine distance, or hyperbolic embeddings). When no distance_function is supplied, it defaults to the $L_2$ pairwise distance, matching the reference framework semantics.
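
The contract distance_function must meet is to map two $(N, D)$ batches to $N$ non-negative distances, one per row. The sketch below illustrates that contract with plain Python lists standing in for Lucid tensors (an assumption made so the example runs without Lucid); l2_distance and cosine_distance are hypothetical helpers, not functions exported by lucid.nn.functional.

```python
import math
from typing import List, Sequence

Batch = Sequence[Sequence[float]]  # stand-in for an (N, D) tensor


def l2_distance(x: Batch, y: Batch) -> List[float]:
    # Row-wise Euclidean distance: returns N values, one per (x_i, y_i) pair.
    return [math.sqrt(sum((a - b) ** 2 for a, b in zip(xr, yr))) for xr, yr in zip(x, y)]


def cosine_distance(x: Batch, y: Batch, eps: float = 1e-12) -> List[float]:
    # Row-wise 1 - cosine similarity, in [0, 2]; eps guards against all-zero rows.
    out = []
    for xr, yr in zip(x, y):
        dot = sum(a * b for a, b in zip(xr, yr))
        norm = math.sqrt(sum(a * a for a in xr)) * math.sqrt(sum(b * b for b in yr))
        # Clamp at 0 so rounding error never yields a tiny negative distance.
        out.append(max(0.0, 1.0 - dot / max(norm, eps)))
    return out


print(l2_distance([[0.0, 0.0]], [[3.0, 4.0]]))      # prints [5.0]
print(cosine_distance([[1.0, 0.0]], [[0.0, 1.0]]))  # prints [1.0]
```

Cosine distance ignores vector magnitude, which is why it suits normalized embeddings; the eps guard and the clamp at 0 are choices made for this sketch. With real Lucid tensors, the callable would instead be built from tensor operations, such as the pairwise_distance call shown in the Examples section.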

Parameters

anchor : Tensor
Anchor embedding of shape $(N, D)$.
positive : Tensor
Positive sample embedding of the same shape.
negative : Tensor
Negative sample embedding of the same shape.
distance_function : callable or None, default None
A callable (x, y) -> Tensor that returns non-negative per-row distances of shape $(N,)$. Defaults to the $L_2$ pairwise distance.
margin : float, default 1.0
Minimum desired margin between positive and negative distances (default 1.0).
swap : bool, default False
Enable the Balntas-2016 anchor-swap heuristic: replace $d(a, n)$ with $\min\big(d(a, n), d(p, n)\big)$ so that the harder negative drives the gradient (default False).
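reduction : str, default 'mean'
Reduction applied to the per-triplet losses: "mean" (default) averages them, "sum" adds them, and "none" returns them unreduced.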

Returns

Tensor

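A scalar tensor when reduction is "mean" or "sum"; a tensor of shape $(N,)$ holding the per-triplet losses when reduction is "none".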

Notes

Per-triplet loss:

$$L_i = \max\big(0,\; d(a_i, p_i) - d(a_i, n_i) + \text{margin}\big)$$
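
As a cross-check of the per-triplet formula, here is a framework-free reference sketch of the whole computation, including the swap heuristic and the three reduction modes. It operates on nested Python lists rather than Lucid tensors, and its default _l2 distance omits any numerical-stability epsilon that Lucid's pairwise_distance may apply; both are simplifying assumptions, and triplet_loss_reference is a hypothetical name, not the library's implementation.

```python
import math
from typing import Callable, List, Optional, Sequence, Union

Batch = Sequence[Sequence[float]]  # stand-in for an (N, D) tensor
DistanceFn = Callable[[Batch, Batch], List[float]]


def _l2(x: Batch, y: Batch) -> List[float]:
    # Assumed default: row-wise Euclidean norm of x - y, with no epsilon.
    return [math.sqrt(sum((a - b) ** 2 for a, b in zip(xr, yr))) for xr, yr in zip(x, y)]


def triplet_loss_reference(
    anchor: Batch,
    positive: Batch,
    negative: Batch,
    distance_function: Optional[DistanceFn] = None,
    margin: float = 1.0,
    swap: bool = False,
    reduction: str = "mean",
) -> Union[float, List[float]]:
    dist = distance_function if distance_function is not None else _l2
    d_ap = dist(anchor, positive)  # d(a_i, p_i), one value per triplet
    d_an = dist(anchor, negative)  # d(a_i, n_i)
    if swap:
        # Swap heuristic: keep the smaller negative distance (the harder negative).
        d_pn = dist(positive, negative)
        d_an = [min(an, pn) for an, pn in zip(d_an, d_pn)]
    # Hinge: L_i = max(0, d(a_i, p_i) - d(a_i, n_i) + margin)
    losses = [max(0.0, ap - an + margin) for ap, an in zip(d_ap, d_an)]
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise ValueError(f"unknown reduction: {reduction!r}")


def manhattan(x: Batch, y: Batch) -> List[float]:
    # Row-wise L1 distance, mirroring the manhattan callable in the Examples section.
    return [sum(abs(a - b) for a, b in zip(xr, yr)) for xr, yr in zip(x, y)]


print(triplet_loss_reference([[1.0, 0.0]], [[1.0, 0.1]], [[0.0, 1.0]], distance_function=manhattan))
# prints 0.0
```

The final call feeds in the same inputs as the Examples section.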

The Lucid module wrapper lucid.nn.TripletMarginWithDistanceLoss forwards to this function; both are valid entry points.

Examples

>>> import lucid
>>> from lucid.nn.functional import (
...     triplet_margin_with_distance_loss,
...     pairwise_distance,
... )
>>> def manhattan(a, b):
...     return pairwise_distance(a, b, p=1.0)
>>> a = lucid.tensor([[1.0, 0.0]])
>>> p = lucid.tensor([[1.0, 0.1]])
>>> n = lucid.tensor([[0.0, 1.0]])
>>> triplet_margin_with_distance_loss(a, p, n, distance_function=manhattan)
Tensor(0.0)
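
Working the example by hand with the Manhattan distance (ignoring any small epsilon that pairwise_distance may add for numerical stability):

$$d(a, p) = |1 - 1| + |0 - 0.1| = 0.1, \qquad d(a, n) = |1 - 0| + |0 - 1| = 2$$

$$L = \max\big(0,\; 0.1 - 2 + 1\big) = \max(0,\, -0.9) = 0$$

The negative is already farther from the anchor than the positive by more than the margin, so the hinge is inactive and this triplet contributes no loss.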