import logging
from typing import List, Tuple

import torch
from torch import nn
from torchmetrics.classification import MulticlassAccuracy
from transformers.trainer_pt_utils import LabelSmoother

from utils import get_rank

# Label value skipped by the text/codec losses and by the accuracy helpers below.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


class EncoderProjector(nn.Module):
    """
    The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
    Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
    Args:
        encoder_dim (:obj:`int`): The dimension of the encoder outputs.
        llm_dim (:obj:`int`): The dimension of the language model.
        downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use.
    """

    def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
        super().__init__()
        self.downsample_rate = downsample_rate
        self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(llm_dim, llm_dim)

    def forward(self, x):
        """Stack ``downsample_rate`` consecutive frames and project them to ``llm_dim``.

        Args:
            x: Encoder output of shape (batch, seq_len, encoder_dim).

        Returns:
            Tensor of shape (batch, seq_len // downsample_rate, llm_dim).
            Trailing frames that do not fill a complete group are dropped.
        """
        batch_size, seq_len, feat_dim = x.size()
        num_frames_to_discard = seq_len % self.downsample_rate
        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
        seq_len = x.size(1)

        # Concatenate each group of `downsample_rate` frames along the feature axis.
        x = x.contiguous()
        x = x.view(
            batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
        )

        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x


class SPEECH_LLM(nn.Module):
    """
    The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
    The encoder is used to extract speech features from the input speech signal.
    The encoder projector is used to project the encoder outputs to the same dimension as the language model.
    The language model is used to generate the text from the speech features.
    Optionally, a codec language model ("talker") predicts speech codec tokens
    conditioned on the text LLM's last hidden states and input embeddings.
    Args:
        encoder (:obj:`nn.Module`): The encoder module.
        llm (:obj:`nn.Module`): The language model module.
        encoder_projector (:obj:`nn.Module`): The encoder projector module.
        codec_lm (:obj:`nn.Module`, `optional`): The codec language model used for speech output.
        codec_lm_padding_side (:obj:`str`, `optional`, defaults to "left"): "left" or "right";
            how codec-LM training sequences are padded in `forward_with_speech_output`.
    """

    def __init__(
        self,
        encoder: nn.Module,
        llm: nn.Module,
        encoder_projector: nn.Module,
        codec_lm: nn.Module = None,
        codec_lm_padding_side: str = "left",
    ):
        super().__init__()
        self.encoder = encoder
        self.llm = llm
        self.encoder_projector = encoder_projector
        self.codec_lm = codec_lm
        if self.codec_lm is not None:
            # Maps [LLM last hidden state ; LLM input embedding] to the codec LM width.
            self.speech_token_projector = nn.Linear(
                self.llm.config.hidden_size + self.llm.config.hidden_size,
                self.codec_lm.config.hidden_size,
            )
            self.codec_lm_head = nn.Linear(
                self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
            )
            self.speech_token_projector = self.speech_token_projector.to(
                dtype=torch.float16
            )
            self.codec_lm_head = self.codec_lm_head.to(dtype=torch.float16)
            self.loss_fct = torch.nn.CrossEntropyLoss()
            self.codec_lm_padding_side = codec_lm_padding_side

            # Number of classes must match the width of `codec_lm_head`'s logits.
            self.audio_accuracy_metric = MulticlassAccuracy(
                self.codec_lm.config.vocab_size,
                top_k=10,
                average="micro",
                multidim_average="global",
                ignore_index=IGNORE_TOKEN_ID,
            )

    def _merge_input_ids_with_speech_features(
        self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
    ):
        """
        Merge the speech features with the input_ids and attention_mask.
        This is done by replacing the speech tokens with the speech features and padding the input_ids to the maximum length of the speech features.
        Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277.
        Args:
            speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids,
                shape (num_speechs, speech_len, embed_dim).
            inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids.
            input_ids (:obj:`torch.Tensor`): The input ids to merge, shape (batch, seq_len).
            attention_mask (:obj:`torch.Tensor`): The attention mask to merge.
            labels (:obj:`torch.Tensor`, `optional`): The labels to merge.
        Returns:
            :obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids.
            Labels are None when `labels` is None.
        Raises:
            ValueError: if the number of speech placeholder positions does not
                match the number of speech feature frames.
        """
        num_speechs, speech_len, embed_dim = speech_features.shape
        batch_size, sequence_length = input_ids.shape
        # The batch is treated as left-padded when no row ends with a pad token.
        left_padding = not torch.sum(
            input_ids[:, -1] == self.llm.config.pad_token_id
        )
        # 1. Create a mask to know where special speech tokens are
        special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
        num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
        # Compute the maximum embed dimension
        max_embed_dim = (
            num_special_speech_tokens.max() * (speech_len - 1)
        ) + sequence_length
        batch_indices, non_speech_indices = torch.where(
            input_ids != self.llm.config.default_speech_token_id
        )

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in merged speech-text sequence.
        # `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `speech_len - 1` extra positions.
        # `torch.cumsum` computes how each speech token shifts subsequent text token positions.
        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
        new_token_positions = (
            torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
        )
        nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_speech_pad[:, None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size,
            max_embed_dim,
            embed_dim,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
        final_attention_mask = torch.zeros(
            batch_size,
            max_embed_dim,
            dtype=attention_mask.dtype,
            device=inputs_embeds.device,
        )
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim),
                IGNORE_TOKEN_ID,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
        # In case the speech encoder or the language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_speech_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_speech_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask. If we have ["hey" "<speech>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
            batch_indices, non_speech_indices
        ]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
            batch_indices, non_speech_indices
        ]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[
                batch_indices, non_speech_indices
            ]

        # 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
        speech_to_overwrite = torch.full(
            (batch_size, max_embed_dim),
            True,
            dtype=torch.bool,
            device=inputs_embeds.device,
        )
        speech_to_overwrite[batch_indices, text_to_overwrite] = False
        # Exclude the leading padding slots that belong to neither text nor speech.
        speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
            :, None
        ].to(target_device)

        if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
            raise ValueError(
                f"The input provided to the model are wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while"
                f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
            )

        # Masked assignment requires matching dtypes, so cast the speech features
        # to the embedding dtype (a no-op when they already agree).
        final_embedding[speech_to_overwrite] = (
            speech_features.contiguous()
            .reshape(-1, embed_dim)
            .to(device=target_device, dtype=final_embedding.dtype)
        )
        final_attention_mask |= speech_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
            (final_attention_mask == 0), 1
        )

        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(
            input_ids == self.llm.config.pad_token_id
        )
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0

        if labels is None:
            final_labels = None

        return final_embedding, final_attention_mask, final_labels, position_ids

    def forward(
        self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        labels: torch.LongTensor = None,
    ):
        """Compute the text LM loss for a batch.

        Args:
            fbank: Input passed to `self.encoder`.
            input_ids: Prompt/answer token ids containing speech placeholder tokens.
            attention_mask: Attention mask for `input_ids`.
            labels: Target ids aligned with `input_ids` (required: used for accuracy).

        Returns:
            (loss, acc): the LLM loss and the next-token accuracy over
            non-ignored label positions.
        """
        encoder_outs = self.encoder(fbank)
        speech_features = self.encoder_projector(encoder_outs)
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        (
            inputs_embeds,
            attention_mask,
            labels,
            _,
        ) = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask, labels
        )

        model_outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
        )

        with torch.no_grad():
            preds = torch.argmax(model_outputs.logits, -1)
            # Logits at position t predict the label at position t + 1.
            acc = compute_accuracy(
                preds.detach()[:, :-1],
                labels.detach()[:, 1:],
                ignore_label=IGNORE_TOKEN_ID,
            )
        return model_outputs.loss, acc

    def forward_with_speech_output(
        self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        labels: torch.LongTensor = None,
        speech_codec_ids: torch.LongTensor = None,
    ):
        """Compute the text LM loss and the codec LM loss for a batch.

        The codec LM input for item i is laid out as
        ``[mask] * (question_len + delay) + [bos] + codes`` with labels
        ``[pad] * (question_len + delay) + codes + [eos]``, padded on
        `self.codec_lm_padding_side`. The projected
        [LLM last hidden ; LLM input embedding] of the text is added to the
        codec embeddings starting at the first mask position.

        Args:
            fbank: Input passed to `self.encoder`.
            input_ids: Prompt/answer token ids containing speech placeholder tokens.
            attention_mask: Attention mask for `input_ids`.
            labels: Text target ids aligned with `input_ids`.
            speech_codec_ids: Per-item sequences of codec token ids; each item is
                concatenated with Python lists below, so it must be a list.

        Returns:
            (text_loss, acc, codec_loss, audio_acc, audio_topk_acc)

        Raises:
            ValueError: if `codec_lm_padding_side` is neither "left" nor "right".
        """
        encoder_outs = self.encoder(fbank)
        speech_features = self.encoder_projector(encoder_outs)
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        (
            inputs_embeds,
            attention_mask,
            labels,
            _,
        ) = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask, labels
        )
        input_seq_len = attention_mask.sum(dim=1)  # shape, B
        (
            text_label_start_index_list,
            text_input_start_index_list,
            input_question_len_list,
        ) = ([], [], [])
        for i in range(labels.shape[0]):
            input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
            input_embeds_start_index = input_embeds_valid_index[0]
            text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
            text_labels_start_index = text_labels_valid_index[0]

            # The attended span must be contiguous and end with the answer.
            assert (
                input_seq_len[i]
                == input_embeds_valid_index[-1] - input_embeds_start_index + 1
            ), f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
            assert (
                input_embeds_valid_index[-1] == text_labels_valid_index[-1]
            ), f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
            # Number of attended positions before the first supervised label.
            input_question_len = text_labels_start_index - input_embeds_start_index
            assert (
                input_question_len
                + text_labels_valid_index[-1]
                - text_labels_start_index
                + 1
                == input_seq_len[i]
            )
            text_label_start_index_list.append(text_labels_start_index)
            text_input_start_index_list.append(input_embeds_start_index)
            input_question_len_list.append(input_question_len)

        model_outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
        )
        text_loss = model_outputs.loss
        delay_step = 1

        # prepare codec lm inputs
        codec_config = self.codec_lm.config
        audio_codes_lens = [
            len(x) + input_question_len_list[i] + delay_step + 1
            for i, x in enumerate(speech_codec_ids)
        ]
        max_len_speech_codec = max(audio_codes_lens)

        if self.codec_lm_padding_side == "right":
            audio_codes = [
                [codec_config.mask_token_id] * (input_question_len_list[i] + delay_step)
                + [codec_config.bos_token_id]
                + x
                + [codec_config.pad_token_id]
                * (max_len_speech_codec - audio_codes_lens[i])
                for i, x in enumerate(speech_codec_ids)
            ]
            audio_labels = [
                [codec_config.pad_token_id] * (input_question_len_list[i] + delay_step)
                + x
                + [codec_config.eos_token_id]
                + [codec_config.pad_token_id]
                * (max_len_speech_codec - audio_codes_lens[i])
                for i, x in enumerate(speech_codec_ids)
            ]
        elif self.codec_lm_padding_side == "left":
            audio_codes = [
                [codec_config.pad_token_id]
                * (max_len_speech_codec - audio_codes_lens[i])
                + [codec_config.mask_token_id] * (input_question_len_list[i] + delay_step)
                + [codec_config.bos_token_id]
                + x
                for i, x in enumerate(speech_codec_ids)
            ]
            audio_labels = [
                [codec_config.pad_token_id]
                * (max_len_speech_codec - audio_codes_lens[i])
                + [codec_config.pad_token_id] * (input_question_len_list[i] + delay_step)
                + x
                + [codec_config.eos_token_id]
                for i, x in enumerate(speech_codec_ids)
            ]
        else:
            raise ValueError(
                f"Unsupported codec_lm_padding_side: {self.codec_lm_padding_side!r}; "
                "expected 'left' or 'right'."
            )
        audio_codes = torch.tensor(
            audio_codes, dtype=torch.int64, device=input_ids.device
        )
        audio_labels = torch.tensor(
            audio_labels, dtype=torch.int64, device=input_ids.device
        )

        audio_attention_mask = audio_codes.ne(codec_config.pad_token_id)
        audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)

        # Per item: project [hidden state at position t ; embedding of token t + 1]
        # over the attended text span (length input_seq_len - 1).
        last_layer_hidden = model_outputs.hidden_states[-1]
        text_input_embeds_list = []
        for i in range(len(text_label_start_index_list)):
            start = text_input_start_index_list[i]
            span_len = input_seq_len[i]
            text_last_hidden = last_layer_hidden[i, start : start + span_len - 1]
            text_embed = inputs_embeds[i, start + 1 : start + span_len]  # exclude bos
            text_input_embeds = torch.cat(
                [text_last_hidden, text_embed], dim=-1
            )  # shape, T, D1 + D2
            text_input_embeds = self.speech_token_projector(
                text_input_embeds
            )  # shape, T, D_codec
            text_input_embeds_list.append(text_input_embeds)

        for i in range(audio_embeddings.shape[0]):
            text_input_embeds = text_input_embeds_list[i]
            if self.codec_lm_padding_side == "left":
                start_idx = torch.where(audio_codes[i] == codec_config.mask_token_id)[
                    0
                ][0]
                start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
                assert (
                    start_idx == start_idx_re_compute
                ), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
            else:
                # Right padding: the codec sequence starts at position 0.
                start_idx = 0
            # Never write past the end of the codec sequence (both padding sides).
            if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
                text_input_embeds = text_input_embeds[
                    : audio_embeddings.shape[1] - start_idx
                ]
                logging.warning(
                    f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}"
                )
            audio_embeddings[
                i, start_idx : start_idx + text_input_embeds.shape[0]
            ] += text_input_embeds

        speech_outputs = self.codec_lm(
            attention_mask=audio_attention_mask,
            inputs_embeds=audio_embeddings,
            return_dict=True,
            output_hidden_states=True,
        )
        last_hidden_state = speech_outputs.hidden_states[-1].clone()

        audio_logits = self.codec_lm_head(last_hidden_state)  # shape, B, T, vocab_size
        audio_logits = audio_logits.contiguous().view(-1, codec_config.vocab_size)
        audio_labels = audio_labels.contiguous().view(-1)
        # Pad positions (padding and the masked question/delay prefix) are not supervised.
        audio_labels = audio_labels.masked_fill(
            audio_labels == codec_config.pad_token_id, IGNORE_TOKEN_ID
        )
        codec_loss = self.loss_fct(audio_logits, audio_labels)

        with torch.no_grad():
            preds = torch.argmax(model_outputs.logits, -1)
            acc = compute_accuracy(
                preds.detach()[:, :-1],
                labels.detach()[:, 1:],
                ignore_label=IGNORE_TOKEN_ID,
            )
            # Codec labels are already aligned with the logits (no shift).
            audio_preds = torch.argmax(audio_logits, -1)
            audio_acc = compute_accuracy(
                audio_preds.detach(),
                audio_labels.detach(),
                ignore_label=IGNORE_TOKEN_ID,
            )
            audio_topk_acc = self.audio_accuracy_metric(
                audio_logits.detach(), audio_labels.detach()
            ).item()

        return text_loss, acc, codec_loss, audio_acc, audio_topk_acc

    def decode(
        self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):
        """Generate text ids from speech + prompt with `self.llm.generate`.

        Recognised kwargs (with defaults): max_new_tokens=1024, num_beams=1,
        do_sample=True, min_length=1, top_p=0.5, top_k=20,
        repetition_penalty=1.1, temperature=0.7. Other kwargs are ignored.
        """
        encoder_outs = self.encoder(fbank)
        speech_features = self.encoder_projector(encoder_outs)
        # Match the LLM's dtype (same as decode_with_speech_output).
        speech_features = speech_features.to(self.llm.dtype)
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
        (
            inputs_embeds,
            attention_mask,
            _,
            _,
        ) = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask
        )
        generated_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=kwargs.get("max_new_tokens", 1024),
            num_beams=kwargs.get("num_beams", 1),
            do_sample=kwargs.get("do_sample", True),
            min_length=kwargs.get("min_length", 1),
            top_p=kwargs.get("top_p", 0.5),
            top_k=kwargs.get("top_k", 20),
            repetition_penalty=kwargs.get("repetition_penalty", 1.1),
            temperature=kwargs.get("temperature", 0.7),
            bos_token_id=self.llm.config.bos_token_id,
            eos_token_id=self.llm.config.eos_token_id,
            pad_token_id=self.llm.config.pad_token_id,
        )

        return generated_ids

    def decode_with_speech_output(
        self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,  # Prompt input_ids
        attention_mask: torch.Tensor = None,  # Prompt attention_mask
        max_text_new_tokens: int = 1024,
        max_speech_new_tokens: int = 2048,  # Max length for speech tokens
        llm_kwargs: dict = None,  # Kwargs for text LLM generate
        codec_lm_kwargs: dict = None,  # Currently unused
    ) -> Tuple[torch.LongTensor, List[int]]:
        """
        Generates text and corresponding speech tokens.

        Args:
            fbank: Input passed to `self.encoder`; batch size must be 1.
            input_ids: Input token IDs for the text prompt.
            attention_mask: Attention mask for the text prompt.
            max_text_new_tokens: Max new tokens for text generation.
            max_speech_new_tokens: Max number of codec decoding steps.
            llm_kwargs: Additional arguments for self.llm.generate; they override
                the sampling defaults set below.
            codec_lm_kwargs: Currently ignored; codec tokens are sampled with
                `topk_sampling` defaults.

        Returns:
            Tuple[torch.LongTensor, List[int]]:
                - generated_text_ids: `sequences` returned by `self.llm.generate`.
                - generated_speech_tokens: flat list of generated codec token ids
                  for the single batch item (the EOS token is not included).

        Raises:
            ValueError: if the codec LM and its projector/head are not initialised.
        """
        assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation."
        if (
            self.codec_lm is None
            or getattr(self, "speech_token_projector", None) is None
            or getattr(self, "codec_lm_head", None) is None
        ):
            raise ValueError(
                "codec_lm and associated layers must be initialized to generate speech output."
            )

        device = next(self.parameters()).device  # Use model's device
        batch_size = fbank.shape[0]

        # --- 1. Prepare Prompt Embeddings ---
        encoder_outs = self.encoder(fbank)
        speech_features = self.encoder_projector(encoder_outs)
        speech_features = speech_features.to(self.llm.dtype)  # Ensure matching dtype

        prompt_embeds = self.llm.get_input_embeddings()(input_ids)

        # Merge speech features with prompt embeddings
        (
            merged_prompt_inputs_embeds,
            merged_prompt_attention_mask,
            _,
            _,
        ) = self._merge_input_ids_with_speech_features(
            speech_features, prompt_embeds, input_ids, attention_mask
        )

        # --- 2. Generate Text using LLM ---
        final_llm_kwargs = {
            "bos_token_id": self.llm.config.bos_token_id,
            "eos_token_id": self.llm.config.eos_token_id,
            "pad_token_id": self.llm.config.pad_token_id,
            "num_beams": 1,
            "do_sample": True,  # Typically false for S2ST/S2TT tasks unless exploration needed
            "top_p": 0.5,
            "top_k": 20,
            "repetition_penalty": 1.1,
            "temperature": 0.7,
            **(llm_kwargs or {}),  # User-provided kwargs override defaults
        }

        text_outputs = self.llm.generate(
            inputs_embeds=merged_prompt_inputs_embeds,
            attention_mask=merged_prompt_attention_mask,
            max_new_tokens=max_text_new_tokens,
            return_dict_in_generate=True,
            output_hidden_states=True,
            **final_llm_kwargs,
        )

        delay_step = 1
        generated_text_ids = text_outputs.sequences  # [B, S_full]
        eos_token_id = self.llm.config.eos_token_id
        eos_token_embedding = self.llm.get_input_embeddings()(
            torch.tensor([[eos_token_id]], device=device)
        )
        assert (
            generated_text_ids[0, -1] == eos_token_id
        ), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"

        # Per generation step: first-layer output (embeddings) and last-layer output.
        thinker_token_embeds_org = [
            token_hidden_states[0].to(self.llm.device)
            for token_hidden_states in text_outputs.hidden_states
        ]
        # Shift embeddings left by one so each hidden state at position t is
        # paired with the embedding of token t + 1.
        first_thinker_token_embed = torch.cat(
            [
                thinker_token_embeds_org[0][:, 1:],
                thinker_token_embeds_org[1],
            ],
            dim=1,
        )

        thinker_token_embeds = (
            [first_thinker_token_embed]
            + thinker_token_embeds_org[2:]
            + [eos_token_embedding]
        )
        thinker_hidden_states = [
            token_hidden_states[-1].to(self.llm.device)
            for token_hidden_states in text_outputs.hidden_states
        ]

        thinker_reply_part = [
            torch.cat(
                [
                    thinker_hidden_state,
                    thinker_token_embed,
                ],
                dim=-1,
            )
            for thinker_hidden_state, thinker_token_embed in zip(
                thinker_hidden_states[1:], thinker_token_embeds[1:]
            )
        ]
        thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
        thinker_prompt_part = torch.cat(
            [
                thinker_hidden_states[0],
                thinker_token_embeds[0],
            ],
            dim=-1,
        )

        thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)
        thinker_reply_part = self.speech_token_projector(thinker_reply_part)

        # Talker prefix: mask tokens over prompt + delay, then BOS.
        thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
        talker_input_ids = torch.full(
            (batch_size, thinker_prompt_part_seq_len + delay_step + 1),
            self.codec_lm.config.mask_token_id,
            dtype=torch.long,
            device=self.llm.device,
        )
        talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
        talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)
        thinker_input_embeds = torch.cat(
            [
                thinker_prompt_part,
                thinker_reply_part[:, : delay_step + 1, :],
            ],
            dim=1,
        )
        talker_inputs_embeds += thinker_input_embeds
        # Remaining reply conditioning, consumed one position per codec step.
        thinker_reply_part = thinker_reply_part[:, delay_step + 1 :, :]

        past_key_values = None

        generated_speech_tokens_list = []
        next_token_ids = None

        for t in range(max_speech_new_tokens):
            if t > 0:
                talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
                    next_token_ids
                )
                if thinker_reply_part.shape[1] > 0:
                    talker_inputs_embeds += thinker_reply_part[:, :1, :]
                    thinker_reply_part = thinker_reply_part[:, 1:, :]

            codec_outputs = self.codec_lm(
                inputs_embeds=talker_inputs_embeds,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
                output_hidden_states=True,
            )
            last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]
            next_token_logits = self.codec_lm_head(last_token_hidden_state)

            next_token_ids = topk_sampling(
                next_token_logits,
            )
            if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
                break

            past_key_values = codec_outputs.past_key_values  # Update KV cache
            generated_speech_tokens_list.append(
                next_token_ids.squeeze(1).cpu().tolist()[0]
            )

        return generated_text_ids, generated_speech_tokens_list


def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.
    Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
    Args:
        pad_outputs (LongTensor): Prediction tensors (B, Lmax).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float tensor: Accuracy value (0.0 - 1.0); NaN if every target is ignored.
    """
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    return numerator.float() / denominator.float()


def topk_sampling(
    logits,
    top_k=50,
    top_p=0.95,
    temperature=0.8,
):
    """Sample one token per row after temperature scaling and top-k/top-p filtering.

    Args:
        logits: Logits of shape (batch, vocab).
        top_k: Keep only the `top_k` highest logits (0 disables).
        top_p: Nucleus threshold (1.0 disables).
        temperature: Logits are divided by this value when it is not 1.0.

    Returns:
        LongTensor of shape (batch, 1) with the sampled token ids.
    """
    if temperature != 1.0:
        logits = logits / temperature
    # Top-p/top-k filtering (on a copy: the filter mutates its input)
    logits_filtered = top_k_top_p_filtering(
        logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2
    )
    # Sample
    probs = torch.nn.functional.softmax(logits_filtered, dim=-1)
    tokens = torch.multinomial(probs, num_samples=1)
    return tokens


# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
    logits, top_k=20, top_p=0.5, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size); modified in place
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
        )

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep the first min_tokens_to_keep sorted tokens; together with the
            # right shift below, this keeps one additional token.
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing (dim 1: expects 2-D logits)
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits