Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-07 08:04:18 +00:00
Commit d0cea4f2f8: move encoder_pos
Parent: 8e6a51edaa
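Summary: this commit moves the relative positional encoding into EmformerEncoder. Its forward and infer methods now compute pos_emb internally via self.encoder_pos instead of receiving it as an argument, so the Emformer wrapper drops its self.encoder_pos = RelPositionalEncoding(d_model, dropout) member and stops threading pos_emb through both call sites. (The corresponding addition of encoder_pos to EmformerEncoder.__init__ is presumably part of this commit but is not among the hunks shown below.)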
@@ -1231,10 +1231,7 @@ class EmformerEncoder(nn.Module):
         return attention_mask
 
     def forward(
-        self,
-        x: torch.Tensor,
-        lengths: torch.Tensor,
-        pos_emb: torch.Tensor,
+        self, x: torch.Tensor, lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward pass for training and non-streaming inference.
 
@@ -1250,9 +1247,6 @@ class EmformerEncoder(nn.Module):
             With shape (B,) and i-th element representing number of valid
             utterance frames for i-th batch element in x, which contains the
             right_context at the end.
-          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For training mode, P = 2*U-1.
 
         Returns:
           A tuple of 2 tensors:
@@ -1260,8 +1254,11 @@ class EmformerEncoder(nn.Module):
           - output_lengths, with shape (B,), without containing the
             right_context at the end.
         """
+        U = x.size(0) - self.right_context_length
+        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
+
         right_context = self._gen_right_context(x)
-        utterance = x[: x.size(0) - self.right_context_length]
+        utterance = x[:U]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)
         memory = (
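For orientation: judging from the call above and the removed docstring lines (PE = 2*U-1 in training mode), self.encoder_pos returns the (possibly dropout-processed) input together with a relative position embedding of pos_len + neg_len - 1 rows. Below is a minimal sketch of a module with that interface; the class name is hypothetical, and this is not icefall's actual RelPositionalEncoding, which caches its embedding table rather than rebuilding it per call.

import math

import torch
import torch.nn as nn


class RelPositionalEncodingSketch(nn.Module):
    """Hypothetical stand-in illustrating the assumed (pos_len, neg_len) API."""

    def __init__(self, d_model: int, dropout_rate: float = 0.1) -> None:
        super().__init__()
        # The sinusoidal construction below assumes an even model dimension.
        assert d_model % 2 == 0
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor, pos_len: int, neg_len: int):
        # Relative offsets pos_len - 1, ..., 1, 0, -1, ..., -(neg_len - 1):
        # pos_len + neg_len - 1 values in total, so pos_len = neg_len = U
        # yields the 2*U - 1 rows mentioned in the removed docstring.
        positions = torch.arange(pos_len - 1, -neg_len, -1.0, device=x.device)
        positions = positions.unsqueeze(1)  # (PE, 1)
        inv_freq = torch.exp(
            torch.arange(0, self.d_model, 2, device=x.device).float()
            * -(math.log(10000.0) / self.d_model)
        )  # (d_model / 2,)
        pos_emb = torch.zeros(
            pos_len + neg_len - 1, self.d_model, device=x.device
        )
        pos_emb[:, 0::2] = torch.sin(positions * inv_freq)
        pos_emb[:, 1::2] = torch.cos(positions * inv_freq)
        return self.dropout(x), pos_emb.to(x.dtype)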
@@ -1271,6 +1268,7 @@ class EmformerEncoder(nn.Module):
             if self.use_memory
             else torch.empty(0).to(dtype=x.dtype, device=x.device)
         )
+
         output = utterance
         for layer in self.emformer_layers:
             output, right_context, memory = layer(
@@ -1289,7 +1287,6 @@ class EmformerEncoder(nn.Module):
         self,
         x: torch.Tensor,
         lengths: torch.Tensor,
-        pos_emb: torch.Tensor,
         states: Optional[List[List[torch.Tensor]]] = None,
         conv_caches: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[
@@ -1309,9 +1306,6 @@ class EmformerEncoder(nn.Module):
             With shape (B,) and i-th element representing number of valid
             utterance frames for i-th batch element in x, which contains the
             right_context at the end.
-          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For infer mode, PE = L+2*U-1.
           states (List[List[torch.Tensor]], optional):
             Cached states from proceeding chunk's computation, where each
             element (List[torch.Tensor]) corresponds to each emformer layer.
@@ -1332,6 +1326,11 @@ class EmformerEncoder(nn.Module):
                 f"expected size of {self.chunk_length + self.right_context_length} "
                 f"for dimension 1 of x, but got {x.size(1)}."
             )
+
+        pos_len = self.chunk_length + self.left_context_length
+        neg_len = self.chunk_length
+        x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
+
         right_context_start_idx = x.size(0) - self.right_context_length
         right_context = x[right_context_start_idx:]
         utterance = x[:right_context_start_idx]
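A quick sanity check on the streaming case: with pos_len = chunk_length + left_context_length and neg_len = chunk_length, an embedding of pos_len + neg_len - 1 rows reproduces the PE = L+2*U-1 figure from the removed docstring (L the left-context length, U the chunk length). Hypothetical numbers, just to verify the arithmetic:

chunk_length, left_context_length = 8, 32
pos_len = chunk_length + left_context_length  # 40
neg_len = chunk_length  # 8
# 40 + 8 - 1 == 32 + 2 * 8 - 1 == 47
assert pos_len + neg_len - 1 == left_context_length + 2 * chunk_length - 1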
@@ -1414,8 +1413,6 @@ class Emformer(EncoderInterface):
         else:
             self.encoder_embed = Conv2dSubsampling(num_features, d_model)
 
-        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
-
         self.encoder = EmformerEncoder(
             chunk_length // 4,
             d_model,
@@ -1463,10 +1460,6 @@ class Emformer(EncoderInterface):
             right_context at the end.
         """
         x = self.encoder_embed(x)
-
-        # TODO: The length computation in the encoder class should be moved here. # noqa
-        U = x.size(1) - self.right_context_length // 4
-        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         # Caution: We assume the subsampling factor is 4!
@@ -1475,7 +1468,7 @@ class Emformer(EncoderInterface):
         x_lens = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == x_lens.max().item()
 
-        output, output_lengths = self.encoder(x, x_lens, pos_emb)  # (T, N, C)
+        output, output_lengths = self.encoder(x, x_lens)  # (T, N, C)
 
         logits = self.encoder_output_layer(output)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
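Net effect at the call sites: training/non-streaming forward now takes just (x, x_lens), and streaming infer (next two hunks) takes (x, x_lens, states, conv_caches). The positional embedding becomes an internal concern of EmformerEncoder, so external callers no longer compute or pass pos_emb.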
@@ -1521,12 +1514,6 @@ class Emformer(EncoderInterface):
           - updated convolution caches from current chunk.
         """
         x = self.encoder_embed(x)
-
-        # TODO: The length computation in the encoder class should be moved here. # noqa
-        pos_len = self.chunk_length // 4 + self.left_context_length // 4
-        neg_len = self.chunk_length // 4
-        x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
-
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         # Caution: We assume the subsampling factor is 4!
@@ -1540,7 +1527,7 @@ class Emformer(EncoderInterface):
             output_lengths,
             output_states,
             output_conv_caches,
-        ) = self.encoder.infer(x, x_lens, pos_emb, states, conv_caches)
+        ) = self.encoder.infer(x, x_lens, states, conv_caches)
 
         logits = self.encoder_output_layer(output)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)