move encoder_pos

This commit is contained in:
yaozengwei 2022-05-09 16:27:56 +08:00
parent 8e6a51edaa
commit d0cea4f2f8

View File

@ -1231,10 +1231,7 @@ class EmformerEncoder(nn.Module):
return attention_mask
def forward(
self,
x: torch.Tensor,
lengths: torch.Tensor,
pos_emb: torch.Tensor,
self, x: torch.Tensor, lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for training and non-streaming inference.
@ -1250,9 +1247,6 @@ class EmformerEncoder(nn.Module):
With shape (B,) and i-th element representing number of valid
utterance frames for i-th batch element in x, which contains the
right_context at the end.
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D).
For training mode, P = 2*U-1.
Returns:
A tuple of 2 tensors:
@ -1260,8 +1254,11 @@ class EmformerEncoder(nn.Module):
- output_lengths, with shape (B,), without containing the
right_context at the end.
"""
U = x.size(0) - self.right_context_length
x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
right_context = self._gen_right_context(x)
utterance = x[: x.size(0) - self.right_context_length]
utterance = x[:U]
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
attention_mask = self._gen_attention_mask(utterance)
memory = (
@ -1271,6 +1268,7 @@ class EmformerEncoder(nn.Module):
if self.use_memory
else torch.empty(0).to(dtype=x.dtype, device=x.device)
)
output = utterance
for layer in self.emformer_layers:
output, right_context, memory = layer(
@ -1289,7 +1287,6 @@ class EmformerEncoder(nn.Module):
self,
x: torch.Tensor,
lengths: torch.Tensor,
pos_emb: torch.Tensor,
states: Optional[List[List[torch.Tensor]]] = None,
conv_caches: Optional[List[torch.Tensor]] = None,
) -> Tuple[
@ -1309,9 +1306,6 @@ class EmformerEncoder(nn.Module):
With shape (B,) and i-th element representing number of valid
utterance frames for i-th batch element in x, which contains the
right_context at the end.
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D).
For infer mode, PE = L+2*U-1.
states (List[List[torch.Tensor]], optional):
Cached states from proceeding chunk's computation, where each
element (List[torch.Tensor]) corresponds to each emformer layer.
@ -1332,6 +1326,11 @@ class EmformerEncoder(nn.Module):
f"expected size of {self.chunk_length + self.right_context_length} "
f"for dimension 1 of x, but got {x.size(1)}."
)
pos_len = self.chunk_length + self.left_context_length
neg_len = self.chunk_length
x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
right_context_start_idx = x.size(0) - self.right_context_length
right_context = x[right_context_start_idx:]
utterance = x[:right_context_start_idx]
@ -1414,8 +1413,6 @@ class Emformer(EncoderInterface):
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
self.encoder = EmformerEncoder(
chunk_length // 4,
d_model,
@ -1463,10 +1460,6 @@ class Emformer(EncoderInterface):
right_context at the end.
"""
x = self.encoder_embed(x)
# TODO: The length computation in the encoder class should be moved here. # noqa
U = x.size(1) - self.right_context_length // 4
x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4!
@ -1475,7 +1468,7 @@ class Emformer(EncoderInterface):
x_lens = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == x_lens.max().item()
output, output_lengths = self.encoder(x, x_lens, pos_emb) # (T, N, C)
output, output_lengths = self.encoder(x, x_lens) # (T, N, C)
logits = self.encoder_output_layer(output)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
@ -1521,12 +1514,6 @@ class Emformer(EncoderInterface):
- updated convolution caches from current chunk.
"""
x = self.encoder_embed(x)
# TODO: The length computation in the encoder class should be moved here. # noqa
pos_len = self.chunk_length // 4 + self.left_context_length // 4
neg_len = self.chunk_length // 4
x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4!
@ -1540,7 +1527,7 @@ class Emformer(EncoderInterface):
output_lengths,
output_states,
output_conv_caches,
) = self.encoder.infer(x, x_lens, pos_emb, states, conv_caches)
) = self.encoder.infer(x, x_lens, states, conv_caches)
logits = self.encoder_output_layer(output)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)