From 1e35ea3260727325ea23f07b92ea64941773e6c0 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 22 Nov 2021 18:52:12 +0800 Subject: [PATCH] streaming conformer code --- .../ASR/streaming_conformer_ctc/conformer.py | 497 ++++++++++++++++- .../streaming_decode.py | 506 ++++++++++++++++++ .../ASR/streaming_conformer_ctc/train.py | 21 +- .../streaming_conformer_ctc/transformer.py | 28 +- 4 files changed, 1018 insertions(+), 34 deletions(-) create mode 100755 egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py index b19b94db1..33f8e2e15 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py @@ -25,6 +25,42 @@ from torch import Tensor, nn from transformer import Supervisions, Transformer, encoder_padding_mask +# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py#L42 +def subsequent_chunk_mask( + size: int, + chunk_size: int, + num_left_chunks: int = -1, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size) with chunk size, + this is for streaming encoder + Args: + size (int): size of mask + chunk_size (int): size of chunk + num_left_chunks (int): number of left chunks + <0: use full chunk + >=0: use num_left_chunks + device (torch.device): "cpu" or "cuda" or torch.Tensor.device + Returns: + torch.Tensor: mask + Examples: + >>> subsequent_chunk_mask(4, 2) + [[1, 1, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]] + """ + ret = torch.zeros(size, size, device=device, dtype=torch.bool) + for i in range(size): + if num_left_chunks < 0: + start = 0 + else: + start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) + ending = min((i // chunk_size + 1) * chunk_size, size) + ret[i, start:ending] = True + return ret + + class Conformer(Transformer): """ Args: @@ -57,6 +93,7 @@ class Conformer(Transformer): normalize_before: bool = True, vgg_frontend: bool = False, use_feat_batchnorm: bool = False, + causal: bool = False, ) -> None: super(Conformer, self).__init__( num_features=num_features, @@ -82,6 +119,7 @@ class Conformer(Transformer): dropout, cnn_module_kernel, normalize_before, + causal, ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) self.normalize_before = normalize_before @@ -93,7 +131,13 @@ class Conformer(Transformer): self.after_norm = identity def run_encoder( - self, x: Tensor, supervisions: Optional[Supervisions] = None + self, + x: Tensor, + supervisions: Optional[Supervisions] = None, + dynamic_chunk_training: bool = False, + short_chunk_proportion: float = 0.5, + chunk_size: int = -1, + simulate_streaming: bool = False, ) -> Tuple[Tensor, Optional[Tensor]]: """ Args: @@ -107,23 +151,235 @@ class Conformer(Transformer): It is read directly from the batch, without any sorting. It is used to compute encoder padding mask, which is used as memory key padding mask for the decoder. + dynamic_chunk_training: + For training only. + IF True, train with dynamic right context for some batches + sampled with a distribution + if False, train with full right context all the time. + short_chunk_proportion: + For training only. + Proportion of samples that will be trained with dynamic chunk. + chunk_size: + For eval only. + right context when evaluating test utts. + -1 means all right context. + simulate_streaming=False, + For eval only. + If true, the feature will be feeded into the model chunk by chunk. + If false, the whole utts if feeded into the model together i.e. the + model only foward once. + + + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). + Tensor: Mask tensor of dimension (batch_size, input_length) + """ + if self.encoder.training: + return self.train_run_encoder( + x, supervisions, dynamic_chunk_training, short_chunk_proportion + ) + else: + return self.eval_run_encoder( + x, supervisions, chunk_size, simulate_streaming + ) + + def train_run_encoder( + self, + x: Tensor, + supervisions: Optional[Supervisions] = None, + dynamic_chunk_training: bool = False, + short_chunk_threshold: float = 0.5, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + dynamic_chunk_training: + IF True, train with dynamic right context for some batches + sampled with a distribution + if False, train with full right context all the time. + short_chunk_proportion: + Proportion of samples that will be trained with dynamic chunk. + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + src_key_padding_mask = encoder_padding_mask(x.size(0), supervisions) + if src_key_padding_mask is not None: + src_key_padding_mask = src_key_padding_mask.to(x.device) + + if dynamic_chunk_training: + max_len = x.size(0) + chunk_size = torch.randint(1, max_len, (1,)).item() + if chunk_size > (max_len * short_chunk_threshold): + chunk_size = max_len + else: + chunk_size = chunk_size % 25 + 1 + mask = ~subsequent_chunk_mask( + size=x.size(0), chunk_size=chunk_size, device=x.device + ) + x = self.encoder( + x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask + ) # (T, B, F) + else: + x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + + if self.normalize_before: + x = self.after_norm(x) + + return x, src_key_padding_mask + + def eval_run_encoder( + self, + feature: Tensor, + supervisions: Optional[Supervisions] = None, + chunk_size: int = -1, + simulate_streaming=False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + feature: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. Returns: Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). Tensor: Mask tensor of dimension (batch_size, input_length) """ - x = self.encoder_embed(x) - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - mask = encoder_padding_mask(x.size(0), supervisions) - if mask is not None: - mask = mask.to(x.device) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + # feature.shape: N T C + num_frames = feature.size(1) + + # As temporarily in icefall only subsampling_rate == 4 is supported, + # following parameters are hard-coded here. + # Change it accordingly if other subsamling_rate are supported. + embed_left_context = 7 + embed_conv_right_context = 3 + subsampling_rate = 4 + stride = chunk_size * subsampling_rate + decoding_window = embed_conv_right_context + stride + + # This is also only compatible to sumsampling_rate == 4 + length_after_subsampling = ((feature.size(1) - 1) // 2 - 1) // 2 + src_key_padding_mask = encoder_padding_mask( + length_after_subsampling, supervisions + ) + if src_key_padding_mask is not None: + src_key_padding_mask = src_key_padding_mask.to(feature.device) + + if chunk_size < 0: + # non-streaming decoding + x = self.encoder_embed(feature) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + x = self.encoder( + x, pos_emb, src_key_padding_mask=src_key_padding_mask + ) # (T, B, F) + else: + if simulate_streaming: + # simulate chunk_by_chunk streaming decoding + # Results of this branch should be identical to following + # "else" branch. + # But this branch is a little slower + # as the feature is feeded chunk by chunk + + # store the result of chunk_by_chunk decoding + encoder_output = [] + + # caches + pos_emb_positive = [] + pos_emb_negative = [] + pos_emb_central = None + encoder_cache = [None for i in range(len(self.encoder.layers))] + conv_cache = [None for i in range(len(self.encoder.layers))] + + # start chunk_by_chunk decoding + offset = 0 + for cur in range( + 0, num_frames - embed_left_context + 1, stride + ): + end = min(cur + decoding_window, num_frames) + cur_feature = feature[:, cur:end, :] + cur_feature = self.encoder_embed(cur_feature) + cur_embed, cur_pos_emb = self.encoder_pos( + cur_feature, offset + ) + cur_embed = cur_embed.permute( + 1, 0, 2 + ) # (B, T, F) -> (T, B, F) + + cur_T = cur_feature.size(1) + if cur == 0: + # for first chunk extract the central pos embedding + pos_emb_central = cur_pos_emb[ + 0, (chunk_size - 1), : + ].view(1, 1, -1) + cur_T -= 1 + pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0)) + pos_emb_negative.append(cur_pos_emb[0, -cur_T:]) + assert pos_emb_positive[-1].size(0) == cur_T + + pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze( + 0 + ) + pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze( + 0 + ) + cur_pos_emb = torch.cat( + [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg], + dim=1, + ) + + x = self.encoder.chunk_forward( + cur_embed, + cur_pos_emb, + src_key_padding_mask=src_key_padding_mask[ + :, : offset + cur_embed.size(0) + ], + encoder_cache=encoder_cache, + conv_cache=conv_cache, + offset=offset, + ) # (T, B, F) + encoder_output.append(x) + offset += cur_embed.size(0) + + x = torch.cat(encoder_output, dim=0) + else: + # NOT simulate chunk_by_chunk decoding + # Results of this branch should be identical to previous + # simulate chunk_by_chunk decoding branch. + # But this branch is faster. + x = self.encoder_embed(feature) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + mask = ~subsequent_chunk_mask( + size=x.size(0), chunk_size=chunk_size, device=x.device + ) + x = self.encoder( + x, + pos_emb, + mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) # (T, B, F) if self.normalize_before: x = self.after_norm(x) - return x, mask + return x, src_key_padding_mask class ConformerEncoderLayer(nn.Module): @@ -154,6 +410,7 @@ class ConformerEncoderLayer(nn.Module): dropout: float = 0.1, cnn_module_kernel: int = 31, normalize_before: bool = True, + causal: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() self.self_attn = RelPositionMultiheadAttention( @@ -174,7 +431,9 @@ class ConformerEncoderLayer(nn.Module): nn.Linear(dim_feedforward, d_model), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) self.norm_ff_macaron = nn.LayerNorm( d_model @@ -264,6 +523,97 @@ class ConformerEncoderLayer(nn.Module): return src + def chunk_forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + encoder_cache: Optional[Tensor] = None, + conv_cache: Optional[Tensor] = None, + offset=0, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + # macaron style feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff_macaron(src) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) + if not self.normalize_before: + src = self.norm_ff_macaron(src) + + # multi-headed self-attention module + residual = src + if self.normalize_before: + src = self.norm_mha(src) + if encoder_cache is None: + # src: [chunk_size, N, F] e.g. [8, 41, 512] + key = src + val = key + encoder_cache = key + else: + key = torch.cat([encoder_cache, src], dim=0) + val = key + encoder_cache = key + src_att = self.self_attn( + src, + key, + val, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + offset=offset, + )[0] + src = residual + self.dropout(src_att) + if not self.normalize_before: + src = self.norm_mha(src) + + # convolution module + residual = src # [chunk_size, N, F] e.g. [8, 41, 512] + if self.normalize_before: + src = self.norm_conv(src) + if conv_cache is not None: + src = torch.cat([conv_cache, src], dim=0) + conv_cache = src + + src = self.conv_module(src) + src = src[-residual.size(0) :, :, :] # noqa: E203 + + src = residual + self.dropout(src) + if not self.normalize_before: + src = self.norm_conv(src) + + # feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) + if not self.normalize_before: + src = self.norm_ff(src) + + if self.normalize_before: + src = self.norm_final(src) + + return src, encoder_cache, conv_cache + class ConformerEncoder(nn.TransformerEncoder): r"""ConformerEncoder is a stack of N encoder layers @@ -326,6 +676,52 @@ class ConformerEncoder(nn.TransformerEncoder): return output + def chunk_forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + encoder_cache=None, + conv_cache=None, + offset=0, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for layer_index, mod in enumerate(self.layers): + output, e_cache, c_cache = mod.chunk_forward( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + encoder_cache=encoder_cache[layer_index], + conv_cache=conv_cache[layer_index], + offset=offset, + ) + encoder_cache[layer_index] = e_cache + conv_cache[layer_index] = c_cache + + if self.norm is not None: + output = self.norm(output) + + return output + class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. @@ -351,12 +747,13 @@ class RelPositionalEncoding(torch.nn.Module): self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - def extend_pe(self, x: Tensor) -> None: + def extend_pe(self, x: Tensor, offset: int = 0) -> None: """Reset the positional encodings.""" + x_size_1 = offset + x.size(1) if self.pe is not None: # self.pe contains both positive and negative parts # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device if self.pe.dtype != x.dtype or str(self.pe.device) != str( x.device @@ -366,9 +763,9 @@ class RelPositionalEncoding(torch.nn.Module): # Suppose `i` means to the position of query vecotr and `j` means the # position of key vector. We use position relative positions when keys # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + def forward( + self, x: torch.Tensor, offset: int = 0 + ) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: @@ -397,15 +796,31 @@ class RelPositionalEncoding(torch.nn.Module): torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). """ - self.extend_pe(x) + self.extend_pe(x, offset) x = x * self.xscale + x_size_1 = offset + x.size(1) pos_emb = self.pe[ :, self.pe.size(1) // 2 - - x.size(1) + - x_size_1 + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(1), + + x_size_1, ] + x_T = x.size(1) + if offset > 0: + pos_emb = torch.cat([pos_emb[:, :x_T], pos_emb[:, -x_T:]], dim=1) + else: + pos_emb = torch.cat( + [ + pos_emb[:, : (x_T - 1)], + self.pe[0, self.pe.size(1) // 2].view( + 1, 1, self.pe.size(-1) + ), + pos_emb[:, -(x_T - 1) :], # noqa: E203 + ], + dim=1, + ) + return self.dropout(x), self.dropout(pos_emb) @@ -469,6 +884,7 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, + offset=0, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -527,9 +943,10 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, + offset=offset, ) - def rel_shift(self, x: Tensor) -> Tensor: + def rel_shift(self, x: Tensor, offset=0) -> Tensor: """Compute relative positional encoding. Args: @@ -538,18 +955,20 @@ class RelPositionMultiheadAttention(nn.Module): Returns: Tensor: tensor of shape (batch, head, time1, time2) - (note: time2 has the same value as time1, but it is for + (note: time2 == time1 + offset, since it is for the key, while time1 is for the query). """ (batch_size, num_heads, time1, n) = x.shape - assert n == 2 * time1 - 1 + time2 = time1 + offset + assert n == 2 * time2 - 1 # Note: TorchScript requires explicit arg for stride() batch_stride = x.stride(0) head_stride = x.stride(1) time1_stride = x.stride(2) n_stride = x.stride(3) + return x.as_strided( - (batch_size, num_heads, time1, time1), + (batch_size, num_heads, time1, time2), (batch_stride, head_stride, time1_stride - n_stride, n_stride), storage_offset=n_stride * (time1 - 1), ) @@ -571,6 +990,7 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, + offset=0, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -749,7 +1169,9 @@ class RelPositionMultiheadAttention(nn.Module): pos_emb_bsz = pos_emb.size(0) assert pos_emb_bsz in (1, bsz) # actually it is 1 p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) + p = p.permute(0, 2, 3, 1) q_with_bias_u = (q + self.pos_bias_u).transpose( 1, 2 @@ -769,10 +1191,11 @@ class RelPositionMultiheadAttention(nn.Module): # compute matrix b and matrix d matrix_bd = torch.matmul( - q_with_bias_v, p.transpose(-2, -1) + q_with_bias_v, p ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd) - + matrix_bd = self.rel_shift( + matrix_bd, offset=offset + ) # [B, head, time1, time2] attn_output_weights = ( matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) @@ -843,7 +1266,11 @@ class ConvolutionModule(nn.Module): """ def __init__( - self, channels: int, kernel_size: int, bias: bool = True + self, + channels: int, + kernel_size: int, + bias: bool = True, + causal: bool = False, ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() @@ -858,12 +1285,20 @@ class ConvolutionModule(nn.Module): padding=0, bias=bias, ) + # from https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/convolution.py#L41 + if causal: + self.lorder = kernel_size - 1 + padding = 0 # manualy padding self.lorder zeros to the left later + else: + assert (kernel_size - 1) % 2 == 0 + self.lorder = 0 + padding = (kernel_size - 1) // 2 self.depthwise_conv = nn.Conv1d( channels, channels, kernel_size, stride=1, - padding=(kernel_size - 1) // 2, + padding=padding, groups=channels, bias=bias, ) @@ -896,6 +1331,10 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if self.lorder > 0: + # manualy padding self.lorder zeros to the left + # make depthwise_conv causal + x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) x = self.depthwise_conv(x) x = self.activation(self.norm(x)) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py new file mode 100755 index 000000000..e88a4323c --- /dev/null +++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py @@ -0,0 +1,506 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py#L166 +def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: + new_hyp: List[int] = [] + cur = 0 + while cur < len(hyp): + if hyp[cur] != 0: + new_hyp.append(hyp[cur]) + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + cur += 1 + return new_hyp + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=34, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--chunk-size", + type=int, + default=8, + help="Frames of right context" + "-1 for whole right context, i.e. non-streaming decoding", + ) + + parser.add_argument( + "--tailing-num-frames", + type=int, + default=20, + help="tailing dummy frames padded to the right," + "only used during decoding", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="simulate chunk by chunk decoding", + ) + parser.add_argument( + "--method", + type=str, + default="ctc-greedy-search", + help="Streaming Decoding method", + ) + + parser.add_argument( + "--export", + type=str2bool, + default=False, + help="""When enabled, the averaged model is saved to + conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved. + pretrained.pt contains a dict {"model": model.state_dict()}, + which can be loaded by `icefall.checkpoint.load_checkpoint()`. + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="streaming_conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe", + help="The lang dir", + ) + + parser.add_argument( + "--avg-models", + type=str, + default=None, + help="Manually select models to average, seperated by comma;" + "e.g. 60,62,63,72", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "exp_dir": Path("conformer_ctc/exp"), + "lang_dir": Path("data/lang_bpe"), + "lm_dir": Path("data/lm"), + # parameters for conformer + "causal": True, + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + chunk_size: int = -1, + simulate_streaming=False, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + model: + The neural model. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + feature = batch["inputs"] + device = torch.device("cuda") + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + # Extra dummy tailing frames my reduce deletion error + # example WITHOUT padding: + # CHAPTER SEVEN ON THE RACES OF MAN + # example WITH padding: + # CHAPTER SEVEN ON THE RACES OF (MAN->*) + tailing_frames = ( + torch.tensor([-23.0259]) + .expand([feature.size(0), params.tailing_num_frames, 80]) + .to(feature.device) + ) + feature = torch.cat([feature, tailing_frames], dim=1) + supervisions["num_frames"] += params.tailing_num_frames + + nnet_output, memory, memory_key_padding_mask = model( + feature, + supervisions, + chunk_size=chunk_size, + simulate_streaming=simulate_streaming, + ) + + assert params.method == "ctc-greedy-search" + key = "ctc-greedy-search" + batch_size = nnet_output.size(0) + maxlen = nnet_output.size(1) + topk_prob, topk_index = nnet_output.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) + topk_index = topk_index.masked_fill_( + memory_key_padding_mask, 0 + ) # (B, maxlen) + token_ids = [token_id.tolist() for token_id in topk_index] + token_ids = [ + remove_duplicates_and_blank(token_id) for token_id in token_ids + ] + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + return {key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + chunk_size: int = -1, + simulate_streaming=False, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + chunk_size: + right context to simulate streaming decoding + -1 for whole right context, i.e. non-stream decoding + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + results = [] + + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + sos_id=sos_id, + eos_id=eos_id, + chunk_size=chunk_size, + simulate_streaming=simulate_streaming, + ) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + if params.method == "attention-decoder": + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() + if params.avg_models is not None: + avg_models = params.avg_models.replace(",", "_") + result_file_prefix = f"epoch-avg-{avg_models}-chunksize \ + -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-" + else: + result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-chunksize \ + -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-" + for key, results in results_dict.items(): + recog_path = ( + params.exp_dir + / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.exp_dir + / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + causal=params.causal, + ) + + if params.avg == 1 and params.avg_models is not None: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + filenames = [] + if params.avg_models is not None: + model_ids = params.avg_models.split(",") + for i in model_ids: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + else: + start = params.epoch - params.avg + 1 + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + if params.export: + logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) + return + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + # CAUTION: `test_sets` is for displaying only. + # If you want to skip test-clean, you have to skip + # it inside the for loop. That is, use + # + # if test_set == 'test-clean': continue + # + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + test_sets = ["test-clean", "test-other"] + for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + bpe_model=bpe_model, + word_table=lexicon.word_table, + sos_id=sos_id, + eos_id=eos_id, + chunk_size=params.chunk_size, + simulate_streaming=params.simulate_streaming, + ) + + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py index c6063fade..1881cfcd0 100755 --- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py @@ -124,6 +124,20 @@ def get_parser(): """, ) + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="Whether to use dynamic right context during training.", + ) + + parser.add_argument( + "--short-chunk-proportion", + type=float, + default=0.7, + help="Proportion of samples trained with short right context", + ) + return parser @@ -340,7 +354,12 @@ def compute_loss( supervisions = batch["supervisions"] with torch.set_grad_enabled(is_training): - nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + nnet_output, encoder_memory, memory_mask = model( + feature, + supervisions, + dynamic_chunk_training=params.dynamic_chunk_training, + short_chunk_proportion=params.short_chunk_proportion, + ) # nnet_output is (N, T, C) # NOTE: We need `encode_supervisions` to sort sequences with diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py index f93914aaa..bc78e4a41 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py @@ -158,7 +158,13 @@ class Transformer(nn.Module): self.decoder_criterion = None def forward( - self, x: torch.Tensor, supervision: Optional[Supervisions] = None + self, + x: torch.Tensor, + supervision: Optional[Supervisions] = None, + dynamic_chunk_training: bool = False, + short_chunk_proportion: float = 0.5, + chunk_size: int = -1, + simulate_streaming=False, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Args: @@ -184,13 +190,21 @@ class Transformer(nn.Module): x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision + x, + supervision, + dynamic_chunk_training=dynamic_chunk_training, + short_chunk_proportion=short_chunk_proportion, + chunk_size=chunk_size, + simulate_streaming=simulate_streaming, ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask def run_encoder( - self, x: torch.Tensor, supervisions: Optional[Supervisions] = None + self, + x: torch.Tensor, + supervisions: Optional[Supervisions] = None, + chunk_size: int = -1, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Run the transformer encoder. @@ -205,6 +219,8 @@ class Transformer(nn.Module): It is read directly from the batch, without any sorting. It is used to compute the encoder padding mask, which is used as memory key padding mask for the decoder. + chunk_size: right chunk_size to simulate streaming decoding + -1 for whole right context Returns: Return a tuple with two tensors: - The encoder output, with shape (T, N, C) @@ -212,12 +228,16 @@ class Transformer(nn.Module): The mask is None if `supervisions` is None. It is used as memory key padding mask in the decoder. """ + # streaming decoding(chunk_size >= 0) is only verified with Conformer + assert chunk_size == -1 x = self.encoder_embed(x) x = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) mask = mask.to(x.device) if mask is not None else None - x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) + x = self.encoder( + x, src_key_padding_mask=mask, chunk_size=chunk_size + ) # (T, N, C) return x, mask