mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Merge remote-tracking branch 'dan/master' into mmi
This commit is contained in:
commit
f03c991781
@ -276,7 +276,6 @@ class Transformer(nn.Module):
|
|||||||
# We set the first column to False since the first column in ys_in_pad
|
# We set the first column to False since the first column in ys_in_pad
|
||||||
# contains sos_id, which is the same as eos_id in our current setting.
|
# contains sos_id, which is the same as eos_id in our current setting.
|
||||||
tgt_key_padding_mask[:, 0] = False
|
tgt_key_padding_mask[:, 0] = False
|
||||||
# tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
|
|
||||||
|
|
||||||
tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
|
tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
|
||||||
tgt = self.decoder_pos(tgt)
|
tgt = self.decoder_pos(tgt)
|
||||||
@ -322,7 +321,6 @@ class Transformer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
# The common part between this function and decoder_forward could be
|
# The common part between this function and decoder_forward could be
|
||||||
# extracted as a separate function.
|
# extracted as a separate function.
|
||||||
|
|
||||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||||
ys_in = [torch.tensor(y) for y in ys_in]
|
ys_in = [torch.tensor(y) for y in ys_in]
|
||||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
|
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
|
||||||
@ -341,7 +339,6 @@ class Transformer(nn.Module):
|
|||||||
|
|
||||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||||
tgt_key_padding_mask[:, 0] = False
|
tgt_key_padding_mask[:, 0] = False
|
||||||
# tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
|
|
||||||
|
|
||||||
tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
|
tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
|
||||||
tgt = self.decoder_pos(tgt)
|
tgt = self.decoder_pos(tgt)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user