BERTForQuestionAnswering¶
- class lucid.models.BERTForQuestionAnswering(config: BERTConfig)¶
The BERTForQuestionAnswering class predicts start and end logits for extractive question answering.
Class Signature¶
class BERTForQuestionAnswering(config: BERTConfig)
Parameters¶
config (BERTConfig): BERT configuration for token span prediction.
Methods¶
- BERTForQuestionAnswering.forward(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None) tuple[Tensor, Tensor]
Compute start and end logits for extractive answer span prediction.
- BERTForQuestionAnswering.get_loss(start_positions: Tensor, end_positions: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, *, reduction: str | None = 'mean') Tensor
Compute QA training loss from start and end target positions.
- BERTForQuestionAnswering.predict_spans(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None) tuple[Tensor, Tensor]
Return argmax start and end indices per sample.
- BERTForQuestionAnswering.get_best_spans(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, *, max_answer_length: int = 30) tuple[Tensor, Tensor, Tensor]
Return the best start/end indices and their scores per sample, restricted to spans no longer than max_answer_length.
- BERTForQuestionAnswering.get_accuracy(start_positions: Tensor, end_positions: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None) Tensor
Compute exact-match span accuracy (both start and end must match).
- BERTForQuestionAnswering.predict_spans_from_text(tokenizer: BERTTokenizerFast, question: str, context: str, *, device: Literal['cpu', 'gpu'] = 'cpu') tuple[Tensor, Tensor]
Predict start/end spans directly from a (question, context) text pair.
- BERTForQuestionAnswering.predict_answer_from_text(tokenizer: BERTTokenizerFast, question: str, context: str, *, device: Literal['cpu', 'gpu'] = 'cpu', max_answer_length: int = 30) str
Return the decoded extractive answer text directly from a (question, context) pair.
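The constrained argmax that get_best_spans performs can be sketched in plain Python. This is an illustrative re-implementation, not the library's code; the additive start+end scoring rule is an assumption based on the standard BERT-QA recipe:

```python
import math

def best_span(start_logits, end_logits, max_answer_length=30):
    """Pick (start, end) maximizing start_logits[s] + end_logits[e],
    subject to s <= e and span length <= max_answer_length.
    Illustrative sketch of the standard BERT-QA span search."""
    best_score, best_s, best_e = -math.inf, 0, 0
    n = len(start_logits)
    for s in range(n):
        # Only consider end positions within the allowed answer length.
        for e in range(s, min(s + max_answer_length, n)):
            score = start_logits[s] + end_logits[e]
            if score > best_score:
                best_score, best_s, best_e = score, s, e
    return best_s, best_e, best_score
```

In practice the same search is done per batch element over the model's start and end logits, with the length cap controlled by the max_answer_length keyword.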
Examples¶
>>> import lucid.models as models
>>> model = models.bert_for_question_answering_base()
>>> print(model)
BERTForQuestionAnswering(...)
>>> start_logits, end_logits = model(input_ids=input_ids, attention_mask=attention_mask)
>>> loss = model.get_loss(start_positions, end_positions, input_ids=input_ids, attention_mask=attention_mask)
>>> best_start, best_end, best_score = model.get_best_spans(input_ids=input_ids, attention_mask=attention_mask)
>>> tokenizer = models.BERTTokenizerFast.from_pretrained(".data/bert/pretrained")
>>> answer = model.predict_answer_from_text(
... tokenizer=tokenizer,
... question="What helps model performance?",
... context="Tokenization quality strongly affects language model performance.",
... device="gpu",
... )
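Conceptually, the training loss returned by get_loss is the average of two token-classification cross-entropies, one over start positions and one over end positions. A minimal sketch of that objective for a single sample, assuming mean reduction (not the library's implementation):

```python
import math

def cross_entropy(logits, target):
    """Negative log-softmax probability of the target index."""
    m = max(logits)  # subtract max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]

def qa_loss(start_logits, end_logits, start_pos, end_pos):
    """Per-sample QA loss: mean of the start and end cross-entropies,
    mirroring the standard BERT-QA training objective."""
    return 0.5 * (cross_entropy(start_logits, start_pos)
                  + cross_entropy(end_logits, end_pos))
```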