From d7be9bd9c5c98d8093da577fff5a5312f5a6ee48 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Tue, 7 Jun 2022 12:00:26 +0800
Subject: [PATCH] Minor fixes

---
 .../ASR/pruned_transducer_stateless/decode.py | 22 ++---
 .../decode_stream.py                          |  9 +-
 .../ASR/pruned_transducer_stateless/export.py | 10 +--
 .../ASR/pruned_transducer_stateless/model.py  | 26 +-----
 .../streaming_decode.py                       | 64 ++++++++------
 .../ASR/pruned_transducer_stateless/train.py  | 42 +--------
 .../pruned_transducer_stateless2/conformer.py | 61 +++++++++----
 .../pruned_transducer_stateless2/decode.py    | 25 ++----
 .../pruned_transducer_stateless2/export.py    | 10 +--
 .../ASR/pruned_transducer_stateless2/model.py | 26 +-----
 .../streaming_decode.py                       | 25 +++---
 .../ASR/pruned_transducer_stateless2/train.py | 46 +---------
 .../pruned_transducer_stateless3/decode.py    | 20 ++---
 .../pruned_transducer_stateless3/export.py    | 10 +--
 .../streaming_decode.py                       | 32 +++----
 .../ASR/pruned_transducer_stateless3/train.py | 37 +-------
 .../pruned_transducer_stateless4/decode.py    | 23 ++---
 .../streaming_decode.py                       | 30 +++----
 .../ASR/pruned_transducer_stateless4/train.py | 45 +---------
 .../ASR/transducer_stateless/conformer.py     | 85 ++++++++++++++---
 icefall/utils.py                              |  7 +-
 21 files changed, 246 insertions(+), 409 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index c5e3465e9..6b048519a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -73,7 +73,7 @@ Usage:
         --avg 15 \
         --simulate-streaming 1 \
         --causal-convolution 1 \
-        --right-chunk-size 16 \
+        --decode-chunk-size 16 \
         --left-context 64 \
         --exp-dir ./pruned_transducer_stateless/exp \
         --max-duration 600 \
@@ -302,16 +302,9 @@ def get_parser():
         test a streaming model.
         """,
     )
+
     parser.add_argument(
-        "--causal-convolution",
-        type=str2bool,
-        default=False,
-        help="""Whether to use causal convolution, this requires to be True when
-        using dynamic_chunk_training.
-        """,
-    )
-    parser.add_argument(
-        "--right-chunk-size",
+        "--decode-chunk-size",
         type=int,
         default=16,
         help="The chunk size for decoding (in frames after subsampling)",
@@ -379,7 +372,7 @@ def decode_one_batch(
             x=feature,
             x_lens=feature_lens,
             states=[],
-            chunk_size=params.right_chunk_size,
+            chunk_size=params.decode_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,
         )
@@ -610,7 +603,7 @@ def main():
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

     if params.simulate_streaming:
-        params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
         params.suffix += f"-left-context-{params.left_context}"

     if "fast_beam_search" in params.decoding_method:
@@ -646,9 +639,8 @@ def main():
     params.vocab_size = sp.get_piece_size()

     if params.simulate_streaming:
-        assert (
-            params.causal_convolution
-        ), "Decoding in streaming requires causal convolution"
+        # Decoding in streaming requires causal convolution
+        params.causal_convolution = True

     logging.info(params)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py
index 9bbbf57ee..050bef60a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py
@@ -37,6 +37,7 @@ class DecodeStream(object):
             `get_init_state` in conformer.py
           decoding_graph:
             Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+            Used only when decoding_method is fast_beam_search.
           device:
             The device to run this stream.
         """
@@ -49,7 +50,8 @@
         # It contains a 2-D tensors representing the feature frames.
         self.features: torch.Tensor = None

-        # how many frames are processed. (before subsampling).
+        # how many frames have been processed (before subsampling);
+        # we only modify this value in `get_feature_frames`.
         self.num_processed_frames: int = 0
         self._done: bool = False
         # The transcript of current utterance.
@@ -57,6 +59,9 @@
         # The decoding result (partial or final) of current utterance.
         self.hyp: List = []

+        # how many frames have been processed, after subsampling (i.e. a
+        # cumulative sum of the second return value of
+        # encoder.streaming_forward).
         self.feature_len: int = 0

         if params.decoding_method == "greedy_search":
@@ -69,7 +74,7 @@
         else:
             assert (
                 False
-            ), f"Decoding method :{params.decoding_method} do not support"
+            ), f"Decoding method {params.decoding_method} is not supported."

     @property
     def done(self) -> bool:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
index d05bef337..4c4c0ee2d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
@@ -149,14 +149,6 @@ def get_parser():
         are streaming model, this should be True.
         """,
     )
-    parser.add_argument(
-        "--causal-convolution",
-        type=str2bool,
-        default=False,
-        help="""Whether to use causal convolution, this requires to be True when
-        exporting a streaming model.
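
# (Illustrative aside, not part of the patch.) The pattern this hunk and its
# siblings introduce: rather than asserting that the user also passed
# --causal-convolution, the scripts now imply the flag whenever a streaming
# or simulated-streaming mode is requested. A toy equivalent:
from argparse import Namespace

params = Namespace(streaming_model=True, causal_convolution=False)
if params.streaming_model:
    # streaming models require causal convolution, so enable it implicitly
    params.causal_convolution = True
assert params.causal_convolution
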
- """, - ) return parser @@ -183,7 +175,7 @@ def main(): params.vocab_size = sp.get_piece_size() if params.streaming_model: - assert params.causal_convolution + params.causal_convolution = True logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 1c83adc44..2f019bcdb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -66,8 +66,6 @@ class Transducer(nn.Module): prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - delay_penalty: float = 0.0, - return_sym_delay: bool = False, ) -> torch.Tensor: """ Args: @@ -138,31 +136,10 @@ class Transducer(nn.Module): lm_only_scale=lm_scale, am_only_scale=am_scale, boundary=boundary, - delay_penalty=delay_penalty, reduction="sum", return_grad=True, ) - sym_delay = None - if return_sym_delay: - B, S, T0 = px_grad.shape - T = T0 - 1 - if boundary is None: - offset = torch.tensor( - (T - 1) / 2, - dtype=px_grad.dtype, - device=px_grad.device, - ).expand(B, 1, 1) - total_syms = S * B - else: - offset = (boundary[:, 3] - 1) / 2 - total_syms = torch.sum(boundary[:, 2]) - offset = torch.arange(T0, device=px_grad.device).reshape( - 1, 1, T0 - ) - offset.reshape(B, 1, 1) - sym_delay = px_grad * offset - sym_delay = torch.sum(sym_delay) / total_syms - # ranges : [B, T, prune_range] ranges = k2.get_rnnt_prune_ranges( px_grad=px_grad, @@ -186,8 +163,7 @@ class Transducer(nn.Module): ranges=ranges, termination_symbol=blank_id, boundary=boundary, - delay_penalty=delay_penalty, reduction="sum", ) - return (simple_loss, pruned_loss, sym_delay) + return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index 1f3fa79b7..f61d10b20 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -20,9 +20,12 @@ Usage: ./pruned_transducer_stateless2/streaming_decode.py \ --epoch 28 \ --avg 15 \ + --decode-chunk-size 8 \ + --left-context 32 \ + --right-context 2 \ --exp-dir ./pruned_transducer_stateless2/exp \ --decoding_method greedy_search \ - --num-decode-streams 200 + --num-decode-streams 1000 """ import argparse @@ -182,15 +185,6 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=True, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - parser.add_argument( "--decode-chunk-size", type=int, @@ -205,6 +199,13 @@ def get_parser(): help="left context can be seen during decoding (in frames after subsampling)", ) + parser.add_argument( + "--right-context", + type=int, + default=4, + help="right context can be seen during decoding (in frames after subsampling)", + ) + parser.add_argument( "--num-decode-streams", type=int, @@ -343,14 +344,18 @@ def decode_one_chunk( processed_feature_lens = [] for stream in decode_streams: + # we plus 2 here because we will cut off one frame on each size of + # encoder_embed output as they see invalid paddings. so we need extra 2 + # frames. 
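
# A quick sketch (not part of the patch) of the chunk-size arithmetic used
# just below: each chunk fetches enough input frames for decode_chunk_size
# subsampled output frames, plus 2 subsampled frames to compensate for the
# frame trimmed from each side of the encoder_embed output, plus the chunk's
# look-ahead.
decode_chunk_size = 8      # output frames per chunk, after subsampling
right_context = 2          # look-ahead frames, after subsampling
subsampling_factor = 4
frames_needed = (decode_chunk_size + 2 + right_context) * subsampling_factor
print(frames_needed)       # 48 input feature frames fetched per chunk
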
         feat, feat_len = stream.get_feature_frames(
-            params.decode_chunk_size * params.subsampling_factor
+            (params.decode_chunk_size + 2 + params.right_context)
+            * params.subsampling_factor
         )
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
+        processed_feature_lens.append(stream.feature_len)
         if params.decoding_method == "fast_beam_search":
-            processed_feature_lens.append(stream.feature_len)
             rnnt_stream_list.append(stream.rnnt_decoding_stream)

     feature_lens = torch.tensor(feature_lens, device=device)
@@ -358,15 +363,21 @@
     # if T is less than 7 there will be an error in time reduction layer,
     # because we subsample features with ((x_len - 1) // 2 - 1) // 2
-    if features.size(1) < 7:
-        feature_lens += 7 - features.size(1)
+    # we add 2 here because we will cut off one frame on each side of the
+    # encoder_embed output, as those frames see invalid padding; so we need
+    # 2 extra frames.
+    tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
+    if features.size(1) < tail_length:
+        feature_lens += tail_length - features.size(1)
         features = torch.cat(
             [
                 features,
                 torch.tensor(
                     LOG_EPS, dtype=features.dtype, device=device
                 ).expand(
-                    features.size(0), 7 - features.size(1), features.size(2)
+                    features.size(0),
+                    tail_length - features.size(1),
+                    features.size(2),
                 ),
             ],
             dim=1,
         )

     states = [
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]

+    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+
     # Note: states will be modified in streaming_forward.
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
         x_lens=feature_lens,
         states=states,
         left_context=params.left_context,
+        right_context=params.right_context,
+        processed_lens=processed_feature_lens,
     )

     if params.decoding_method == "greedy_search":
@@ -395,9 +410,6 @@
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-        processed_feature_lens = torch.tensor(
-            processed_feature_lens, device=device
-        )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
         processed_lens = processed_feature_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
@@ -411,8 +423,8 @@
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
+        decode_streams[i].feature_len += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
-            decode_streams[i].feature_len += encoder_out_lens[i]
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
             finished_streams.append(i)
@@ -457,12 +469,14 @@
     opts.frame_opts.samp_freq = 16000
     opts.mel_opts.num_bins = 80

-    log_interval = 300
+    log_interval = 50

     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.get_init_state(params.left_context, device=device)
+    initial_states = model.encoder.get_init_state(
+        params.left_context, device=device
+    )
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
decode_stream = DecodeStream( @@ -536,7 +550,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): @@ -597,6 +611,7 @@ def main(): # for streaming params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" params.suffix += f"-left-context-{params.left_context}" + params.suffix += f"-right-context-{params.right_context}" # for fast_beam_search if params.decoding_method == "fast_beam_search": @@ -620,10 +635,7 @@ def main(): params.blank_id = sp.piece_to_id("") params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" + params.causal_convolution = True logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index bcb78414f..7dcbc23ec 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -37,7 +37,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --exp-dir pruned_transducer_stateless/exp \ --full-libri 1 \ --dynamic-chunk-training 1 \ - --causal-convolution 1 \ --short-chunk-size 25 \ --num-left-chunks 4 \ --max-duration 300 @@ -244,15 +243,6 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - parser.add_argument( "--short-chunk-size", type=int, @@ -269,25 +259,6 @@ def get_parser(): help="How many left context can be seen in chunks when calculating attention.", ) - parser.add_argument( - "--delay-penalty", - type=float, - default=0.0, - help="""A constant value to penalize symbol delay, this may be - needed when training with time masking, to avoid the time masking - encouraging the network to delay symbols. - """, - ) - - parser.add_argument( - "--return-sym-delay", - type=str2bool, - default=False, - help="""Whether to return `sym_delay` during training, this is a stat - to measure symbols emission delay, especially for time masking training. 
- """, - ) - return parser @@ -554,17 +525,14 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y).to(device) - sym_delay = None with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, sym_delay = model( + simple_loss, pruned_loss = model( x=feature, x_lens=feature_lens, y=y, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, - delay_penalty=params.delay_penalty, - return_sym_delay=params.return_sym_delay, ) loss = params.simple_loss_scale * simple_loss + pruned_loss @@ -582,9 +550,6 @@ def compute_loss( info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() - if sym_delay is not None: - info["sym_delay"] = sym_delay.detatch().cpu().item() - return loss, info @@ -839,9 +804,8 @@ def run(rank, world_size, args): params.vocab_size = sp.get_piece_size() if params.dynamic_chunk_training: - assert ( - params.causal_convolution - ), "dynamic_chunk_training requires causal convolution" + # dynamic_chunk_training requires causal convolution + params.causal_convolution = True logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index bb786c775..e28b5034d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -279,19 +279,26 @@ class Conformer(EncoderInterface): The chunk size for decoding, this will be used to simulate streaming decoding using masking. left_context: - How many old frames the attention can see in current chunk, it MUST - be equal to left_context in decode_states. + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. simulate_streaming: If setting True, it will use a masking strategy to simulate streaming fashion (i.e. every chunk data only see limited left context and right context). The whole sequence is supposed to be send at a time When using simulate_streaming. + processed_lens: + How many frames (after subsampling) have been processed for each sequence. Returns: Return a tuple containing 2 tensors: - logits, its shape is (batch_size, output_seq_len, output_dim) - logit_lens, a tensor of shape (batch_size,) containing the number of frames in `logits` before padding. - - decode_states, the updated DecodeStates including the information + - decode_states, the updated states including the information of current chunk. """ @@ -321,8 +328,6 @@ class Conformer(EncoderInterface): {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)}, given {states[1].shape}.""" - # src_key_padding_mask = make_pad_mask(lengths + left_context) - lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output src_key_padding_mask = make_pad_mask(lengths) @@ -341,6 +346,8 @@ class Conformer(EncoderInterface): embed = self.encoder_embed(x) + # cut off 1 frame on each size of embed as they see the padding + # value which causes a training and decoding mismatch. embed = embed[:, 1:-1, :] embed, pos_enc = self.encoder_pos(embed, left_context) @@ -359,7 +366,8 @@ class Conformer(EncoderInterface): x = x[0:-right_context, ...] 
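
            # (Illustrative note, not part of the patch.) With a chunk of 8
            # subsampled frames plus right_context = 2 look-ahead frames,
            # the encoder processes all 10 frames but emits only the first 8:
            #
            #     y = layers(chunk_plus_lookahead)   # (10, N, C)
            #     emitted = y[0:-right_context]      # (8, N, C)
            #
            # The look-ahead frames are fed again with the next chunk, where
            # they have enough left context of their own.
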
lengths -= right_context else: - + # this branch simulates streaming decoding using mask as we are + # using in training time. src_key_padding_mask = make_pad_mask(lengths) x = self.encoder_embed(x) x, pos_emb = self.encoder_pos(x) @@ -558,9 +566,14 @@ class ConformerEncoderLayer(nn.Module): src_key_padding_mask: the mask for the src keys per batch (optional). warmup: controls selective bypass of of layers; if < 1.0, we will bypass layers more frequently. - left_context: left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. + left_context: + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. Shape: src: (S, N, E). @@ -708,10 +721,14 @@ class ConformerEncoder(nn.Module): src_key_padding_mask: the mask for the src keys per batch (optional). warmup: controls selective bypass of of layers; if < 1.0, we will bypass layers more frequently. - left_context: left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. - + left_context: + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. Shape: src: (S, N, E). pos_emb: (N, 2*(S+left_context)-1, E). @@ -1273,9 +1290,17 @@ class RelPositionMultiheadAttention(nn.Module): and attn_mask.dtype == torch.bool and key_padding_mask is not None ): - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + if attn_mask.size(0) != 1: + attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) + combined_mask = attn_mask | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze( + 0 + ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) @@ -1404,6 +1429,10 @@ class ConvolutionModule(nn.Module): x: Input tensor (#time, batch, channels). cache: The cache of depthwise_conv, only used in real streaming decoding. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. Returns: If cache is None return the output tensor (#time, batch, channels). 
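
# A self-contained sketch (not part of the patch) of the depthwise-conv
# caching described above, including the right_context handling this patch
# adds: a causal conv needs the last lorder frames of the previous chunk,
# and look-ahead frames must be kept out of the cache.
import torch
import torch.nn as nn

kernel_size, lorder, right_context = 5, 4, 2
conv = nn.Conv1d(2, 2, kernel_size, groups=2)

chunk = torch.randn(1, 2, 8 + right_context)   # (batch, channels, time)
cache = torch.zeros(1, 2, lorder)              # tail of the previous chunk
x = torch.cat([cache, chunk], dim=2)
y = conv(x)                                    # output length == chunk length
new_cache = x[:, :, -(lorder + right_context):-right_context]
print(y.shape, new_cache.shape)                # (1, 2, 10) (1, 2, 4)
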
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 57f05c6bf..96d5f12e2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -59,8 +59,7 @@ Usage: --epoch 28 \ --avg 15 \ --simulate-streaming 1 \ - --causal-convolution 1 \ - --right-chunk-size 16 \ + --decode-chunk-size 16 \ --left-context 64 \ --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 600 \ @@ -257,16 +256,7 @@ def get_parser(): ) parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - - parser.add_argument( - "--right-chunk-size", + "--decode-chunk-size", type=int, default=16, help="The chunk size for decoding (in frames after subsampling)", @@ -335,11 +325,11 @@ def decode_one_batch( ) if params.simulate_streaming: - encoder_out, encoder_out_lens = model.encoder.streaming_forward( + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, states=[], - chunk_size=params.right_chunk_size, + chunk_size=params.decode_chunk_size, left_context=params.left_context, simulate_streaming=True, ) @@ -561,7 +551,7 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}" + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" params.suffix += f"-left-context-{params.left_context}" if "fast_beam_search" in params.decoding_method: @@ -594,9 +584,8 @@ def main(): params.vocab_size = sp.get_piece_size() if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" + # Decoding in streaming requires causal convolution + params.causal_convolution = True logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index c018bfa03..3c6424359 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -164,14 +164,6 @@ def get_parser(): are streaming model, this should be True. """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - exporting a streaming model. 
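
# A toy version (not part of the patch; the real mask counts history in
# chunks, not frames) of the masking idea behind --simulate-streaming in the
# decode script above: the full utterance is sent at once, but an attention
# mask confines every frame to its own chunk plus some history.
import torch

def simulated_streaming_mask(T: int, chunk: int, left: int) -> torch.Tensor:
    idx = torch.arange(T)
    chunk_end = (idx // chunk + 1) * chunk     # end of each frame's chunk
    allowed = (idx[None, :] < chunk_end[:, None]) & (
        idx[None, :] >= chunk_end[:, None] - chunk - left
    )
    return ~allowed                            # True means "masked out"

print(simulated_streaming_mask(8, chunk=4, left=4).int())
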
- """, - ) return parser @@ -197,7 +189,7 @@ def main(): params.vocab_size = sp.get_piece_size() if params.streaming_model: - assert params.causal_convolution + params.causal_convolution = True logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 4b787363e..2434fd41d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -78,8 +78,6 @@ class Transducer(nn.Module): am_scale: float = 0.0, lm_scale: float = 0.0, warmup: float = 1.0, - delay_penalty: float = 0.0, - return_sym_delay: bool = False, ) -> torch.Tensor: """ Args: @@ -157,31 +155,10 @@ class Transducer(nn.Module): lm_only_scale=lm_scale, am_only_scale=am_scale, boundary=boundary, - delay_penalty=delay_penalty, reduction="sum", return_grad=True, ) - sym_delay = None - if return_sym_delay: - B, S, T0 = px_grad.shape - T = T0 - 1 - if boundary is None: - offset = torch.tensor( - (T - 1) / 2, - dtype=px_grad.dtype, - device=px_grad.device, - ).expand(B, 1, 1) - total_syms = S * B - else: - offset = (boundary[:, 3] - 1) / 2 - total_syms = torch.sum(boundary[:, 2]) - offset = torch.arange(T0, device=px_grad.device).reshape( - 1, 1, T0 - ) - offset.reshape(B, 1, 1) - sym_delay = px_grad * offset - sym_delay = torch.sum(sym_delay) / total_syms - # ranges : [B, T, prune_range] ranges = k2.get_rnnt_prune_ranges( px_grad=px_grad, @@ -210,9 +187,8 @@ class Transducer(nn.Module): symbols=y_padded, ranges=ranges, termination_symbol=blank_id, - delay_penalty=delay_penalty, boundary=boundary, reduction="sum", ) - return (simple_loss, pruned_loss, sym_delay) + return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index 65bffd063..6072a288e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -20,9 +20,12 @@ Usage: ./pruned_transducer_stateless2/streaming_decode.py \ --epoch 28 \ --avg 15 \ + --left-context 32 \ + --decode-chunk-size 8 \ + --right-context 2 \ --exp-dir ./pruned_transducer_stateless2/exp \ --decoding_method greedy_search \ - --num-decode-streams 200 + --num-decode-streams 1000 """ import argparse @@ -182,15 +185,6 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=True, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - parser.add_argument( "--decode-chunk-size", type=int, @@ -356,6 +350,9 @@ def decode_one_chunk( processed_feature_lens = [] for stream in decode_streams: + # we plus 2 here because we will cut off one frame on each size of + # encoder_embed output as they see invalid paddings. so we need extra 2 + # frames. 
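
# A small sketch (not part of the patch; LOG_EPS here stands in for the
# log-of-a-tiny-value constant the scripts define) of the tail padding done
# a few lines below: when the last chunk is shorter than the minimum the
# subsampling front end needs, it is right-padded with log-space "silence".
import math
import torch

LOG_EPS = math.log(1e-10)
features = torch.randn(2, 9, 80)   # (N, T, C): a tail chunk with T too short
tail_length = 15
if features.size(1) < tail_length:
    pad = torch.full(
        (features.size(0), tail_length - features.size(1), features.size(2)),
        LOG_EPS,
    )
    features = torch.cat([features, pad], dim=1)
print(features.shape)              # torch.Size([2, 15, 80])
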
feat, feat_len = stream.get_feature_frames( (params.decode_chunk_size + 2 + params.right_context) * params.subsampling_factor @@ -372,7 +369,7 @@ def decode_one_chunk( # if T is less than 7 there will be an error in time reduction layer, # because we subsample features with ((x_len - 1) // 2 - 1) // 2 - tail_length = 15 + params.right_context * params.subsampling_factor + tail_length = 7 + (2 + params.right_context) * params.subsampling_factor if features.size(1) < tail_length: feature_lens += tail_length - features.size(1) features = torch.cat( @@ -642,10 +639,8 @@ def main(): params.blank_id = sp.piece_to_id("") params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" + # Decoding in streaming requires causal convolution + params.causal_convolution = True logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index d445713fe..465d7da0f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -48,11 +48,9 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --exp-dir pruned_transducer_stateless/exp \ --full-libri 1 \ --dynamic-chunk-training 1 \ - --causal-convolution 1 \ --short-chunk-size 25 \ --num-left-chunks 4 \ --max-duration 300 - """ @@ -285,15 +283,6 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - parser.add_argument( "--short-chunk-size", type=int, @@ -310,25 +299,6 @@ def get_parser(): help="How many left context can be seen in chunks when calculating attention.", ) - parser.add_argument( - "--delay-penalty", - type=float, - default=0.0, - help="""A constant value to penalize symbol delay, this may be - needed when training with time masking, to avoid the time masking - encouraging the network to delay symbols. - """, - ) - - parser.add_argument( - "--return-sym-delay", - type=str2bool, - default=False, - help="""Whether to return `sym_delay` during training, this is a stat - to measure symbols emission delay, especially for time masking training. 
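
# A simplified sketch (not part of the patch; the exact ramp lives in
# compute_loss further down) of how the two RNN-T losses are blended now
# that the model returns only (simple_loss, pruned_loss): early in training
# the pruned loss is excluded, then lightly scaled, then fully weighted.
def total_loss(simple_loss, pruned_loss, warmup, simple_loss_scale=0.5):
    pruned_loss_scale = 0.0 if warmup < 1.0 else (0.1 if warmup < 2.0 else 1.0)
    return simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

print(total_loss(10.0, 8.0, warmup=0.5))   # 5.0, pruned loss ignored
print(total_loss(10.0, 8.0, warmup=3.0))   # 13.0, full pruned loss
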
- """, - ) - return parser @@ -611,7 +581,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, sym_delay = model( + simple_loss, pruned_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -619,8 +589,6 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, warmup=warmup, - delay_penalty=params.delay_penalty, - return_sym_delay=params.return_sym_delay, ) # after the main warmup step, we keep pruned_loss_scale small # for the same amount of time (model_warm_step), to avoid @@ -650,9 +618,6 @@ def compute_loss( info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.return_sym_delay: - info["sym_delay"] = sym_delay.detach().cpu().item() - return loss, info @@ -882,13 +847,8 @@ def run(rank, world_size, args): params.vocab_size = sp.get_piece_size() if params.dynamic_chunk_training: - assert ( - params.causal_convolution - ), "dynamic_chunk_training requires causal convolution" - else: - assert ( - params.delay_penalty == 0.0 - ), "delay_penalty is intended for dynamic_chunk_training" + # dynamic_chunk_training requires causal convolution + params.causal_convolution = True logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 43af59761..a95af9ddc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -266,16 +266,7 @@ def get_parser(): ) parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - - parser.add_argument( - "--right-chunk-size", + "--decode-chunk-size", type=int, default=16, help="The chunk size for decoding (in frames after subsampling)", @@ -348,7 +339,7 @@ def decode_one_batch( x=feature, x_lens=feature_lens, states=[], - chunk_size=params.right_chunk_size, + chunk_size=params.decode_chunk_size, left_context=params.left_context, simulate_streaming=True, ) @@ -596,7 +587,7 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}" + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" params.suffix += f"-left-context-{params.left_context}" if params.decoding_method == "fast_beam_search": @@ -635,9 +626,8 @@ def main(): params.vocab_size = sp.get_piece_size() if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" + # Decoding in streaming requires causal convolution + params.causal_convolution = True logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index f7f96f9a6..020586e7d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -165,14 +165,6 @@ def get_parser(): are streaming model, this should be True. """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - exporting a streaming model. 
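
# For reference (not part of the patch): the length formula these scripts
# quote, ((x_len - 1) // 2 - 1) // 2, is the 4x subsampling of the
# convolutional front end, and it is why at least 7 input frames are needed
# to produce a single output frame.
def subsampled_len(x_len: int) -> int:
    return ((x_len - 1) // 2 - 1) // 2

for x_len in (7, 15, 16, 100):
    print(x_len, "->", subsampled_len(x_len))   # 1, 3, 3, 24
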
- """, - ) return parser @@ -198,7 +190,7 @@ def main(): params.vocab_size = sp.get_piece_size() if params.streaming_model: - assert params.causal_convolution + params.causal_convolution = True logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 82d7d024d..1117acbf4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -20,9 +20,12 @@ Usage: ./pruned_transducer_stateless2/streaming_decode.py \ --epoch 28 \ --avg 15 \ + --left-context 32 \ + --decode-chunk-size 8 \ + --right-context 2 \ --exp-dir ./pruned_transducer_stateless2/exp \ --decoding_method greedy_search \ - --num-decode-streams 200 + --num-decode-streams 1000 """ import argparse @@ -183,15 +186,6 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=True, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - parser.add_argument( "--decode-chunk-size", type=int, @@ -357,6 +351,9 @@ def decode_one_chunk( processed_feature_lens = [] for stream in decode_streams: + # we plus 2 here because we will cut off one frame on each size of + # encoder_embed output as they see invalid paddings. so we need extra 2 + # frames. feat, feat_len = stream.get_feature_frames( (params.decode_chunk_size + 2 + params.right_context) * params.subsampling_factor @@ -373,7 +370,10 @@ def decode_one_chunk( # if T is less than 7 there will be an error in time reduction layer, # because we subsample features with ((x_len - 1) // 2 - 1) // 2 - tail_length = 15 + params.right_context * params.subsampling_factor + # we plus 2 here because we will cut off one frame on each size of + # encoder_embed output as they see invalid paddings. so we need extra 2 + # frames. + tail_length = 7 + (2 + params.right_context) * params.subsampling_factor if features.size(1) < tail_length: feature_lens += tail_length - features.size(1) features = torch.cat( @@ -481,7 +481,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -641,10 +643,8 @@ def main(): params.blank_id = sp.piece_to_id("") params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" + # Decoding in streaming requires causal convolution + params.causal_convolution = True logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index a8b8c7349..18473167d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -295,15 +295,6 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. 
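
# A sketch (not part of the patch; shapes inferred from the assertions in
# conformer.py's streaming_forward) of the caches that the
# encoder.get_init_state call above allocates: per layer, an attention cache
# covering left_context frames and a convolution cache of kernel - 1 frames.
import torch

num_layers, d_model, cnn_module_kernel = 12, 512, 31
left_context, batch = 32, 1

attn_cache = torch.zeros(num_layers, left_context, batch, d_model)
conv_cache = torch.zeros(num_layers, cnn_module_kernel - 1, batch, d_model)
states = [attn_cache, conv_cache]
print(states[0].shape, states[1].shape)
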
- """, - ) - parser.add_argument( "--short-chunk-size", type=int, @@ -320,25 +311,6 @@ def get_parser(): help="How many left context can be seen in chunks when calculating attention.", ) - parser.add_argument( - "--delay-penalty", - type=float, - default=0.0, - help="""A constant value to penalize symbol delay, this may be - needed when training with time masking, to avoid the time masking - encouraging the network to delay symbols. - """, - ) - - parser.add_argument( - "--return-sym-delay", - type=str2bool, - default=False, - help="""Whether to return `sym_delay` during training, this is a stat - to measure symbols emission delay, especially for time masking training. - """, - ) - return parser @@ -963,13 +935,8 @@ def run(rank, world_size, args): params.vocab_size = sp.get_piece_size() if params.dynamic_chunk_training: - assert ( - params.causal_convolution - ), "dynamic_chunk_training requires causal convolution" - else: - assert ( - params.delay_penalty == 0.0 - ), "delay_penalty is intended for dynamic_chunk_training" + # dynamic_chunk_training requires causal convolution + params.causal_convolution = True logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 9d340a20e..d830886d3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -60,8 +60,7 @@ Usage: --epoch 30 \ --avg 15 \ --simulate-streaming 1 \ - --causal-convolution 1 \ - --right-chunk-size 16 \ + --decode-chunk-size 16 \ --left-context 64 \ --exp-dir ./pruned_transducer_stateless4/exp \ --max-duration 600 \ @@ -269,16 +268,7 @@ def get_parser(): ) parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. 
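
# Back-of-the-envelope latency (not part of the patch): --decode-chunk-size,
# defined just below, counts frames after subsampling, so with the usual
# 10 ms feature frames and 4x subsampling the per-chunk latency is:
decode_chunk_size = 16
subsampling_factor = 4
frame_shift_ms = 10
print(decode_chunk_size * subsampling_factor * frame_shift_ms, "ms")  # 640 ms
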
- """, - ) - - parser.add_argument( - "--right-chunk-size", + "--decode-chunk-size", type=int, default=16, help="The chunk size for decoding (in frames after subsampling)", @@ -351,7 +341,7 @@ def decode_one_batch( x=feature, x_lens=feature_lens, states=[], - chunk_size=params.right_chunk_size, + chunk_size=params.decode_chunk_size, left_context=params.left_context, simulate_streaming=True, ) @@ -573,7 +563,7 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}" + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" params.suffix += f"-left-context-{params.left_context}" if "fast_beam_search" in params.decoding_method: @@ -609,9 +599,8 @@ def main(): params.vocab_size = sp.get_piece_size() if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" + # Decoding in streaming requires causal convolution + params.causal_convolution = True logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index b9fdaa68e..7d072079e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -20,6 +20,9 @@ Usage: ./pruned_transducer_stateless2/streaming_decode.py \ --epoch 28 \ --avg 15 \ + --left-context 32 \ + --decode-chunk-size 8 \ + --right-context 2 \ --exp-dir ./pruned_transducer_stateless2/exp \ --decoding_method greedy_search \ --num-decode-streams 200 @@ -194,15 +197,6 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=True, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - parser.add_argument( "--decode-chunk-size", type=int, @@ -368,6 +362,9 @@ def decode_one_chunk( processed_feature_lens = [] for stream in decode_streams: + # we plus 2 here because we will cut off one frame on each size of + # encoder_embed output as they see invalid paddings. so we need extra 2 + # frames. feat, feat_len = stream.get_feature_frames( (params.decode_chunk_size + 2 + params.right_context) * params.subsampling_factor @@ -384,7 +381,10 @@ def decode_one_chunk( # if T is less than 7 there will be an error in time reduction layer, # because we subsample features with ((x_len - 1) // 2 - 1) // 2 - tail_length = 15 + params.right_context * params.subsampling_factor + # we plus 2 here because we will cut off one frame on each size of + # encoder_embed output as they see invalid paddings. so we need extra 2 + # frames. + tail_length = 7 + (2 + params.right_context) * params.subsampling_factor if features.size(1) < tail_length: feature_lens += tail_length - features.size(1) features = torch.cat( @@ -492,7 +492,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
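
# A toy version (not part of the patch) of the per-stream bookkeeping above:
# feature_len now accumulates encoder output lengths for every decoding
# method, not only fast_beam_search, so processed_lens can be derived chunk
# by chunk.
import torch

feature_len = torch.tensor([0, 0])                   # two parallel streams
for encoder_out_lens in (torch.tensor([8, 8]), torch.tensor([8, 5])):
    processed_lens = feature_len + encoder_out_lens  # handed to the decoder
    feature_len = feature_len + encoder_out_lens     # carried to next chunk
print(feature_len)                                   # tensor([16, 13])
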
decode_stream = DecodeStream( @@ -655,10 +657,8 @@ def main(): params.blank_id = sp.piece_to_id("") params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" + # Decoding in streaming requires causal convolution + params.causal_convolution = True logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 2a46d2c18..637c6214b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -49,7 +49,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --exp-dir pruned_transducer_stateless4/exp \ --full-libri 1 \ --dynamic-chunk-training 1 \ - --causal-convolution 1 \ --short-chunk-size 25 \ --num-left-chunks 4 \ --max-duration 300 @@ -302,15 +301,6 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - parser.add_argument( "--short-chunk-size", type=int, @@ -327,25 +317,6 @@ def get_parser(): help="How many left context can be seen in chunks when calculating attention.", ) - parser.add_argument( - "--delay-penalty", - type=float, - default=0.0, - help="""A constant value to penalize symbol delay, this may be - needed when training with time masking, to avoid the time masking - encouraging the network to delay symbols. - """, - ) - - parser.add_argument( - "--return-sym-delay", - type=str2bool, - default=False, - help="""Whether to return `sym_delay` during training, this is a stat - to measure symbols emission delay, especially for time masking training. 
- """, - ) - return parser @@ -640,7 +611,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, sym_delay = model( + simple_loss, pruned_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -648,8 +619,6 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, warmup=warmup, - delay_penalty=params.delay_penalty, - return_sym_delay=params.return_sym_delay, ) # after the main warmup step, we keep pruned_loss_scale small # for the same amount of time (model_warm_step), to avoid @@ -679,9 +648,6 @@ def compute_loss( info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.return_sym_delay: - info["sym_delay"] = sym_delay.detach().cpu().item() - return loss, info @@ -922,13 +888,8 @@ def run(rank, world_size, args): params.vocab_size = sp.get_piece_size() if params.dynamic_chunk_training: - assert ( - params.causal_convolution - ), "dynamic_chunk_training requires causal convolution" - else: - assert ( - params.delay_penalty == 0.0 - ), "delay_penalty is intended for dynamic_chunk_training" + # dynamic_chunk_training requires causal convolution + params.causal_convolution = True logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 031d2414a..61409a3a7 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -248,7 +248,9 @@ class Conformer(Transformer): states: List[torch.Tensor], chunk_size: int = 16, left_context: int = 64, + right_context: int = 0, simulate_streaming: bool = False, + processed_lens: Optional[Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: """ Args: @@ -268,13 +270,20 @@ class Conformer(Transformer): The chunk size for decoding, this will be used to simulate streaming decoding using masking. left_context: - How many old frames the attention can see in current chunk, it MUST - be equal to left_context in decode_states. + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. simulate_streaming: If setting True, it will use a masking strategy to simulate streaming fashion (i.e. every chunk data only see limited left context and right context). The whole sequence is supposed to be send at a time When using simulate_streaming. + processed_lens: + How many frames (after subsampling) have been processed for each sequence. 
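
# A toy version (not part of the patch) of what the processed_lens argument
# described above is used for; the construction appears later in this hunk:
# left-context cache slots that a stream has not filled yet are masked out,
# and the cache fills from the right.
import torch

left_context, batch = 4, 2
processed_lens = torch.tensor([1, 3]).view(batch, 1)
processed_mask = torch.arange(left_context).expand(batch, left_context)
processed_mask = (processed_lens <= processed_mask).flip(1)
print(processed_mask)
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False]])
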
Returns: Return a tuple containing 2 tensors: - logits, its shape is (batch_size, output_seq_len, output_dim) @@ -310,9 +319,27 @@ class Conformer(Transformer): {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)}, given {states[1].shape}.""" - src_key_padding_mask = make_pad_mask(lengths + left_context) + lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output + src_key_padding_mask = make_pad_mask(lengths) + + assert processed_lens is not None + + processed_mask = torch.arange(left_context, device=x.device).expand( + x.size(0), left_context + ) + processed_lens = processed_lens.view(x.size(0), 1) + processed_mask = (processed_lens <= processed_mask).flip(1) + + src_key_padding_mask = torch.cat( + [processed_mask, src_key_padding_mask], dim=1 + ) embed = self.encoder_embed(x) + + # cut off 1 frame on each size of embed as they see the padding + # value which causes a training and decoding mismatch. + embed = embed[:, 1:-1, :] + embed, pos_enc = self.encoder_pos(embed, left_context) embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) @@ -322,6 +349,7 @@ class Conformer(Transformer): src_key_padding_mask=src_key_padding_mask, states=states, left_context=left_context, + right_context=right_context, ) # (T, B, F) else: src_key_padding_mask = make_pad_mask(lengths) @@ -512,6 +540,7 @@ class ConformerEncoderLayer(nn.Module): src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, left_context: int = 0, + right_context: int = 0, ) -> Tuple[Tensor, List[Tensor]]: """ Pass the input through the encoder layer. @@ -528,9 +557,14 @@ class ConformerEncoderLayer(nn.Module): Note: states will be modified in this function. src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). - left_context: left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. + left_context: + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. Shape: src: (S, N, E). pos_emb: (N, 2*(S+left_context)-1, E). @@ -562,7 +596,12 @@ class ConformerEncoderLayer(nn.Module): # separately) if needed. key = torch.cat([states[0], src], dim=0) val = key - states[0] = key[-left_context:, ...] + if right_context > 0: + states[0] = key[ + -(left_context + right_context) : -right_context, ... # noqa + ] + else: + states[0] = key[-left_context:, ...] src_att = self.self_attn( src, @@ -582,7 +621,9 @@ class ConformerEncoderLayer(nn.Module): if self.normalize_before: src = self.norm_conv(src) - src, conv_cache = self.conv_module(src, states[1]) + src, conv_cache = self.conv_module( + src, states[1], right_context=right_context + ) states[1] = conv_cache src = residual + self.dropout(src) @@ -669,6 +710,7 @@ class ConformerEncoder(nn.Module): mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, left_context: int = 0, + right_context: int = 0, ) -> Tuple[Tensor, List[Tensor]]: r"""Pass the input through the encoder layers in turn. @@ -684,9 +726,14 @@ class ConformerEncoder(nn.Module): Note: states will be modified in this function. mask: the mask for the src sequence (optional). 
src_key_padding_mask: the mask for the src keys per batch (optional). - left_context: left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. + left_context: + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. Shape: src: (S, N, E). pos_emb: (N, 2*(S+left_context)-1, E). @@ -707,6 +754,7 @@ class ConformerEncoder(nn.Module): src_mask=mask, src_key_padding_mask=src_key_padding_mask, left_context=left_context, + right_context=right_context, ) states[0][layer_index] = cache[0] states[1][layer_index] = cache[1] @@ -1329,7 +1377,10 @@ class ConvolutionModule(nn.Module): self.activation = Swish() def forward( - self, x: Tensor, cache: Optional[Tensor] = None + self, + x: Tensor, + cache: Optional[Tensor] = None, + right_context: int = 0, ) -> Tuple[Tensor, Tensor]: """Compute convolution module. @@ -1359,7 +1410,15 @@ class ConvolutionModule(nn.Module): ), "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) - cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa + if right_context > 0: + cache = x.permute(2, 0, 1)[ + -(self.lorder + right_context) : ( # noqa + -right_context + ), + ..., + ] + else: + cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa x = self.depthwise_conv(x) # x is (batch, channels, time) diff --git a/icefall/utils.py b/icefall/utils.py index 60ca9dcde..8e87d29df 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -535,11 +535,8 @@ class MetricsTracker(collections.defaultdict): ans = [] for k, v in self.items(): if k != "frames": - if k != "sym_delay": - norm_value = float(v) / num_frames - ans.append((k, norm_value)) - else: - ans.append((k, float(v))) + norm_value = float(v) / num_frames + ans.append((k, norm_value)) return ans def reduce(self, device):
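
# A toy version (not part of the patch) of the MetricsTracker behavior
# restored by the hunk above: with the sym_delay special case gone, every
# tracked value is normalized by the frame count again.
info = {"frames": 200.0, "simple_loss": 50.0, "pruned_loss": 30.0}
num_frames = info["frames"]
ans = [(k, float(v) / num_frames) for k, v in info.items() if k != "frames"]
print(ans)   # [('simple_loss', 0.25), ('pruned_loss', 0.15)]
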