From 5bd2490b4499f645031f5cb91dc25f54a329c2c3 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Wed, 18 May 2022 23:26:24 +0800
Subject: [PATCH] support streaming in conformer

---
 .../ASR/pruned_transducer_stateless/decode.py |  42 ++-
 .../ASR/pruned_transducer_stateless/train.py  |  27 ++
 .../ASR/transducer_stateless/conformer.py     | 317 ++++++++++++++++--
 icefall/__init__.py                           |   1 +
 icefall/utils.py                              |  36 ++
 5 files changed, 392 insertions(+), 31 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index ea43836bd..fee28d4fe 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -249,6 +249,25 @@ def get_parser():
         help="""Maximum number of symbols per frame. Used only when
         --decoding_method is greedy_search""",
     )
+    parser.add_argument(
+        "--streaming-mode",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, i.e. decode
+        with the chunk-wise attention masks used when training a
+        streaming model.""",
+    )
+    parser.add_argument(
+        "--right-chunk-size",
+        type=int,
+        default=16,
+        help="Chunk size (right context) to attend to during decoding.",
+    )
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="Left context (number of frames) to attend to during decoding.",
+    )

     return parser

@@ -301,9 +320,18 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    if params.streaming_mode:
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=feature,
+            x_lens=feature_lens,
+            chunk_size=params.right_chunk_size,
+            left_context=params.left_context,
+            streaming_data=False,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(
+            x=feature, x_lens=feature_lens
+        )
     hyps = []

     if params.decoding_method == "fast_beam_search":
@@ -526,6 +554,10 @@ def main():
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

+    if params.streaming_mode:
+        params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
+
     if "fast_beam_search" in params.decoding_method:
         params.suffix += f"-use-LG-{params.use_LG}"
         params.suffix += f"-beam-{params.beam}"
@@ -561,6 +593,10 @@ def main():
     logging.info(params)

     logging.info("About to create model")
+    # TODO(wei kang): make the following config more elegant
+    params.dynamic_chunk_training = params.streaming_mode
+    params.short_chunk_size = 25
+    params.num_left_chunks = params.left_context // params.right_chunk_size
     model = get_transducer_model(params)

     if params.iter > 0:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
index c360d025a..a922192e5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -222,6 +222,29 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="Chunk length (after subsampling) used in dynamic chunk training.",
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="Number of left chunks to attend to in dynamic chunk training.",
+    )
+
+    parser.add_argument(
+        "--dynamic-chunk-training",
+        type=str2bool,
+        default=False,
+        help="""Whether to use dynamic chunk training. This must be True
+        if you want a streaming model.
+        """,
+    )
+
     return parser

@@ -310,6 +333,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
+        dynamic_chunk_training=params.dynamic_chunk_training,
+        short_chunk_size=params.short_chunk_size,
+        num_left_chunks=params.num_left_chunks,
+        causal=params.dynamic_chunk_training,
     )
     return encoder
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 488c82386..046345508 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -18,13 +18,77 @@
 import copy
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 import torch
 from torch import Tensor, nn
 from transformer import Transformer

-from icefall.utils import make_pad_mask
+from icefall.utils import make_pad_mask, subsequent_chunk_mask
+
+
+class DecodeStates(object):
+    """Per-stream attention/convolution caches carried across chunks
+    in streaming decoding (one cache entry per encoder layer)."""
+
+    def __init__(self,
+                 layers: int,
+                 left_context: int,
+                 dim: int,
+                 init: bool = True,
+                 dtype: torch.dtype = torch.float32,
+                 device: torch.device = torch.device('cpu')):
+        self.layers = layers
+        self.left_context = left_context
+        self.dim = dim
+        self.dtype = dtype
+        self.device = device
+        if init:
+            # shape (layers, left_context, dim)
+            self.attn_cache = torch.zeros((layers, left_context, dim),
+                                          dtype=dtype,
+                                          device=device)
+            self.conv_cache = torch.zeros((layers, left_context, dim),
+                                          dtype=dtype,
+                                          device=device)
+            self.offset = torch.tensor([0], dtype=dtype, device=device)
+
+    @staticmethod
+    def stack(states: List['DecodeStates']) -> 'DecodeStates':
+        """Batch per-stream states; the caches gain a batch dim at dim 2."""
+        assert len(states) >= 1
+        obj = DecodeStates(layers=states[0].layers,
+                           left_context=states[0].left_context,
+                           dim=states[0].dim,
+                           init=False,
+                           dtype=states[0].dtype,
+                           device=states[0].device)
+        attn_cache = []
+        conv_cache = []
+        offset = []
+        for i in range(len(states)):
+            attn_cache.append(states[i].attn_cache)
+            conv_cache.append(states[i].conv_cache)
+            offset.append(states[i].offset)
+        obj.attn_cache = torch.stack(attn_cache, dim=2)
+        obj.conv_cache = torch.stack(conv_cache, dim=2)
+        obj.offset = torch.stack(offset, dim=0)
+        return obj
+
+    @staticmethod
+    def unstack(states: 'DecodeStates') -> List['DecodeStates']:
+        """Inverse of stack: split a batched state into per-stream states."""
+        results = []
+        attn_cache = torch.unbind(states.attn_cache, dim=2)
+        conv_cache = torch.unbind(states.conv_cache, dim=2)
+        offset = torch.unbind(states.offset, dim=0)
+        for i in range(states.attn_cache.size(2)):
+            obj = DecodeStates(layers=states.layers,
+                               left_context=states.left_context,
+                               dim=states.dim,
+                               init=False,
+                               dtype=states.dtype,
+                               device=states.device)
+            obj.attn_cache = attn_cache[i]
+            obj.conv_cache = conv_cache[i]
+            obj.offset = offset[i]
+            results.append(obj)
+        return results


 class Conformer(Transformer):
@@ -56,6 +120,11 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
+        dynamic_chunk_training: bool = False,
+        short_chunk_threshold: float = 0.75,
+        short_chunk_size: int = 25,
+        num_left_chunks: int = -1,
+        causal: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -70,6 +139,12 @@ class Conformer(Transformer):
             vgg_frontend=vgg_frontend,
         )

+        self.dynamic_chunk_training = dynamic_chunk_training
+        self.short_chunk_threshold = short_chunk_threshold
+        self.short_chunk_size = short_chunk_size
+        self.num_left_chunks = num_left_chunks
+        self.causal = causal
+
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)

         encoder_layer = ConformerEncoderLayer(
@@ -79,6 +154,7 @@ class Conformer(Transformer):
             dropout,
             cnn_module_kernel,
             normalize_before,
+            causal,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
         self.normalize_before = normalize_before
@@ -115,9 +191,29 @@ class Conformer(Transformer):
         lengths = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == lengths.max().item()

-        mask = make_pad_mask(lengths)
-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, N, C)
+        src_key_padding_mask = make_pad_mask(lengths)
+        mask = None
+
+        if self.dynamic_chunk_training:
+            assert (
+                self.causal
+            ), "Causal convolution is required for streaming conformer."
+            max_len = x.size(0)
+            chunk_size = torch.randint(1, max_len, (1,)).item()
+            if chunk_size > (max_len * self.short_chunk_threshold):
+                chunk_size = max_len
+            else:
+                chunk_size = chunk_size % self.short_chunk_size + 1
+
+            mask = ~subsequent_chunk_mask(
+                size=x.size(0), chunk_size=chunk_size,
+                num_left_chunks=self.num_left_chunks, device=x.device
+            )
+
+        x = self.encoder(
+            x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
+        )  # (T, N, C)

         if self.normalize_before:
             x = self.after_norm(x)
@@ -128,6 +224,80 @@ class Conformer(Transformer):

         return logits, lengths

+    def streaming_forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        decode_states: Optional[DecodeStates] = None,
+        chunk_size: int = 32,
+        left_context: int = 64,
+        streaming_data: bool = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor, DecodeStates]:
+        # x: [N, T, C]
+
+        # Caution: We assume the subsampling factor is 4!
+        lengths = ((x_lens - 1) // 2 - 1) // 2
+
+        if streaming_data:
+            assert (
+                decode_states is not None
+            ), "A cache is required when sending data in streaming mode."
+            assert (
+                left_context == decode_states.left_context
+            ), f"""The given left_context must equal the left_context in
+            `decode_states`; need {decode_states.left_context}, given
+            {left_context}."""
+
+            src_key_padding_mask = make_pad_mask(lengths + left_context)
+
+            embed = self.encoder_embed(x)
+            embed, pos_enc = self.encoder_pos(embed, left_context)
+            embed = embed.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+            x = self.encoder(
+                embed,
+                pos_enc,
+                src_key_padding_mask=src_key_padding_mask,
+                attn_cache=decode_states.attn_cache,
+                conv_cache=decode_states.conv_cache,
+                left_context=decode_states.left_context,
+            )  # (T, N, C)
+
+            decode_states.offset += embed.size(0)
+        else:
+            assert decode_states is None
+
+            src_key_padding_mask = make_pad_mask(lengths)
+            x = self.encoder_embed(x)
+            x, pos_emb = self.encoder_pos(x)
+            x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+            assert x.size(0) == lengths.max().item()
+            assert left_context % chunk_size == 0
+            num_left_chunks = left_context // chunk_size
+
+            mask = ~subsequent_chunk_mask(
+                size=x.size(0),
+                chunk_size=chunk_size,
+                num_left_chunks=num_left_chunks,
+                device=x.device,
+            )
+            x = self.encoder(
+                x,
+                pos_emb,
+                mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+            )  # (T, N, C)
+
+        if self.normalize_before:
+            x = self.after_norm(x)
+
+        logits = self.encoder_output_layer(x)
+        logits = logits.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+
+        return logits, lengths, decode_states
+

 class ConformerEncoderLayer(nn.Module):
     """
     ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
@@ -156,6 +326,7 @@ class ConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
+        causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
@@ -176,7 +347,9 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )

-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, causal=causal
+        )

         self.norm_ff_macaron = nn.LayerNorm(
             d_model
@@ -201,6 +374,9 @@ class ConformerEncoderLayer(nn.Module):
         pos_emb: Tensor,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        attn_cache: Optional[Tensor] = None,
+        conv_cache: Optional[Tensor] = None,
+        left_context: int = 0,
-    ) -> Tensor:
+    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
         """
         Pass the input through the encoder layer.
@@ -233,13 +409,25 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_mha(src)
+
+        key = src
+        val = src
+        if not self.training and attn_cache is not None:
+            # src: [chunk_size, N, F] e.g. [8, 41, 512]
+            key = torch.cat([attn_cache, src], dim=0)
+            val = key
+            attn_cache = key
+        else:
+            assert left_context == 0
+
         src_att = self.self_attn(
             src,
-            src,
-            src,
+            key,
+            val,
             pos_emb=pos_emb,
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
+            left_context=left_context,
         )[0]
         src = residual + self.dropout(src_att)
         if not self.normalize_before:
@@ -249,7 +437,15 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_conv(src)
-        src = residual + self.dropout(self.conv_module(src))
+
+        if not self.training and conv_cache is not None:
+            src = torch.cat([conv_cache, src], dim=0)
+            conv_cache = src
+
+        src = self.conv_module(src)
+        src = src[-residual.size(0) :, :, :]  # noqa: E203
+
+        src = residual + self.dropout(src)
         if not self.normalize_before:
             src = self.norm_conv(src)
@@ -264,7 +460,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_final(src)

-        return src
+        return src, attn_cache, conv_cache


 class ConformerEncoder(nn.Module):
@@ -295,6 +491,9 @@ class ConformerEncoder(nn.Module):
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        attn_cache: Optional[Tensor] = None,
+        conv_cache: Optional[Tensor] = None,
+        left_context: int = 0,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.

@@ -314,13 +513,26 @@ class ConformerEncoder(nn.Module):
         """
         output = src

-        for mod in self.layers:
-            output = mod(
+        if self.training:
+            assert left_context == 0
+            assert attn_cache is None
+            assert conv_cache is None
+        else:
+            assert left_context >= 0
+
+        for layer_index, mod in enumerate(self.layers):
+            output, a_cache, c_cache = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
+                attn_cache=None if attn_cache is None else attn_cache[layer_index],
+                conv_cache=None if conv_cache is None else conv_cache[layer_index],
+                left_context=left_context,
             )
+            if attn_cache is not None and conv_cache is not None:
+                attn_cache[layer_index, ...] = a_cache[-left_context:, ...]
+                conv_cache[layer_index, ...] = c_cache[-left_context:, ...]
         return output


@@ -349,12 +561,13 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))

-    def extend_pe(self, x: Tensor) -> None:
+    def extend_pe(self, x: Tensor, context: int = 0) -> None:
         """Reset the positional encodings."""
+        x_size_1 = x.size(1) + context
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x.size(1) * 2 - 1:
+            if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
                 if self.pe.dtype != x.dtype or str(self.pe.device) != str(
                     x.device
@@ -364,9 +577,9 @@ class RelPositionalEncoding(torch.nn.Module):
         # Suppose `i` means the position of the query vector and `j` means
         # the position of the key vector. We use positive relative positions
         # when keys are to the left (i > j) and negative otherwise (i < j).
-        pe_positive = torch.zeros(x.size(1), self.d_model)
-        pe_negative = torch.zeros(x.size(1), self.d_model)
-        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        pe_positive = torch.zeros(x_size_1, self.d_model)
+        pe_negative = torch.zeros(x_size_1, self.d_model)
+        position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
@@ ... @@ class RelPositionalEncoding(torch.nn.Module):
-    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: int = 0
+    ) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.

         Args:
@@ -395,14 +612,15 @@ class RelPositionalEncoding(torch.nn.Module):
             torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).

         """
-        self.extend_pe(x)
+        self.extend_pe(x, context)
         x = x * self.xscale
+        x_size_1 = x.size(1) + context
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
-            - x.size(1)
+            - x_size_1
             + 1 : self.pe.size(1) // 2  # noqa E203
-            + x.size(1),
+            + x_size_1,
         ]
         return self.dropout(x), self.dropout(pos_emb)

@@ -467,6 +685,7 @@ class RelPositionMultiheadAttention(nn.Module):
         key_padding_mask: Optional[Tensor] = None,
         need_weights: bool = True,
         attn_mask: Optional[Tensor] = None,
+        left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
@@ -525,9 +744,10 @@ class RelPositionMultiheadAttention(nn.Module):
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
             attn_mask=attn_mask,
+            left_context=left_context,
         )

-    def rel_shift(self, x: Tensor) -> Tensor:
+    def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
         """Compute relative positional encoding.

         Args:
-            x: Input tensor (batch, head, time1, 2*time1-1).
+            x: Input tensor (batch, head, time1, 2*time2-1).
                 time1 means the length of query vector.
+            left_context: number of cached left-context frames
+                (time2 == time1 + left_context).

         Returns:
             Tensor: tensor of shape (batch, head, time1, time2)
-          (note: time2 has the same value as time1, but it is for
+          (note: time2 == time1 + left_context; time2 is for
           the key, while time1 is for the query).
""" (batch_size, num_heads, time1, n) = x.shape - assert n == 2 * time1 - 1 + time2 = time1 + left_context + + assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1" + # Note: TorchScript requires explicit arg for stride() batch_stride = x.stride(0) head_stride = x.stride(1) time1_stride = x.stride(2) n_stride = x.stride(3) return x.as_strided( - (batch_size, num_heads, time1, time1), + (batch_size, num_heads, time1, time2), (batch_stride, head_stride, time1_stride - n_stride, n_stride), storage_offset=n_stride * (time1 - 1), ) @@ -569,6 +792,7 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, + left_context: int = 0, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -748,7 +972,9 @@ class RelPositionMultiheadAttention(nn.Module): pos_emb_bsz = pos_emb.size(0) assert pos_emb_bsz in (1, bsz) # actually it is 1 p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) + p = p.permute(0, 2, 3, 1) q_with_bias_u = (q + self.pos_bias_u).transpose( 1, 2 @@ -768,9 +994,10 @@ class RelPositionMultiheadAttention(nn.Module): # compute matrix b and matrix d matrix_bd = torch.matmul( - q_with_bias_v, p.transpose(-2, -1) + q_with_bias_v, p ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd) + + matrix_bd = self.rel_shift(matrix_bd, left_context=left_context) attn_output_weights = ( matrix_ac + matrix_bd @@ -805,6 +1032,24 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + + # If we are using dynamic_chunk_training and setting a limited + # num_left_chunks, the attention may only see the padding values which + # will also be masked out by `key_padding_mask`, at this circumstances, + # the whole column of `attn_output_weights` will be `-inf` + # (i.e. be `nan` after softmax), so, we fill `0.0` at the masking + # positions to avoid invalid loss value below. 
+        if (
+            attn_mask is not None
+            and attn_mask.dtype == torch.bool
+            and key_padding_mask is not None
+        ):
+            combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                1
+            ).unsqueeze(2)
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(
+                combined_mask, 0.0
+            )
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, tgt_len, src_len
+            )
+
         attn_output_weights = nn.functional.dropout(
             attn_output_weights, p=dropout_p, training=training
         )
@@ -842,12 +1087,17 @@ class ConvolutionModule(nn.Module):
     """

     def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
+        self,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+        causal: bool = False,
     ) -> None:
         """Construct a ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernel_size should be an odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0
+        self.causal = causal

         self.pointwise_conv1 = nn.Conv1d(
             channels,
@@ -857,12 +1107,18 @@ class ConvolutionModule(nn.Module):
             padding=0,
             bias=bias,
         )
+
+        self.lorder = kernel_size - 1
+        padding = (kernel_size - 1) // 2
+        if self.causal:
+            padding = 0
+
         self.depthwise_conv = nn.Conv1d(
             channels,
             channels,
             kernel_size,
             stride=1,
-            padding=(kernel_size - 1) // 2,
+            padding=padding,
             groups=channels,
             bias=bias,
         )
@@ -895,6 +1151,11 @@ class ConvolutionModule(nn.Module):
         x = nn.functional.glu(x, dim=1)  # (batch, channels, time)

         # 1D Depthwise Conv
+        if self.causal and self.lorder > 0:
+            # Make depthwise_conv causal by manually padding
+            # self.lorder zeros to the left along the time axis.
+            x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+
         x = self.depthwise_conv(x)
         # x is (batch, channels, time)
         x = x.permute(0, 2, 1)
diff --git a/icefall/__init__.py b/icefall/__init__.py
index ec77e89b5..52d551c6a 100644
--- a/icefall/__init__.py
+++ b/icefall/__init__.py
@@ -61,5 +61,6 @@ from .utils import (
     setup_logger,
     store_transcripts,
     str2bool,
+    subsequent_chunk_mask,
     write_error_stats,
 )
diff --git a/icefall/utils.py b/icefall/utils.py
index daccd4346..8e87d29df 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -693,6 +693,42 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
     return expaned_lengths >= lengths.unsqueeze(1)


+# Copied and modified from
+# https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
+def subsequent_chunk_mask(
+    size: int,
+    chunk_size: int,
+    num_left_chunks: int = -1,
+    device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create a mask for subsequent steps (size, size) with chunk size;
+    this is for a streaming encoder.
+
+    Args:
+        size (int): size of the mask
+        chunk_size (int): size of a chunk
+        num_left_chunks (int): number of left chunks to attend to
+            < 0: use the full left context
+            >= 0: use num_left_chunks left chunks
+        device (torch.device): "cpu", "cuda" or torch.Tensor.device
+    Returns:
+        torch.Tensor: mask
+    Examples:
+        >>> subsequent_chunk_mask(4, 2)
+        [[1, 1, 0, 0],
+         [1, 1, 0, 0],
+         [1, 1, 1, 1],
+         [1, 1, 1, 1]]
+    """
+    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+    for i in range(size):
+        if num_left_chunks < 0:
+            start = 0
+        else:
+            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+        ending = min((i // chunk_size + 1) * chunk_size, size)
+        ret[i, start:ending] = True
+    return ret
+
+
 def l1_norm(x):
     return torch.sum(torch.abs(x))
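
Reviewer note (illustration, not part of the patch): what `subsequent_chunk_mask`
produces when `num_left_chunks` limits the left context. The expected values in
the comments follow directly from the loop in the function above.

    import torch

    from icefall.utils import subsequent_chunk_mask

    # chunk_size=2, attend to at most one left chunk
    mask = subsequent_chunk_mask(size=6, chunk_size=2, num_left_chunks=1)
    # Rows are query frames, columns are key frames:
    # [[1, 1, 0, 0, 0, 0],
    #  [1, 1, 0, 0, 0, 0],
    #  [1, 1, 1, 1, 0, 0],
    #  [1, 1, 1, 1, 0, 0],
    #  [0, 0, 1, 1, 1, 1],   # frames in chunk 2 no longer see chunk 0
    #  [0, 0, 1, 1, 1, 1]]
    print(mask.int())

    # The encoder consumes the inverted mask (True = masked position):
    attn_mask = ~mask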
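
Reviewer note (illustration, not part of the patch): the chunk size for dynamic
chunk training is re-sampled per batch in `Conformer.forward`; the sketch below
restates that logic standalone. With the defaults (short_chunk_threshold=0.75,
short_chunk_size=25), roughly 25% of batches are trained with full context and
the rest with a short chunk of 1 to 25 frames.

    import torch

    def sample_chunk_size(max_len: int,
                          short_chunk_threshold: float = 0.75,
                          short_chunk_size: int = 25) -> int:
        # Same logic as Conformer.forward with dynamic_chunk_training=True.
        chunk_size = torch.randint(1, max_len, (1,)).item()
        if chunk_size > (max_len * short_chunk_threshold):
            return max_len  # full-context training
        return chunk_size % short_chunk_size + 1  # short chunk in [1, 25]

    print(sample_chunk_size(200))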
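
Reviewer note (illustration, not part of the patch): with `causal=True` the
`ConvolutionModule` drops the symmetric built-in padding of the depthwise conv
and instead pads `kernel_size - 1` zeros on the left, so no output frame
depends on future frames. A standalone sketch of just that padding trick (the
real module also applies pointwise convs and GLU, omitted here):

    import torch
    import torch.nn as nn

    kernel_size, channels = 31, 8
    lorder = kernel_size - 1  # number of left-pad zeros

    # Causal variant: no built-in padding; pad the left side only.
    conv = nn.Conv1d(channels, channels, kernel_size, padding=0, groups=channels)

    x = torch.randn(1, channels, 100)  # (batch, channels, time)
    y = conv(nn.functional.pad(x, (lorder, 0), "constant", 0.0))
    assert y.shape == x.shape  # time length is preserved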
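
Reviewer note (hypothetical usage, not part of the patch): `DecodeStates.stack`
and `DecodeStates.unstack` exist so that several simultaneous streams can be
batched for decoding; they insert/remove a batch dimension at dim 2 of the
caches. Assuming `conformer.py` from the recipe directory is importable, a
round trip looks like:

    import torch

    from conformer import DecodeStates  # assumes the recipe dir is on sys.path

    streams = [DecodeStates(layers=12, left_context=64, dim=512)
               for _ in range(4)]
    batched = DecodeStates.stack(streams)
    assert batched.attn_cache.shape == (12, 64, 4, 512)  # (layers, T, batch, dim)
    assert batched.offset.shape == (4, 1)

    singles = DecodeStates.unstack(batched)
    assert len(singles) == 4
    assert singles[0].attn_cache.shape == (12, 64, 512)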