diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/decode.py index c40b01dfa..47b4f9fd0 100755 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/decode.py @@ -480,6 +480,7 @@ def main(): # <blk> is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("<blk>") + params.unk_id = sp.piece_to_id("<unk>") params.vocab_size = sp.get_piece_size() logging.info(params) diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py index 6ecc8d420..0f4aad163 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py @@ -550,8 +550,8 @@ class EmformerAttention(nn.Module): Cached attention value of left context from preceding computation, with shape (L, B, D). pos_emb (torch.Tensor): - Position encoding embedding, with shape (PE, D). - For infer mode, PE = L+2*U-1. + Position encoding embedding, with shape (PE, D), + where PE = L + 2 * U - 1. Returns: A tuple containing 4 tensors: @@ -1264,6 +1264,10 @@ class EmformerEncoder(nn.Module): right_context at the end. """ U = x.size(0) - self.right_context_length + + # for query of [utterance] (i), key-value [utterance] (j), + # the max relative distance i - j is U - 1 + # the min relative distance i - j is -(U - 1) x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U) right_context = self._gen_right_context(x) @@ -1329,8 +1333,12 @@ class EmformerEncoder(nn.Module): f"expected size of {self.chunk_length + self.right_context_length} " f"for dimension 1 of x, but got {x.size(1)}." 
) + pos_len = self.chunk_length + self.left_context_length neg_len = self.chunk_length + # for query of [utterance] (i), key-value [left_context, utterance] (j), + # the max relative distance i - j is L + U - 1 + # the min relative distance i - j is -(U - 1) x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len) right_context = x[self.chunk_length :]