diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index f79e63962..771658273 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -32,7 +32,7 @@ from scaling import ( ) from torch import Tensor, nn -from icefall.utils import make_pad_mask +from icefall.utils import make_pad_mask, subsequent_chunk_mask class Conformer(EncoderInterface): @@ -46,8 +46,27 @@ class Conformer(EncoderInterface): num_encoder_layers (int): number of encoder layers dropout (float): dropout rate layer_dropout (float): layer-dropout rate. - cnn_module_kernel (int): Kernel size of convolution module - vgg_frontend (bool): whether to use vgg frontend. + cnn_module_kernel (int): Kernel size of convolution module. + dynamic_chunk_training (bool): whether to use dynamic chunk training, if + you want to train a streaming model, this is expected to be True. + When setting True, it will use a masking strategy to make the attention + see only limited left and right context. + short_chunk_threshold (float): a threshold to determinize the chunk size + to be used in masking training, if the randomly generated chunk size + is greater than ``max_len * short_chunk_threshold`` (max_len is the + max sequence length of current batch) then it will use + full context in training (i.e. with chunk size equals to max_len). + This will be used only when dynamic_chunk_training is True. + short_chunk_size (int): see docs above, if the randomly generated chunk + size equals to or less than ``max_len * short_chunk_threshold``, the + chunk size will be sampled uniformly from 1 to short_chunk_size. + This also will be used only when dynamic_chunk_training is True. + num_left_chunks (int): the left context (in chunks) attention can see, the + chunk size is decided by short_chunk_threshold and short_chunk_size. + A minus value means seeing full left context. + This also will be used only when dynamic_chunk_training is True. + causal (bool): Whether to use causal convolution in conformer encoder + layer. This MUST be True when using dynamic_chunk_training. """ def __init__( @@ -62,6 +81,11 @@ class Conformer(EncoderInterface): layer_dropout: float = 0.075, cnn_module_kernel: int = 31, aux_layer_period: int = 3, + dynamic_chunk_training: bool = False, + short_chunk_threshold: float = 0.75, + short_chunk_size: int = 25, + num_left_chunks: int = -1, + causal: bool = False, ) -> None: super(Conformer, self).__init__() @@ -79,6 +103,15 @@ class Conformer(EncoderInterface): self.encoder_pos = RelPositionalEncoding(d_model, dropout) + self.encoder_layers = num_encoder_layers + self.d_model = d_model + self.cnn_module_kernel = cnn_module_kernel + self.causal = causal + self.dynamic_chunk_training = dynamic_chunk_training + self.short_chunk_threshold = short_chunk_threshold + self.short_chunk_size = short_chunk_size + self.num_left_chunks = num_left_chunks + encoder_layer = ConformerEncoderLayer( d_model, nhead, @@ -119,11 +152,39 @@ class Conformer(EncoderInterface): lengths = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == lengths.max().item() - mask = make_pad_mask(lengths) + src_key_padding_mask = make_pad_mask(lengths) - x = self.encoder( - x, pos_emb, src_key_padding_mask=mask, warmup=warmup - ) # (T, N, C) + if self.dynamic_chunk_training: + assert ( + self.causal + ), "Causal convolution is required for streaming conformer." + max_len = x.size(0) + chunk_size = torch.randint(1, max_len, (1,)).item() + if chunk_size > (max_len * self.short_chunk_threshold): + chunk_size = max_len + else: + chunk_size = chunk_size % self.short_chunk_size + 1 + + mask = ~subsequent_chunk_mask( + size=x.size(0), + chunk_size=chunk_size, + num_left_chunks=self.num_left_chunks, + device=x.device, + ) + x = self.encoder( + x, + pos_emb, + mask=mask, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + ) # (T, N, C) + else: + x = self.encoder( + x, + pos_emb, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + ) # (T, N, C) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -141,6 +202,8 @@ class ConformerEncoderLayer(nn.Module): dim_feedforward: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): Whether to use causal convolution in conformer encoder + layer. This MUST be True when using dynamic_chunk_training and streaming decoding. Examples:: >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) @@ -157,6 +220,7 @@ class ConformerEncoderLayer(nn.Module): dropout: float = 0.1, layer_dropout: float = 0.075, cnn_module_kernel: int = 31, + causal: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() @@ -184,7 +248,9 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) self.norm_final = BasicNorm(d_model) @@ -250,7 +316,8 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - src = src + self.dropout(self.conv_module(src)) + conv, _ = self.conv_module(src) + src = src + self.dropout(conv) # feed forward module src = src + self.dropout(self.feed_forward(src)) @@ -832,6 +899,39 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + + # If we are using dynamic_chunk_training and setting a limited + # num_left_chunks, the attention may only see the padding values which + # will also be masked out by `key_padding_mask`, at this circumstances, + # the whole column of `attn_output_weights` will be `-inf` + # (i.e. be `nan` after softmax), so, we fill `0.0` at the masking + # positions to avoid invalid loss value below. + if ( + attn_mask is not None + and attn_mask.dtype == torch.bool + and key_padding_mask is not None + ): + if attn_mask.size(0) != 1: + attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) + combined_mask = attn_mask | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze( + 0 + ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + combined_mask, 0.0 + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + attn_output_weights = nn.functional.dropout( attn_output_weights, p=dropout_p, training=training ) @@ -865,17 +965,24 @@ class ConvolutionModule(nn.Module): channels (int): The number of channels of conv layers. kernel_size (int): Kernerl size of conv layers. bias (bool): Whether to use bias in conv layers (default=True). + causal (bool): Whether to use causal convolution. """ def __init__( - self, channels: int, kernel_size: int, bias: bool = True + self, + channels: int, + kernel_size: int, + bias: bool = True, + causal: bool = False, ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding assert (kernel_size - 1) % 2 == 0 + self.causal = causal + self.pointwise_conv1 = ScaledConv1d( channels, 2 * channels, @@ -902,12 +1009,17 @@ class ConvolutionModule(nn.Module): channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 ) + self.lorder = kernel_size - 1 + padding = (kernel_size - 1) // 2 + if self.causal: + padding = 0 + self.depthwise_conv = ScaledConv1d( channels, channels, kernel_size, stride=1, - padding=(kernel_size - 1) // 2, + padding=padding, groups=channels, bias=bias, ) @@ -928,15 +1040,20 @@ class ConvolutionModule(nn.Module): initial_scale=0.25, ) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: Tensor, cache: Optional[Tensor] = None) -> Tensor: """Compute convolution module. Args: x: Input tensor (#time, batch, channels). + cache: The cache of depthwise_conv, only used in real streaming + decoding. Returns: Tensor: Output tensor (#time, batch, channels). - + If cache is None return the output tensor (#time, batch, channels). + If cache is not None, return a tuple of Tensor, the first one is + the output tensor (#time, batch, channels), the second one is the + new cache for next chunk (#kernel_size - 1, batch, channels). """ # exchange the temporal dimension and the feature dimension x = x.permute(1, 2, 0) # (#batch, channels, time). @@ -948,6 +1065,19 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if self.causal and self.lorder > 0: + if cache is None: + # Make depthwise_conv causal by + # manualy padding self.lorder zeros to the left + x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + else: + assert ( + not self.training + ), "Cache should be None in training time" + assert cache.size(0) == self.lorder + x = torch.cat([cache.permute(1, 2, 0), x], dim=2) + cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa + x = self.depthwise_conv(x) x = self.deriv_balancer2(x) @@ -955,7 +1085,11 @@ class ConvolutionModule(nn.Module): x = self.pointwise_conv2(x) # (batch, channel, time) - return x.permute(2, 0, 1) + # torch.jit.script requires return types be the same as annotated above + if cache is None: + cache = torch.empty(0) + + return x.permute(2, 0, 1), cache class Conv2dSubsampling(nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index eaf893997..7ca6f298e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -134,6 +134,40 @@ def add_model_arguments(parser: argparse.ArgumentParser): """, ) + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -408,6 +442,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: nhead=params.nhead, dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, + dynamic_chunk_training=params.dynamic_chunk_training, + short_chunk_size=params.short_chunk_size, + num_left_chunks=params.num_left_chunks, + causal=params.causal_convolution, ) return encoder @@ -892,6 +930,11 @@ def run(rank, world_size, args): params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + if params.dynamic_chunk_training: + assert ( + params.causal_convolution + ), "dynamic_chunk_training requires causal convolution" + logging.info(params) logging.info("About to create model")