diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index fa47698f0..0b4e5a6d8 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -17,7 +17,6 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional
 
-import numpy as np
 import torch
 from model import Transducer
 
@@ -108,8 +107,11 @@ class Hypothesis:
     # Newly predicted tokens are appended to `ys`.
     ys: List[int]
 
-    # The log prob of ys
-    log_prob: float
+    # The log prob of ys.
+    # It contains only one entry.
+    # TODO(fangjun): It was a float before. We need to change its usage
+    # in greedy_search and beam_search.
+    log_prob: torch.Tensor
 
     @property
     def key(self) -> str:
@@ -145,8 +147,10 @@ class HypothesisList(object):
         """
         key = hyp.key
         if key in self:
-            old_hyp = self._data[key]
-            old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
+            old_hyp = self._data[key]  # shallow copy
+            torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+            )
         else:
             self._data[key] = hyp
 
@@ -348,47 +352,70 @@ def modified_beam_search(
     T = encoder_out.size(1)
 
     B = HypothesisList()
-    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
+    B.add(
+        Hypothesis(
+            ys=[blank_id] * context_size,
+            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+        )
+    )
 
     encoder_out_len = torch.tensor([1])
     decoder_out_len = torch.tensor([1])
 
-    decoder_cache: Dict[str, torch.Tensor] = {}
-
     for t in range(T):
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :]
+        # current_encoder_out is of shape (1, 1, encoder_out_dim)
        # fmt: on
-        A = B
+        A = list(B)
         B = HypothesisList()
 
-        joint_cache: Dict[str, torch.Tensor] = {}
+        ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
+        # ys_log_probs is of shape (num_hyps, 1)
 
-        for hyp in A:
-            decoder_out = run_decoder(
-                ys=hyp.ys, model=model, decoder_cache=decoder_cache
-            )
-
-            key = "_".join(map(str, hyp.ys[-context_size:]))
-            key += f"-t-{t}"
-            log_prob = run_joiner(
-                key=key,
-                model=model,
-                encoder_out=current_encoder_out,
-                decoder_out=decoder_out,
-                encoder_out_len=encoder_out_len,
-                decoder_out_len=decoder_out_len,
-                joint_cache=joint_cache,
-            )
-
-            log_prob = log_prob.cpu().tolist()
-
-            for i, v in enumerate(log_prob):
-                if i == blank_id:
-                    # Use [:] to make a copy
-                    new_ys = hyp.ys[:]
-                else:
-                    new_ys = hyp.ys + [i]
-
-                new_hyp = Hypothesis(ys=new_ys, log_prob=hyp.log_prob + v)
-                B.add(new_hyp)
-
-        B = B.topk(beam)
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyp in A],
+            device=device,
+        )
+        # decoder_input is of shape (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False)
+        # decoder_out is of shape (num_hyps, 1, decoder_output_dim)
+
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            encoder_out_len.expand(decoder_out.size(0)),
+            decoder_out_len.expand(decoder_out.size(0)),
+        )
+        # logits is of shape (num_hyps, vocab_size)
+        log_probs = logits.log_softmax(dim=-1)
+
+        log_probs.add_(ys_log_probs)
+
+        log_probs = log_probs.reshape(-1)
+        topk_log_probs, topk_indexes = log_probs.topk(beam)
+
+        # topk_hyp_indexes are indexes into `A`
+        topk_hyp_indexes = topk_indexes // logits.size(-1)
+        topk_token_indexes = topk_indexes % logits.size(-1)
+
+        topk_hyp_indexes = topk_hyp_indexes.tolist()
+        topk_token_indexes = topk_token_indexes.tolist()
+
+        for i in range(len(topk_hyp_indexes)):
+            hyp = A[topk_hyp_indexes[i]]
+            new_ys = hyp.ys[:]
+            new_token = topk_token_indexes[i]
+            if new_token != blank_id:
+                new_ys.append(new_token)
+            new_log_prob = topk_log_probs[i]
+            new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+            B.add(new_hyp)
 
     best_hyp = B.get_most_probable(length_norm=True)
     ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index dca084477..b82fed37b 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -75,24 +75,24 @@ class Decoder(nn.Module):
         """
         Args:
           y:
-            A 2-D tensor of shape (N, U) with blank prepended.
+            A 2-D tensor of shape (N, U).
           need_pad:
             True to left pad the input. Should be True during training.
             False to not pad the input. Should be False during inference.
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
-        embeding_out = self.embedding(y)
+        embedding_out = self.embedding(y)
         if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
+            embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
                 )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
-            embeding_out = self.conv(embeding_out)
-            embeding_out = embeding_out.permute(0, 2, 1)
-        return embeding_out
+                assert embedding_out.size(-1) == self.context_size
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        return embedding_out
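For reviewers, here is a minimal standalone sketch (not part of the patch) of the flattened top-k trick that the rewritten modified_beam_search relies on. The sizes num_hyps, vocab_size, beam, blank_id and the scores below are made-up toy values; in the patch, rows correspond to the hypotheses in `A` and columns to the joiner's output vocabulary.

```python
import torch

# Toy sizes for illustration only; in the patch num_hyps is len(A) and
# vocab_size is the joiner's output dimension.
num_hyps, vocab_size, beam, blank_id = 2, 5, 4, 0

# Accumulated log prob of each hypothesis, shape (num_hyps, 1),
# playing the role of ys_log_probs in the patch.
ys_log_probs = torch.tensor([[-0.5], [-1.2]])

# Per-token log probs for this frame, shape (num_hyps, vocab_size),
# standing in for logits.log_softmax(dim=-1).
token_log_probs = torch.randn(num_hyps, vocab_size).log_softmax(dim=-1)

# Add each hypothesis's accumulated score to all of its expansions
# (broadcast over the vocab dimension), then run a single top-k over the
# flattened (num_hyps * vocab_size) scores instead of looping per hypothesis.
log_probs = (token_log_probs + ys_log_probs).reshape(-1)
topk_log_probs, topk_indexes = log_probs.topk(beam)

# Recover which hypothesis (row) and which token (column) each of the
# k winners came from, exactly as the patch does with logits.size(-1).
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()

for hyp_idx, token, score in zip(
    topk_hyp_indexes, topk_token_indexes, topk_log_probs.tolist()
):
    # A blank expansion keeps ys unchanged; any other token is appended,
    # mirroring the loop at the end of the new modified_beam_search.
    action = "keep ys" if token == blank_id else f"append token {token}"
    print(f"hyp {hyp_idx}: {action}, new log_prob {score:.3f}")
```

Note also that because Hypothesis.log_prob is now a tensor rather than a float, HypothesisList.add can merge two hypotheses with identical ys in place via torch.logaddexp(..., out=old_hyp.log_prob), replacing the old np.logaddexp on Python floats and dropping the numpy dependency.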