Mirror of https://github.com/k2-fsa/icefall.git
Remove SOS from decoder.
This commit is contained in:
parent 14c93add50
commit 2cf1b56cb3
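The gist of the change: the decoder drops its dedicated SOS symbol and reuses the blank symbol in its place, both when seeding the search and when preparing training targets. A minimal sketch of the decoding-side effect (plain PyTorch with toy values, not the icefall code itself):

import torch

blank_id = 0  # assumed value; icefall reads it from the BPE model ("<blk>")

# The search is now primed with a single blank token instead of <sos/eos>;
# shape (N, U) = (1, 1), matching the decoder's expected input.
decoder_input = torch.tensor([blank_id], dtype=torch.int64).reshape(1, 1)
print(decoder_input)  # tensor([[0]])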
@@ -111,7 +111,6 @@ def beam_search(
 # support only batch_size == 1 for now
 assert encoder_out.size(0) == 1, encoder_out.size(0)
 blank_id = model.decoder.blank_id
-sos_id = model.decoder.sos_id
 device = model.device
 
 sos = torch.tensor([blank_id], device=device).reshape(1, 1)
@@ -192,7 +191,7 @@ def beam_search(
 
 # Second, choose other labels
 for i, v in enumerate(log_prob.tolist()):
-    if i in (blank_id, sos_id):
+    if i == blank_id:
         continue
     new_ys = y_star.ys + [i]
     new_log_prob = y_star.log_prob + v
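With no SOS symbol left in the decoder, the hypothesis-expansion loop above only has to skip the blank. A toy stand-in for that loop (the Hypothesis class and the log_prob values here are invented for illustration; icefall's beam_search() differs in detail):

import torch
from dataclasses import dataclass, field
from typing import List

@dataclass
class Hypothesis:
    ys: List[int] = field(default_factory=list)
    log_prob: float = 0.0

blank_id = 0
y_star = Hypothesis(ys=[blank_id], log_prob=0.0)
log_prob = torch.log_softmax(torch.randn(6), dim=-1)  # toy decoder+joiner output

expanded = []
for i, v in enumerate(log_prob.tolist()):
    if i == blank_id:  # previously: if i in (blank_id, sos_id)
        continue
    expanded.append(Hypothesis(ys=y_star.ys + [i], log_prob=y_star.log_prob + v))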
@@ -155,7 +155,6 @@ def get_decoder_model(params: AttributeDict):
 vocab_size=params.vocab_size,
 embedding_dim=params.decoder_embedding_dim,
 blank_id=params.blank_id,
-sos_id=params.sos_id,
 num_layers=params.num_decoder_layers,
 hidden_dim=params.decoder_hidden_dim,
 output_dim=params.encoder_out_dim,
@@ -393,9 +392,8 @@ def main():
 sp = spm.SentencePieceProcessor()
 sp.load(params.bpe_model)
 
-# <blk> and <sos/eos> are defined in local/train_bpe_model.py
+# <blk> is defined in local/train_bpe_model.py
 params.blank_id = sp.piece_to_id("<blk>")
-params.sos_id = sp.piece_to_id("<sos/eos>")
 params.vocab_size = sp.get_piece_size()
 
 logging.info(params)
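For reference, only the blank piece now has to be looked up from the BPE model. A hedged sketch of that lookup with the standard sentencepiece API (the model path is a placeholder; <blk> is defined in local/train_bpe_model.py, as the comment above notes):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe/bpe.model")  # placeholder path

blank_id = sp.piece_to_id("<blk>")  # the only special symbol the decoder needs now
vocab_size = sp.get_piece_size()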
@@ -27,7 +27,6 @@ class Decoder(nn.Module):
 vocab_size: int,
 embedding_dim: int,
 blank_id: int,
-sos_id: int,
 num_layers: int,
 hidden_dim: int,
 output_dim: int,
@@ -42,8 +41,6 @@ class Decoder(nn.Module):
   Dimension of the input embedding.
 blank_id:
   The ID of the blank symbol.
-sos_id:
-  The ID of the SOS symbol.
 num_layers:
   Number of LSTM layers.
 hidden_dim:
@@ -71,7 +68,6 @@ class Decoder(nn.Module):
         dropout=rnn_dropout,
     )
     self.blank_id = blank_id
-    self.sos_id = sos_id
     self.output_linear = nn.Linear(hidden_dim, output_dim)
 
 def forward(
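A compressed stand-in for the resulting prediction network (an illustrative sketch with simplified hyperparameter handling, not the icefall Decoder class): it keeps blank_id but no longer stores an sos_id.

import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Illustrative RNN prediction network; details differ from icefall's Decoder."""

    def __init__(self, vocab_size: int, embedding_dim: int, blank_id: int,
                 num_layers: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=blank_id)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.blank_id = blank_id  # the only special symbol kept
        self.output_linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) label tokens whose first column is the blank, not an SOS.
        embedding_out = self.embedding(y)
        rnn_out, _ = self.rnn(embedding_out)
        return self.output_linear(rnn_out)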
@@ -49,7 +49,7 @@ class Transducer(nn.Module):
 decoder:
   It is the prediction network in the paper. Its input shape
   is (N, U) and its output shape is (N, U, C). It should contain
-  two attributes: `blank_id` and `sos_id`.
+  one attributes `blank_id`.
 joiner:
   It has two inputs with shapes: (N, T, C) and (N, U, C). Its
   output shape is (N, T, U, C). Note that its output contains
@@ -58,7 +58,6 @@ class Transducer(nn.Module):
 super().__init__()
 assert isinstance(encoder, EncoderInterface)
 assert hasattr(decoder, "blank_id")
-assert hasattr(decoder, "sos_id")
 
 self.encoder = encoder
 self.decoder = decoder
@@ -97,8 +96,7 @@ class Transducer(nn.Module):
 y_lens = row_splits[1:] - row_splits[:-1]
 
 blank_id = self.decoder.blank_id
-sos_id = self.decoder.sos_id
-sos_y = add_sos(y, sos_id=sos_id)
+sos_y = add_sos(y, sos_id=blank_id)
 
 sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
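What the two lines above now compute, rebuilt with plain Python/PyTorch for illustration (icefall's add_sos and pad operate on k2 ragged tensors): the blank is prepended to every label sequence and then also used as the padding value.

import torch

blank_id = 0
y = [[5, 9, 3], [7, 2]]                   # toy ragged label sequences

sos_y = [[blank_id] + row for row in y]   # add_sos(y, sos_id=blank_id): [[0, 5, 9, 3], [0, 7, 2]]
max_len = max(len(row) for row in sos_y)
sos_y_padded = torch.tensor(
    [row + [blank_id] * (max_len - len(row)) for row in sos_y]
)
# tensor([[0, 5, 9, 3],
#         [0, 7, 2, 0]])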
@@ -179,8 +179,6 @@ def get_params() -> AttributeDict:
 
   - num_decoder_layers: Number of decoder layer of transformer decoder.
 
-  - weight_decay: The weight_decay for the optimizer.
-
   - warm_step: The warm_step for Noam optimizer.
 """
 params = AttributeDict(
@@ -206,7 +204,6 @@ def get_params() -> AttributeDict:
     "num_decoder_layers": 4,
     "decoder_hidden_dim": 512,
     # parameters for Noam
-    "weight_decay": 1e-6,
     "warm_step": 80000, # For the 100h subset, use 8k
     "env_info": get_env_info(),
 }
@@ -232,7 +229,6 @@ def get_decoder_model(params: AttributeDict):
 vocab_size=params.vocab_size,
 embedding_dim=params.decoder_embedding_dim,
 blank_id=params.blank_id,
-sos_id=params.sos_id,
 num_layers=params.num_decoder_layers,
 hidden_dim=params.decoder_hidden_dim,
 output_dim=params.encoder_out_dim,
@@ -568,9 +564,8 @@ def run(rank, world_size, args):
 sp = spm.SentencePieceProcessor()
 sp.load(params.bpe_model)
 
-# <blk> and <sos/eos> are defined in local/train_bpe_model.py
+# <blk> is defined in local/train_bpe_model.py
 params.blank_id = sp.piece_to_id("<blk>")
-params.sos_id = sp.piece_to_id("<sos/eos>")
 params.vocab_size = sp.get_piece_size()
 
 logging.info(params)
@@ -594,7 +589,6 @@ def run(rank, world_size, args):
     model_size=params.encoder_hidden_size,
     factor=params.lr_factor,
     warm_step=params.warm_step,
-    weight_decay=params.weight_decay,
 )
 
 if checkpoints and "optimizer" in checkpoints: