mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Fix the mismatch of forward & backward joiner label
This commit is contained in:
parent
18f997fe51
commit
08729f88b1
@ -18,6 +18,7 @@
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
|
||||||
from icefall.utils import add_eos, add_sos
|
from icefall.utils import add_eos, add_sos
|
||||||
@ -162,6 +163,8 @@ class Transducer(nn.Module):
|
|||||||
|
|
||||||
eos_y = add_eos(y, eos_id=blank_id)
|
eos_y = add_eos(y, eos_id=blank_id)
|
||||||
eos_y_padded = eos_y.pad(mode="constant", padding_value=blank_id)
|
eos_y_padded = eos_y.pad(mode="constant", padding_value=blank_id)
|
||||||
|
eos_y_padded = F.pad(eos_y_padded[:, 1:], pad=(0, 1), value=blank_id)
|
||||||
|
|
||||||
# backward loss
|
# backward loss
|
||||||
assert self.backward_decoder is not None
|
assert self.backward_decoder is not None
|
||||||
assert self.backward_joiner is not None
|
assert self.backward_joiner is not None
|
||||||
@ -174,7 +177,7 @@ class Transducer(nn.Module):
|
|||||||
)
|
)
|
||||||
backward_pruned_loss = k2.rnnt_loss_pruned(
|
backward_pruned_loss = k2.rnnt_loss_pruned(
|
||||||
backward_logits,
|
backward_logits,
|
||||||
sos_y_padded.to(torch.int64),
|
y_padded.to(torch.int64),
|
||||||
ranges,
|
ranges,
|
||||||
blank_id,
|
blank_id,
|
||||||
boundary,
|
boundary,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user