From 10998bef69ba2d47c82f2be226754eaaa996534e Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Tue, 31 May 2022 18:53:20 +0800
Subject: [PATCH] add tail padding

---
 .../conv_emformer_transducer_stateless/decode.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py
index dd6885b19..cb5c398c1 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py
@@ -59,6 +59,7 @@ Usage:
 
 import argparse
 import logging
+import math
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -70,7 +71,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import (
     beam_search,
-    fast_beam_search,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -91,6 +92,8 @@ from icefall.utils import (
     write_error_stats,
 )
 
+LOG_EPS = math.log(1e-10)
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -262,13 +265,20 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
+    feature_lens += params.right_context_length
+    feature = torch.nn.functional.pad(
+        feature,
+        pad=(0, 0, 0, params.right_context_length),
+        value=LOG_EPS,
+    )
+
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search(
+        hyp_tokens = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -294,6 +304,7 @@
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
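
Note (not part of the patch): the new lines in decode_one_batch() append params.right_context_length frames filled with LOG_EPS (a very low log-mel energy) to the end of each utterance and enlarge feature_lens to match, so that the last chunk of each utterance still has the right-context frames the streaming Emformer encoder expects. Below is a minimal, self-contained sketch of just that padding step; the batch shape and the value of right_context_length are made-up illustration values, not taken from the recipe.

import math

import torch

LOG_EPS = math.log(1e-10)  # same padding value as in the patch

right_context_length = 8            # assumed value; taken from params in the patch
feature = torch.randn(2, 100, 80)   # dummy (batch, num_frames, num_mel_bins) features
feature_lens = torch.tensor([100, 90])

# Append right_context_length frames of LOG_EPS along the time axis and grow
# the lengths to match, mirroring the new lines in decode_one_batch().
feature_lens = feature_lens + right_context_length
feature = torch.nn.functional.pad(
    feature,
    pad=(0, 0, 0, right_context_length),  # last dim: no pad; time dim: pad on the right
    value=LOG_EPS,
)

print(feature.shape)   # torch.Size([2, 108, 80])
print(feature_lens)    # tensor([108,  98])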