mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Fix the mismatch of forward & backward joiner label
This commit is contained in:
parent
18f997fe51
commit
08729f88b1
@ -18,6 +18,7 @@
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_eos, add_sos
|
||||
@ -162,6 +163,8 @@ class Transducer(nn.Module):
|
||||
|
||||
eos_y = add_eos(y, eos_id=blank_id)
|
||||
eos_y_padded = eos_y.pad(mode="constant", padding_value=blank_id)
|
||||
eos_y_padded = F.pad(eos_y_padded[:, 1:], pad=(0, 1), value=blank_id)
|
||||
|
||||
# backward loss
|
||||
assert self.backward_decoder is not None
|
||||
assert self.backward_joiner is not None
|
||||
@ -174,7 +177,7 @@ class Transducer(nn.Module):
|
||||
)
|
||||
backward_pruned_loss = k2.rnnt_loss_pruned(
|
||||
backward_logits,
|
||||
sos_y_padded.to(torch.int64),
|
||||
y_padded.to(torch.int64),
|
||||
ranges,
|
||||
blank_id,
|
||||
boundary,
|
||||
|
Loading…
x
Reference in New Issue
Block a user