From 7cc697c03a9c97856f0e5397dfc227f7cfcc348c Mon Sep 17 00:00:00 2001 From: pkufool Date: Sun, 22 May 2022 17:14:17 +0800 Subject: [PATCH] support streaming on pruned_transducer_stateless2; add delay penalty; fixes for decode states --- .../ASR/pruned_transducer_stateless/model.py | 26 +- .../ASR/pruned_transducer_stateless/train.py | 27 +- .../pruned_transducer_stateless2/conformer.py | 346 ++++++++++++++++-- .../pruned_transducer_stateless2/decode.py | 103 +++++- .../ASR/pruned_transducer_stateless2/model.py | 27 +- .../ASR/pruned_transducer_stateless2/train.py | 86 ++++- .../ASR/transducer_stateless/conformer.py | 177 +++------ icefall/utils.py | 7 +- 8 files changed, 637 insertions(+), 162 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 2f019bcdb..8b50ac657 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -66,6 +66,8 @@ class Transducer(nn.Module): prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, + delay_penalty: float = 0.0, + return_sym_delay: bool = False, ) -> torch.Tensor: """ Args: @@ -136,10 +138,31 @@ class Transducer(nn.Module): lm_only_scale=lm_scale, am_only_scale=am_scale, boundary=boundary, + delay_penalty=delay_penalty, reduction="sum", return_grad=True, ) + sym_delay = None + if return_sym_delay: + B, S, T0 = px_grad.shape + T = T0 - 1 + if boundary is None: + offset = torch.tensor( + (T - 1) / 2, + dtype=px_grad.dtype, + device=px_grad.device, + ).expand(B, 1, 1) + total_syms = S * B + else: + offset = (boundary[:, 3] - 1) / 2 + total_syms = torch.sum(boundary[:, 2]) + offset = torch.arange( + T0, device=px_grad.device + ).reshape(1, 1, T0) - offset.reshape(B, 1, 1) + sym_delay = px_grad * offset + sym_delay = torch.sum(sym_delay) / total_syms + # ranges : [B, T, prune_range] ranges = k2.get_rnnt_prune_ranges( px_grad=px_grad, @@ -163,7 +186,8 @@ class Transducer(nn.Module): ranges=ranges, termination_symbol=blank_id, boundary=boundary, + delay_penalty=delay_penalty, reduction="sum", ) - return (simple_loss, pruned_loss) + return (simple_loss, pruned_loss, sym_delay) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index cb8f13c54..f505cbc14 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -269,6 +269,25 @@ def get_parser(): help="How many left context can be seen in chunks when calculating attention.", ) + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value to penalize symbol delay, this may be + needed when training with time masking, to avoid the time masking + encouraging the network to delay symbols. + """, + ) + + parser.add_argument( + "--return-sym-delay", + type=str2bool, + default=False, + help="""Whether to return `sym_delay` during training, this is a stat + to measure symbols emission delay, especially for time masking training. 
+        """,
+    )
+
     return parser
 
 
@@ -536,14 +555,17 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)
 
+    sym_delay = None
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, sym_delay = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            delay_penalty=params.delay_penalty,
+            return_sym_delay=params.return_sym_delay,
         )
         loss = params.simple_loss_scale * simple_loss + pruned_loss
 
@@ -561,6 +583,9 @@ def compute_loss(
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
 
+    if sym_delay is not None:
+        info["sym_delay"] = sym_delay.detach().cpu().item()
+
     return loss, info
 
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index 257936b59..b36c02df2 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -32,7 +32,7 @@ from scaling import (
 )
 from torch import Tensor, nn
 
-from icefall.utils import make_pad_mask
+from icefall.utils import make_pad_mask, subsequent_chunk_mask
 
 
 class Conformer(EncoderInterface):
@@ -48,6 +48,26 @@ class Conformer(EncoderInterface):
         layer_dropout (float): layer-dropout rate.
         cnn_module_kernel (int): Kernel size of convolution module
         vgg_frontend (bool): whether to use vgg frontend.
+        dynamic_chunk_training (bool): whether to use dynamic chunk training; if
+            you want to train a streaming model, this is expected to be True.
+            When set to True, a masking strategy is used to make the attention
+            see only a limited left and right context.
+        short_chunk_threshold (float): a threshold to determine the chunk size
+            to be used in masking training; if the randomly generated chunk size
+            is greater than ``max_len * short_chunk_threshold`` (max_len is the
+            max sequence length of the current batch), then full context is used
+            in training (i.e. with a chunk size equal to max_len).
+            This will be used only when dynamic_chunk_training is True.
+        short_chunk_size (int): see docs above; if the randomly generated chunk
+            size is less than or equal to ``max_len * short_chunk_threshold``, the
+            chunk size will be sampled uniformly from 1 to short_chunk_size.
+            This also will be used only when dynamic_chunk_training is True.
+        num_left_chunks (int): the left context (in chunks) attention can see; the
+            chunk size is decided by short_chunk_threshold and short_chunk_size.
+            A negative value means seeing the full left context.
+            This also will be used only when dynamic_chunk_training is True.
+        causal (bool): Whether to use causal convolution in the conformer encoder
+            layer. This MUST be True when using dynamic_chunk_training.
""" def __init__( @@ -61,6 +81,11 @@ class Conformer(EncoderInterface): dropout: float = 0.1, layer_dropout: float = 0.075, cnn_module_kernel: int = 31, + dynamic_chunk_training: bool = False, + short_chunk_threshold: float = 0.75, + short_chunk_size: int = 25, + num_left_chunks: int = -1, + causal: bool = False, ) -> None: super(Conformer, self).__init__() @@ -76,6 +101,14 @@ class Conformer(EncoderInterface): # (2) embedding: num_features -> d_model self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.encoder_layers = num_encoder_layers + self.d_model = d_model + self.dynamic_chunk_training = dynamic_chunk_training + self.short_chunk_threshold = short_chunk_threshold + self.short_chunk_size = short_chunk_size + self.num_left_chunks = num_left_chunks + self.causal = causal + self.encoder_pos = RelPositionalEncoding(d_model, dropout) encoder_layer = ConformerEncoderLayer( @@ -85,6 +118,7 @@ class Conformer(EncoderInterface): dropout, layer_dropout, cnn_module_kernel, + causal, ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) @@ -117,10 +151,31 @@ class Conformer(EncoderInterface): # Caution: We assume the subsampling factor is 4! lengths = ((x_lens - 1) // 2 - 1) // 2 assert x.size(0) == lengths.max().item() - mask = make_pad_mask(lengths) - x = self.encoder( - x, pos_emb, src_key_padding_mask=mask, warmup=warmup + src_key_padding_mask = make_pad_mask(lengths) + mask = None + + if self.dynamic_chunk_training: + assert ( + self.causal + ), "Causal convolution is required for streaming conformer." + max_len = x.size(0) + chunk_size = torch.randint(1, max_len, (1,)).item() + if chunk_size > (max_len * self.short_chunk_threshold): + chunk_size = max_len + else: + chunk_size = chunk_size % self.short_chunk_size + 1 + + mask = ~subsequent_chunk_mask( + size=x.size(0), chunk_size=chunk_size, + num_left_chunks=self.num_left_chunks, device=x.device + ) + + x, _ = self.encoder( + x, pos_emb, + mask=mask, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, ) # (T, N, C) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -128,6 +183,116 @@ class Conformer(EncoderInterface): return x, lengths + def streaming_forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + warmup: float = 1.0, + states: Optional[Tensor] = None, + chunk_size: int = 16, + left_context: int = 64, + simulate_streaming: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + states: + The decode states for previous frames which contains the cached data. + It has a shape of (2, encoder_layers, left_context, batch, attention_dim), + states[0,...] is the attn_cache, states[1,...] is the conv_cache. + chunk_size: + The chunk size for decoding, this will be used to simulate streaming + decoding using masking. + left_context: + How many old frames the attention can see in current chunk, it MUST + be equal to left_context in decode_states. + simulate_streaming: + If setting True, it will use a masking strategy to simulate streaming + fashion (i.e. every chunk data only see limited left context and + right context). The whole sequence is supposed to be send at a time + When using simulate_streaming. 
+        Returns:
+          Return a tuple containing 3 tensors:
+            - logits, its shape is (batch_size, output_seq_len, output_dim)
+            - logit_lens, a tensor of shape (batch_size,) containing the number
+              of frames in `logits` before padding.
+            - states, the updated states containing the information
+              of the current chunk.
+        """
+
+        # x: [N, T, C]
+        # Caution: We assume the subsampling factor is 4!
+        lengths = ((x_lens - 1) // 2 - 1) // 2
+
+        if not simulate_streaming:
+            assert (
+                states is not None
+            ), "Require cache when sending data in streaming mode"
+
+            assert (
+                states.shape == (2, self.encoder_layers, left_context, x.size(0), self.d_model)
+            ), f"""The shape of states MUST be equal to
+            (2, encoder_layers, left_context, batch, d_model) which is
+            {(2, self.encoder_layers, left_context, x.size(0), self.d_model)}
+            given {states.shape}."""
+
+            src_key_padding_mask = make_pad_mask(lengths + left_context)
+
+            embed = self.encoder_embed(x)
+            embed, pos_enc = self.encoder_pos(embed, left_context)
+            embed = embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
+
+            x, states = self.encoder(
+                embed,
+                pos_enc,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+                states=states,
+                left_context=left_context,
+            )  # (T, B, F)
+
+        else:
+            assert states is None
+
+            src_key_padding_mask = make_pad_mask(lengths)
+            x = self.encoder_embed(x)
+            x, pos_emb = self.encoder_pos(x)
+            x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+            assert x.size(0) == lengths.max().item()
+
+            num_left_chunks = -1
+            if left_context >= 0:
+                assert left_context % chunk_size == 0
+                num_left_chunks = left_context // chunk_size
+
+            mask = ~subsequent_chunk_mask(
+                size=x.size(0),
+                chunk_size=chunk_size,
+                num_left_chunks=num_left_chunks,
+                device=x.device
+            )
+            x, _ = self.encoder(
+                x,
+                pos_emb,
+                mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+            )  # (T, N, C)
+
+        x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
+
+        return x, lengths, states
+
+
 class ConformerEncoderLayer(nn.Module):
     """
     ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
 
     Args:
         d_model: the number of expected features in the input (required).
         nhead: the number of heads in the multiheadattention models (required).
         dim_feedforward: the dimension of the feedforward network model (default=2048).
         dropout: the dropout value (default=0.1).
         cnn_module_kernel (int): Kernel size of convolution module.
+        causal (bool): Whether to use causal convolution in the conformer encoder
+            layer. This MUST be True when using dynamic_chunk_training and streaming decoding.
 
     Examples::
         >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -155,6 +322,7 @@
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
         cnn_module_kernel: int = 31,
+        causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
 
@@ -182,7 +350,11 @@
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
 
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_module = ConvolutionModule(
+            d_model,
+            cnn_module_kernel,
+            causal=causal,
+        )
 
         self.norm_final = BasicNorm(d_model)
 
@@ -200,7 +372,9 @@
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-    ) -> Tensor:
+        states: Optional[Tensor] = None,
+        left_context: int = 0,
+    ) -> Tuple[Tensor, Tensor]:
         """
         Pass the input through the encoder layer.
@@ -211,10 +385,17 @@
         src_key_padding_mask: the mask for the src keys per batch (optional).
         warmup: controls selective bypass of layers; if < 1.0, we will
           bypass layers more frequently.
+        states:
+          The decode states for previous frames, which contain the cached data.
+          It has a shape of (2, left_context, batch, attention_dim);
+          states[0,...] is the attn_cache, states[1,...] is the conv_cache.
+        left_context: left context (in frames) used during streaming decoding.
+          this is used only in real streaming decoding, in other circumstances,
+          it MUST be 0.
 
         Shape:
             src: (S, N, E).
-            pos_emb: (N, 2*S-1, E)
+            pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
             src_mask: (S, S).
             src_key_padding_mask: (N, S).
             S is the source sequence length, N is the batch size, E is the feature number
@@ -236,19 +417,38 @@
         # macaron style feed forward module
         src = src + self.dropout(self.feed_forward_macaron(src))
 
+        key = src
+        val = src
+        if not self.training and states is not None:
+            # src: [chunk_size, N, F] e.g. [8, 41, 512]
+            key = torch.cat([states[0, ...], src], dim=0)
+            val = key
+            states[0, ...] = key[-left_context:, ...]
+        else:
+            assert left_context == 0
+
         # multi-headed self-attention module
         src_att = self.self_attn(
             src,
-            src,
-            src,
+            key,
+            val,
             pos_emb=pos_emb,
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
+            left_context=left_context,
         )[0]
+
         src = src + self.dropout(src_att)
 
         # convolution module
-        src = src + self.dropout(self.conv_module(src))
+        residual = src
+        if not self.training and states is not None:
+            src = torch.cat([states[1, ...], src], dim=0)
+            states[1, ...] = src[-left_context:, ...]
+
+        conv = self.conv_module(src)
+        conv = conv[-residual.size(0) :, :, :]  # noqa: E203
+
+        src = residual + self.dropout(conv)
 
         # feed forward module
         src = src + self.dropout(self.feed_forward(src))
@@ -258,7 +458,7 @@
         if alpha != 1.0:
             src = alpha * src + (1 - alpha) * src_orig
 
-        return src
+        return src, states
 
 
 class ConformerEncoder(nn.Module):
@@ -290,7 +490,9 @@
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-    ) -> Tensor:
+        states: Optional[Tensor] = None,
+        left_context: int = 0,
+    ) -> Tuple[Tensor, Tensor]:
         r"""Pass the input through the encoder layers in turn.
 
         Args:
             src: the sequence to the encoder (required).
             pos_emb: Positional embedding tensor (required).
             mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
+            warmup: controls selective bypass of layers; if < 1.0, we will
+              bypass layers more frequently.
+            states:
+              The decode states for previous frames, which contain the cached data.
+              It has a shape of (2, encoder_layers, left_context, batch, attention_dim);
+              states[0,...] is the attn_cache, states[1,...] is the conv_cache.
+            left_context: left context (in frames) used during streaming decoding.
+              this is used only in real streaming decoding, in other circumstances,
+              it MUST be 0.
 
         Shape:
             src: (S, N, E).
-            pos_emb: (N, 2*S-1, E)
+            pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
             mask: (S, S).
             src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number @@ -309,16 +520,26 @@ class ConformerEncoder(nn.Module): """ output = src - for i, mod in enumerate(self.layers): - output = mod( + if self.training: + assert left_context == 0 + assert states is None + else: + assert left_context >= 0 + + for layer_index, mod in enumerate(self.layers): + output, cache = mod( output, pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, warmup=warmup, + states=None if states is None else states[:, layer_index, ...], + left_context=left_context, ) + if states is not None: + states[:, layer_index, ...] = cache - return output + return output, states class RelPositionalEncoding(torch.nn.Module): @@ -344,12 +565,13 @@ class RelPositionalEncoding(torch.nn.Module): self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - def extend_pe(self, x: Tensor) -> None: + def extend_pe(self, x: Tensor, context: int = 0) -> None: """Reset the positional encodings.""" + x_size_1 = x.size(1) + context if self.pe is not None: # self.pe contains both positive and negative parts # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device if self.pe.dtype != x.dtype or str(self.pe.device) != str( x.device @@ -359,9 +581,9 @@ class RelPositionalEncoding(torch.nn.Module): # Suppose `i` means to the position of query vecotr and `j` means the # position of key vector. We use position relative positions when keys # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + def forward( + self, + x: torch.Tensor, + context: int = 0 + ) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: x (torch.Tensor): Input tensor (batch, time, `*`). + context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. Returns: torch.Tensor: Encoded tensor (batch, time, `*`). torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). """ - self.extend_pe(x) + self.extend_pe(x, context) + x_size_1 = x.size(1) + context pos_emb = self.pe[ :, self.pe.size(1) // 2 - - x.size(1) + - x_size_1 + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(1), + + x_size_1, ] return self.dropout(x), self.dropout(pos_emb) @@ -466,6 +696,7 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, + left_context: int = 0, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -479,6 +710,9 @@ class RelPositionMultiheadAttention(nn.Module): need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. Shape: - Inputs: @@ -524,14 +758,18 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, + left_context=left_context, ) - def rel_shift(self, x: Tensor) -> Tensor: + def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor: """Compute relative positional encoding. 
 
         Args:
             x: Input tensor (batch, head, time1, 2*time1-1).
                 time1 means the length of query vector.
+            left_context (int): left context (in frames) used during streaming decoding.
+                this is used only in real streaming decoding, in other circumstances,
+                it MUST be 0.
 
         Returns:
             Tensor: tensor of shape (batch, head, time1, time2)
              (note: time2 = time1 + left_context, so it equals time1 when
              left_context is 0; time2 is for the key, while time1 is for
              the query).
         """
         (batch_size, num_heads, time1, n) = x.shape
-        assert n == 2 * time1 - 1
+
+        time2 = time1 + left_context
+        assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
+
         # Note: TorchScript requires explicit arg for stride()
         batch_stride = x.stride(0)
         head_stride = x.stride(1)
         time1_stride = x.stride(2)
         n_stride = x.stride(3)
         return x.as_strided(
-            (batch_size, num_heads, time1, time1),
+            (batch_size, num_heads, time1, time2),
             (batch_stride, head_stride, time1_stride - n_stride, n_stride),
             storage_offset=n_stride * (time1 - 1),
         )
@@ -568,6 +809,7 @@
         key_padding_mask: Optional[Tensor] = None,
         need_weights: bool = True,
         attn_mask: Optional[Tensor] = None,
+        left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
@@ -585,6 +827,9 @@
             need_weights: output attn_output_weights.
             attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                 the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+            left_context (int): left context (in frames) used during streaming decoding.
+                this is used only in real streaming decoding, in other circumstances,
+                it MUST be 0.
 
         Shape:
             Inputs:
@@ -748,7 +993,8 @@
         pos_emb_bsz = pos_emb.size(0)
         assert pos_emb_bsz in (1, bsz)  # actually it is 1
         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
-        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
+        # (batch, 2*time1-1, head, d_k) -> (batch, head, d_k, 2*time1-1)
+        p = p.permute(0, 2, 3, 1)
 
         q_with_bias_u = (q + self._pos_bias_u()).transpose(
             1, 2
@@ -768,9 +1014,9 @@
 
         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
-            q_with_bias_v, p.transpose(-2, -1)
+            q_with_bias_v, p
         )  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(matrix_bd)
+        matrix_bd = self.rel_shift(matrix_bd, left_context)
 
         attn_output_weights = (
             matrix_ac + matrix_bd
@@ -805,6 +1051,24 @@
         )
 
         attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+
+        # If we are using dynamic_chunk_training with a limited
+        # num_left_chunks, the attention may see only padding values, which
+        # will also be masked out by `key_padding_mask`. In this case, the
+        # whole column of `attn_output_weights` will be `-inf`
+        # (i.e. `nan` after softmax), so we fill `0.0` at the masked
+        # positions to avoid invalid loss values below.
+ if attn_mask is not None and attn_mask.dtype == torch.bool and \ + key_padding_mask is not None: + combined_mask = attn_mask.unsqueeze( + 0) | key_padding_mask.unsqueeze(1).unsqueeze(2) + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + combined_mask, 0.0) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len) + attn_output_weights = nn.functional.dropout( attn_output_weights, p=dropout_p, training=training ) @@ -838,16 +1102,21 @@ class ConvolutionModule(nn.Module): channels (int): The number of channels of conv layers. kernel_size (int): Kernerl size of conv layers. bias (bool): Whether to use bias in conv layers (default=True). - + causal (bool): Whether to use causal convolution. """ def __init__( - self, channels: int, kernel_size: int, bias: bool = True + self, + channels: int, + kernel_size: int, + bias: bool = True, + causal: bool = False ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding assert (kernel_size - 1) % 2 == 0 + self.causal = causal self.pointwise_conv1 = ScaledConv1d( channels, @@ -875,12 +1144,17 @@ class ConvolutionModule(nn.Module): channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 ) + self.lorder = kernel_size - 1 + padding = (kernel_size - 1) // 2 + if self.causal: + padding = 0 + self.depthwise_conv = ScaledConv1d( channels, channels, kernel_size, stride=1, - padding=(kernel_size - 1) // 2, + padding=padding, groups=channels, bias=bias, ) @@ -921,6 +1195,10 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if self.causal and self.lorder > 0: + # Make depthwise_conv causal by + # manualy padding self.lorder zeros to the left + x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) x = self.depthwise_conv(x) x = self.deriv_balancer2(x) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 05a4cdca5..cbc2baa8a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -53,6 +53,18 @@ Usage: --beam 4 \ --max-contexts 4 \ --max-states 8 + +(5) decode in streaming mode (take greedy search as an example) +./pruned_transducer_stateless2/decode.py \ + --epoch 28 \ + --avg 15 \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --right-chunk-size 16 \ + --left-context 64 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method greedy_search """ @@ -85,6 +97,7 @@ from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -190,6 +203,7 @@ def get_parser(): help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) + parser.add_argument( "--max-sym-per-frame", type=int, @@ -198,6 +212,70 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. 
+ """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="""How many left context can be seen in chunks when calculating attention. + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--right-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + return parser @@ -246,9 +324,19 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + if params.simulate_streaming: + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.right_chunk_size, + left_context=params.left_context, + simulate_streaming=True + ) + else: + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + hyps = [] if params.decoding_method == "fast_beam_search": @@ -461,6 +549,10 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -490,6 +582,11 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + logging.info(params) logging.info("About to create model") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 599bf2506..f510c6d15 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -15,6 +15,7 @@ # limitations under the License. 
+import logging import k2 import torch import torch.nn as nn @@ -77,6 +78,8 @@ class Transducer(nn.Module): am_scale: float = 0.0, lm_scale: float = 0.0, warmup: float = 1.0, + delay_penalty: float = 0.0, + return_sym_delay: bool = False, ) -> torch.Tensor: """ Args: @@ -154,10 +157,31 @@ class Transducer(nn.Module): lm_only_scale=lm_scale, am_only_scale=am_scale, boundary=boundary, + delay_penalty=delay_penalty, reduction="sum", return_grad=True, ) + sym_delay = None + if return_sym_delay: + B, S, T0 = px_grad.shape + T = T0 - 1 + if boundary is None: + offset = torch.tensor( + (T - 1) / 2, + dtype=px_grad.dtype, + device=px_grad.device, + ).expand(B, 1, 1) + total_syms = S * B + else: + offset = (boundary[:, 3] - 1) / 2 + total_syms = torch.sum(boundary[:, 2]) + offset = torch.arange( + T0, device=px_grad.device + ).reshape(1, 1, T0) - offset.reshape(B, 1, 1) + sym_delay = px_grad * offset + sym_delay = torch.sum(sym_delay) / total_syms + # ranges : [B, T, prune_range] ranges = k2.get_rnnt_prune_ranges( px_grad=px_grad, @@ -186,8 +210,9 @@ class Transducer(nn.Module): symbols=y_padded, ranges=ranges, termination_symbol=blank_id, + delay_penalty=delay_penalty, boundary=boundary, reduction="sum", ) - return (simple_loss, pruned_loss) + return (simple_loss, pruned_loss, sym_delay) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 51c1a231a..0deaede7a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -40,6 +40,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --full-libri 1 \ --max-duration 550 +# train a streaming model +./pruned_transducer_stateless2/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --exp-dir pruned_transducer_stateless/exp \ + --full-libri 1 \ + --dynamic-chunk-training 1 \ + --causal-convolution 1 \ + --short-chunk-size 25 \ + --num-left-chunks 4 \ + --max-duration 300 + """ @@ -263,6 +276,59 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value to penalize symbol delay, this may be + needed when training with time masking, to avoid the time masking + encouraging the network to delay symbols. + """, + ) + + parser.add_argument( + "--return-sym-delay", + type=str2bool, + default=False, + help="""Whether to return `sym_delay` during training, this is a stat + to measure symbols emission delay, especially for time masking training. 
+ """, + ) + return parser @@ -349,6 +415,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: nhead=params.nhead, dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, + dynamic_chunk_training=params.dynamic_chunk_training, + short_chunk_size=params.short_chunk_size, + num_left_chunks=params.num_left_chunks, + causal=params.causal_convolution, ) return encoder @@ -541,7 +611,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( + simple_loss, pruned_loss, sym_delay = model( x=feature, x_lens=feature_lens, y=y, @@ -549,6 +619,8 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, warmup=warmup, + delay_penalty=params.delay_penalty, + return_sym_delay=params.return_sym_delay, ) # after the main warmup step, we keep pruned_loss_scale small # for the same amount of time (model_warm_step), to avoid @@ -577,6 +649,9 @@ def compute_loss( info["loss"] = loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() + + if params.return_sym_delay: + info["sym_delay"] = sym_delay.detach().cpu().item() return loss, info @@ -806,6 +881,15 @@ def run(rank, world_size, args): params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + if params.dynamic_chunk_training: + assert ( + params.causal_convolution + ), "dynamic_chunk_training requires causal convolution" + else: + assert ( + params.delay_penalty == 0.0 + ), "delay_penalty is intended for dynamic_chunk_training" + logging.info(params) logging.info("About to create model") diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index f2a051471..45efe9fd4 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -18,7 +18,7 @@ import copy import math import warnings -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch from torch import Tensor, nn @@ -27,70 +27,6 @@ from transformer import Transformer from icefall.utils import make_pad_mask, subsequent_chunk_mask -class DecodeStates(object): - def __init__(self, - layers: int, - left_context: int, - dim: int, - init: bool = True, - dtype: torch.dtype = torch.float32, - device: torch.device = torch.device('cpu')): - self.layers = layers - self.left_context = left_context - self.dim = dim - self.dtype = dtype - self.device = device - if init: - # shape (layer, T, dim) - self.attn_cache = torch.zeros((layers, left_context, dim), - dtype=dtype, - device=device) - self.conv_cache = torch.zeros((layers, left_context, dim), - dtype=dtype, - device=device) - self.offset = torch.tensor([0], dtype=dtype, device=device) - - @staticmethod - def stack(states: List['DecodeStates']) -> 'DecodeStates': - assert len(states) >= 1 - obj = DecodeStates(layers=states[0].layers, - left_context=states[0].left_context, - dim=states[0].dim, - init=False, - dtype=states[0].dtype, - device=states[0].device) - attn_cache = [] - conv_cache = [] - offset = [] - for i in range(len(states)): - attn_cache.append(states[i].attn_cache) - conv_cache.append(states[i].conv_cache) - offset.append(states[i].offset) - obj.attn_cache = torch.stack(attn_cache, dim=2) - obj.conv_cache = torch.stack(conv_cache, dim=2) - obj.offset = torch.stack(offset, dim=0) - return obj - - @staticmethod - def unstack(states: 'DecodeStates') -> 
List['DecodeStates']: - results = [] - attn_cache = torch.unbind(states.attn_cache, dim=2) - conv_cache = torch.unbind(states.conv_cache, dim=2) - offset = torch.unbind(states.offset, dim=0) - for i in range(states.attn_cache.size(2)): - obj = DecodeStates(layers=states.layers, - left_context=states.left_context, - dim=states.dim, - init=False, - dtype=states.dtype, - device=states.device) - obj.attn_cache = attn_cache[i] - obj.conv_cache = conv_cache[i] - obj.offset = offset[i] - results.append(obj) - return results - - class Conformer(Transformer): """ Args: @@ -119,7 +55,7 @@ class Conformer(Transformer): size equals to or less than ``max_len * short_chunk_threshold``, the chunk size will be sampled uniformly from 1 to short_chunk_size. This also will be used only when dynamic_chunk_training is True. - num_left_chunks (int): the left context attention can see in chunks, the + num_left_chunks (int): the left context (in chunks) attention can see, the chunk size is decided by short_chunk_threshold and short_chunk_size. A minus value means seeing full left context. This also will be used only when dynamic_chunk_training is True. @@ -159,6 +95,8 @@ class Conformer(Transformer): vgg_frontend=vgg_frontend, ) + self.encoder_layers = num_encoder_layers + self.d_model = d_model self.dynamic_chunk_training = dynamic_chunk_training self.short_chunk_threshold = short_chunk_threshold self.short_chunk_size = short_chunk_size @@ -231,7 +169,7 @@ class Conformer(Transformer): num_left_chunks=self.num_left_chunks, device=x.device ) - x = self.encoder( + x, _ = self.encoder( x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask ) # (T, N, C) @@ -248,11 +186,11 @@ class Conformer(Transformer): self, x: torch.Tensor, x_lens: torch.Tensor, - decode_states: Optional[DecodeStates] = None, + states: Optional[torch.Tensor] = None, chunk_size: int = 16, left_context: int = 64, simulate_streaming: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor, DecodeStates]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -260,9 +198,10 @@ class Conformer(Transformer): x_lens: A tensor of shape (batch_size,) containing the number of frames in `x` before padding. - decode_states: - The decode states for previous frames which contains the cached data - and the offset of current chunk in the whole sequence. + states: + The decode states for previous frames which contains the cached data. + It has a shape of (2, encoder_layers, left_context, batch, attention_dim), + states[0,...] is the attn_cache, states[1,...] is the conv_cache. chunk_size: The chunk size for decoding, this will be used to simulate streaming decoding using masking. 
@@ -289,13 +228,15 @@ class Conformer(Transformer): if not simulate_streaming: assert ( - decode_states is not None + states is not None ), "Require cache when sending data in streaming mode" + assert ( - left_context == decode_states.left_context - ), f"""The given left_context must equal to the left_context in - `decode_states`, need {decode_states.left_context} given - {left_context}.""" + states.shape == (2, self.encoder_layers, left_context, x.size(0), self.d_model) + ), f"""The shape of states MUST be equal to + (2, encoder_layers, left_context, batch, d_model) which is + {(2, self.encoder_layers, left_context, x.size(0), self.d_model)} + given {states.shape}.""" src_key_padding_mask = make_pad_mask(lengths + left_context) @@ -303,18 +244,16 @@ class Conformer(Transformer): embed, pos_enc = self.encoder_pos(embed, left_context) embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - x = self.encoder( + x, states = self.encoder( embed, pos_enc, src_key_padding_mask=src_key_padding_mask, - attn_cache=decode_states.attn_cache, - conv_cache=decode_states.conv_cache, - left_context=decode_states.left_context, + states=states, + left_context=left_context, ) # (T, B, F) - decode_states.offset += embed.size(0) else: - assert decode_states is None + assert states is None src_key_padding_mask = make_pad_mask(lengths) x = self.encoder_embed(x) @@ -322,8 +261,11 @@ class Conformer(Transformer): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) assert x.size(0) == lengths.max().item() - assert left_context % chunk_size == 0 - num_left_chunks = left_context // chunk_size + + num_left_chunks = -1 + if left_context >= 0: + assert left_context % chunk_size == 0 + num_left_chunks = left_context // chunk_size mask = ~subsequent_chunk_mask( size=x.size(0), @@ -331,7 +273,7 @@ class Conformer(Transformer): num_left_chunks=num_left_chunks, device=x.device ) - x = self.encoder( + x, _ = self.encoder( x, pos_emb, mask=mask, @@ -344,7 +286,7 @@ class Conformer(Transformer): logits = self.encoder_output_layer(x) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return logits, lengths, decode_states + return logits, lengths, states class ConformerEncoderLayer(nn.Module): @@ -425,10 +367,9 @@ class ConformerEncoderLayer(nn.Module): pos_emb: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - attn_cache: Optional[Tensor] = None, - conv_cache: Optional[Tensor] = None, + states: Optional[Tensor] = None, left_context: int = 0, - ) -> Tensor: + ) -> Tuple[Tensor, Tensor]: """ Pass the input through the encoder layer. @@ -437,9 +378,10 @@ class ConformerEncoderLayer(nn.Module): pos_emb: Positional embedding tensor (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). - attn_cache: attention cache for previous frames. - conv_cache: convolution cache for previous frames. - left_context: left context in frames used during streaming decoding. + states: The decode states for previous frames which contains the cached data. + It has a shape of (2, left_context, batch, attention_dim), + states[0,...] is the attn_cache, states[1,...] is the conv_cache. + left_context: left context (in frames) used during streaming decoding. this is used only in real streaming decoding, in other circumstances, it MUST be 0. 
Shape: @@ -467,11 +409,11 @@ class ConformerEncoderLayer(nn.Module): key = src val = src - if not self.training and attn_cache is not None: + if not self.training and states is not None: # src: [chunk_size, N, F] e.g. [8, 41, 512] - key = torch.cat([attn_cache, src], dim=0) + key = torch.cat([states[0, ...], src], dim=0) val = key - attn_cache = key + states[0, ...] = key[-left_context:, ...] else: assert left_context == 0 @@ -493,9 +435,9 @@ class ConformerEncoderLayer(nn.Module): if self.normalize_before: src = self.norm_conv(src) - if not self.training and conv_cache is not None: - src = torch.cat([conv_cache, src], dim=0) - conv_cache = src + if not self.training and states is not None: + src = torch.cat([states[1, ...], src], dim=0) + states[1, ...] = src[-left_context:, ...] src = self.conv_module(src) src = src[-residual.size(0) :, :, :] # noqa: E203 @@ -515,7 +457,7 @@ class ConformerEncoderLayer(nn.Module): if self.normalize_before: src = self.norm_final(src) - return src, attn_cache, conv_cache + return src, states class ConformerEncoder(nn.Module): @@ -546,10 +488,9 @@ class ConformerEncoder(nn.Module): pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - attn_cache: Optional[Tensor] = None, - conv_cache: Optional[Tensor] = None, + states: Optional[Tensor] = None, left_context: int = 0, - ) -> Tensor: + ) -> Tuple[Tensor, Tensor]: r"""Pass the input through the encoder layers in turn. Args: @@ -557,9 +498,10 @@ class ConformerEncoder(nn.Module): pos_emb: Positional embedding tensor (required). mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). - attn_cache: attention cache for previous frames. - conv_cache: convolution cache for previous frames. - left_context: left context in frames used during streaming decoding. + states: The decode states for previous frames which contains the cached data. + It has a shape of (2, encoder_layers, left_context, batch, attention_dim), + states[0,...] is the attn_cache, states[1,...] is the conv_cache. + left_context: left context (in frames) used during streaming decoding. this is used only in real streaming decoding, in other circumstances, it MUST be 0. Shape: @@ -576,26 +518,23 @@ class ConformerEncoder(nn.Module): if self.training: assert left_context == 0 - assert attn_cache is None - assert conv_cache is None + assert states is None else: assert left_context >= 0 for layer_index, mod in enumerate(self.layers): - output, a_cache, c_cache = mod( + output, cache = mod( output, pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, - attn_cache=None if attn_cache is None else attn_cache[layer_index], - conv_cache=None if conv_cache is None else conv_cache[layer_index], + states=None if states is None else states[:,layer_index, ...], left_context=left_context, ) - if attn_cache is not None and conv_cache is not None: - attn_cache[layer_index, ...] = a_cache[-left_context:, ...] - conv_cache[layer_index, ...] = c_cache[-left_context:, ...] + if states is not None: + states[:, layer_index, ...] = cache - return output + return output, states class RelPositionalEncoding(torch.nn.Module): @@ -667,7 +606,7 @@ class RelPositionalEncoding(torch.nn.Module): Args: x (torch.Tensor): Input tensor (batch, time, `*`). - context (int): left context in frames used during streaming decoding. + context (int): left context (in frames) used during streaming decoding. 
this is used only in real streaming decoding, in other circumstances, it MUST be 0. Returns: @@ -762,7 +701,7 @@ class RelPositionMultiheadAttention(nn.Module): need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. - left_context (int): left context in frames used during streaming decoding. + left_context (int): left context (in frames) used during streaming decoding. this is used only in real streaming decoding, in other circumstances, it MUST be 0. @@ -819,7 +758,7 @@ class RelPositionMultiheadAttention(nn.Module): Args: x: Input tensor (batch, head, time1, 2*time1-1). time1 means the length of query vector. - left_context (int): left context in frames used during streaming decoding. + left_context (int): left context (in frames) used during streaming decoding. this is used only in real streaming decoding, in other circumstances, it MUST be 0. @@ -879,7 +818,7 @@ class RelPositionMultiheadAttention(nn.Module): need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. - left_context (int): left context in frames used during streaming decoding. + left_context (int): left context (in frames) used during streaming decoding. this is used only in real streaming decoding, in other circumstances, it MUST be 0. diff --git a/icefall/utils.py b/icefall/utils.py index 8e87d29df..60ca9dcde 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -535,8 +535,11 @@ class MetricsTracker(collections.defaultdict): ans = [] for k, v in self.items(): if k != "frames": - norm_value = float(v) / num_frames - ans.append((k, norm_value)) + if k != "sym_delay": + norm_value = float(v) / num_frames + ans.append((k, norm_value)) + else: + ans.append((k, float(v))) return ans def reduce(self, device):
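
A note on the chunk masking used throughout this patch: both dynamic chunk
training (`forward` with dynamic_chunk_training=True) and simulated streaming
decoding (`streaming_forward` with simulate_streaming=True) build their
attention mask as `~subsequent_chunk_mask(...)`, so each frame attends only to
its own chunk plus a bounded number of chunks to its left. The sketch below is
illustrative only: it mirrors the semantics described in the docstrings above
(True marks a position the query may attend to), it is not necessarily the
exact `icefall.utils.subsequent_chunk_mask` implementation, and the name
`chunk_mask_sketch` is hypothetical.

    import torch


    def chunk_mask_sketch(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
    ) -> torch.Tensor:
        """Return a (size, size) bool tensor where entry (i, j) is True if
        frame i may attend to frame j. Frames are grouped into chunks of
        ``chunk_size``; a frame sees its own chunk and ``num_left_chunks``
        chunks to its left (all of them if the value is negative), and
        nothing beyond the end of its own chunk.
        """
        ret = torch.zeros(size, size, device=device, dtype=torch.bool)
        for i in range(size):
            cur_chunk = i // chunk_size
            if num_left_chunks < 0:
                start = 0
            else:
                start = max((cur_chunk - num_left_chunks) * chunk_size, 0)
            ending = min((cur_chunk + 1) * chunk_size, size)
            ret[i, start:ending] = True
        return ret


    if __name__ == "__main__":
        # 8 frames, chunks of 2, one left chunk visible: frame 5 (chunk 2)
        # sees frames 2..5 and nothing after frame 5.
        mask = chunk_mask_sketch(size=8, chunk_size=2, num_left_chunks=1)
        print(mask.int())
        # The encoders above pass ``~mask`` as the attention mask, i.e.
        # True then marks positions that must NOT be attended to.

During dynamic chunk training the chunk size is re-sampled for every batch
(full context when the sample exceeds ``max_len * short_chunk_threshold``,
otherwise ``1 + sample % short_chunk_size``), so a single model learns to
operate at many chunk sizes; this is what lets one checkpoint serve both
streaming and non-streaming decoding.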