from local
This commit is contained in:
parent
5f6b801d41
commit
1994866c0a
Binary file not shown.
@ -243,55 +243,7 @@ class Interformer(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
x_lens: torch.Tensor,
|
x_lens: torch.Tensor,
|
||||||
):
|
):
|
||||||
"""
|
encoder_out, x_lens = self.pt_encoder(x, x_lens, warmup=warmup)
|
||||||
Args:
|
|
||||||
x:
|
|
||||||
A 3-D tensor of shape (N, T, C).
|
|
||||||
x_lens:
|
|
||||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
|
||||||
before padding.
|
|
||||||
y:
|
|
||||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
|
||||||
utterance.
|
|
||||||
prune_range:
|
|
||||||
The prune range for the RNN-T loss, i.e., how many symbols (context)
|
|
||||||
we are considering for each frame to compute the loss.
|
|
||||||
am_scale:
|
|
||||||
The scale to smooth the loss with am (output of encoder network)
|
|
||||||
part
|
|
||||||
lm_scale:
|
|
||||||
The scale to smooth the loss with lm (output of predictor network)
|
|
||||||
part
|
|
||||||
warmup:
|
|
||||||
A value warmup >= 0 that determines which modules are active; values
|
|
||||||
warmup > 1 are "fully warmed up" and all modules will be active.
|
|
||||||
reduction:
|
|
||||||
"sum" to sum the losses over all utterances in the batch.
|
|
||||||
"none" to return the loss in a 1-D tensor for each utterance
|
|
||||||
in the batch.
|
|
||||||
delay_penalty:
|
|
||||||
A constant value used to penalize symbol delay, to encourage
|
|
||||||
streaming models to emit symbols earlier.
|
|
||||||
See https://github.com/k2-fsa/k2/issues/955 and
|
|
||||||
https://arxiv.org/pdf/2211.00490.pdf for more details.
|
|
||||||
Returns:
|
|
||||||
Returns:
|
|
||||||
Return the transducer loss.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
Regarding am_scale & lm_scale, they will make the loss function one of
|
|
||||||
the form:
|
|
||||||
lm_scale * lm_probs + am_scale * am_probs +
|
|
||||||
(1-lm_scale-am_scale) * combined_probs
|
|
||||||
"""
|
|
||||||
assert reduction in ("sum", "none"), reduction
|
|
||||||
assert x.ndim == 3, x.shape
|
|
||||||
assert x_lens.ndim == 1, x_lens.shape
|
|
||||||
assert y.num_axes == 2, y.num_axes
|
|
||||||
|
|
||||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
|
||||||
|
|
||||||
encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
|
|
||||||
assert torch.all(x_lens > 0)
|
assert torch.all(x_lens > 0)
|
||||||
|
|
||||||
# Now for the decoder, i.e., the prediction network
|
# Now for the decoder, i.e., the prediction network
|
||||||
|
|||||||
Reference in New Issue
Block a user