diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index 6dfe11cee..3c4500087 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -107,6 +107,7 @@ Usage:
 
 import argparse
 import logging
+import math
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -138,6 +139,8 @@ from icefall.utils import (
     write_error_stats,
 )
 
+LOG_EPS = math.log(1e-10)
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -288,7 +291,7 @@ def get_parser():
         "--decode-chunk-size",
         type=int,
         default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
+        help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
     )
     parser.add_argument(
         "--left-context",
@@ -370,6 +373,14 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)
 
     if params.simulate_streaming:
+        if params.decode_chunk_size > 0:
+            # Padding is not needed when using full attention (decode_chunk_size == -1).
+            feature_lens += params.left_context
+            feature = torch.nn.functional.pad(
+                feature,
+                pad=(0, 0, 0, params.left_context),
+                value=LOG_EPS,
+            )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index f94ffef59..9bac46004 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -375,6 +375,11 @@ class Conformer(EncoderInterface):
 
         assert x.size(0) == lengths.max().item()
 
+        if chunk_size < 0:
+            # Use full attention: a single chunk covers the whole utterance.
+            chunk_size = x.size(0)
+            left_context = -1
+
         num_left_chunks = -1
         if left_context >= 0:
             assert left_context % chunk_size == 0
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
index 172c9ab7c..c57514193 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -295,7 +295,7 @@ def get_parser():
         "--decode-chunk-size",
         type=int,
         default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
+        help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
     )
 
     parser.add_argument(
@@ -378,12 +378,14 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)
 
     if params.simulate_streaming:
-        feature_lens += params.left_context
-        feature = torch.nn.functional.pad(
-            feature,
-            pad=(0, 0, 0, params.left_context),
-            value=LOG_EPS,
-        )
+        if params.decode_chunk_size > 0:
+            # Padding is not needed when using full attention (decode_chunk_size == -1).
+            feature_lens += params.left_context
+            feature = torch.nn.functional.pad(
+                feature,
+                pad=(0, 0, 0, params.left_context),
+                value=LOG_EPS,
+            )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index aa055049e..b39007dfc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -344,7 +344,7 @@ def get_parser():
         "--decode-chunk-size",
         type=int,
         default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
+        help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
     )
 
     parser.add_argument(
@@ -508,12 +508,14 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)
 
     if params.simulate_streaming:
-        feature_lens += params.left_context
-        feature = torch.nn.functional.pad(
-            feature,
-            pad=(0, 0, 0, params.left_context),
-            value=LOG_EPS,
-        )
+        if params.decode_chunk_size > 0:
+            # Padding is not needed when using full attention (decode_chunk_size == -1).
+            feature_lens += params.left_context
+            feature = torch.nn.functional.pad(
+                feature,
+                pad=(0, 0, 0, params.left_context),
+                value=LOG_EPS,
+            )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index 5ec3d3b45..79d919ab1 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -326,14 +326,14 @@ def get_parser():
         "--decode-chunk-size",
         type=int,
         default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
+        help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
     )
 
     parser.add_argument(
         "--left-context",
         type=int,
         default=64,
-        help="left context can be seen during decoding (in frames after subsampling)",  # noqa
+        help="""Left context can be seen during decoding (in frames after subsampling).""",  # noqa
""", # noqa ) parser.add_argument( @@ -409,12 +409,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 2be895feb..af0b2d9fc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -291,7 +291,7 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.", ) parser.add_argument( @@ -470,12 +470,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 01e8c5b21..94d0393c2 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -358,6 +358,11 @@ class Conformer(Transformer): assert x.size(0) == lengths.max().item() + if chunk_size < 0: + # use full attention + chunk_size = x.size(0) + left_context = -1 + num_left_chunks = -1 if left_context >= 0: assert left_context % chunk_size == 0