minor fix

This commit is contained in:
yaozengwei 2022-05-15 12:11:05 +08:00
parent 8b60d43ead
commit 20a23d13ce

View File

@ -1637,7 +1637,7 @@ class EmformerEncoder(nn.Module):
return attention_mask return attention_mask
def forward( def forward(
self, x: torch.Tensor, lengths: torch.Tensor self, x: torch.Tensor, lengths: torch.Tensor, warmup: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for training and validation mode. """Forward pass for training and validation mode.
@ -1684,6 +1684,7 @@ class EmformerEncoder(nn.Module):
memory, memory,
attention_mask, attention_mask,
pos_emb, pos_emb,
warmup=warmup,
) )
return output, output_lengths return output, output_lengths
@ -1831,7 +1832,7 @@ class Emformer(EncoderInterface):
) )
def forward( def forward(
self, x: torch.Tensor, x_lens: torch.Tensor self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for training and non-streaming inference. """Forward pass for training and non-streaming inference.
@ -1847,6 +1848,10 @@ class Emformer(EncoderInterface):
With shape (B,) and i-th element representing number of valid With shape (B,) and i-th element representing number of valid
utterance frames for i-th batch element in x, containing the utterance frames for i-th batch element in x, containing the
right_context at the end. right_context at the end.
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
Returns: Returns:
(Tensor, Tensor): (Tensor, Tensor):
@ -1864,7 +1869,9 @@ class Emformer(EncoderInterface):
x_lens = ((x_lens - 1) // 2 - 1) // 2 x_lens = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == x_lens.max().item() assert x.size(0) == x_lens.max().item()
output, output_lengths = self.encoder(x, x_lens) # (T, N, C) output, output_lengths = self.encoder(
x, x_lens, warmup=warmup
) # (T, N, C)
output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C)