mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
dc641991a5
commit
c6b71fc222
Binary file not shown.
@ -248,32 +248,6 @@ class Interformer(nn.Module):
|
|||||||
warmup=warmup,
|
warmup=warmup,
|
||||||
get_layer_output=True
|
get_layer_output=True
|
||||||
)
|
)
|
||||||
assert torch.all(x_lens > 0)
|
|
||||||
|
|
||||||
# Now for the decoder, i.e., the prediction network
|
|
||||||
row_splits = y.shape.row_splits(1)
|
|
||||||
y_lens = row_splits[1:] - row_splits[:-1]
|
|
||||||
|
|
||||||
blank_id = self.decoder.blank_id
|
|
||||||
sos_y = add_sos(y, sos_id=blank_id)
|
|
||||||
|
|
||||||
# sos_y_padded: [B, S + 1], start with SOS.
|
|
||||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
|
||||||
|
|
||||||
# decoder_out: [B, S + 1, decoder_dim]
|
|
||||||
decoder_out = self.decoder(sos_y_padded)
|
|
||||||
|
|
||||||
# Note: y does not start with SOS
|
|
||||||
# y_padded : [B, S]
|
|
||||||
y_padded = y.pad(mode="constant", padding_value=0)
|
|
||||||
|
|
||||||
y_padded = y_padded.to(torch.int64)
|
|
||||||
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
|
|
||||||
boundary[:, 2] = y_lens
|
|
||||||
boundary[:, 3] = x_lens
|
|
||||||
|
|
||||||
lm = self.simple_lm_proj(decoder_out)
|
|
||||||
am = self.simple_am_proj(encoder_out)
|
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user