mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 08:04:18 +00:00)

minor fix

This commit is contained in:
parent 8b60d43ead
commit 20a23d13ce
@@ -1637,7 +1637,7 @@ class EmformerEncoder(nn.Module):
         return attention_mask
 
     def forward(
-        self, x: torch.Tensor, lengths: torch.Tensor
+        self, x: torch.Tensor, lengths: torch.Tensor, warmup: float = 1.0
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward pass for training and validation mode.
 
@@ -1684,6 +1684,7 @@ class EmformerEncoder(nn.Module):
                 memory,
                 attention_mask,
                 pos_emb,
+                warmup=warmup,
             )
 
         return output, output_lengths
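The warmup value is forwarded into every encoder layer call. As an illustration of how a layer can use such a value to switch itself on gradually, here is a minimal sketch; the class name HypotheticalLayer and the blending scheme are assumptions for illustration, not the actual icefall layer code.

import torch
import torch.nn as nn


class HypotheticalLayer(nn.Module):
    """Toy layer that blends its output with the identity path during warmup."""

    def __init__(self, d_model: int = 512):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
        out = torch.relu(self.linear(x))
        # With warmup near 0 the layer mostly passes x through unchanged;
        # once warmup >= 1.0 it contributes fully.
        scale = min(1.0, max(0.0, warmup))
        return scale * out + (1.0 - scale) * x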
@@ -1831,7 +1832,7 @@ class Emformer(EncoderInterface):
         )
 
     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor
+        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward pass for training and non-streaming inference.
 
@@ -1847,6 +1848,10 @@ class Emformer(EncoderInterface):
             With shape (B,) and i-th element representing number of valid
             utterance frames for i-th batch element in x, containing the
             right_context at the end.
+          warmup:
+            A floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
 
         Returns:
           (Tensor, Tensor):
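The added docstring describes warmup as a value that grows from 0 toward 1 over the course of training. One way a training loop could produce such a schedule is a simple linear ramp over a fixed number of batches; the function name and the 3000-batch period below are assumptions for illustration, not necessarily what icefall's training script does.

def warmup_value(batch_idx: int, warmup_batches: int = 3000) -> float:
    """Hypothetical linear warmup schedule: 0.0 at the first batch,
    reaching 1.0 ("fully warmed up") after `warmup_batches` batches."""
    return min(1.0, batch_idx / warmup_batches)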
@@ -1864,7 +1869,9 @@ class Emformer(EncoderInterface):
         x_lens = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == x_lens.max().item()
 
-        output, output_lengths = self.encoder(x, x_lens)  # (T, N, C)
+        output, output_lengths = self.encoder(
+            x, x_lens, warmup=warmup
+        )  # (T, N, C)
 
         output = output.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
 
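With the default of warmup=1.0, existing callers that pass only (x, x_lens) keep working and behave as fully warmed up; during training the current warmup value can be passed explicitly. A minimal usage sketch, assuming an already constructed Emformer instance named model and made-up input shapes:

import torch

# Dummy batch: 8 utterances, 200 feature frames of dimension 80.
x = torch.randn(8, 200, 80)
x_lens = torch.full((8,), 200, dtype=torch.int64)

# Training: pass the current warmup value, e.g. from a schedule like the
# one sketched above.
encoder_out, encoder_out_lens = model(x, x_lens, warmup=0.5)

# Validation / non-streaming inference: omit warmup; the default of 1.0
# means "fully warmed up".
encoder_out, encoder_out_lens = model(x, x_lens)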
|