From 0325e3a04e50680c80d41d01e474bb02cf727bf2 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Sun, 29 May 2022 16:25:01 +0800
Subject: [PATCH] Add torch.jit.export

---
 .../ASR/pruned_transducer_stateless/decode.py |   3 +-
 .../ASR/pruned_transducer_stateless/export.py |  42 +++
 .../ASR/pruned_transducer_stateless/train.py  |   1 -
 .../pruned_transducer_stateless2/conformer.py | 262 ++++++++++++------
 .../pruned_transducer_stateless2/decode.py    |   3 +-
 .../pruned_transducer_stateless2/export.py    |  41 +++
 .../ASR/transducer_stateless/conformer.py     | 216 ++++++++++-----
 7 files changed, 415 insertions(+), 153 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index b08542950..c5e3465e9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -375,9 +375,10 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)

     if params.simulate_streaming:
-        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+        encoder_out, encoder_out_lens = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
+            states=[],
             chunk_size=params.right_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
index a4210831c..fb16ab6c1 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
@@ -109,6 +109,47 @@ def get_parser():
         "2 means tri-gram",
     )

+    parser.add_argument(
+        "--dynamic-chunk-training",
+        type=str2bool,
+        default=False,
+        help="""Whether to use dynamic_chunk_training; if you want a streaming
+        model, this must be True.
+        Note: not needed here; it is only added so that we can construct the
+        transducer model by reusing the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="""Chunk length for dynamic training; the chunk size is either the
+        maximum sequence length of the current batch or uniformly sampled from
+        (1, short_chunk_size).
+        Note: not needed here; it is only added so that we can construct the
+        transducer model by reusing the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="""How many chunks of left context the attention can see when
+        calculating attention.
+        Note: not needed here; it is only added so that we can construct the
+        transducer model by reusing the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--causal-convolution",
+        type=str2bool,
+        default=False,
+        help="""Whether to use causal convolution; this must be True when
+        using dynamic_chunk_training.
+ """, + ) + return parser @@ -130,6 +171,7 @@ def main(): # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 41020255b..84fad00ff 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -288,7 +288,6 @@ def get_parser(): """, ) - return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index a5186e150..8a039d6f9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -18,7 +18,7 @@ import copy import math import warnings -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch from encoder_interface import EncoderInterface @@ -157,7 +157,6 @@ class Conformer(EncoderInterface): assert x.size(0) == lengths.max().item() src_key_padding_mask = make_pad_mask(lengths) - mask = None if self.dynamic_chunk_training: assert ( @@ -176,24 +175,32 @@ class Conformer(EncoderInterface): num_left_chunks=self.num_left_chunks, device=x.device, ) - - x, _ = self.encoder( - x, - pos_emb, - mask=mask, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - ) # (T, N, C) + x = self.encoder( + x, + pos_emb, + mask=mask, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + ) # (T, N, C) + else: + x = self.encoder( + x, + pos_emb, + mask=None, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + ) # (T, N, C) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) return x, lengths + @torch.jit.export def streaming_forward( self, x: torch.Tensor, x_lens: torch.Tensor, + states: List[Tensor], warmup: float = 1.0, - states: Optional[List[Tensor]] = None, chunk_size: int = 16, left_context: int = 64, simulate_streaming: bool = False, @@ -205,17 +212,17 @@ class Conformer(EncoderInterface): x_lens: A tensor of shape (batch_size,) containing the number of frames in `x` before padding. - warmup: - A floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. states: The decode states for previous frames which contains the cached data. It has two elements, the first element is the attn_cache which has a shape of (encoder_layers, left_context, batch, attention_dim), the second element is the conv_cache which has a shape of (encoder_layers, cnn_module_kernel-1, batch, conv_dim). - Note: If not None, states will be modified in this function. + Note: states will be modified in this function. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. chunk_size: The chunk size for decoding, this will be used to simulate streaming decoding using masking. 
@@ -245,10 +252,6 @@ class Conformer(EncoderInterface): lengths = (((x_lens - 1) >> 1) - 1) >> 1 if not simulate_streaming: - assert ( - states is not None - ), "Require cache when sending data in streaming mode" - assert ( len(states) == 2 and states[0].shape @@ -272,7 +275,7 @@ class Conformer(EncoderInterface): embed, pos_enc = self.encoder_pos(embed, left_context) embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - x = self.encoder( + x = self.encoder.chunk_forward( embed, pos_enc, src_key_padding_mask=src_key_padding_mask, @@ -282,7 +285,6 @@ class Conformer(EncoderInterface): ) # (T, B, F) else: - assert states is None src_key_padding_mask = make_pad_mask(lengths) x = self.encoder_embed(x) @@ -392,9 +394,7 @@ class ConformerEncoderLayer(nn.Module): src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, warmup: float = 1.0, - states: Optional[List[Tensor]] = None, - left_context: int = 0, - ) -> Tuple[Tensor]: + ) -> Tensor: """ Pass the input through the encoder layer. @@ -405,20 +405,9 @@ class ConformerEncoderLayer(nn.Module): src_key_padding_mask: the mask for the src keys per batch (optional). warmup: controls selective bypass of of layers; if < 1.0, we will bypass layers more frequently. - states: - The decode states for previous frames which contains the cached data. - It has two elements, the first element is the attn_cache which has - a shape of (left_context, batch, attention_dim), - the second element is the conv_cache which has a shape of - (cnn_module_kernel-1, batch, conv_dim). - Note: If not None, states will be modified in this function. - left_context: left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. - Shape: src: (S, N, E). - pos_emb: (N, 2*S-1, E),for streaming decoding it is (N, 2*(S+left_context)-1, E). + pos_emb: (N, 2*S-1, E) src_mask: (S, S). src_key_padding_mask: (N, S). S is the source sequence length, N is the batch size, E is the feature number @@ -440,15 +429,82 @@ class ConformerEncoderLayer(nn.Module): # macaron style feed forward module src = src + self.dropout(self.feed_forward_macaron(src)) - key = src - val = src - if not self.training and states is not None: - # src: [chunk_size, N, F] e.g. [8, 41, 512] - key = torch.cat([states[0], src], dim=0) - val = key - states[0] = key[-left_context:, ...] - else: - assert left_context == 0 + # multi-headed self-attention module + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + + src = src + self.dropout(src_att) + + # convolution module + conv, _ = self.conv_module(src) + src = src + self.dropout(conv) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + return src + + @torch.jit.export + def chunk_forward( + self, + src: Tensor, + pos_emb: Tensor, + states: List[Tensor], + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + left_context: int = 0, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + states: + The decode states for previous frames which contains the cached data. 
+            It has two elements, the first element is the attn_cache which has
+            a shape of (left_context, batch, attention_dim),
+            the second element is the conv_cache which has a shape of
+            (cnn_module_kernel-1, batch, conv_dim).
+            Note: states will be modified in this function.
+          src_mask: the mask for the src sequence (optional).
+          src_key_padding_mask: the mask for the src keys per batch (optional).
+          warmup: controls selective bypass of layers; if < 1.0, we will
+            bypass layers more frequently.
+          left_context: left context (in frames) used during streaming decoding.
+            this is used only in real streaming decoding, in other circumstances,
+            it MUST be 0.
+
+        Shape:
+            src: (S, N, E).
+            pos_emb: (N, 2*(S+left_context)-1, E).
+            src_mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, N is the batch size, E is the feature number
+        """
+
+        assert not self.training
+        assert len(states) == 2
+        assert states[0].shape == (left_context, src.size(1), src.size(2))
+
+        # macaron style feed forward module
+        src = src + self.dropout(self.feed_forward_macaron(src))
+
+        key = torch.cat([states[0], src], dim=0)
+        val = key
+        states[0] = key[-left_context:, ...]

         # multi-headed self-attention module
         src_att = self.self_attn(
@@ -464,11 +520,9 @@

         src = src + self.dropout(src_att)

         # convolution module
-        if not self.training and states is not None:
-            conv, conv_cache = self.conv_module(src, states[1])
-            states[1] = conv_cache
-        else:
-            conv = self.conv_module(src)
+        conv, conv_cache = self.conv_module(src, states[1])
+        states[1] = conv_cache
+
         src = src + self.dropout(conv)

         # feed forward module
@@ -476,9 +530,6 @@

         src = self.norm_final(self.balancer(src))

-        if alpha != 1.0:
-            src = alpha * src + (1 - alpha) * src_orig
-
         return src


@@ -511,8 +562,6 @@ class ConformerEncoder(nn.Module):
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-        states: Optional[List[Tensor]] = None,
-        left_context: int = 0,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.

@@ -523,20 +572,10 @@ class ConformerEncoder(nn.Module):
           src_key_padding_mask: the mask for the src keys per batch (optional).
           warmup: controls selective bypass of of layers; if < 1.0, we will
             bypass layers more frequently.
-          states:
-            The decode states for previous frames which contains the cached data.
-            It has two elements, the first element is the attn_cache which has
-            a shape of (encoder_layers, left_context, batch, attention_dim),
-            the second element is the conv_cache which has a shape of
-            (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
-            Note: If not None, states will be modified in this function.
-          left_context: left context (in frames) used during streaming decoding.
-            this is used only in real streaming decoding, in other circumstances,
-            it MUST be 0.

         Shape:
             src: (S, N, E).
-            pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
+            pos_emb: (N, 2*S-1, E)
             mask: (S, S).
             src_key_padding_mask: (N, S).
            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number

        """
        output = src

-        if self.training:
-            assert left_context == 0
-            assert states is None
-        else:
-            assert left_context >= 0
-
         for layer_index, mod in enumerate(self.layers):
-            cache = (
-                None
-                if states is None
-                else [states[0][layer_index], states[1][layer_index]]
-            )
             output = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
                 warmup=warmup,
+            )
+
+        return output
+
+    @torch.jit.export
+    def chunk_forward(
+        self,
+        src: Tensor,
+        pos_emb: Tensor,
+        states: List[Tensor],
+        mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        warmup: float = 1.0,
+        left_context: int = 0,
+    ) -> Tensor:
+        r"""Pass the input through the encoder layers in turn.
+
+        Args:
+          src: the sequence to the encoder (required).
+          pos_emb: Positional embedding tensor (required).
+          states:
+            The decode states for previous frames which contains the cached data.
+            It has two elements, the first element is the attn_cache which has
+            a shape of (encoder_layers, left_context, batch, attention_dim),
+            the second element is the conv_cache which has a shape of
+            (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
+            Note: states will be modified in this function.
+          mask: the mask for the src sequence (optional).
+          src_key_padding_mask: the mask for the src keys per batch (optional).
+          warmup: controls selective bypass of layers; if < 1.0, we will
+            bypass layers more frequently.
+          left_context: left context (in frames) used during streaming decoding.
+            this is used only in real streaming decoding, in other circumstances,
+            it MUST be 0.
+
+        Shape:
+            src: (S, N, E).
+            pos_emb: (N, 2*(S+left_context)-1, E).
+            mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
+
+        """
+        assert not self.training
+        assert len(states) == 2
+        assert states[0].shape == (
+            len(self.layers),
+            left_context,
+            src.size(1),
+            src.size(2),
+        )
+        assert states[1].size(0) == len(self.layers)
+
+        output = src
+
+        for layer_index, mod in enumerate(self.layers):
+            cache = [states[0][layer_index], states[1][layer_index]]
+            output = mod.chunk_forward(
+                output,
+                pos_emb,
                 states=cache,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
                 left_context=left_context,
             )
-            if states is not None:
-                states[0][layer_index] = cache[0]
-                states[1][layer_index] = cache[1]
+            states[0][layer_index] = cache[0]
+            states[1][layer_index] = cache[1]

         return output


@@ -1216,7 +1306,7 @@ class ConvolutionModule(nn.Module):
         self,
         x: Tensor,
         cache: Optional[Tensor] = None,
-    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+    ) -> Tuple[Tensor, Tensor]:
         """Compute convolution module.

        Args:
@@ -1260,9 +1350,11 @@

         x = self.pointwise_conv2(x)  # (batch, channel, time)

-        return (
-            x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)
-        )
+        # torch.jit.script requires return types to be the same as annotated above
+        if cache is None:
+            cache = torch.empty(0)
+
+        return x.permute(2, 0, 1), cache


 class Conv2dSubsampling(nn.Module):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
index e812fb534..57f05c6bf 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -335,9 +335,10 @@ def decode_one_batch(
     )

     if params.simulate_streaming:
-        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+        encoder_out, encoder_out_lens = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
+            states=[],
             chunk_size=params.right_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
index cff9c7377..a8cbbbd9b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
@@ -124,6 +124,47 @@ def get_parser():
         "2 means tri-gram",
     )

+    parser.add_argument(
+        "--dynamic-chunk-training",
+        type=str2bool,
+        default=False,
+        help="""Whether to use dynamic_chunk_training; if you want a streaming
+        model, this must be True.
+        Note: not needed here; it is only added so that we can construct the
+        transducer model by reusing the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="""Chunk length for dynamic training; the chunk size is either the
+        maximum sequence length of the current batch or uniformly sampled from
+        (1, short_chunk_size).
+        Note: not needed here; it is only added so that we can construct the
+        transducer model by reusing the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="""How many chunks of left context the attention can see when
+        calculating attention.
+        Note: not needed here; it is only added so that we can construct the
+        transducer model by reusing the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--causal-convolution",
+        type=str2bool,
+        default=False,
+        help="""Whether to use causal convolution; this must be True when
+        using dynamic_chunk_training.
+ """, + ) + return parser diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 7d3c9869d..3d6b089c1 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -18,7 +18,7 @@ import copy import math import warnings -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch from torch import Tensor, nn @@ -155,7 +155,6 @@ class Conformer(Transformer): assert x.size(0) == lengths.max().item() src_key_padding_mask = make_pad_mask(lengths) - mask = None if self.dynamic_chunk_training: assert ( @@ -174,10 +173,13 @@ class Conformer(Transformer): num_left_chunks=self.num_left_chunks, device=x.device, ) - - x, _ = self.encoder( - x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask - ) # (T, N, C) + x = self.encoder( + x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask + ) # (T, N, C) + else: + x = self.encoder( + x, pos_emb, mask=None, src_key_padding_mask=src_key_padding_mask + ) # (T, N, C) if self.normalize_before: x = self.after_norm(x) @@ -187,11 +189,12 @@ class Conformer(Transformer): return logits, lengths + @torch.jit.export def streaming_forward( self, x: torch.Tensor, x_lens: torch.Tensor, - states: Optional[List[torch.Tensor]] = None, + states: List[torch.Tensor], chunk_size: int = 16, left_context: int = 64, simulate_streaming: bool = False, @@ -209,7 +212,7 @@ class Conformer(Transformer): a shape of (encoder_layers, left_context, batch, attention_dim), the second element is the conv_cache which has a shape of (encoder_layers, cnn_module_kernel-1, batch, conv_dim). - Note: If not None, states will be modified in this function. + Note: states will be modified in this function. chunk_size: The chunk size for decoding, this will be used to simulate streaming decoding using masking. @@ -239,10 +242,6 @@ class Conformer(Transformer): lengths = (((x_lens - 1) >> 1) - 1) >> 1 if not simulate_streaming: - assert ( - states is not None - ), "Require cache when sending data in streaming mode" - assert ( len(states) == 2 and states[0].shape @@ -266,17 +265,14 @@ class Conformer(Transformer): embed, pos_enc = self.encoder_pos(embed, left_context) embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - x = self.encoder( + x = self.encoder.chunk_forward( embed, pos_enc, src_key_padding_mask=src_key_padding_mask, states=states, left_context=left_context, ) # (T, B, F) - else: - assert states is None - src_key_padding_mask = make_pad_mask(lengths) x = self.encoder_embed(x) x, pos_emb = self.encoder_pos(x) @@ -389,8 +385,6 @@ class ConformerEncoderLayer(nn.Module): pos_emb: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - states: Optional[List[Tensor]] = None, - left_context: int = 0, ) -> Tensor: """ Pass the input through the encoder layer. @@ -400,19 +394,95 @@ class ConformerEncoderLayer(nn.Module): pos_emb: Positional embedding tensor (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). 
+          S is the source sequence length, N is the batch size, E is the feature number
+        """
+        # macaron style feed forward module
+        residual = src
+        if self.normalize_before:
+            src = self.norm_ff_macaron(src)
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
+        if not self.normalize_before:
+            src = self.norm_ff_macaron(src)
+
+        # multi-headed self-attention module
+        residual = src
+        if self.normalize_before:
+            src = self.norm_mha(src)
+
+        src_att = self.self_attn(
+            src,
+            src,
+            src,
+            pos_emb=pos_emb,
+            attn_mask=src_mask,
+            key_padding_mask=src_key_padding_mask,
+        )[0]
+        src = residual + self.dropout(src_att)
+        if not self.normalize_before:
+            src = self.norm_mha(src)
+
+        # convolution module
+        residual = src
+        if self.normalize_before:
+            src = self.norm_conv(src)
+
+        src, _ = self.conv_module(src)
+        src = residual + self.dropout(src)
+
+        if not self.normalize_before:
+            src = self.norm_conv(src)
+
+        # feed forward module
+        residual = src
+        if self.normalize_before:
+            src = self.norm_ff(src)
+        src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
+        if not self.normalize_before:
+            src = self.norm_ff(src)
+
+        if self.normalize_before:
+            src = self.norm_final(src)
+
+        return src
+
+    @torch.jit.export
+    def chunk_forward(
+        self,
+        src: Tensor,
+        pos_emb: Tensor,
+        states: List[Tensor],
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        left_context: int = 0,
+    ) -> Tensor:
+        """
+        Pass the input through the encoder layer.
+
+        Args:
+          src: the sequence to the encoder layer (required).
+          pos_emb: Positional embedding tensor (required).
           states:
             The decode states for previous frames which contains the cached data.
             It has two elements, the first element is the attn_cache which has
             a shape of (left_context, batch, attention_dim),
             the second element is the conv_cache which has a shape of
             (cnn_module_kernel-1, batch, conv_dim).
-            Note: If not None, states will be modified in this function.
+            Note: states will be modified in this function.
+          src_mask: the mask for the src sequence (optional).
+          src_key_padding_mask: the mask for the src keys per batch (optional).
           left_context: left context (in frames) used during streaming decoding.
             this is used only in real streaming decoding, in other circumstances,
             it MUST be 0.

         Shape:
             src: (S, N, E).
-            pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
+            pos_emb: (N, 2*(S+left_context)-1, E).
             src_mask: (S, S).
             src_key_padding_mask: (N, S).
             S is the source sequence length, N is the batch size, E is the feature number
@@ -433,15 +503,9 @@
         if self.normalize_before:
             src = self.norm_mha(src)

-        key = src
-        val = src
-        if not self.training and states is not None:
-            # src: [chunk_size, N, F] e.g. [8, 41, 512]
-            key = torch.cat([states[0], src], dim=0)
-            val = key
-            states[0] = key[-left_context:, ...]
-        else:
-            assert left_context == 0
+        key = torch.cat([states[0], src], dim=0)
+        val = key
+        states[0] = key[-left_context:, ...]
src_att = self.self_attn( src, @@ -461,11 +525,8 @@ class ConformerEncoderLayer(nn.Module): if self.normalize_before: src = self.norm_conv(src) - if not self.training and states is not None: - src, conv_cache = self.conv_module(src, states[1]) - states[1] = conv_cache - else: - src = self.conv_module(src) + src, conv_cache = self.conv_module(src, states[1]) + states[1] = conv_cache src = residual + self.dropout(src) if not self.normalize_before: @@ -513,8 +574,6 @@ class ConformerEncoder(nn.Module): pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - states: Optional[List[Tensor]] = None, - left_context: int = 0, ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -523,21 +582,11 @@ class ConformerEncoder(nn.Module): pos_emb: Positional embedding tensor (required). mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). - states: - The decode states for previous frames which contains the cached data. - It has two elements, the first element is the attn_cache which has - a shape of (encoder_layers, left_context, batch, attention_dim), - the second element is the conv_cache which has a shape of - (encoder_layers, cnn_module_kernel-1, batch, conv_dim). - Note: If not None, states will be modified in this function. - left_context: left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. Shape: Shape: src: (S, N, E). - pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E). + pos_emb: (N, 2*S-1, E). mask: (S, S). src_key_padding_mask: (N, S). S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number @@ -545,29 +594,65 @@ class ConformerEncoder(nn.Module): """ output = src - if self.training: - assert left_context == 0 - assert states is None - else: - assert left_context >= 0 - for layer_index, mod in enumerate(self.layers): - cache = ( - None - if states is None - else [states[0][layer_index], states[1][layer_index]] - ) output = mod( output, pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, + ) + return output + + @torch.jit.export + def chunk_forward( + self, + src: Tensor, + pos_emb: Tensor, + states: List[Tensor], + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + left_context: int = 0, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + states: + The decode states for previous frames which contains the cached data. + It has two elements, the first element is the attn_cache which has + a shape of (encoder_layers, left_context, batch, attention_dim), + the second element is the conv_cache which has a shape of + (encoder_layers, cnn_module_kernel-1, batch, conv_dim). + Note: states will be modified in this function. + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + left_context: left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. + Shape: + src: (S, N, E). + pos_emb: (N, 2*(S+left_context)-1, E). + mask: (S, S). + src_key_padding_mask: (N, S). 
+ S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + assert not self.training + output = src + + for layer_index, mod in enumerate(self.layers): + cache = [states[0][layer_index], states[1][layer_index]] + output = mod.chunk_forward( + output, + pos_emb, states=cache, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, left_context=left_context, ) - if states is not None: - states[0][layer_index] = cache[0] - states[1][layer_index] = cache[1] + states[0][layer_index] = cache[0] + states[1][layer_index] = cache[1] return output @@ -1186,7 +1271,7 @@ class ConvolutionModule(nn.Module): def forward( self, x: Tensor, cache: Optional[Tensor] = None - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + ) -> Tuple[Tensor, Tensor]: """Compute convolution module. Args: @@ -1227,9 +1312,10 @@ class ConvolutionModule(nn.Module): x = self.pointwise_conv2(x) # (batch, channel, time) - return ( - x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache) - ) + if cache is None: + cache = torch.empty(0) + + return x.permute(2, 0, 1), cache class Swish(torch.nn.Module):
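
Usage sketch (not part of the patch): with states now a plain List[Tensor]
and the streaming entry points marked @torch.jit.export, the encoder survives
torch.jit.script and can be driven chunk by chunk. The sketch below shows the
intended calling pattern for the pruned_transducer_stateless2 encoder. The
hyper-parameters (num_encoder_layers, d_model, cnn_module_kernel), the
100-frame chunks, and the zero-initialized first-chunk caches are illustrative
assumptions, not values fixed by this patch.

    import torch
    from typing import List


    def chunked_decode(
        encoder: torch.nn.Module, feature: torch.Tensor
    ) -> List[torch.Tensor]:
        """Run a recipe encoder over `feature` (N, T, C) in fixed-size chunks.

        `encoder` is assumed to be a Conformer from this recipe, optionally
        scripted with torch.jit.script(model.encoder), which works now that
        streaming_forward / chunk_forward are @torch.jit.export'ed.
        """
        num_encoder_layers = 12  # assumption: must match the trained model
        d_model = 512            # assumption: attention/conv dimension
        cnn_module_kernel = 31   # assumption: depthwise conv kernel size
        left_context = 64        # cached left-context frames after subsampling
        chunk_frames = 100       # assumption: input frames fed per call

        batch = feature.size(0)

        # The two caches documented in streaming_forward: attn_cache and
        # conv_cache.  Zero-initializing them for the very first chunk is an
        # assumption of this sketch.
        states = [
            torch.zeros(num_encoder_layers, left_context, batch, d_model),
            torch.zeros(num_encoder_layers, cnn_module_kernel - 1, batch, d_model),
        ]

        outputs = []
        for start in range(0, feature.size(1), chunk_frames):
            chunk = feature[:, start : start + chunk_frames, :]  # (N, T, C)
            chunk_lens = torch.full((batch,), chunk.size(1), dtype=torch.int64)
            # states is modified in place and carried over to the next call.
            out, out_lens = encoder.streaming_forward(
                x=chunk,
                x_lens=chunk_lens,
                states=states,
                chunk_size=16,
                left_context=left_context,
            )
            outputs.append(out)
        return outputs

Note that decode.py's simulated-streaming path instead passes states=[] with
simulate_streaming=True; in that branch the caches are unused and an attention
mask emulates the chunking.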