diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py index 2f2145c70..78baa2b78 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py @@ -46,7 +46,8 @@ class Conformer(EncoderInterface): num_encoder_layers (int): number of encoder layers dropout (float): dropout rate layer_dropout (float): layer-dropout rate. - cnn_module_kernel (int): Kernel size of convolution module. + cnn_module_kernel (int): Kernel size of convolution module + vgg_frontend (bool): whether to use vgg frontend. dynamic_chunk_training (bool): whether to use dynamic chunk training, if you want to train a streaming model, this is expected to be True. When setting True, it will use a masking strategy to make the attention @@ -80,7 +81,6 @@ class Conformer(EncoderInterface): dropout: float = 0.1, layer_dropout: float = 0.075, cnn_module_kernel: int = 31, - aux_layer_period: int = 3, dynamic_chunk_training: bool = False, short_chunk_threshold: float = 0.75, short_chunk_size: int = 25, @@ -101,8 +101,6 @@ class Conformer(EncoderInterface): # (2) embedding: num_features -> d_model self.encoder_embed = Conv2dSubsampling(num_features, d_model) - self.encoder_pos = RelPositionalEncoding(d_model, dropout) - self.encoder_layers = num_encoder_layers self.d_model = d_model self.cnn_module_kernel = cnn_module_kernel @@ -112,6 +110,8 @@ class Conformer(EncoderInterface): self.short_chunk_size = short_chunk_size self.num_left_chunks = num_left_chunks + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + encoder_layer = ConformerEncoderLayer( d_model, nhead, @@ -119,12 +119,10 @@ class Conformer(EncoderInterface): dropout, layer_dropout, cnn_module_kernel, + causal, ) - self.encoder = ConformerEncoder( - encoder_layer, - num_encoder_layers, - aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)), - ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self._init_state: List[torch.Tensor] = [torch.empty(0)] def forward( self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 @@ -150,8 +148,15 @@ class Conformer(EncoderInterface): x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + # Caution: We assume the subsampling factor is 4! + + # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning + # + # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 lengths = (((x_lens - 1) >> 1) - 1) >> 1 + assert x.size(0) == lengths.max().item() + src_key_padding_mask = make_pad_mask(lengths) if self.dynamic_chunk_training: @@ -182,13 +187,215 @@ class Conformer(EncoderInterface): x = self.encoder( x, pos_emb, + mask=None, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + ) # (T, N, C) + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + return x, lengths + + @torch.jit.export + def get_init_state( + self, left_context: int, device: torch.device + ) -> List[torch.Tensor]: + """Return the initial cache state of the model. + Args: + left_context: The left context size (in frames after subsampling). + Returns: + Return the initial state of the model, it is a list containing two + tensors, the first one is the cache for attentions which has a shape + of (num_encoder_layers, left_context, encoder_dim), the second one + is the cache of conv_modules which has a shape of + (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). + NOTE: the returned tensors are on the given device. + """ + if ( + len(self._init_state) == 2 + and self._init_state[0].size(1) == left_context + ): + # Note: It is OK to share the init state as it is + # not going to be modified by the model + return self._init_state + + init_states: List[torch.Tensor] = [ + torch.zeros( + ( + self.encoder_layers, + left_context, + self.d_model, + ), + device=device, + ), + torch.zeros( + ( + self.encoder_layers, + self.cnn_module_kernel - 1, + self.d_model, + ), + device=device, + ), + ] + + self._init_state = init_states + + return init_states + + @torch.jit.export + def streaming_forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: Optional[List[Tensor]] = None, + processed_lens: Optional[Tensor] = None, + left_context: int = 64, + right_context: int = 4, + chunk_size: int = 16, + simulate_streaming: bool = False, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: + The decode states for previous frames which contains the cached data. + It has two elements, the first element is the attn_cache which has + a shape of (encoder_layers, left_context, batch, attention_dim), + the second element is the conv_cache which has a shape of + (encoder_layers, cnn_module_kernel-1, batch, conv_dim). + Note: states will be modified in this function. + processed_lens: + How many frames (after subsampling) have been processed for each sequence. + left_context: + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. + chunk_size: + The chunk size for decoding, this will be used to simulate streaming + decoding using masking. + simulate_streaming: + If setting True, it will use a masking strategy to simulate streaming + fashion (i.e. every chunk data only see limited left context and + right context). The whole sequence is supposed to be send at a time + When using simulate_streaming. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + Returns: + Return a tuple containing 2 tensors: + - logits, its shape is (batch_size, output_seq_len, output_dim) + - logit_lens, a tensor of shape (batch_size,) containing the number + of frames in `logits` before padding. + - decode_states, the updated states including the information + of current chunk. + """ + + # x: [N, T, C] + # Caution: We assume the subsampling factor is 4! + + # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning + # + # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 + lengths = (((x_lens - 1) >> 1) - 1) >> 1 + + if not simulate_streaming: + assert states is not None + assert processed_lens is not None + assert ( + len(states) == 2 + and states[0].shape + == (self.encoder_layers, left_context, x.size(0), self.d_model) + and states[1].shape + == ( + self.encoder_layers, + self.cnn_module_kernel - 1, + x.size(0), + self.d_model, + ) + ), f"""The length of states MUST be equal to 2, and the shape of + first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)}, + given {states[0].shape}. the shape of second element should be + {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)}, + given {states[1].shape}.""" + + lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output + + src_key_padding_mask = make_pad_mask(lengths) + + processed_mask = torch.arange(left_context, device=x.device).expand( + x.size(0), left_context + ) + processed_lens = processed_lens.view(x.size(0), 1) + processed_mask = (processed_lens <= processed_mask).flip(1) + + src_key_padding_mask = torch.cat( + [processed_mask, src_key_padding_mask], dim=1 + ) + + embed = self.encoder_embed(x) + + # cut off 1 frame on each size of embed as they see the padding + # value which causes a training and decoding mismatch. + embed = embed[:, 1:-1, :] + + embed, pos_enc = self.encoder_pos(embed, left_context) + embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + + x, states = self.encoder.chunk_forward( + embed, + pos_enc, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + states=states, + left_context=left_context, + right_context=right_context, + ) # (T, B, F) + if right_context > 0: + x = x[0:-right_context, ...] + lengths -= right_context + else: + assert states is None + states = [] # just to make torch.script.jit happy + # this branch simulates streaming decoding using mask as we are + # using in training time. + src_key_padding_mask = make_pad_mask(lengths) + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + assert x.size(0) == lengths.max().item() + + num_left_chunks = -1 + if left_context >= 0: + assert left_context % chunk_size == 0 + num_left_chunks = left_context // chunk_size + + mask = ~subsequent_chunk_mask( + size=x.size(0), + chunk_size=chunk_size, + num_left_chunks=num_left_chunks, + device=x.device, + ) + x = self.encoder( + x, + pos_emb, + mask=mask, src_key_padding_mask=src_key_padding_mask, warmup=warmup, ) # (T, N, C) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return x, lengths + return x, lengths, states class ConformerEncoderLayer(nn.Module): @@ -309,6 +516,7 @@ class ConformerEncoderLayer(nn.Module): attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] + src = src + self.dropout(src_att) # convolution module @@ -325,6 +533,98 @@ class ConformerEncoderLayer(nn.Module): return src + @torch.jit.export + def chunk_forward( + self, + src: Tensor, + pos_emb: Tensor, + states: List[Tensor], + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + left_context: int = 0, + right_context: int = 0, + ) -> Tuple[Tensor, List[Tensor]]: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + states: + The decode states for previous frames which contains the cached data. + It has two elements, the first element is the attn_cache which has + a shape of (left_context, batch, attention_dim), + the second element is the conv_cache which has a shape of + (cnn_module_kernel-1, batch, conv_dim). + Note: states will be modified in this function. + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + left_context: + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. + Shape: + src: (S, N, E). + pos_emb: (N, 2*(S+left_context)-1, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + assert not self.training + assert len(states) == 2 + assert states[0].shape == (left_context, src.size(1), src.size(2)) + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + # We put the attention cache this level (i.e. before linear transformation) + # to save memory consumption, when decoding in streaming fashion, the + # batch size would be thousands (for 32GB machine), if we cache key & val + # separately, it needs extra several GB memory. + # TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val + # separately) if needed. + key = torch.cat([states[0], src], dim=0) + val = key + if right_context > 0: + states[0] = key[ + -(left_context + right_context) : -right_context, ... # noqa + ] + else: + states[0] = key[-left_context:, ...] + + # multi-headed self-attention module + src_att = self.self_attn( + src, + key, + val, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + left_context=left_context, + )[0] + + src = src + self.dropout(src_att) + + # convolution module + conv, conv_cache = self.conv_module(src, states[1], right_context) + states[1] = conv_cache + + src = src + self.dropout(conv) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + return src, states + class ConformerEncoder(nn.Module): r"""ConformerEncoder is a stack of N encoder layers @@ -339,32 +639,13 @@ class ConformerEncoder(nn.Module): >>> out = conformer_encoder(src, pos_emb) """ - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - aux_layers: List[int], - ) -> None: + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: super().__init__() self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) self.num_layers = num_layers - assert len(set(aux_layers)) == len(aux_layers) - - assert num_layers - 1 not in aux_layers - self.aux_layers = aux_layers + [num_layers - 1] - - num_channels = encoder_layer.norm_final.num_channels - self.combiner = RandomCombine( - num_inputs=len(self.aux_layers), - num_channels=num_channels, - final_weight=0.5, - pure_prob=0.333, - stddev=2.0, - ) - def forward( self, src: Tensor, @@ -379,6 +660,8 @@ class ConformerEncoder(nn.Module): pos_emb: Positional embedding tensor (required). mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. Shape: src: (S, N, E). pos_emb: (N, 2*S-1, E) @@ -388,9 +671,7 @@ class ConformerEncoder(nn.Module): """ output = src - outputs = [] - - for i, mod in enumerate(self.layers): + for layer_index, mod in enumerate(self.layers): output = mod( output, pos_emb, @@ -398,13 +679,80 @@ class ConformerEncoder(nn.Module): src_key_padding_mask=src_key_padding_mask, warmup=warmup, ) - if i in self.aux_layers: - outputs.append(output) - - output = self.combiner(outputs) return output + @torch.jit.export + def chunk_forward( + self, + src: Tensor, + pos_emb: Tensor, + states: List[Tensor], + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + left_context: int = 0, + right_context: int = 0, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn. + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + states: + The decode states for previous frames which contains the cached data. + It has two elements, the first element is the attn_cache which has + a shape of (encoder_layers, left_context, batch, attention_dim), + the second element is the conv_cache which has a shape of + (encoder_layers, cnn_module_kernel-1, batch, conv_dim). + Note: states will be modified in this function. + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + left_context: + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. + Shape: + src: (S, N, E). + pos_emb: (N, 2*(S+left_context)-1, E). + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + """ + assert not self.training + assert len(states) == 2 + assert states[0].shape == ( + self.num_layers, + left_context, + src.size(1), + src.size(2), + ) + assert states[1].size(0) == self.num_layers + + output = src + + for layer_index, mod in enumerate(self.layers): + cache = [states[0][layer_index], states[1][layer_index]] + output, cache = mod.chunk_forward( + output, + pos_emb, + states=cache, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + left_context=left_context, + right_context=right_context, + ) + states[0][layer_index] = cache[0] + states[1][layer_index] = cache[1] + + return output, states + class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. @@ -426,12 +774,13 @@ class RelPositionalEncoding(torch.nn.Module): self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - def extend_pe(self, x: Tensor) -> None: + def extend_pe(self, x: Tensor, left_context: int = 0) -> None: """Reset the positional encodings.""" + x_size_1 = x.size(1) + left_context if self.pe is not None: # self.pe contains both positive and negative parts # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device if self.pe.dtype != x.dtype or str(self.pe.device) != str( x.device @@ -441,9 +790,9 @@ class RelPositionalEncoding(torch.nn.Module): # Suppose `i` means to the position of query vector and `j` means the # position of key vector. We use position relative positions when keys # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + def forward( + self, + x: torch.Tensor, + left_context: int = 0, + ) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: x (torch.Tensor): Input tensor (batch, time, `*`). + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. Returns: torch.Tensor: Encoded tensor (batch, time, `*`). torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). """ - self.extend_pe(x) + self.extend_pe(x, left_context) + x_size_1 = x.size(1) + left_context pos_emb = self.pe[ :, self.pe.size(1) // 2 - - x.size(1) + - x_size_1 + 1 : self.pe.size(1) // 2 # noqa E203 + x.size(1), ] @@ -541,6 +898,7 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, + left_context: int = 0, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -554,6 +912,9 @@ class RelPositionMultiheadAttention(nn.Module): need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. Shape: - Inputs: - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is @@ -597,27 +958,36 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, + left_context=left_context, ) - def rel_shift(self, x: Tensor) -> Tensor: + def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor: """Compute relative positional encoding. Args: x: Input tensor (batch, head, time1, 2*time1-1). time1 means the length of query vector. + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. Returns: Tensor: tensor of shape (batch, head, time1, time2) (note: time2 has the same value as time1, but it is for the key, while time1 is for the query). """ (batch_size, num_heads, time1, n) = x.shape - assert n == 2 * time1 - 1 + + time2 = time1 + left_context + assert ( + n == left_context + 2 * time1 - 1 + ), f"{n} == {left_context} + 2 * {time1} - 1" + # Note: TorchScript requires explicit arg for stride() batch_stride = x.stride(0) head_stride = x.stride(1) time1_stride = x.stride(2) n_stride = x.stride(3) return x.as_strided( - (batch_size, num_heads, time1, time1), + (batch_size, num_heads, time1, time2), (batch_stride, head_stride, time1_stride - n_stride, n_stride), storage_offset=n_stride * (time1 - 1), ) @@ -639,6 +1009,7 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, + left_context: int = 0, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -656,6 +1027,9 @@ class RelPositionMultiheadAttention(nn.Module): need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. Shape: Inputs: - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is @@ -817,7 +1191,8 @@ class RelPositionMultiheadAttention(nn.Module): pos_emb_bsz = pos_emb.size(0) assert pos_emb_bsz in (1, bsz) # actually it is 1 p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) + p = p.permute(0, 2, 3, 1) q_with_bias_u = (q + self._pos_bias_u()).transpose( 1, 2 @@ -837,9 +1212,9 @@ class RelPositionMultiheadAttention(nn.Module): # compute matrix b and matrix d matrix_bd = torch.matmul( - q_with_bias_v, p.transpose(-2, -1) + q_with_bias_v, p ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd) + matrix_bd = self.rel_shift(matrix_bd, left_context) attn_output_weights = ( matrix_ac + matrix_bd @@ -953,7 +1328,6 @@ class ConvolutionModule(nn.Module): super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding assert (kernel_size - 1) % 2 == 0 - self.causal = causal self.pointwise_conv1 = ScaledConv1d( @@ -1013,14 +1387,22 @@ class ConvolutionModule(nn.Module): initial_scale=0.25, ) - def forward(self, x: Tensor, cache: Optional[Tensor] = None) -> Tensor: + def forward( + self, + x: Tensor, + cache: Optional[Tensor] = None, + right_context: int = 0, + ) -> Tuple[Tensor, Tensor]: """Compute convolution module. Args: x: Input tensor (#time, batch, channels). cache: The cache of depthwise_conv, only used in real streaming decoding. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. Returns: - Tensor: Output tensor (#time, batch, channels). If cache is None return the output tensor (#time, batch, channels). If cache is not None, return a tuple of Tensor, the first one is the output tensor (#time, batch, channels), the second one is the @@ -1047,8 +1429,15 @@ class ConvolutionModule(nn.Module): ), "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) - cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa - + if right_context > 0: + cache = x.permute(2, 0, 1)[ + -(self.lorder + right_context) : ( # noqa + -right_context + ), + ..., + ] + else: + cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa x = self.depthwise_conv(x) x = self.deriv_balancer2(x) @@ -1153,282 +1542,6 @@ class Conv2dSubsampling(nn.Module): return x -class RandomCombine(nn.Module): - """ - This module combines a list of Tensors, all with the same shape, to - produce a single output of that same shape which, in training time, - is a random combination of all the inputs; but which in test time - will be just the last input. - All but the last input will have a linear transform before we - randomly combine them; these linear transforms will be initialized - to the identity transform. - The idea is that the list of Tensors will be a list of outputs of multiple - conformer layers. This has a similar effect as iterated loss. (See: - DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER - NETWORKS). - """ - - def __init__( - self, - num_inputs: int, - num_channels: int, - final_weight: float = 0.5, - pure_prob: float = 0.5, - stddev: float = 2.0, - ) -> None: - """ - Args: - num_inputs: - The number of tensor inputs, which equals the number of layers' - outputs that are fed into this module. E.g. in an 18-layer neural - net if we output layers 16, 12, 18, num_inputs would be 3. - num_channels: - The number of channels on the input, e.g. 512. - final_weight: - The amount of weight or probability we assign to the - final layer when randomly choosing layers or when choosing - continuous layer weights. - pure_prob: - The probability, on each frame, with which we choose - only a single layer to output (rather than an interpolation) - stddev: - A standard deviation that we add to log-probs for computing - randomized weights. - The method of choosing which layers, or combinations of layers, to use, - is conceptually as follows:: - With probability `pure_prob`:: - With probability `final_weight`: choose final layer, - Else: choose random non-final layer. - Else:: - Choose initial log-weights that correspond to assigning - weight `final_weight` to the final layer and equal - weights to other layers; then add Gaussian noise - with variance `stddev` to these log-weights, and normalize - to weights (note: the average weight assigned to the - final layer here will not be `final_weight` if stddev>0). - """ - super().__init__() - assert 0 <= pure_prob <= 1, pure_prob - assert 0 < final_weight < 1, final_weight - assert num_inputs >= 1 - - self.linear = nn.ModuleList( - [ - nn.Linear(num_channels, num_channels, bias=True) - for _ in range(num_inputs - 1) - ] - ) - - self.num_inputs = num_inputs - self.final_weight = final_weight - self.pure_prob = pure_prob - self.stddev = stddev - - self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) - .log() - .item() - ) - self._reset_parameters() - - def _reset_parameters(self): - for i in range(len(self.linear)): - nn.init.eye_(self.linear[i].weight) - nn.init.constant_(self.linear[i].bias, 0.0) - - def forward(self, inputs: List[Tensor]) -> Tensor: - """Forward function. - Args: - inputs: - A list of Tensor, e.g. from various layers of a transformer. - All must be the same shape, of (*, num_channels) - Returns: - A Tensor of shape (*, num_channels). In test mode - this is just the final input. - """ - num_inputs = self.num_inputs - assert len(inputs) == num_inputs - if not self.training or torch.jit.is_scripting(): - return inputs[-1] - - # Shape of weights: (*, num_inputs) - num_channels = inputs[0].shape[-1] - num_frames = inputs[0].numel() // num_channels - - mod_inputs = [] - - if False: - # It throws the following error for torch 1.6.0 when using - # torch script. - # - # Expected integer literal for index. ModuleList/Sequential - # indexing is only supported with integer literals. Enumeration is - # supported, e.g. 'for index, v in enumerate(self): ...': - # for i in range(num_inputs - 1): - # mod_inputs.append(self.linear[i](inputs[i])) - assert False - else: - for i, linear in enumerate(self.linear): - if i < num_inputs - 1: - mod_inputs.append(linear(inputs[i])) - - mod_inputs.append(inputs[num_inputs - 1]) - - ndim = inputs[0].ndim - # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape( - (num_frames, num_channels, num_inputs) - ) - - # weights: (num_frames, num_inputs) - weights = self._get_random_weights( - inputs[0].dtype, inputs[0].device, num_frames - ) - - weights = weights.reshape(num_frames, num_inputs, 1) - # ans: (num_frames, num_channels, 1) - ans = torch.matmul(stacked_inputs, weights) - # ans: (*, num_channels) - - ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) - - # The following if causes errors for torch script in torch 1.6.0 - # if __name__ == "__main__": - # # for testing only... - # print("Weights = ", weights.reshape(num_frames, num_inputs)) - return ans - - def _get_random_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ) -> Tensor: - """Return a tensor of random weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired - Returns: - A tensor of shape (num_frames, self.num_inputs), such that - `ans.sum(dim=1)` is all ones. - """ - pure_prob = self.pure_prob - if pure_prob == 0.0: - return self._get_random_mixed_weights(dtype, device, num_frames) - elif pure_prob == 1.0: - return self._get_random_pure_weights(dtype, device, num_frames) - else: - p = self._get_random_pure_weights(dtype, device, num_frames) - m = self._get_random_mixed_weights(dtype, device, num_frames) - return torch.where( - torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m - ) - - def _get_random_pure_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ): - """Return a tensor of random one-hot weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired. - Returns: - A one-hot tensor of shape `(num_frames, self.num_inputs)`, with - exactly one weight equal to 1.0 on each frame. - """ - final_prob = self.final_weight - - # final contains self.num_inputs - 1 in all elements - final = torch.full((num_frames,), self.num_inputs - 1, device=device) - # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) - - indexes = torch.where( - torch.rand(num_frames, device=device) < final_prob, final, nonfinal - ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) - return ans - - def _get_random_mixed_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ): - """Return a tensor of random one-hot weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired. - Returns: - A tensor of shape (num_frames, self.num_inputs), which elements - in [0..1] that sum to one over the second axis, i.e. - `ans.sum(dim=1)` is all ones. - """ - logprobs = ( - torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) - * self.stddev - ) - logprobs[:, -1] += self.final_log_weight - return logprobs.softmax(dim=1) - - -def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): - print( - f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" - ) - num_inputs = 3 - num_channels = 50 - m = RandomCombine( - num_inputs=num_inputs, - num_channels=num_channels, - final_weight=final_weight, - pure_prob=pure_prob, - stddev=stddev, - ) - - x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] - - y = m(x) - assert y.shape == x[0].shape - assert torch.allclose(y, x[0]) # .. since actually all ones. - - -def _test_random_combine_main(): - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.0) - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.3) - _test_random_combine(0.5, 1, 0.3) - _test_random_combine(0.5, 0.5, 0.3) - - feature_dim = 50 - c = Conformer( - num_features=feature_dim, output_dim=256, d_model=128, nhead=4 - ) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - if __name__ == "__main__": feature_dim = 50 c = Conformer(num_features=feature_dim, d_model=128, nhead=4) @@ -1440,5 +1553,3 @@ if __name__ == "__main__": torch.full((batch_size,), seq_len, dtype=torch.int64), warmup=0.5, ) - - _test_random_combine_main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index 41e7a0f44..6962aef86 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -53,6 +53,7 @@ When training with the L subset, usage: import argparse import logging +import math from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -68,10 +69,11 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, + average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) @@ -80,9 +82,12 @@ from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, + str2bool, write_error_stats, ) +LOG_EPS = math.log(1e-10) + def get_parser(): parser = argparse.ArgumentParser( @@ -92,17 +97,20 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", ) parser.add_argument( - "--batch", + "--iter", type=int, - default=None, - help="It specifies the batch checkpoint to use for decoding." - "Note: Epoch counts from 0.", + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, ) parser.add_argument( @@ -111,24 +119,24 @@ def get_parser(): default=15, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + "'--epoch' and '--iter'", ) parser.add_argument( - "--avg-last-n", - type=int, - default=0, - help="""If positive, --epoch and --avg are ignored and it - will use the last n checkpoints exp_dir/checkpoint-xxx.pt - where xxx is the number of processed batches while - saving that checkpoint. - """, + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="pruned_transducer_stateless5/exp", help="The experiment dir", ) @@ -204,6 +212,31 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + return parser @@ -240,7 +273,7 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. """ - device = model.device + device = next(model.parameters()).device feature = batch["inputs"] assert feature.ndim == 3 @@ -250,9 +283,26 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, ) + + if params.simulate_streaming: + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + hyps = [] if params.decoding_method == "fast_beam_search": @@ -482,32 +532,92 @@ def main(): params.blank_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + logging.info(params) logging.info("About to create model") model = get_transducer_model(params) - if params.avg_last_n > 0: - filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - elif params.batch is not None: - filenames = f"{params.exp_dir}/checkpoint-{params.batch}.pt" - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints([filenames], device=device)) + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) model.to(device) model.eval() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py new file mode 100644 index 000000000..ba5e80555 --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py @@ -0,0 +1,126 @@ +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import k2 +import torch + +from icefall.utils import AttributeDict + + +class DecodeStream(object): + def __init__( + self, + params: AttributeDict, + initial_states: List[torch.Tensor], + decoding_graph: Optional[k2.Fsa] = None, + device: torch.device = torch.device("cpu"), + ) -> None: + """ + Args: + initial_states: + Initial decode states of the model, e.g. the return value of + `get_init_state` in conformer.py + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + Used only when decoding_method is fast_beam_search. + device: + The device to run this stream. + """ + if decoding_graph is not None: + assert device == decoding_graph.device + + self.params = params + self.LOG_EPS = math.log(1e-10) + + self.states = initial_states + + # It contains a 2-D tensors representing the feature frames. + self.features: torch.Tensor = None + + self.num_frames: int = 0 + # how many frames have been processed. (before subsampling). + # we only modify this value in `func:get_feature_frames`. + self.num_processed_frames: int = 0 + + self._done: bool = False + + # The transcript of current utterance. + self.ground_truth: str = "" + + # The decoding result (partial or final) of current utterance. + self.hyp: List = [] + + # how many frames have been processed, after subsampling (i.e. a + # cumulative sum of the second return value of + # encoder.streaming_forward + self.done_frames: int = 0 + + self.pad_length = ( + params.right_context + 2 + ) * params.subsampling_factor + 3 + + if params.decoding_method == "greedy_search": + self.hyp = [params.blank_id] * params.context_size + elif params.decoding_method == "fast_beam_search": + # The rnnt_decoding_stream for fast_beam_search. + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) + ) + else: + assert ( + False + ), f"Decoding method :{params.decoding_method} do not support." + + @property + def done(self) -> bool: + """Return True if all the features are processed.""" + return self._done + + def set_features( + self, + features: torch.Tensor, + ) -> None: + """Set features tensor of current utterance.""" + assert features.dim() == 2, features.dim() + self.features = torch.nn.functional.pad( + features, + (0, 0, 0, self.pad_length), + mode="constant", + value=self.LOG_EPS, + ) + self.num_frames = self.features.size(0) + + def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: + """Consume chunk_size frames of features""" + chunk_length = chunk_size + self.pad_length + + ret_length = min( + self.num_frames - self.num_processed_frames, chunk_length + ) + + ret_features = self.features[ + self.num_processed_frames : self.num_processed_frames # noqa + + ret_length + ] + + self.num_processed_frames += chunk_size + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_features, ret_length diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py new file mode 100644 index 000000000..03bd45d20 --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -0,0 +1,739 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +python pruned_transducer_stateless5/streaming_decode.py \ + --epoch 5 \ + --avg 1 \ + --decode-chunk-size 16 \ + --left-context 64 \ + --right-context 0 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import torch +import torch.nn as nn +from asr_datamodule import WenetSpeechAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import one_best_decoding +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Support only greedy_search and fast_beam_search now. + """, + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--right-context", + type=int, + default=0, + help="right context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], +) -> List[List[int]]: + + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + T = encoder_out.size(1) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (N, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + # print("encoder_out shape: ", current_encoder_out.shape, "decoder_out shape: ", decoder_out.shape) + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) + + hyp_tokens = [] + for stream in streams: + hyp_tokens.append(stream.hyp) + return hyp_tokens + + +def fast_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + decoding_streams: k2.RnntDecodingStreams, +) -> List[List[int]]: + + B, T, C = encoder_out.shape + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = decoding_streams.format_output(processed_lens.tolist()) + best_path = one_best_decoding(lattice) + hyp_tokens = get_texts(best_path) + return hyp_tokens + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + + rnnt_stream_list = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames( + params.decode_chunk_size * params.subsampling_factor + ) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + if params.decoding_method == "fast_beam_search": + rnnt_stream_list.append(stream.rnnt_decoding_stream) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # if T is less than 7 there will be an error in time reduction layer, + # because we subsample features with ((x_len - 1) // 2 - 1) // 2 + # we plus 2 here because we will cut off one frame on each size of + # encoder_embed output as they see invalid paddings. so we need extra 2 + # frames. + tail_length = 7 + (2 + params.right_context) * params.subsampling_factor + if features.size(1) < tail_length: + feature_lens += tail_length - features.size(1) + features = torch.cat( + [ + features, + torch.tensor( + LOG_EPS, dtype=features.dtype, device=device + ).expand( + features.size(0), + tail_length - features.size(1), + features.size(2), + ), + ], + dim=1, + ) + + states = [ + torch.stack([x[0] for x in states], dim=2), + torch.stack([x[1] for x in states], dim=2), + ] + + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + left_context=params.left_context, + right_context=params.right_context, + processed_lens=processed_lens, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + hyp_tokens = greedy_search(model, encoder_out, decode_streams) + elif params.decoding_method == "fast_beam_search": + config = k2.RnntDecodingConfig( + vocab_size=params.vocab_size, + decoder_history_len=params.context_size, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config) + processed_lens = processed_lens + encoder_out_lens + hyp_tokens = fast_beam_search( + model, encoder_out, processed_lens, decoding_streams + ) + else: + assert False + + states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = [states[0][i], states[1][i]] + decode_streams[i].done_frames += encoder_out_lens[i] + if params.decoding_method == "fast_beam_search": + decode_streams[i].hyp = hyp_tokens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + decode_stream = DecodeStream( + params=params, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + decode_stream.set_features(fbank(samples.to(device))) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + hyp = decode_streams[i].hyp + if params.decoding_method == "greedy_search": + hyp = hyp[params.context_size :] # noqa + decode_results.append( + ( + list(decode_streams[i].ground_truth), + [lexicon.token_table[idx] for idx in hyp], + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + hyp = decode_streams[i].hyp + if params.decoding_method == "greedy_search": + hyp = hyp[params.context_size :] # noqa + decode_results.append( + ( + list(decode_streams[i].ground_truth), + [lexicon.token_table[idx] for idx in hyp], + ) + ) + del decode_streams[i] + + key = "greedy_search" + if params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + # sort results so we can easily compare the difference between two + # recognition results + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + WenetSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + params.suffix += f"-right-context-{params.right_context}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + params.causal_convolution = True + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + wenetspeech = WenetSpeechAsrDataModule(args) + + dev_cuts = wenetspeech.valid_cuts() + test_net_cuts = wenetspeech.test_net_cuts() + test_meeting_cuts = wenetspeech.test_meeting_cuts() + + test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] + test_cuts = [dev_cuts, test_net_cuts, test_meeting_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + lexicon=lexicon, + decoding_graph=decoding_graph, + ) + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main()