BERTForPreTraining¶
- class lucid.models.BERTForPreTraining(config: BERTConfig)¶
The BERTForPreTraining class combines masked language modeling and next-sentence prediction heads on top of a BERT backbone.
Class Signature¶
class BERTForPreTraining(config: BERTConfig)
Parameters¶
config (BERTConfig): BERT configuration describing the encoder and the pooled output shared by the pretraining heads.
Methods¶
- BERTForPreTraining.forward(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None) tuple[Tensor, Tensor]
Run the backbone and both heads, returning (prediction_scores, seq_relationship_scores): per-token MLM logits and binary NSP logits.
- BERTForPreTraining.get_mlm_loss(mlm_labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, *, ignore_index: int = -100, reduction: str | None = 'mean') Tensor
Compute the masked language modeling (MLM) loss from token labels.
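The key detail of the MLM loss is that only masked positions contribute: any position whose label equals ignore_index is skipped entirely. This is not lucid's implementation, just a pure-Python sketch (the function name mlm_cross_entropy is hypothetical) of how ignore_index interacts with the mean reduction:

```python
import math

def mlm_cross_entropy(logits, labels, ignore_index=-100):
    """Mean cross-entropy over positions whose label is not ignore_index.

    logits: list of per-position score lists (vocab-sized);
    labels: list of target token ids, with ignore_index at unmasked positions.
    """
    losses = []
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # unmasked positions carry no loss signal
        log_z = math.log(sum(math.exp(s) for s in scores))  # log-sum-exp normalizer
        losses.append(log_z - scores[label])  # -log softmax(scores)[label]
    return sum(losses) / len(losses) if losses else 0.0
```

Because typically only ~15% of tokens are masked, the divisor is the count of masked positions, not the sequence length.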
- BERTForPreTraining.get_nsp_loss(nsp_labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, *, reduction: str | None = 'mean') Tensor
Compute the next sentence prediction (NSP) loss from sequence labels.
- BERTForPreTraining.get_loss(mlm_labels: Tensor, nsp_labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, *, mlm_weight: float = 1.0, nsp_weight: float = 1.0, ignore_index: int = -100, reduction: str | None = 'mean') Tensor
Compute the combined pretraining loss as a weighted sum of MLM and NSP losses.
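The combination itself is a plain weighted sum; with the default weights of 1.0 it recovers the classic BERT objective L = L_MLM + L_NSP. A minimal sketch (the function name is illustrative, not lucid API):

```python
def combined_pretraining_loss(mlm_loss, nsp_loss,
                              mlm_weight=1.0, nsp_weight=1.0):
    # Weighted sum of the two objectives; set nsp_weight=0.0 to
    # train with the MLM objective alone.
    return mlm_weight * mlm_loss + nsp_weight * nsp_loss
```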
- BERTForPreTraining.get_loss_from_text(tokenizer: BERTTokenizerFast, text_a: str, text_b: str | None = None, *, nsp_label: int = 0, device: Literal['cpu', 'gpu'] = 'cpu', mask_token_id: int | None = None, mlm_probability: float = 0.15, mask_replace_prob: float = 0.8, random_replace_prob: float = 0.1, ignore_index: int = -100, reduction: str | None = 'mean', mlm_weight: float = 1.0, nsp_weight: float = 1.0) Tensor
Compute pretraining loss directly from raw text pairs using BERTTokenizerFast.
- BERTForPreTraining.create_masked_lm_inputs(input_ids: Tensor, attention_mask: Tensor | None = None, special_tokens_mask: Tensor | None = None, *, mask_token_id: int = 103, mlm_probability: float = 0.15, mask_replace_prob: float = 0.8, random_replace_prob: float = 0.1, ignore_index: int = -100) tuple[Tensor, Tensor]
Create BERT-style masked inputs and MLM labels for pretraining batches.
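The defaults encode the standard BERT 80/10/10 recipe: each non-special token is selected with probability mlm_probability; a selected token becomes the mask token 80% of the time, a random vocabulary token 10%, and is left unchanged 10%, while its label keeps the original id (all other labels are ignore_index). A pure-Python sketch of that recipe for a single sequence, assuming a hypothetical masked_lm_inputs helper rather than lucid's batched implementation:

```python
import random

def masked_lm_inputs(input_ids, vocab_size, special_ids,
                     mask_token_id=103, mlm_probability=0.15,
                     mask_replace_prob=0.8, random_replace_prob=0.1,
                     ignore_index=-100, seed=None):
    rng = random.Random(seed)
    masked, labels = [], []
    for tok in input_ids:
        if tok not in special_ids and rng.random() < mlm_probability:
            labels.append(tok)  # loss is computed only at selected positions
            roll = rng.random()
            if roll < mask_replace_prob:
                masked.append(mask_token_id)          # 80%: [MASK]
            elif roll < mask_replace_prob + random_replace_prob:
                masked.append(rng.randrange(vocab_size))  # 10%: random token
            else:
                masked.append(tok)                    # 10%: keep original
        else:
            masked.append(tok)
            labels.append(ignore_index)  # unselected: no loss
    return masked, labels
```

Keeping 10% of selected tokens unchanged forces the model to produce useful representations even where the input looks intact, since it cannot tell which unchanged tokens are being scored.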
- BERTForPreTraining.predict_mlm_token_ids(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None) Tensor
Predict the most likely token id at each position using the MLM head.
- BERTForPreTraining.predict_nsp_labels(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None) Tensor
Predict binary NSP labels.
- BERTForPreTraining.get_mlm_accuracy(mlm_labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, *, ignore_index: int = -100) Tensor
Compute masked-token accuracy for MLM labels.
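Like the MLM loss, the accuracy is measured only at masked positions: predictions where the label is ignore_index are excluded from both numerator and denominator. A sketch in plain Python (mlm_accuracy is an illustrative name, not lucid API):

```python
def mlm_accuracy(pred_ids, mlm_labels, ignore_index=-100):
    # Score only positions that were actually masked; everything
    # else has label == ignore_index and carries no signal.
    scored = [(p, t) for p, t in zip(pred_ids, mlm_labels)
              if t != ignore_index]
    if not scored:
        return 0.0
    return sum(p == t for p, t in scored) / len(scored)
```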
- BERTForPreTraining.get_nsp_accuracy(nsp_labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None) Tensor
Compute classification accuracy for NSP labels.
- BERTForPreTraining.get_accuracy(mlm_labels: Tensor, nsp_labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, *, ignore_index: int = -100, mlm_weight: float = 1.0, nsp_weight: float = 1.0) tuple[Tensor, Tensor, Tensor]
Return (mlm_accuracy, nsp_accuracy, weighted_accuracy) for joint pretraining.
Examples¶
>>> import lucid.models as models
>>> model = models.bert_for_pre_training_base()
>>> print(model)
BERTForPreTraining(...)
>>> prediction_scores, seq_relationship_scores = model(
... input_ids=input_ids,
... attention_mask=attention_mask,
... token_type_ids=token_type_ids,
... )
>>> mlm_loss = model.get_mlm_loss(
... mlm_labels=mlm_labels,
... input_ids=input_ids,
... attention_mask=attention_mask,
... token_type_ids=token_type_ids,
... )
>>> nsp_loss = model.get_nsp_loss(
... nsp_labels=nsp_labels,
... input_ids=input_ids,
... attention_mask=attention_mask,
... token_type_ids=token_type_ids,
... )
>>> total_loss = model.get_loss(
... mlm_labels=mlm_labels,
... nsp_labels=nsp_labels,
... input_ids=input_ids,
... attention_mask=attention_mask,
... token_type_ids=token_type_ids,
... )
>>> tokenizer = models.BERTTokenizerFast.from_pretrained(".data/bert/pretrained")
>>> total_loss = model.get_loss_from_text(
... tokenizer=tokenizer,
... text_a="Machine learning helps us build useful systems.",
... text_b="Tokenization quality strongly affects language model performance.",
... nsp_label=0,
... device="gpu",
... )
>>> masked_input_ids, mlm_labels = model.create_masked_lm_inputs(
... input_ids=input_ids,
... attention_mask=attention_mask,
... )
>>> mlm_pred_ids = model.predict_mlm_token_ids(
... input_ids=masked_input_ids,
... attention_mask=attention_mask,
... token_type_ids=token_type_ids,
... )
>>> nsp_pred = model.predict_nsp_labels(
... input_ids=input_ids,
... attention_mask=attention_mask,
... token_type_ids=token_type_ids,
... )
>>> mlm_acc = model.get_mlm_accuracy(
... mlm_labels=mlm_labels,
... input_ids=masked_input_ids,
... attention_mask=attention_mask,
... token_type_ids=token_type_ids,
... )
>>> nsp_acc = model.get_nsp_accuracy(
... nsp_labels=nsp_labels,
... input_ids=input_ids,
... attention_mask=attention_mask,
... token_type_ids=token_type_ids,
... )
>>> mlm_acc, nsp_acc, weighted_acc = model.get_accuracy(
... mlm_labels=mlm_labels,
... nsp_labels=nsp_labels,
... input_ids=masked_input_ids,
... attention_mask=attention_mask,
... token_type_ids=token_type_ids,
... )