move encoder_pos

This commit is contained in:
yaozengwei 2022-05-09 16:27:56 +08:00
parent 8e6a51edaa
commit d0cea4f2f8

View File

@ -1231,10 +1231,7 @@ class EmformerEncoder(nn.Module):
return attention_mask return attention_mask
def forward( def forward(
self, self, x: torch.Tensor, lengths: torch.Tensor
x: torch.Tensor,
lengths: torch.Tensor,
pos_emb: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for training and non-streaming inference. """Forward pass for training and non-streaming inference.
@ -1250,9 +1247,6 @@ class EmformerEncoder(nn.Module):
With shape (B,) and i-th element representing number of valid With shape (B,) and i-th element representing number of valid
utterance frames for i-th batch element in x, which contains the utterance frames for i-th batch element in x, which contains the
right_context at the end. right_context at the end.
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D).
For training mode, P = 2*U-1.
Returns: Returns:
A tuple of 2 tensors: A tuple of 2 tensors:
@ -1260,8 +1254,11 @@ class EmformerEncoder(nn.Module):
- output_lengths, with shape (B,), without containing the - output_lengths, with shape (B,), without containing the
right_context at the end. right_context at the end.
""" """
U = x.size(0) - self.right_context_length
x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
right_context = self._gen_right_context(x) right_context = self._gen_right_context(x)
utterance = x[: x.size(0) - self.right_context_length] utterance = x[:U]
output_lengths = torch.clamp(lengths - self.right_context_length, min=0) output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
attention_mask = self._gen_attention_mask(utterance) attention_mask = self._gen_attention_mask(utterance)
memory = ( memory = (
@ -1271,6 +1268,7 @@ class EmformerEncoder(nn.Module):
if self.use_memory if self.use_memory
else torch.empty(0).to(dtype=x.dtype, device=x.device) else torch.empty(0).to(dtype=x.dtype, device=x.device)
) )
output = utterance output = utterance
for layer in self.emformer_layers: for layer in self.emformer_layers:
output, right_context, memory = layer( output, right_context, memory = layer(
@ -1289,7 +1287,6 @@ class EmformerEncoder(nn.Module):
self, self,
x: torch.Tensor, x: torch.Tensor,
lengths: torch.Tensor, lengths: torch.Tensor,
pos_emb: torch.Tensor,
states: Optional[List[List[torch.Tensor]]] = None, states: Optional[List[List[torch.Tensor]]] = None,
conv_caches: Optional[List[torch.Tensor]] = None, conv_caches: Optional[List[torch.Tensor]] = None,
) -> Tuple[ ) -> Tuple[
@ -1309,9 +1306,6 @@ class EmformerEncoder(nn.Module):
With shape (B,) and i-th element representing number of valid With shape (B,) and i-th element representing number of valid
utterance frames for i-th batch element in x, which contains the utterance frames for i-th batch element in x, which contains the
right_context at the end. right_context at the end.
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D).
For infer mode, PE = L+2*U-1.
states (List[List[torch.Tensor]], optional): states (List[List[torch.Tensor]], optional):
Cached states from proceeding chunk's computation, where each Cached states from proceeding chunk's computation, where each
element (List[torch.Tensor]) corresponds to each emformer layer. element (List[torch.Tensor]) corresponds to each emformer layer.
@ -1332,6 +1326,11 @@ class EmformerEncoder(nn.Module):
f"expected size of {self.chunk_length + self.right_context_length} " f"expected size of {self.chunk_length + self.right_context_length} "
f"for dimension 1 of x, but got {x.size(1)}." f"for dimension 1 of x, but got {x.size(1)}."
) )
pos_len = self.chunk_length + self.left_context_length
neg_len = self.chunk_length
x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
right_context_start_idx = x.size(0) - self.right_context_length right_context_start_idx = x.size(0) - self.right_context_length
right_context = x[right_context_start_idx:] right_context = x[right_context_start_idx:]
utterance = x[:right_context_start_idx] utterance = x[:right_context_start_idx]
@ -1414,8 +1413,6 @@ class Emformer(EncoderInterface):
else: else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model) self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
self.encoder = EmformerEncoder( self.encoder = EmformerEncoder(
chunk_length // 4, chunk_length // 4,
d_model, d_model,
@ -1463,10 +1460,6 @@ class Emformer(EncoderInterface):
right_context at the end. right_context at the end.
""" """
x = self.encoder_embed(x) x = self.encoder_embed(x)
# TODO: The length computation in the encoder class should be moved here. # noqa
U = x.size(1) - self.right_context_length // 4
x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4! # Caution: We assume the subsampling factor is 4!
@ -1475,7 +1468,7 @@ class Emformer(EncoderInterface):
x_lens = ((x_lens - 1) // 2 - 1) // 2 x_lens = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == x_lens.max().item() assert x.size(0) == x_lens.max().item()
output, output_lengths = self.encoder(x, x_lens, pos_emb) # (T, N, C) output, output_lengths = self.encoder(x, x_lens) # (T, N, C)
logits = self.encoder_output_layer(output) logits = self.encoder_output_layer(output)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
@ -1521,12 +1514,6 @@ class Emformer(EncoderInterface):
- updated convolution caches from current chunk. - updated convolution caches from current chunk.
""" """
x = self.encoder_embed(x) x = self.encoder_embed(x)
# TODO: The length computation in the encoder class should be moved here. # noqa
pos_len = self.chunk_length // 4 + self.left_context_length // 4
neg_len = self.chunk_length // 4
x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4! # Caution: We assume the subsampling factor is 4!
@ -1540,7 +1527,7 @@ class Emformer(EncoderInterface):
output_lengths, output_lengths,
output_states, output_states,
output_conv_caches, output_conv_caches,
) = self.encoder.infer(x, x_lens, pos_emb, states, conv_caches) ) = self.encoder.infer(x, x_lens, states, conv_caches)
logits = self.encoder_output_layer(output) logits = self.encoder_output_layer(output)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)