diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index fb40bf5a5..a5186e150 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -15,7 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import copy import math import warnings @@ -172,22 +171,23 @@ class Conformer(EncoderInterface): chunk_size = chunk_size % self.short_chunk_size + 1 mask = ~subsequent_chunk_mask( - size=x.size(0), chunk_size=chunk_size, - num_left_chunks=self.num_left_chunks, device=x.device + size=x.size(0), + chunk_size=chunk_size, + num_left_chunks=self.num_left_chunks, + device=x.device, ) x, _ = self.encoder( - x, pos_emb, + x, + pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask, warmup=warmup, ) # (T, N, C) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return x, lengths - def streaming_forward( self, x: torch.Tensor, @@ -250,9 +250,16 @@ class Conformer(EncoderInterface): ), "Require cache when sending data in streaming mode" assert ( - len(states) == 2 and - states[0].shape == (self.encoder_layers, left_context, x.size(0), self.d_model) and - states[1].shape == (self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model) + len(states) == 2 + and states[0].shape + == (self.encoder_layers, left_context, x.size(0), self.d_model) + and states[1].shape + == ( + self.encoder_layers, + self.cnn_module_kernel - 1, + x.size(0), + self.d_model, + ) ), f"""The length of states MUST be equal to 2, and the shape of first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)}, given {states[0].shape}. the shape of second element should be @@ -293,7 +300,7 @@ class Conformer(EncoderInterface): size=x.size(0), chunk_size=chunk_size, num_left_chunks=num_left_chunks, - device=x.device + device=x.device, ) x = self.encoder( x, @@ -366,9 +373,7 @@ class ConformerEncoderLayer(nn.Module): ) self.conv_module = ConvolutionModule( - d_model, - cnn_module_kernel, - causal=causal + d_model, cnn_module_kernel, causal=causal ) self.norm_final = BasicNorm(d_model) @@ -546,7 +551,11 @@ class ConformerEncoder(nn.Module): assert left_context >= 0 for layer_index, mod in enumerate(self.layers): - cache = None if states is None else [states[0][layer_index], states[1][layer_index]] + cache = ( + None + if states is None + else [states[0][layer_index], states[1][layer_index]] + ) output = mod( output, pos_emb, @@ -623,10 +632,10 @@ class RelPositionalEncoding(torch.nn.Module): self.pe = pe.to(device=x.device, dtype=x.dtype) def forward( - self, - x: torch.Tensor, - context: int = 0 - ) -> Tuple[Tensor, Tensor]: + self, + x: torch.Tensor, + context: int = 0, + ) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: @@ -1079,16 +1088,23 @@ class RelPositionMultiheadAttention(nn.Module): # the whole column of `attn_output_weights` will be `-inf` # (i.e. be `nan` after softmax), so, we fill `0.0` at the masking # positions to avoid invalid loss value below. - if attn_mask is not None and attn_mask.dtype == torch.bool and \ - key_padding_mask is not None: - combined_mask = attn_mask.unsqueeze( - 0) | key_padding_mask.unsqueeze(1).unsqueeze(2) + if ( + attn_mask is not None + and attn_mask.dtype == torch.bool + and key_padding_mask is not None + ): + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len) + bsz, num_heads, tgt_len, src_len + ) attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0) + combined_mask, 0.0 + ) attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len) + bsz * num_heads, tgt_len, src_len + ) attn_output_weights = nn.functional.dropout( attn_output_weights, p=dropout_p, training=training @@ -1131,7 +1147,7 @@ class ConvolutionModule(nn.Module): channels: int, kernel_size: int, bias: bool = True, - causal: bool = False + causal: bool = False, ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() @@ -1197,10 +1213,10 @@ class ConvolutionModule(nn.Module): ) def forward( - self, - x: Tensor, - cache: Optional[Tensor] = None - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + self, + x: Tensor, + cache: Optional[Tensor] = None, + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Compute convolution module. Args: @@ -1231,10 +1247,12 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert not self.training, "Cache should be None in training time" + assert ( + not self.training + ), "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) - cache = x.permute(2, 0, 1)[-self.lorder:,...] + cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa x = self.depthwise_conv(x) x = self.deriv_balancer2(x) @@ -1242,7 +1260,9 @@ class ConvolutionModule(nn.Module): x = self.pointwise_conv2(x) # (batch, channel, time) - return x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache) + return ( + x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache) + ) class Conv2dSubsampling(nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 92038222c..4b787363e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -15,7 +15,6 @@ # limitations under the License. -import logging import k2 import torch import torch.nn as nn @@ -177,9 +176,9 @@ class Transducer(nn.Module): else: offset = (boundary[:, 3] - 1) / 2 total_syms = torch.sum(boundary[:, 2]) - offset = torch.arange( - T0, device=px_grad.device - ).reshape(1, 1, T0) - offset.reshape(B, 1, 1) + offset = torch.arange(T0, device=px_grad.device).reshape( + 1, 1, T0 + ) - offset.reshape(B, 1, 1) sym_delay = px_grad * offset sym_delay = torch.sum(sym_delay) / total_syms