diff --git a/egs/librispeech/ASR/streaming_pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/streaming_pruned_transducer_stateless/beam_search.py
new file mode 120000
index 000000000..227d2247c
--- /dev/null
+++ b/egs/librispeech/ASR/streaming_pruned_transducer_stateless/beam_search.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless/beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/streaming_pruned_transducer_stateless/conformer.py b/egs/librispeech/ASR/streaming_pruned_transducer_stateless/conformer.py
index fc838f75b..8bb3c4c89 100644
--- a/egs/librispeech/ASR/streaming_pruned_transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/streaming_pruned_transducer_stateless/conformer.py
@@ -15,7 +15,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import copy
+
+import logging
 import math
 import warnings
 from typing import Optional, Tuple
@@ -24,7 +25,7 @@ import torch
 from torch import Tensor, nn
 from transformer import Transformer

-from icefall.utils import make_pad_mask
+from icefall.utils import make_pad_mask, subsequent_chunk_mask


 class Conformer(Transformer):
@@ -56,6 +57,12 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
+        dynamic_chunk_training: bool = True,
+        short_chunk_threshold: float = 0.75,
+        causal: bool = True,
+        short_chunk_size: int = 25,
+        use_codebook_loss: bool = False,
+        num_codebooks: int = 4,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -69,6 +76,9 @@ class Conformer(Transformer):
             normalize_before=normalize_before,
             vgg_frontend=vgg_frontend,
         )
+        self.dynamic_chunk_training = dynamic_chunk_training
+        self.short_chunk_threshold = short_chunk_threshold
+        self.short_chunk_size = short_chunk_size

         self.encoder_pos = RelPositionalEncoding(d_model, dropout)

@@ -79,6 +89,7 @@ class Conformer(Transformer):
             dropout,
             cnn_module_kernel,
             normalize_before,
+            causal,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
         self.normalize_before = normalize_before
@@ -112,9 +123,176 @@ class Conformer(Transformer):
         # Caution: We assume the subsampling factor is 4!
         lengths = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == lengths.max().item()
-        mask = make_pad_mask(lengths)
+        src_key_padding_mask = make_pad_mask(lengths)

-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, N, C)
+        if self.dynamic_chunk_training:
+            max_len = x.size(0)
+            chunk_size = torch.randint(1, max_len, (1,)).item()
+            if chunk_size > (max_len * self.short_chunk_threshold):
+                chunk_size = max_len
+            else:
+                chunk_size = chunk_size % self.short_chunk_size + 1
+
+            mask = ~subsequent_chunk_mask(
+                size=x.size(0), chunk_size=chunk_size, device=x.device
+            )
+            x = self.encoder(
+                x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
+            )  # (T, B, F)
+        else:
+            x = self.encoder(
+                x, pos_emb, src_key_padding_mask=src_key_padding_mask
+            )  # (T, N, C)
+
+        if self.normalize_before:
+            x = self.after_norm(x)
+
+        logits = self.encoder_output_layer(x)
+        logits = logits.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+
+        return logits, lengths
+
+    def streaming_forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        chunk_size: int = 16,
+        simulate_streaming: bool = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # x: [N, T, C]
+
+        # Caution: We assume the subsampling factor is 4!
+        lengths = ((x_lens - 1) // 2 - 1) // 2
+        src_key_padding_mask = make_pad_mask(lengths)
+
+        if chunk_size < 0:
+            # Decoding with full right context.
+            x = self.encoder_embed(x)
+
+            x, pos_emb = self.encoder_pos(x)
+            x = x.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
+            assert x.size(0) == lengths.max().item()
+
+            x = self.encoder(
+                x, pos_emb, src_key_padding_mask=src_key_padding_mask
+            )  # (T, B, F)
+        else:
+            # Currently icefall only supports subsampling_rate == 4,
+            # so the following parameters are hard-coded here.
+            # Change them accordingly if other subsampling rates are supported.
+            # The first encoder_out frame needs at least 7 fbank frames.
+            embed_left_context = 7
+            # Each subsequent encoder_out frame needs 4 more fbank frames.
+            subsampling_rate = 4
+            # So the extra frames needed to generate the first frame encoder_out:
+            embed_conv_context = embed_left_context - subsampling_rate
+
+            stride = chunk_size * subsampling_rate
+            decoding_window = embed_conv_context + stride
+            if simulate_streaming:
+                # Simulate chunk_by_chunk streaming decoding.
+                # Results of this branch should be identical to the following
+                # "else" branch,
+                # but this branch is a little slower
+                # as the feature is fed chunk by chunk.
+
+                # store the result of chunk_by_chunk decoding
+                encoder_output = []
+
+                # caches
+                pos_emb_positive = []
+                pos_emb_negative = []
+                pos_emb_central = None
+                encoder_cache = [None for i in range(len(self.encoder.layers))]
+                conv_cache = [None for i in range(len(self.encoder.layers))]
+
+                # start chunk_by_chunk decoding
+                offset = 0
+                feature = x
+                num_frames = feature.size(1)
+                for cur in range(
+                    0, num_frames - embed_left_context + 1, stride
+                ):
+                    end = min(cur + decoding_window, num_frames)
+                    cur_feature = feature[:, cur:end, :]
+                    cur_feature = self.encoder_embed(cur_feature)
+                    cur_embed, cur_pos_emb = self.encoder_pos(
+                        cur_feature, offset
+                    )
+                    cur_embed = cur_embed.permute(
+                        1, 0, 2
+                    )  # (B, T, F) -> (T, B, F)
+
+                    cur_T = cur_feature.size(1)
+                    if cur == 0:
+                        real_chunk_size = min(cur_T, chunk_size)
+                        assert (
+                            cur_pos_emb.size(1) == 2 * real_chunk_size - 1
+                        ), f"{cur_pos_emb.size(1)} == 2 * {real_chunk_size} - 1"
+
+                        # Extract the central pos embedding during the first chunk
+                        pos_emb_central = cur_pos_emb[
+                            0, (real_chunk_size - 1), :
+                        ].view(1, 1, -1)
+                        cur_T -= 1
+
+                    # first chunk with chunk_size > 1
+                    # or not first chunk
+                    if (cur_T > 1 and cur == 0) or cur != 0:
+                        pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
+                        pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
+
+                        assert pos_emb_positive[-1].size(0) == cur_T
+
+                        pos_emb_pos = torch.cat(
+                            pos_emb_positive, dim=0
+                        ).unsqueeze(0)
+                        pos_emb_neg = torch.cat(
+                            pos_emb_negative, dim=0
+                        ).unsqueeze(0)
+                        cur_pos_emb = torch.cat(
+                            [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
+                            dim=1,
+                        )
+
+                    x = self.encoder.chunk_forward(
+                        cur_embed,
+                        cur_pos_emb,
+                        src_key_padding_mask=src_key_padding_mask[
+                            :, : offset + cur_embed.size(0)
+                        ],
+                        encoder_cache=encoder_cache,
+                        conv_cache=conv_cache,
+                        offset=offset,
+                    )  # (T, B, F)
+                    encoder_output.append(x)
+                    offset += cur_embed.size(0)
+
+                assert num_frames - end <= 3
+                if num_frames != end:
+                    logging.info(
+                        f"The trailing {num_frames - end} fbank frames are not decoded."
+                    )
+                x = torch.cat(encoder_output, dim=0)
+
+            else:
+                # Do NOT simulate chunk_by_chunk decoding.
+                # Results of this branch should be identical to the previous
+                # simulated chunk_by_chunk decoding branch,
+                # but this branch is faster.
+                x = self.encoder_embed(x)
+                x, pos_emb = self.encoder_pos(x)
+                x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+                assert x.size(0) == lengths.max().item()
+                mask = ~subsequent_chunk_mask(
+                    size=x.size(0), chunk_size=chunk_size, device=x.device
+                )
+                x = self.encoder(
+                    x,
+                    pos_emb,
+                    mask=mask,
+                    src_key_padding_mask=src_key_padding_mask,
+                )  # (T, N, C)

         if self.normalize_before:
             x = self.after_norm(x)
@@ -153,6 +331,7 @@ class ConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
+        causal: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
@@ -173,7 +352,9 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )

-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, causal=causal
+        )

         self.norm_ff_macaron = nn.LayerNorm(
             d_model
@@ -263,13 +444,105 @@ class ConformerEncoderLayer(nn.Module):

         return src

+    def chunk_forward(
+        self,
+        src: Tensor,
+        pos_emb: Tensor,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        encoder_cache: Optional[Tensor] = None,
+        conv_cache: Optional[Tensor] = None,
+        offset=0,
+    ) -> Tensor:
+        """
+        Pass the input through the encoder layer.

-class ConformerEncoder(nn.Module):
+        Args:
+            src: the sequence to the encoder layer (required).
+            pos_emb: Positional embedding tensor (required).
+            src_mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            src: (S, N, E).
+            pos_emb: (N, 2*S-1, E)
+            src_mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, N is the batch size, E is the feature number
+        """
+
+        # macaron style feed forward module
+        residual = src
+        if self.normalize_before:
+            src = self.norm_ff_macaron(src)
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
+        if not self.normalize_before:
+            src = self.norm_ff_macaron(src)
+
+        # multi-headed self-attention module
+        residual = src
+        if self.normalize_before:
+            src = self.norm_mha(src)
+        if encoder_cache is None:
+            # src: [chunk_size, N, F] e.g. [8, 41, 512]
+            key = src
+            val = key
+            encoder_cache = key
+        else:
+            key = torch.cat([encoder_cache, src], dim=0)
+            val = key
+            encoder_cache = key
+        src_att = self.self_attn(
+            src,
+            key,
+            val,
+            pos_emb=pos_emb,
+            attn_mask=src_mask,
+            key_padding_mask=src_key_padding_mask,
+            offset=offset,
+        )[0]
+        src = residual + self.dropout(src_att)
+        if not self.normalize_before:
+            src = self.norm_mha(src)
+
+        # convolution module
+        residual = src  # [chunk_size, N, F] e.g. [8, 41, 512]
+        if self.normalize_before:
+            src = self.norm_conv(src)
+        if conv_cache is not None:
+            src = torch.cat([conv_cache, src], dim=0)
+        conv_cache = src
+
+        src = self.conv_module(src)
+        src = src[-residual.size(0) :, :, :]  # noqa: E203
+
+        src = residual + self.dropout(src)
+        if not self.normalize_before:
+            src = self.norm_conv(src)
+
+        # feed forward module
+        residual = src
+        if self.normalize_before:
+            src = self.norm_ff(src)
+        src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
+        if not self.normalize_before:
+            src = self.norm_ff(src)
+
+        if self.normalize_before:
+            src = self.norm_final(src)
+
+        return src, encoder_cache, conv_cache
+
+
+class ConformerEncoder(nn.TransformerEncoder):
     r"""ConformerEncoder is a stack of N encoder layers

     Args:
         encoder_layer: an instance of the ConformerEncoderLayer() class (required).
         num_layers: the number of sub-encoder-layers in the encoder (required).
+        norm: the layer normalization component (optional).

     Examples::
         >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -279,10 +552,11 @@ class ConformerEncoder(nn.Module):
         >>> out = conformer_encoder(src, pos_emb)
     """

-    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
-        super().__init__()
-        self.layers = nn.ModuleList(
-            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+    def __init__(
+        self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
+    ) -> None:
+        super(ConformerEncoder, self).__init__(
+            encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
         )
         self.num_layers = num_layers

@@ -319,6 +593,55 @@ class ConformerEncoder(nn.Module):
                 src_key_padding_mask=src_key_padding_mask,
             )

+        if self.norm is not None:
+            output = self.norm(output)
+
+        return output
+
+    def chunk_forward(
+        self,
+        src: Tensor,
+        pos_emb: Tensor,
+        mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        encoder_cache=None,
+        conv_cache=None,
+        offset=0,
+    ) -> Tensor:
+        r"""Pass the input through the encoder layers in turn.
+
+        Args:
+            src: the sequence to the encoder (required).
+            pos_emb: Positional embedding tensor (required).
+            mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            src: (S, N, E).
+            pos_emb: (N, 2*S-1, E)
+            mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, N is the batch size, E is the feature number.
+
+        """
+        output = src
+
+        for layer_index, mod in enumerate(self.layers):
+            output, e_cache, c_cache = mod.chunk_forward(
+                output,
+                pos_emb,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                encoder_cache=encoder_cache[layer_index],
+                conv_cache=conv_cache[layer_index],
+                offset=offset,
+            )
+            encoder_cache[layer_index] = e_cache
+            conv_cache[layer_index] = c_cache
+
+        if self.norm is not None:
+            output = self.norm(output)
+
         return output


@@ -346,12 +669,13 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))

-    def extend_pe(self, x: Tensor) -> None:
+    def extend_pe(self, x: Tensor, offset: int = 0) -> None:
         """Reset the positional encodings."""
+        x_size_1 = offset + x.size(1)
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x.size(1) * 2 - 1:
+            if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
                 if self.pe.dtype != x.dtype or str(self.pe.device) != str(
                     x.device
@@ -361,9 +685,9 @@ class RelPositionalEncoding(torch.nn.Module):
         # Suppose `i` means to the position of query vecotr and `j` means the
         # position of key vector. We use position relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
-        pe_positive = torch.zeros(x.size(1), self.d_model)
-        pe_negative = torch.zeros(x.size(1), self.d_model)
-        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        pe_positive = torch.zeros(x_size_1, self.d_model)
+        pe_negative = torch.zeros(x_size_1, self.d_model)
+        position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
         div_term = torch.exp(
             torch.arange(0, self.d_model, 2, dtype=torch.float32)
             * -(math.log(10000.0) / self.d_model)
@@ ... @@ class RelPositionalEncoding(torch.nn.Module):
-    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+    def forward(
+        self, x: torch.Tensor, offset: int = 0
+    ) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.

         Args:
             x (torch.Tensor): Input tensor (batch, time, `*`).
+            offset: time index of the first frame of x,
+                used to compute the positional encoding in a streaming fashion.

         Returns:
             torch.Tensor: Encoded tensor (batch, time, `*`).
             torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).

         """
-        self.extend_pe(x)
+        self.extend_pe(x, offset)
         x = x * self.xscale
+        x_size_1 = offset + x.size(1)
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
-            - x.size(1)
+            - x_size_1
             + 1 : self.pe.size(1) // 2  # noqa E203
-            + x.size(1),
+            + x_size_1,
         ]
+        x_T = x.size(1)
+        if offset > 0:
+            pos_emb = torch.cat([pos_emb[:, :x_T], pos_emb[:, -x_T:]], dim=1)
+
         return self.dropout(x), self.dropout(pos_emb)
@@ -464,6 +797,7 @@ class RelPositionMultiheadAttention(nn.Module):
         key_padding_mask: Optional[Tensor] = None,
         need_weights: bool = True,
         attn_mask: Optional[Tensor] = None,
+        offset=0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
@@ -522,9 +856,10 @@ class RelPositionMultiheadAttention(nn.Module):
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
             attn_mask=attn_mask,
+            offset=offset,
         )

-    def rel_shift(self, x: Tensor) -> Tensor:
+    def rel_shift(self, x: Tensor, offset=0) -> Tensor:
         """Compute relative positional encoding.

         Args:
@@ -533,18 +868,20 @@ class RelPositionMultiheadAttention(nn.Module):

         Returns:
             Tensor: tensor of shape (batch, head, time1, time2)
-            (note: time2 has the same value as time1, but it is for
+            (note: time2 == time1 + offset, since it is for
             the key, while time1 is for the query).
""" (batch_size, num_heads, time1, n) = x.shape - assert n == 2 * time1 - 1 + time2 = time1 + offset + assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1" # Note: TorchScript requires explicit arg for stride() batch_stride = x.stride(0) head_stride = x.stride(1) time1_stride = x.stride(2) n_stride = x.stride(3) + return x.as_strided( - (batch_size, num_heads, time1, time1), + (batch_size, num_heads, time1, time2), (batch_stride, head_stride, time1_stride - n_stride, n_stride), storage_offset=n_stride * (time1 - 1), ) @@ -566,6 +903,7 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, + offset=0, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -639,7 +977,6 @@ class RelPositionMultiheadAttention(nn.Module): if _b is not None: _b = _b[_start:_end] q = nn.functional.linear(query, _w, _b) - # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = embed_dim @@ -745,7 +1082,9 @@ class RelPositionMultiheadAttention(nn.Module): pos_emb_bsz = pos_emb.size(0) assert pos_emb_bsz in (1, bsz) # actually it is 1 p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) + p = p.permute(0, 2, 3, 1) q_with_bias_u = (q + self.pos_bias_u).transpose( 1, 2 @@ -765,10 +1104,11 @@ class RelPositionMultiheadAttention(nn.Module): # compute matrix b and matrix d matrix_bd = torch.matmul( - q_with_bias_v, p.transpose(-2, -1) + q_with_bias_v, p ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd) - + matrix_bd = self.rel_shift( + matrix_bd, offset=offset + ) # [B, head, time1, time2] attn_output_weights = ( matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) @@ -835,11 +1175,16 @@ class ConvolutionModule(nn.Module): channels (int): The number of channels of conv layers. kernel_size (int): Kernerl size of conv layers. bias (bool): Whether to use bias in conv layers (default=True). + causal (bool): Whether to use causal convlution (default=True). """ def __init__( - self, channels: int, kernel_size: int, bias: bool = True + self, + channels: int, + kernel_size: int, + bias: bool = True, + causal: bool = True, ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() @@ -854,12 +1199,20 @@ class ConvolutionModule(nn.Module): padding=0, bias=bias, ) + assert ( + causal + ), "Currently, causal convolution is required for streaming conformer." + + # Manualy padding self.lorder zeros to the left during forward. 
+        self.lorder = kernel_size - 1
+        padding = 0
+
+        self.depthwise_conv = nn.Conv1d(
             channels,
             channels,
             kernel_size,
             stride=1,
-            padding=(kernel_size - 1) // 2,
+            padding=padding,
             groups=channels,
             bias=bias,
         )
@@ -892,6 +1245,10 @@ class ConvolutionModule(nn.Module):
         x = nn.functional.glu(x, dim=1)  # (batch, channels, time)

         # 1D Depthwise Conv
+        if self.lorder > 0:
+            # Make depthwise_conv causal by
+            # manually padding self.lorder zeros to the left
+            x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
         x = self.depthwise_conv(x)
         # x is (batch, channels, time)
         x = x.permute(0, 2, 1)
diff --git a/egs/librispeech/ASR/streaming_pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/streaming_pruned_transducer_stateless/decode.py
index 9479d57a8..3ae49c3e8 100755
--- a/egs/librispeech/ASR/streaming_pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/streaming_pruned_transducer_stateless/decode.py
@@ -18,18 +18,22 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless/decode.py \
-        --epoch 28 \
+./streaming_pruned_transducer_stateless/decode.py \
+        --simulate-streaming [True|False] \
+        --right-chunk-size [1/4/8/16/32/-1] \
+        --epoch 49 \
         --avg 15 \
-        --exp-dir ./pruned_transducer_stateless/exp \
+        --exp-dir ./streaming_pruned_transducer_stateless/exp \
         --max-duration 100 \
         --decoding-method greedy_search

 (2) beam search
-./pruned_transducer_stateless/decode.py \
-        --epoch 28 \
+./streaming_pruned_transducer_stateless/decode.py \
+        --simulate-streaming [True|False] \
+        --right-chunk-size [1/4/8/16/32/-1] \
+        --epoch 49 \
         --avg 15 \
-        --exp-dir ./pruned_transducer_stateless/exp \
+        --exp-dir ./streaming_pruned_transducer_stateless/exp \
         --max-duration 100 \
         --decoding-method beam_search \
         --beam-size 4
@@ -38,6 +42,7 @@ Usage:

 import argparse
 import logging
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
@@ -52,12 +57,17 @@
 from decoder import Decoder
 from joiner import Joiner
 from model import Transducer

-from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    load_checkpoint,
+    save_checkpoint,
+)
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
+    str2bool,
     write_error_stats,
 )
@@ -67,6 +77,29 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )

+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="Whether to split fbanks into chunks to simulate running "
+        "the conformer in a streaming fashion",
+    )
+
+    parser.add_argument(
+        "--tailing-dummy-frames",
+        type=int,
+        default=20,
+        help="Trailing dummy frames padded to the right, "
+        "only used during decoding",
+    )
+
+    parser.add_argument(
+        "--right-chunk-size",
+        type=int,
+        default=16,
+        help="Right context to attend to during decoding",
+    )
+
     parser.add_argument(
         "--epoch",
         type=int,
@@ -86,7 +119,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless/exp",
+        default="streaming_pruned_transducer_stateless/exp",
         help="The experiment dir",
     )

@@ -145,6 +178,8 @@ def get_params() -> AttributeDict:
             # parameters for decoder
             "embedding_dim": 512,
             "env_info": get_env_info(),
+            # model average
+            "save_averaged_model": False,
         }
     )
     return params
@@ -236,10 +271,26 @@ def decode_one_batch(
     # at entry, feature is (N, T, C)
     supervisions = batch["supervisions"]

+    # Extra dummy trailing frames may reduce deletion errors,
+    # example WITHOUT padding:
+    #   CHAPTER SEVEN ON THE RACES OF MAN
+    # example WITH padding:
+    #   CHAPTER SEVEN ON THE RACES OF (MAN->*)
+    tailing_frames = (
+        torch.tensor([-23.0259])
+        .expand([feature.size(0), params.tailing_dummy_frames, 80])
+        .to(feature.device)
+    )
+    feature = torch.cat([feature, tailing_frames], dim=1)
+    supervisions["num_frames"] += params.tailing_dummy_frames
+
     feature_lens = supervisions["num_frames"].to(device)
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
+    encoder_out, encoder_out_lens = model.encoder.streaming_forward(
+        x=feature,
+        x_lens=feature_lens,
+        chunk_size=params.right_chunk_size,
+        simulate_streaming=params.simulate_streaming,
     )
     hyps = []
     batch_size = encoder_out.size(0)

@@ -395,6 +446,9 @@ def main():
     params.res_dir = params.exp_dir / params.decoding_method

     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    params.suffix += f"-chunk_size-{params.right_chunk_size}"
+    params.suffix += f"-{params.simulate_streaming}"
+    params.suffix += f"-tailing-dummy-frames-{params.tailing_dummy_frames}"
     if params.decoding_method == "beam_search":
         params.suffix += f"-beam-{params.beam_size}"
     else:
@@ -425,15 +479,24 @@ def main():
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
-        start = params.epoch - params.avg + 1
-        filenames = []
-        for i in range(start, params.epoch + 1):
-            if start >= 0:
-                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
+        model_path = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"  # noqa: E501
+        if os.path.isfile(model_path):
+            load_checkpoint(model_path, model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if start >= 0:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+            if params.save_averaged_model:
+                save_checkpoint(
+                    filename=model_path,
+                    model=model,
+                )

     model.to(device)
     model.eval()
     model.device = device
diff --git a/egs/librispeech/ASR/streaming_pruned_transducer_stateless/train.py b/egs/librispeech/ASR/streaming_pruned_transducer_stateless/train.py
index f0ea2ccaa..1924f7b58 100755
--- a/egs/librispeech/ASR/streaming_pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/streaming_pruned_transducer_stateless/train.py
@@ -21,11 +21,13 @@ Usage:

 export CUDA_VISIBLE_DEVICES="0,1,2,3"

-./pruned_transducer_stateless/train.py \
+./streaming_pruned_transducer_stateless/train.py \
+  --short-chunk-size=25 \
   --world-size 4 \
-  --num-epochs 30 \
+  --full-libri 1 \
+  --num-epochs 50 \
   --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless/exp \
+  --exp-dir streaming_pruned_transducer_stateless/exp \
   --full-libri 1 \
   --max-duration 300

@@ -74,6 +76,12 @@ def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="Chunk length for dynamic chunk training",
+    )
     parser.add_argument(
         "--world-size",
         type=int,
@@ -252,6 +260,8 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
+            "dynamic_chunk_training": True,
+            "causal": True,  # Now only causal convolution is verified
             # parameters for decoder
             "embedding_dim": 512,
             # parameters for Noam
@@ -274,6 +284,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
+        dynamic_chunk_training=params.dynamic_chunk_training,
+        short_chunk_size=params.short_chunk_size,
+        causal=params.causal,
     )
     return encoder

diff --git a/icefall/utils.py b/icefall/utils.py
index c231dbbe4..b38f7a3d7 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -694,6 +694,42 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
     return expaned_lengths >= lengths.unsqueeze(1)


+# From https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py#L42
+def subsequent_chunk_mask(
+    size: int,
+    chunk_size: int,
+    num_left_chunks: int = -1,
+    device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size) with chunk size,
+    this is for streaming encoder
+    Args:
+        size (int): size of mask
+        chunk_size (int): size of chunk
+        num_left_chunks (int): number of left chunks
+            <0: use full chunk
+            >=0: use num_left_chunks
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+    Returns:
+        torch.Tensor: mask
+    Examples:
+        >>> subsequent_chunk_mask(4, 2)
+        [[1, 1, 0, 0],
+         [1, 1, 0, 0],
+         [1, 1, 1, 1],
+         [1, 1, 1, 1]]
+    """
+    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+    for i in range(size):
+        if num_left_chunks < 0:
+            start = 0
+        else:
+            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+        ending = min((i // chunk_size + 1) * chunk_size, size)
+        ret[i, start:ending] = True
+    return ret
+
+
 def l1_norm(x):
     return torch.sum(torch.abs(x))
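The dynamic chunk masking used during training (see Conformer.forward and subsequent_chunk_mask above) samples a random chunk size per batch and builds a block-lower-triangular attention mask. The following is a minimal, self-contained sketch of that logic, not part of the patch; the helper names below are illustrative only.

# Sketch of dynamic chunk masking; mirrors Conformer.forward() and
# icefall.utils.subsequent_chunk_mask() in the patch above.
import torch


def subsequent_chunk_mask_sketch(size: int, chunk_size: int) -> torch.Tensor:
    # Frame i may attend to every frame up to the end of its own chunk.
    ret = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, :ending] = True
    return ret


def sample_chunk_size(
    max_len: int,
    short_chunk_threshold: float = 0.75,
    short_chunk_size: int = 25,
) -> int:
    # Roughly (1 - short_chunk_threshold) of the time train on the full
    # sequence, otherwise on a short chunk in [1, short_chunk_size].
    chunk_size = torch.randint(1, max_len, (1,)).item()
    if chunk_size > max_len * short_chunk_threshold:
        return max_len
    return chunk_size % short_chunk_size + 1


if __name__ == "__main__":
    T = 12
    chunk_size = sample_chunk_size(T)
    # The encoder receives the *negation* of this mask (True = masked),
    # i.e. `mask = ~subsequent_chunk_mask(...)` as in Conformer.forward().
    print(chunk_size)
    print(subsequent_chunk_mask_sketch(T, chunk_size))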