BERTForCausalLM¶
- class lucid.models.BERTForCausalLM(config: BERTConfig)¶
The BERTForCausalLM class wraps BERT as a decoder (causal self-attention with optional key-value caching) and adds a causal language modeling head that predicts the next token at each position.
Class Signature¶
class BERTForCausalLM(config: BERTConfig)
Parameters¶
config (BERTConfig): BERT configuration with decoder/caching options.
Methods¶
- BERTForCausalLM.forward(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, encoder_hidden_states: FloatTensor | None = None, encoder_attention_mask: Tensor | None = None, past_key_values: KVCache | EncoderDecoderCache | None = None, use_cache: bool | None = None, cache_position: Tensor | None = None) Tensor
Compute autoregressive token logits from decoder-style BERT outputs.
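Autoregressive logits depend on a causal attention mask: position i may attend only to positions 0..i. This is a minimal NumPy sketch of that mask, not the library's internal implementation (the function name is hypothetical):

```python
import numpy as np

def causal_mask(seq_len: int) -> np.ndarray:
    # True where position i may attend to position j (i.e. j <= i)
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

mask = causal_mask(4)
# row i of `mask` exposes only positions 0..i
```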
- BERTForCausalLM.get_loss(labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, encoder_hidden_states: FloatTensor | None = None, encoder_attention_mask: Tensor | None = None, past_key_values: KVCache | EncoderDecoderCache | None = None, use_cache: bool | None = None, cache_position: Tensor | None = None, *, shift_labels: bool = True, ignore_index: int = -100, reduction: str | None = 'mean') Tensor
Compute causal language modeling loss (with optional label shifting).
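With shift_labels=True, the logits at position t are scored against the token at position t+1, and positions whose label equals ignore_index are excluded. A plain NumPy sketch of that computation (an illustration of the semantics, not the library's code):

```python
import numpy as np

def shifted_causal_lm_loss(logits, labels, shift_labels=True, ignore_index=-100):
    # logits: (seq_len, vocab_size); labels: (seq_len,)
    if shift_labels:
        # position t predicts token t+1: drop the last logit row
        # and the first label
        logits, labels = logits[:-1], labels[1:]
    keep = labels != ignore_index
    logits, labels = logits[keep], labels[keep]
    # numerically stable log-softmax
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()
```

For uniform logits over a vocabulary of size V, this returns ln(V), which is a handy sanity check for an untrained model.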
- BERTForCausalLM.predict_token_ids(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, encoder_hidden_states: FloatTensor | None = None, encoder_attention_mask: Tensor | None = None, past_key_values: KVCache | EncoderDecoderCache | None = None, use_cache: bool | None = None, cache_position: Tensor | None = None) Tensor
Return argmax token IDs for each step in the current sequence.
- BERTForCausalLM.get_accuracy(labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, encoder_hidden_states: FloatTensor | None = None, encoder_attention_mask: Tensor | None = None, past_key_values: KVCache | EncoderDecoderCache | None = None, use_cache: bool | None = None, cache_position: Tensor | None = None, *, shift_labels: bool = True, ignore_index: int = -100) Tensor
Compute token-level accuracy for causal LM targets.
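Token-level accuracy is the fraction of non-ignored positions where the argmax prediction matches the label. A minimal NumPy sketch of that metric (not the library's implementation):

```python
import numpy as np

def token_accuracy(pred_ids, labels, ignore_index=-100):
    # compare predictions to labels, skipping ignored (e.g. padded) positions
    keep = labels != ignore_index
    return (pred_ids[keep] == labels[keep]).mean()
```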
- BERTForCausalLM.get_perplexity(labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, encoder_hidden_states: FloatTensor | None = None, encoder_attention_mask: Tensor | None = None, past_key_values: KVCache | EncoderDecoderCache | None = None, use_cache: bool | None = None, cache_position: Tensor | None = None, *, shift_labels: bool = True, ignore_index: int = -100) Tensor
Compute perplexity from mean causal LM loss.
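Perplexity is the exponential of the mean negative log-likelihood, so a model that is uniform over a V-token vocabulary has perplexity V. A one-line sketch of the relationship:

```python
import math

def perplexity(mean_nll: float) -> float:
    # perplexity = exp(mean negative log-likelihood per token)
    return math.exp(mean_nll)
```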
- BERTForCausalLM.get_next_token_logits(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, encoder_hidden_states: FloatTensor | None = None, encoder_attention_mask: Tensor | None = None, past_key_values: KVCache | EncoderDecoderCache | None = None, use_cache: bool | None = None, cache_position: Tensor | None = None) Tensor
Return logits for the next token (last position only).
- BERTForCausalLM.predict_next_token_id(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, encoder_hidden_states: FloatTensor | None = None, encoder_attention_mask: Tensor | None = None, past_key_values: KVCache | EncoderDecoderCache | None = None, use_cache: bool | None = None, cache_position: Tensor | None = None) Tensor
Return argmax next-token IDs from last-position logits.
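A method like this is typically called in a greedy decoding loop: append the predicted ID and call again. The sketch below uses a stand-in `predict_next` callable rather than the real model; in practice one would also thread past_key_values / use_cache through the calls so the prefix is not recomputed each step:

```python
def greedy_decode(predict_next, input_ids, max_new_tokens, eos_id=None):
    # Repeatedly append the argmax next-token ID until max_new_tokens
    # or an end-of-sequence token is produced.
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        nxt = predict_next(ids)
        ids.append(nxt)
        if nxt == eos_id:
            break
    return ids

# toy stand-in "model" that always predicts last id + 1
out = greedy_decode(lambda ids: ids[-1] + 1, [5], max_new_tokens=3)
# → [5, 6, 7, 8]
```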
- BERTForCausalLM.get_loss_from_text(tokenizer: BERTTokenizerFast, text_a: str, text_b: str | None = None, *, device: Literal['cpu', 'gpu'] = 'cpu', shift_labels: bool = True, ignore_index: int = -100, reduction: str | None = 'mean') Tensor
Compute causal LM loss directly from raw text input.
- BERTForCausalLM.predict_next_token_id_from_text(tokenizer: BERTTokenizerFast, text_a: str, text_b: str | None = None, *, device: Literal['cpu', 'gpu'] = 'cpu') Tensor
Predict next-token IDs directly from raw text input.
Examples¶
>>> import lucid.models as models
>>> model = models.bert_for_causal_lm_base()
>>> print(model)
BERTForCausalLM(...)
>>> # assume input_ids is a LongTensor of token IDs, shape (batch_size, seq_len)
>>> loss = model.get_loss(labels=input_ids, input_ids=input_ids, shift_labels=True)
>>> ppl = model.get_perplexity(labels=input_ids, input_ids=input_ids)
>>> next_ids = model.predict_next_token_id(input_ids=input_ids)
>>> tokenizer = models.BERTTokenizerFast.from_pretrained(".data/bert/pretrained")
>>> loss = model.get_loss_from_text(
... tokenizer=tokenizer,
... text_a="Language models predict the next token.",
... device="gpu",
... )
>>> next_ids = model.predict_next_token_id_from_text(
... tokenizer=tokenizer,
... text_a="Deep learning is",
... device="gpu",
... )