BERTForCausalLM

class lucid.models.BERTForCausalLM(config: BERTConfig)

The BERTForCausalLM class configures BERT in decoder-style mode (causal self-attention, with optional cross-attention to encoder states) and applies a causal language modeling head that predicts the next token at each position.
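Decoder-style mode means each position may attend only to itself and earlier positions. A minimal plain-Python sketch of the lower-triangular causal mask this implies (illustrative only; it is independent of how lucid actually represents attention masks):

```python
def causal_mask(seq_len: int) -> list[list[int]]:
    """Return a lower-triangular mask: 1 = may attend, 0 = blocked."""
    return [[1 if j <= i else 0 for j in range(seq_len)]
            for i in range(seq_len)]

# Position i can see positions 0..i, never the future.
for row in causal_mask(4):
    print(row)
```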

Class Signature

class BERTForCausalLM(config: BERTConfig)

Parameters

  • config (BERTConfig): BERT configuration with decoder/caching options.

Methods

BERTForCausalLM.forward(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, encoder_hidden_states: FloatTensor | None = None, encoder_attention_mask: Tensor | None = None, past_key_values: KVCache | EncoderDecoderCache | None = None, use_cache: bool | None = None, cache_position: Tensor | None = None) -> Tensor

Compute autoregressive token logits from decoder-style BERT outputs.

BERTForCausalLM.get_loss(labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, encoder_hidden_states: FloatTensor | None = None, encoder_attention_mask: Tensor | None = None, past_key_values: KVCache | EncoderDecoderCache | None = None, use_cache: bool | None = None, cache_position: Tensor | None = None, *, shift_labels: bool = True, ignore_index: int = -100, reduction: str | None = 'mean') -> Tensor

Compute the causal language modeling loss, optionally shifting labels so that the logits at position t are scored against the token at position t + 1.
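Label shifting aligns the logits at position t with the label at position t + 1, so the model is trained to predict the next token; positions whose label equals ignore_index contribute nothing to the loss. A plain-Python sketch of this alignment (lucid does the equivalent on tensors; probabilities are pre-softmaxed here for brevity):

```python
import math

def shifted_nll(probs, labels, ignore_index=-100):
    """Mean negative log-likelihood with causal label shifting.

    probs: per-position probability distributions over the vocabulary;
    labels: token IDs, with ignore_index marking positions to skip.
    """
    # Drop the last distribution and the first label so probs[t] is scored
    # against labels[t + 1] (next-token prediction).
    pairs = zip(probs[:-1], labels[1:])
    losses = [-math.log(dist[tok]) for dist, tok in pairs if tok != ignore_index]
    return sum(losses) / len(losses)

# Two positions predict a next token; the final label is padding and ignored.
probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]
labels = [0, 1, -100]
loss = shifted_nll(probs, labels)  # only probs[0] vs labels[1] is scored
```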

BERTForCausalLM.predict_token_ids(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, encoder_hidden_states: FloatTensor | None = None, encoder_attention_mask: Tensor | None = None, past_key_values: KVCache | EncoderDecoderCache | None = None, use_cache: bool | None = None, cache_position: Tensor | None = None) -> Tensor

Return argmax token IDs for each step in the current sequence.

BERTForCausalLM.get_accuracy(labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, encoder_hidden_states: FloatTensor | None = None, encoder_attention_mask: Tensor | None = None, past_key_values: KVCache | EncoderDecoderCache | None = None, use_cache: bool | None = None, cache_position: Tensor | None = None, *, shift_labels: bool = True, ignore_index: int = -100) -> Tensor

Compute token-level accuracy for causal LM targets.
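Token-level accuracy counts the positions where the argmax prediction matches the label, excluding positions marked with ignore_index. A minimal sketch of that computation (label shifting, which the method also applies, is omitted here for brevity):

```python
def token_accuracy(pred_ids, labels, ignore_index=-100):
    """Fraction of non-ignored positions where the prediction equals the label."""
    kept = [(p, l) for p, l in zip(pred_ids, labels) if l != ignore_index]
    correct = sum(1 for p, l in kept if p == l)
    return correct / len(kept)

# The last position is padding (-100) and is excluded from the denominator.
acc = token_accuracy([5, 7, 9, 2], [5, 7, 1, -100])  # 2 correct of 3 counted
```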

BERTForCausalLM.get_perplexity(labels: Tensor, input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, encoder_hidden_states: FloatTensor | None = None, encoder_attention_mask: Tensor | None = None, past_key_values: KVCache | EncoderDecoderCache | None = None, use_cache: bool | None = None, cache_position: Tensor | None = None, *, shift_labels: bool = True, ignore_index: int = -100) -> Tensor

Compute perplexity from mean causal LM loss.
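Perplexity is the exponential of the mean causal LM loss, so a mean cross-entropy of ln(V) nats corresponds to guessing uniformly over V tokens. Illustrative arithmetic in plain Python:

```python
import math

def perplexity(mean_loss: float) -> float:
    """Perplexity from a mean cross-entropy loss measured in nats."""
    return math.exp(mean_loss)

# A mean loss of ln(10) ~= 2.303 nats means the model is, on average, as
# uncertain as a uniform choice among 10 tokens.
ppl = perplexity(math.log(10))  # -> 10.0
```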

BERTForCausalLM.get_next_token_logits(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, encoder_hidden_states: FloatTensor | None = None, encoder_attention_mask: Tensor | None = None, past_key_values: KVCache | EncoderDecoderCache | None = None, use_cache: bool | None = None, cache_position: Tensor | None = None) -> Tensor

Return logits for the next token (last position only).

BERTForCausalLM.predict_next_token_id(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, token_type_ids: LongTensor | None = None, position_ids: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, encoder_hidden_states: FloatTensor | None = None, encoder_attention_mask: Tensor | None = None, past_key_values: KVCache | EncoderDecoderCache | None = None, use_cache: bool | None = None, cache_position: Tensor | None = None) -> Tensor

Return argmax next-token IDs from last-position logits.
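Greedy next-token prediction takes the logits at the final sequence position and returns the argmax over the vocabulary. A minimal unbatched sketch (the lucid method does the same on batched tensors):

```python
def greedy_next_token(logits_per_position):
    """Argmax over the vocabulary at the last sequence position."""
    last = logits_per_position[-1]
    return max(range(len(last)), key=lambda i: last[i])

# Three positions over a 4-token vocabulary; only the last row matters.
logits = [[0.1, 0.2, 0.3, 0.4],
          [2.0, 0.5, 0.1, 0.3],
          [0.2, 3.1, 0.4, 0.9]]
next_id = greedy_next_token(logits)  # token 1 has the largest last-row logit
```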

BERTForCausalLM.get_loss_from_text(tokenizer: BERTTokenizerFast, text_a: str, text_b: str | None = None, *, device: Literal['cpu', 'gpu'] = 'cpu', shift_labels: bool = True, ignore_index: int = -100, reduction: str | None = 'mean') -> Tensor

Compute causal LM loss directly from raw text input.

BERTForCausalLM.predict_next_token_id_from_text(tokenizer: BERTTokenizerFast, text_a: str, text_b: str | None = None, *, device: Literal['cpu', 'gpu'] = 'cpu') -> Tensor

Predict next-token IDs directly from raw text input.

Examples

>>> import lucid.models as models
>>> model = models.bert_for_causal_lm_base()
>>> print(model)
BERTForCausalLM(...)
>>> # input_ids: a LongTensor of token IDs with shape (batch_size, seq_len)
>>> loss = model.get_loss(labels=input_ids, input_ids=input_ids, shift_labels=True)
>>> ppl = model.get_perplexity(labels=input_ids, input_ids=input_ids)
>>> next_ids = model.predict_next_token_id(input_ids=input_ids)
>>> tokenizer = models.BERTTokenizerFast.from_pretrained(".data/bert/pretrained")
>>> loss = model.get_loss_from_text(
...     tokenizer=tokenizer,
...     text_a="Language models predict the next token.",
...     device="gpu",
... )
>>> next_ids = model.predict_next_token_id_from_text(
...     tokenizer=tokenizer,
...     text_a="Deep learning is",
...     device="gpu",
... )