mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
minor fix
This commit is contained in:
parent
8b60d43ead
commit
20a23d13ce
@ -1637,7 +1637,7 @@ class EmformerEncoder(nn.Module):
|
||||
return attention_mask
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, lengths: torch.Tensor
|
||||
self, x: torch.Tensor, lengths: torch.Tensor, warmup: float = 1.0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass for training and validation mode.
|
||||
|
||||
@ -1684,6 +1684,7 @@ class EmformerEncoder(nn.Module):
|
||||
memory,
|
||||
attention_mask,
|
||||
pos_emb,
|
||||
warmup=warmup,
|
||||
)
|
||||
|
||||
return output, output_lengths
|
||||
@ -1831,7 +1832,7 @@ class Emformer(EncoderInterface):
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass for training and non-streaming inference.
|
||||
|
||||
@ -1847,6 +1848,10 @@ class Emformer(EncoderInterface):
|
||||
With shape (B,) and i-th element representing number of valid
|
||||
utterance frames for i-th batch element in x, containing the
|
||||
right_context at the end.
|
||||
warmup:
|
||||
A floating point value that gradually increases from 0 throughout
|
||||
training; when it is >= 1.0 we are "fully warmed up". It is used
|
||||
to turn modules on sequentially.
|
||||
|
||||
Returns:
|
||||
(Tensor, Tensor):
|
||||
@ -1864,7 +1869,9 @@ class Emformer(EncoderInterface):
|
||||
x_lens = ((x_lens - 1) // 2 - 1) // 2
|
||||
assert x.size(0) == x_lens.max().item()
|
||||
|
||||
output, output_lengths = self.encoder(x, x_lens) # (T, N, C)
|
||||
output, output_lengths = self.encoder(
|
||||
x, x_lens, warmup=warmup
|
||||
) # (T, N, C)
|
||||
|
||||
output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user