BERTForPreTraining

class lucid.models.BERTForPreTraining(config: BERTConfig)

The BERTForPreTraining class places a masked language modeling (MLM) head and a next-sentence prediction (NSP) head on top of a shared BERT backbone, matching the two objectives of the original BERT pretraining setup.

Class Signature

class BERTForPreTraining(config: BERTConfig)

Parameters

  • config (BERTConfig): Configuration of the BERT backbone, covering both the encoder and the pooled output used by the NSP head.

Methods

BERTForPreTraining.forward(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None) → tuple[Tensor, Tensor]

Run the backbone and both heads and return (prediction_scores, seq_relationship_scores): the MLM logits over the vocabulary for every token position, and the NSP logits for each sequence pair.

BERTForPreTraining.get_mlm_loss(mlm_labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, *, ignore_index: int = -100, reduction: str | None = 'mean') → Tensor

Compute the masked language modeling (MLM) loss from per-token labels. Positions whose label equals ignore_index do not contribute to the loss.
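As an illustration of what an MLM loss with ignore_index typically computes, here is a minimal NumPy sketch of token-level cross-entropy that skips ignored positions. It is not lucid's implementation; the helper name mlm_cross_entropy and the behavior of reduction=None (per-position losses, 0 at ignored positions) are assumptions made for this sketch.

```python
import numpy as np

def mlm_cross_entropy(logits, labels, ignore_index=-100, reduction="mean"):
    """Hypothetical helper: cross-entropy over tokens, skipping ignore_index.

    logits: float array of shape (batch, seq_len, vocab_size)
    labels: int array of shape (batch, seq_len)
    """
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    valid = labels != ignore_index
    # Ignored labels are replaced by 0 only so indexing stays in range;
    # their losses are zeroed out right after.
    safe_labels = np.where(valid, labels, 0)
    token_nll = -np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    token_nll = np.where(valid, token_nll, 0.0)

    if reduction == "mean":
        # Average over the positions that actually carry a label.
        return token_nll.sum() / max(int(valid.sum()), 1)
    if reduction == "sum":
        return token_nll.sum()
    if reduction is None:
        return token_nll  # assumed: per-position losses, 0 at ignored positions
    raise ValueError(f"unknown reduction: {reduction!r}")
```

With uniform logits over a 4-token vocabulary, every labeled position costs log(4), and the position labeled -100 is simply left out of the average.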

BERTForPreTraining.get_nsp_loss(nsp_labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, *, reduction: str | None = 'mean') → Tensor

Compute the next sentence prediction (NSP) loss from sequence labels.
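NSP labels are usually produced by pairing each sentence either with the sentence that really follows it or with a random sentence from elsewhere in the corpus. The sketch below builds such (text_a, text_b, nsp_label) triples in plain Python. The helper make_nsp_pairs is hypothetical, and the label convention (0 = text_b follows text_a, 1 = text_b is random) is an assumption borrowed from common BERT implementations; check it against lucid before training.

```python
import random

def make_nsp_pairs(documents, seed=0):
    """Hypothetical helper: build (text_a, text_b, nsp_label) triples.

    documents: list of documents, each a list of consecutive sentences.
    Needs at least two non-empty documents so a "random" sentence can
    always be drawn from a different document.
    Assumed label convention: 0 = real next sentence, 1 = random sentence.
    """
    rng = random.Random(seed)
    pairs = []
    for doc_idx, sentences in enumerate(documents):
        for i in range(len(sentences) - 1):
            if rng.random() < 0.5:
                # Positive pair: the sentence that actually comes next.
                pairs.append((sentences[i], sentences[i + 1], 0))
            else:
                # Negative pair: a sentence drawn from another document.
                others = [d for j, d in enumerate(documents) if j != doc_idx]
                pairs.append((sentences[i], rng.choice(rng.choice(others)), 1))
    return pairs
```

Each triple can then be passed as text_a, text_b and nsp_label to get_loss_from_text.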

BERTForPreTraining.get_loss(mlm_labels: Tensor, nsp_labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, *, mlm_weight: float = 1.0, nsp_weight: float = 1.0, ignore_index: int = -100, reduction: str | None = 'mean') → Tensor

Compute the combined pretraining loss as the weighted sum mlm_weight * mlm_loss + nsp_weight * nsp_loss. With the default weights this is simply the sum of the two losses.
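Written out, with w_mlm and w_nsp standing for mlm_weight and nsp_weight, the combined objective is the following; the numbers in the second line are purely illustrative, not measured values.

```latex
\mathcal{L} = w_{\text{mlm}}\,\mathcal{L}_{\text{MLM}} + w_{\text{nsp}}\,\mathcal{L}_{\text{NSP}}
\qquad\text{e.g.}\quad
w_{\text{mlm}} = 1.0,\; w_{\text{nsp}} = 0.5,\; \mathcal{L}_{\text{MLM}} = 2.4,\; \mathcal{L}_{\text{NSP}} = 0.6
\;\Rightarrow\; \mathcal{L} = 2.4 + 0.3 = 2.7
```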

BERTForPreTraining.get_loss_from_text(tokenizer: BERTTokenizerFast, text_a: str, text_b: str | None = None, *, nsp_label: int = 0, device: Literal['cpu', 'gpu'] = 'cpu', mask_token_id: int | None = None, mlm_probability: float = 0.15, mask_replace_prob: float = 0.8, random_replace_prob: float = 0.1, ignore_index: int = -100, reduction: str | None = 'mean', mlm_weight: float = 1.0, nsp_weight: float = 1.0) → Tensor

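Compute the combined pretraining loss directly from raw text: text_a (and the optional text_b) are tokenized with the given BERTTokenizerFast, randomly masked according to the masking arguments, and scored against nsp_label on the requested device.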
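For a text pair, BERT-style tokenizers conventionally assemble the sequence [CLS] A [SEP] B [SEP], with token_type_ids of 0 for the first segment and 1 for the second. The plain-Python sketch below shows that layout on already-tokenized ids. The helper build_pair_inputs is hypothetical, and the ids 101/102 for [CLS]/[SEP] are an assumption based on the standard uncased BERT vocabulary (in which 103 is [MASK], matching the mask_token_id default of create_masked_lm_inputs).

```python
def build_pair_inputs(ids_a, ids_b=None, cls_id=101, sep_id=102):
    """Hypothetical helper: lay out [CLS] A [SEP] (B [SEP]) for one example.

    Returns (input_ids, token_type_ids, attention_mask, special_tokens_mask)
    as plain lists. cls_id / sep_id assume the uncased BERT vocabulary.
    """
    ids_a = list(ids_a)
    input_ids = [cls_id] + ids_a + [sep_id]
    token_type_ids = [0] * len(input_ids)               # segment A, incl. [CLS] and first [SEP]
    special_tokens_mask = [1] + [0] * len(ids_a) + [1]  # marked by position, not by value

    if ids_b is not None:
        ids_b = list(ids_b)
        input_ids += ids_b + [sep_id]
        token_type_ids += [1] * (len(ids_b) + 1)        # segment B, incl. final [SEP]
        special_tokens_mask += [0] * len(ids_b) + [1]

    attention_mask = [1] * len(input_ids)               # single example, so no padding
    return input_ids, token_type_ids, attention_mask, special_tokens_mask
```

The special-token mask is built from positions rather than by comparing ids, so a content token that happens to share an id with [CLS] or [SEP] is not mislabeled.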

BERTForPreTraining.create_masked_lm_inputs(input_ids: Tensor, attention_mask: Tensor | None = None, special_tokens_mask: Tensor | None = None, *, mask_token_id: int = 103, mlm_probability: float = 0.15, mask_replace_prob: float = 0.8, random_replace_prob: float = 0.1, ignore_index: int = -100) → tuple[Tensor, Tensor]

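Create BERT-style masked inputs and MLM labels for pretraining batches. About mlm_probability of the eligible tokens (those not excluded by attention_mask or special_tokens_mask) are selected as prediction targets; of these, a mask_replace_prob fraction is replaced with mask_token_id, a random_replace_prob fraction with a random token, and the rest keep their original token. Returns (masked_input_ids, mlm_labels), where unselected positions are labeled ignore_index.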
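The selection-then-replacement scheme can be sketched in NumPy as below. This is an illustrative re-implementation of the standard BERT 80/10/10 masking, not lucid's code; the function name, the explicit vocab_size argument and the rng argument are assumptions added so the sketch is self-contained.

```python
import numpy as np

def create_masked_lm_inputs_sketch(input_ids, vocab_size, attention_mask=None,
                                   special_tokens_mask=None, *, mask_token_id=103,
                                   mlm_probability=0.15, mask_replace_prob=0.8,
                                   random_replace_prob=0.1, ignore_index=-100, rng=None):
    """Hypothetical sketch of BERT-style masking; returns (masked_ids, labels)."""
    rng = np.random.default_rng() if rng is None else rng
    input_ids = np.asarray(input_ids)

    # Eligible positions: not padding and not special tokens.
    eligible = np.ones(input_ids.shape, dtype=bool)
    if attention_mask is not None:
        eligible &= np.asarray(attention_mask).astype(bool)
    if special_tokens_mask is not None:
        eligible &= ~np.asarray(special_tokens_mask).astype(bool)

    # 1) Select roughly mlm_probability of the eligible positions as targets.
    selected = eligible & (rng.random(input_ids.shape) < mlm_probability)
    labels = np.where(selected, input_ids, ignore_index)

    # 2) One uniform draw per position decides how a selected token is corrupted.
    roll = rng.random(input_ids.shape)
    to_mask = selected & (roll < mask_replace_prob)
    to_random = selected & (roll >= mask_replace_prob) & (
        roll < mask_replace_prob + random_replace_prob)

    masked = input_ids.copy()
    masked[to_mask] = mask_token_id
    masked[to_random] = rng.integers(0, vocab_size, size=int(to_random.sum()))
    # 3) Remaining selected positions keep their original token.
    return masked, labels
```

Keeping some selected tokens unchanged (and replacing others with random tokens) means the model cannot assume that only [MASK] positions need predicting, which is the usual motivation for this split.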

BERTForPreTraining.predict_mlm_token_ids(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None) → Tensor

Predict the most likely token id at each position from the MLM head.

BERTForPreTraining.predict_nsp_labels(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None) → Tensor

Predict binary NSP labels.
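Both prediction methods presumably reduce the logits returned by forward to class indices. A minimal NumPy sketch of that step, assuming an argmax over the last axis of each output; predict_from_scores is a hypothetical helper, not part of lucid:

```python
import numpy as np

def predict_from_scores(prediction_scores, seq_relationship_scores):
    """Hypothetical helper: turn forward() logits into predicted ids/labels.

    prediction_scores:       (batch, seq_len, vocab_size) MLM logits
    seq_relationship_scores: (batch, 2) NSP logits
    """
    mlm_token_ids = prediction_scores.argmax(axis=-1)     # (batch, seq_len)
    nsp_labels = seq_relationship_scores.argmax(axis=-1)  # (batch,)
    return mlm_token_ids, nsp_labels
```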

BERTForPreTraining.get_mlm_accuracy(mlm_labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, *, ignore_index: int = -100) → Tensor

Compute MLM accuracy over the masked positions only, i.e. the positions whose label is not ignore_index.
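Masked-token accuracy only counts labeled positions, so padding, special tokens and unselected tokens do not inflate the score. A NumPy sketch, where masked_accuracy is a hypothetical helper and returning 0.0 when nothing is labeled is an assumption of this sketch:

```python
import numpy as np

def masked_accuracy(pred_ids, labels, ignore_index=-100):
    """Hypothetical helper: accuracy over positions whose label != ignore_index."""
    valid = labels != ignore_index
    if not valid.any():
        return 0.0  # assumption: no labeled positions -> report 0.0
    return float((pred_ids[valid] == labels[valid]).mean())
```

Because NSP labels contain no ignored entries, the same function reduces to plain classification accuracy for them.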

BERTForPreTraining.get_nsp_accuracy(nsp_labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None) → Tensor

Compute classification accuracy for NSP labels.

BERTForPreTraining.get_accuracy(mlm_labels: Tensor, nsp_labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, *, ignore_index: int = -100, mlm_weight: float = 1.0, nsp_weight: float = 1.0) → tuple[Tensor, Tensor, Tensor]

Return (mlm_accuracy, nsp_accuracy, weighted_accuracy) for joint pretraining, where weighted_accuracy combines the two accuracies using mlm_weight and nsp_weight.

Examples

>>> import lucid.models as models
>>> model = models.bert_for_pre_training_base()
>>> print(model)
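BERTForPreTraining(...)
>>> # input_ids, attention_mask, token_type_ids, mlm_labels and nsp_labels are assumed to be prepared batch tensors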
>>> prediction_scores, seq_relationship_scores = model(
...     input_ids=input_ids,
...     attention_mask=attention_mask,
...     token_type_ids=token_type_ids,
... )
>>> mlm_loss = model.get_mlm_loss(
...     mlm_labels=mlm_labels,
...     input_ids=input_ids,
...     attention_mask=attention_mask,
...     token_type_ids=token_type_ids,
... )
>>> nsp_loss = model.get_nsp_loss(
...     nsp_labels=nsp_labels,
...     input_ids=input_ids,
...     attention_mask=attention_mask,
...     token_type_ids=token_type_ids,
... )
>>> total_loss = model.get_loss(
...     mlm_labels=mlm_labels,
...     nsp_labels=nsp_labels,
...     input_ids=input_ids,
...     attention_mask=attention_mask,
...     token_type_ids=token_type_ids,
... )
>>> tokenizer = models.BERTTokenizerFast.from_pretrained(".data/bert/pretrained")
>>> total_loss = model.get_loss_from_text(
...     tokenizer=tokenizer,
...     text_a="Machine learning helps us build useful systems.",
...     text_b="Tokenization quality strongly affects language model performance.",
...     nsp_label=0,
...     device="gpu",
... )
>>> masked_input_ids, mlm_labels = model.create_masked_lm_inputs(
...     input_ids=input_ids,
...     attention_mask=attention_mask,
... )
>>> mlm_pred_ids = model.predict_mlm_token_ids(
...     input_ids=masked_input_ids,
...     attention_mask=attention_mask,
...     token_type_ids=token_type_ids,
... )
>>> nsp_pred = model.predict_nsp_labels(
...     input_ids=input_ids,
...     attention_mask=attention_mask,
...     token_type_ids=token_type_ids,
... )
>>> mlm_acc = model.get_mlm_accuracy(
...     mlm_labels=mlm_labels,
...     input_ids=masked_input_ids,
...     attention_mask=attention_mask,
...     token_type_ids=token_type_ids,
... )
>>> nsp_acc = model.get_nsp_accuracy(
...     nsp_labels=nsp_labels,
...     input_ids=input_ids,
...     attention_mask=attention_mask,
...     token_type_ids=token_type_ids,
... )
>>> mlm_acc, nsp_acc, weighted_acc = model.get_accuracy(
...     mlm_labels=mlm_labels,
...     nsp_labels=nsp_labels,
...     input_ids=masked_input_ids,
...     attention_mask=attention_mask,
...     token_type_ids=token_type_ids,
... )