BERTForQuestionAnswering

class lucid.models.BERTForQuestionAnswering(config: BERTConfig)

The BERTForQuestionAnswering class predicts per-token start and end logits for extractive question answering; the answer is the context span between the selected start and end tokens.

Class Signature

class BERTForQuestionAnswering(config: BERTConfig)

Parameters

  • config (BERTConfig): BERT configuration for token span prediction.

Methods

BERTForQuestionAnswering.forward(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None) → tuple[Tensor, Tensor]

Compute start and end logits for extractive answer span prediction.

BERTForQuestionAnswering.get_loss(start_positions: Tensor, end_positions: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, *, reduction: str | None = 'mean') → Tensor

Compute QA training loss from start and end target positions.
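The conventional extractive-QA objective averages two cross-entropy terms, one over start positions and one over end positions. A minimal NumPy sketch of that convention (the library's actual reduction details may differ):

```python
import numpy as np

def qa_loss(start_logits, end_logits, start_positions, end_positions):
    """Sketch of the conventional extractive-QA loss:
    mean of cross-entropy over start targets and end targets."""
    def cross_entropy(logits, targets):
        # log-softmax over the sequence dimension, then pick the target index
        shifted = logits - logits.max(axis=-1, keepdims=True)
        log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
        return -log_probs[np.arange(len(targets)), targets]

    start_loss = cross_entropy(start_logits, start_positions)
    end_loss = cross_entropy(end_logits, end_positions)
    return float(((start_loss + end_loss) / 2).mean())
```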

BERTForQuestionAnswering.predict_spans(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None) → tuple[Tensor, Tensor]

Return argmax start and end indices per sample.

BERTForQuestionAnswering.get_best_spans(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, *, max_answer_length: int = 30) → tuple[Tensor, Tensor, Tensor]

Return best start/end span candidates with scores under a max answer length.
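A common way to implement this search, sketched here for a single sample with NumPy (the library's candidate selection may differ), scores each pair as start_logit + end_logit subject to end >= start and the length cap:

```python
import numpy as np

def best_span(start_logits, end_logits, max_answer_length=30):
    """Sketch of best-span search for one sample.

    Score = start_logit + end_logit, constrained to end >= start
    and span length <= max_answer_length tokens."""
    seq_len = len(start_logits)
    best = (0, 0, -np.inf)
    for start in range(seq_len):
        # only consider ends inside the allowed length window
        for end in range(start, min(start + max_answer_length, seq_len)):
            score = start_logits[start] + end_logits[end]
            if score > best[2]:
                best = (start, end, score)
    return best
```

A vectorized implementation would build the full score matrix and mask invalid pairs; the nested loop above keeps the constraint logic explicit.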

BERTForQuestionAnswering.get_accuracy(start_positions: Tensor, end_positions: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None) → Tensor

Compute exact-match span accuracy (both start and end must match).
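Exact-match accuracy counts a sample as correct only when both predicted indices equal their targets. A sketch of that metric in NumPy:

```python
import numpy as np

def exact_match_accuracy(pred_starts, pred_ends, start_positions, end_positions):
    """Fraction of samples where BOTH start and end indices match."""
    both = (np.asarray(pred_starts) == np.asarray(start_positions)) & \
           (np.asarray(pred_ends) == np.asarray(end_positions))
    return float(both.mean())
```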

BERTForQuestionAnswering.predict_spans_from_text(tokenizer: BERTTokenizerFast, question: str, context: str, *, device: Literal['cpu', 'gpu'] = 'cpu') → tuple[Tensor, Tensor]

Predict start/end spans directly from (question, context) text pairs.

BERTForQuestionAnswering.predict_answer_from_text(tokenizer: BERTTokenizerFast, question: str, context: str, *, device: Literal['cpu', 'gpu'] = 'cpu', max_answer_length: int = 30) → str

Return decoded extractive answer text directly from (question, context).
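Conceptually, the decoded answer is the context text covered by the predicted token span. A toy illustration using whitespace-separated words as stand-in tokens (the real method decodes via the BERT tokenizer, which uses subword tokens):

```python
def answer_from_span(context, start, end):
    """Toy decoder: treat whitespace-separated words as tokens and
    return the text covered by token indices start..end (inclusive)."""
    tokens = context.split()
    return " ".join(tokens[start:end + 1])
```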

Examples

>>> import lucid.models as models
>>> model = models.bert_for_question_answering_base()
>>> print(model)
BERTForQuestionAnswering(...)
>>> # input_ids and attention_mask are pre-tokenized batch tensors
>>> start_logits, end_logits = model(input_ids=input_ids, attention_mask=attention_mask)
>>> loss = model.get_loss(start_positions, end_positions, input_ids=input_ids, attention_mask=attention_mask)
>>> best_start, best_end, best_score = model.get_best_spans(input_ids=input_ids, attention_mask=attention_mask)
>>> tokenizer = models.BERTTokenizerFast.from_pretrained(".data/bert/pretrained")
>>> answer = model.predict_answer_from_text(
...     tokenizer=tokenizer,
...     question="What helps model performance?",
...     context="Tokenization quality strongly affects language model performance.",
...     device="gpu",
... )