mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
move encoder_pos
This commit is contained in:
parent
8e6a51edaa
commit
d0cea4f2f8
@ -1231,10 +1231,7 @@ class EmformerEncoder(nn.Module):
|
||||
return attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
self, x: torch.Tensor, lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass for training and non-streaming inference.
|
||||
|
||||
@ -1250,9 +1247,6 @@ class EmformerEncoder(nn.Module):
|
||||
With shape (B,) and i-th element representing number of valid
|
||||
utterance frames for i-th batch element in x, which contains the
|
||||
right_context at the end.
|
||||
pos_emb (torch.Tensor):
|
||||
Position encoding embedding, with shape (PE, D).
|
||||
For training mode, P = 2*U-1.
|
||||
|
||||
Returns:
|
||||
A tuple of 2 tensors:
|
||||
@ -1260,8 +1254,11 @@ class EmformerEncoder(nn.Module):
|
||||
- output_lengths, with shape (B,), without containing the
|
||||
right_context at the end.
|
||||
"""
|
||||
U = x.size(0) - self.right_context_length
|
||||
x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
|
||||
|
||||
right_context = self._gen_right_context(x)
|
||||
utterance = x[: x.size(0) - self.right_context_length]
|
||||
utterance = x[:U]
|
||||
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
|
||||
attention_mask = self._gen_attention_mask(utterance)
|
||||
memory = (
|
||||
@ -1271,6 +1268,7 @@ class EmformerEncoder(nn.Module):
|
||||
if self.use_memory
|
||||
else torch.empty(0).to(dtype=x.dtype, device=x.device)
|
||||
)
|
||||
|
||||
output = utterance
|
||||
for layer in self.emformer_layers:
|
||||
output, right_context, memory = layer(
|
||||
@ -1289,7 +1287,6 @@ class EmformerEncoder(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
states: Optional[List[List[torch.Tensor]]] = None,
|
||||
conv_caches: Optional[List[torch.Tensor]] = None,
|
||||
) -> Tuple[
|
||||
@ -1309,9 +1306,6 @@ class EmformerEncoder(nn.Module):
|
||||
With shape (B,) and i-th element representing number of valid
|
||||
utterance frames for i-th batch element in x, which contains the
|
||||
right_context at the end.
|
||||
pos_emb (torch.Tensor):
|
||||
Position encoding embedding, with shape (PE, D).
|
||||
For infer mode, PE = L+2*U-1.
|
||||
states (List[List[torch.Tensor]], optional):
|
||||
Cached states from proceeding chunk's computation, where each
|
||||
element (List[torch.Tensor]) corresponds to each emformer layer.
|
||||
@ -1332,6 +1326,11 @@ class EmformerEncoder(nn.Module):
|
||||
f"expected size of {self.chunk_length + self.right_context_length} "
|
||||
f"for dimension 1 of x, but got {x.size(1)}."
|
||||
)
|
||||
|
||||
pos_len = self.chunk_length + self.left_context_length
|
||||
neg_len = self.chunk_length
|
||||
x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
|
||||
|
||||
right_context_start_idx = x.size(0) - self.right_context_length
|
||||
right_context = x[right_context_start_idx:]
|
||||
utterance = x[:right_context_start_idx]
|
||||
@ -1414,8 +1413,6 @@ class Emformer(EncoderInterface):
|
||||
else:
|
||||
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
|
||||
self.encoder = EmformerEncoder(
|
||||
chunk_length // 4,
|
||||
d_model,
|
||||
@ -1463,10 +1460,6 @@ class Emformer(EncoderInterface):
|
||||
right_context at the end.
|
||||
"""
|
||||
x = self.encoder_embed(x)
|
||||
|
||||
# TODO: The length computation in the encoder class should be moved here. # noqa
|
||||
U = x.size(1) - self.right_context_length // 4
|
||||
x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
@ -1475,7 +1468,7 @@ class Emformer(EncoderInterface):
|
||||
x_lens = ((x_lens - 1) // 2 - 1) // 2
|
||||
assert x.size(0) == x_lens.max().item()
|
||||
|
||||
output, output_lengths = self.encoder(x, x_lens, pos_emb) # (T, N, C)
|
||||
output, output_lengths = self.encoder(x, x_lens) # (T, N, C)
|
||||
|
||||
logits = self.encoder_output_layer(output)
|
||||
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
@ -1521,12 +1514,6 @@ class Emformer(EncoderInterface):
|
||||
- updated convolution caches from current chunk.
|
||||
"""
|
||||
x = self.encoder_embed(x)
|
||||
|
||||
# TODO: The length computation in the encoder class should be moved here. # noqa
|
||||
pos_len = self.chunk_length // 4 + self.left_context_length // 4
|
||||
neg_len = self.chunk_length // 4
|
||||
x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
|
||||
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
@ -1540,7 +1527,7 @@ class Emformer(EncoderInterface):
|
||||
output_lengths,
|
||||
output_states,
|
||||
output_conv_caches,
|
||||
) = self.encoder.infer(x, x_lens, pos_emb, states, conv_caches)
|
||||
) = self.encoder.infer(x, x_lens, states, conv_caches)
|
||||
|
||||
logits = self.encoder_output_layer(output)
|
||||
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
Loading…
x
Reference in New Issue
Block a user