diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 5876d5158..fae1d5a96 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -93,7 +93,9 @@ def fast_beam_search(
         )
         # fmt: on
         logits = model.joiner(
-            current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1), project_input=False
+            current_encoder_out.unsqueeze(2),
+            decoder_out.unsqueeze(1),
+            project_input=False,
         )
         logits = logits.squeeze(1).squeeze(1)
         log_probs = logits.log_softmax(dim=-1)
@@ -140,7 +142,6 @@ def greedy_search(
 
     encoder_out = model.joiner.encoder_proj(encoder_out)
 
-
     T = encoder_out.size(1)
     t = 0
     hyp = [blank_id] * context_size
@@ -163,9 +164,9 @@
 
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
         # fmt: on
-        logits = model.joiner(current_encoder_out,
-                              decoder_out.unsqueeze(1),
-                              project_input=False)
+        logits = model.joiner(
+            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
+        )
         # logits is (1, 1, 1, vocab_size)
         y = logits.argmax().item()
@@ -228,8 +229,9 @@ def greedy_search_batch(
     for t in range(T):
         current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
         # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
-        logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1),
-                              project_input=False)
+        logits = model.joiner(
+            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
+        )
         # logits'shape (batch_size, 1, 1, vocab_size)
         logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
@@ -466,7 +468,6 @@ def modified_beam_search(
         decoder_out = model.joiner.decoder_proj(decoder_out)
         # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
 
-
         # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
         # as index, so we use `to(torch.int64)` below.
         current_encoder_out = torch.index_select(
@@ -720,7 +721,7 @@ def beam_search(
             logits = model.joiner(
                 current_encoder_out,
                 decoder_out.unsqueeze(1),
-                project_input=False
+                project_input=False,
             )
 
             # TODO(fangjun): Scale the blank posterior
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
index c23568ae9..b6d94aaf1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
@@ -17,9 +17,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import Tensor
-from typing import Optional
-from scaling import ScaledConv1d, ScaledLinear, ScaledEmbedding
+from scaling import ScaledConv1d, ScaledEmbedding
 
 
 class Decoder(nn.Module):
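
Note (illustration only, not part of the patch): every beam_search.py hunk above is a cosmetic reformatting of the same call pattern. The joiner's input projections are applied once, outside the per-frame decoding loop (via model.joiner.encoder_proj and model.joiner.decoder_proj, both visible in the context lines), so each per-frame call passes project_input=False to skip re-projecting. Below is a minimal runnable sketch of that convention; TinyJoiner is a hypothetical stand-in, assuming the repo's real Joiner applies its projections only when project_input=True.

import torch
import torch.nn as nn


class TinyJoiner(nn.Module):
    """Hypothetical stand-in for the repo's Joiner; only the
    project_input contract is modeled here."""

    def __init__(self, encoder_dim=4, decoder_dim=4, joiner_dim=4, vocab_size=10):
        super().__init__()
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(self, encoder_out, decoder_out, project_input=True):
        # Callers that already projected both inputs outside the decoding
        # loop pass project_input=False to avoid projecting twice.
        if project_input:
            encoder_out = self.encoder_proj(encoder_out)
            decoder_out = self.decoder_proj(decoder_out)
        return self.output_linear(torch.tanh(encoder_out + decoder_out))


joiner = TinyJoiner()
encoder_out = joiner.encoder_proj(torch.randn(1, 3, 4))  # projected once, (N, T, J)
decoder_out = joiner.decoder_proj(torch.randn(1, 1, 4))  # projected once, (N, 1, J)
t = 0
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # (N, 1, 1, J)
logits = joiner(current_encoder_out, decoder_out.unsqueeze(1), project_input=False)
print(logits.shape)  # torch.Size([1, 1, 1, 10])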