diff --git a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py
index bde228af7..71ddd34cf 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py
@@ -1231,10 +1231,7 @@ class EmformerEncoder(nn.Module):
         return attention_mask
 
     def forward(
-        self,
-        x: torch.Tensor,
-        lengths: torch.Tensor,
-        pos_emb: torch.Tensor,
+        self, x: torch.Tensor, lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward pass for training and non-streaming inference.
 
@@ -1250,9 +1247,6 @@ class EmformerEncoder(nn.Module):
             With shape (B,) and i-th element representing number of valid
             utterance frames for i-th batch element in x, which contains the
             right_context at the end.
-          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For training mode, P = 2*U-1.
 
         Returns:
           A tuple of 2 tensors:
@@ -1260,8 +1254,11 @@ class EmformerEncoder(nn.Module):
           - output_lengths, with shape (B,), without containing the
             right_context at the end.
         """
+        U = x.size(0) - self.right_context_length
+        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
+
         right_context = self._gen_right_context(x)
-        utterance = x[: x.size(0) - self.right_context_length]
+        utterance = x[:U]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)
         memory = (
@@ -1271,6 +1268,7 @@ class EmformerEncoder(nn.Module):
             if self.use_memory
             else torch.empty(0).to(dtype=x.dtype, device=x.device)
         )
+
         output = utterance
         for layer in self.emformer_layers:
             output, right_context, memory = layer(
@@ -1289,7 +1287,6 @@ class EmformerEncoder(nn.Module):
         self,
         x: torch.Tensor,
         lengths: torch.Tensor,
-        pos_emb: torch.Tensor,
         states: Optional[List[List[torch.Tensor]]] = None,
         conv_caches: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[
@@ -1309,9 +1306,6 @@ class EmformerEncoder(nn.Module):
             With shape (B,) and i-th element representing number of valid
             utterance frames for i-th batch element in x, which contains the
             right_context at the end.
-          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For infer mode, PE = L+2*U-1.
           states (List[List[torch.Tensor]], optional):
             Cached states from proceeding chunk's computation, where each
             element (List[torch.Tensor]) corresponds to each emformer layer.
@@ -1332,6 +1326,11 @@ class EmformerEncoder(nn.Module):
             f"expected size of {self.chunk_length + self.right_context_length} "
             f"for dimension 1 of x, but got {x.size(1)}."
         )
+
+        pos_len = self.chunk_length + self.left_context_length
+        neg_len = self.chunk_length
+        x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
+
         right_context_start_idx = x.size(0) - self.right_context_length
         right_context = x[right_context_start_idx:]
         utterance = x[:right_context_start_idx]
@@ -1414,8 +1413,6 @@ class Emformer(EncoderInterface):
         else:
             self.encoder_embed = Conv2dSubsampling(num_features, d_model)
 
-        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
-
         self.encoder = EmformerEncoder(
             chunk_length // 4,
             d_model,
@@ -1463,10 +1460,6 @@ class Emformer(EncoderInterface):
             right_context at the end.
         """
         x = self.encoder_embed(x)
-
-        # TODO: The length computation in the encoder class should be moved here.  # noqa
-        U = x.size(1) - self.right_context_length // 4
-        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         # Caution: We assume the subsampling factor is 4!
@@ -1475,7 +1468,7 @@ class Emformer(EncoderInterface):
         x_lens = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == x_lens.max().item()
 
-        output, output_lengths = self.encoder(x, x_lens, pos_emb)  # (T, N, C)
+        output, output_lengths = self.encoder(x, x_lens)  # (T, N, C)
 
         logits = self.encoder_output_layer(output)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -1521,12 +1514,6 @@ class Emformer(EncoderInterface):
             - updated convolution caches from current chunk.
         """
         x = self.encoder_embed(x)
-
-        # TODO: The length computation in the encoder class should be moved here.  # noqa
-        pos_len = self.chunk_length // 4 + self.left_context_length // 4
-        neg_len = self.chunk_length // 4
-        x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
-
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         # Caution: We assume the subsampling factor is 4!
@@ -1540,7 +1527,7 @@ class Emformer(EncoderInterface):
             output_lengths,
             output_states,
             output_conv_caches,
-        ) = self.encoder.infer(x, x_lens, pos_emb, states, conv_caches)
+        ) = self.encoder.infer(x, x_lens, states, conv_caches)
 
         logits = self.encoder_output_layer(output)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
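After this change, EmformerEncoder owns the positional encoding: both forward and infer call self.encoder_pos(x, pos_len=..., neg_len=...) instead of receiving a precomputed pos_emb, so the Emformer wrapper no longer threads pos_emb through its call sites and the two removed TODO comments become moot, since the pos_len/neg_len computation now lives next to its only consumer. The module below is a minimal sketch of what such a call could look like; it is a hypothetical stand-in inferred only from the call sites and the removed docstrings (an embedding with pos_len + neg_len - 1 rows, i.e. 2*U-1 in training mode), not the recipe's actual RelPositionalEncoding implementation.

```python
import math

import torch
import torch.nn as nn


class RelPosEncodingSketch(nn.Module):
    """Hypothetical stand-in for the recipe's RelPositionalEncoding.

    Mirrors only the interface used at the two new call sites in this
    diff: forward(x, pos_len, neg_len) returns the dropout-regularized
    input together with a sinusoidal embedding covering relative offsets
    from pos_len - 1 down to -(neg_len - 1), i.e. pos_len + neg_len - 1
    rows in total.
    """

    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
        super().__init__()
        assert d_model % 2 == 0, "sinusoidal encoding assumes even d_model"
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, pos_len: int, neg_len: int):
        # Relative offsets pos_len - 1, ..., 1, 0, -1, ..., -(neg_len - 1).
        offsets = torch.arange(pos_len - 1, -neg_len, -1.0).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0.0, self.d_model, 2.0)
            * -(math.log(10000.0) / self.d_model)
        )
        pos_emb = torch.zeros(pos_len + neg_len - 1, self.d_model)
        pos_emb[:, 0::2] = torch.sin(offsets * inv_freq)
        pos_emb[:, 1::2] = torch.cos(offsets * inv_freq)
        return self.dropout(x), pos_emb.to(dtype=x.dtype, device=x.device)


# Training / non-streaming mode: pos_len = neg_len = U, so the embedding
# has 2*U - 1 rows, matching the removed docstring's "PE = 2*U-1".
U, B, D = 12, 2, 16  # illustrative sizes, not the recipe's defaults
R = 3  # hypothetical right_context_length (subsampled frames)
encoder_pos = RelPosEncodingSketch(D)
x = torch.randn(U + R, B, D)  # as in the diff, x still ends with right context
x, pos_emb = encoder_pos(x, pos_len=U, neg_len=U)
assert pos_emb.shape == (2 * U - 1, D)
```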
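For the streaming path, infer now asks the same module for pos_len = chunk_length + left_context_length and neg_len = chunk_length (all in subsampled frames). Assuming the embedding length is pos_len + neg_len - 1, as inferred from the removed docstrings, this reproduces "PE = L+2*U-1" from the removed infer docstring with U equal to the chunk length. A quick arithmetic check with illustrative values that are not the recipe's defaults:

```python
# Illustrative values (subsampled frames); not the recipe's defaults.
chunk_length = 8         # U in the removed infer docstring
left_context_length = 4  # L in the removed infer docstring

pos_len = chunk_length + left_context_length  # as in EmformerEncoder.infer
neg_len = chunk_length

pe = pos_len + neg_len - 1  # rows of the relative positional embedding
assert pe == left_context_length + 2 * chunk_length - 1  # PE = L + 2*U - 1
print(pe)  # 19
```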