Remove SOS from decoder.

This commit is contained in:
Fangjun Kuang 2021-12-28 10:38:22 +08:00
parent 14c93add50
commit 2cf1b56cb3
5 changed files with 5 additions and 20 deletions

View File

@ -111,7 +111,6 @@ def beam_search(
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
sos_id = model.decoder.sos_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
@ -192,7 +191,7 @@ def beam_search(
# Second, choose other labels
for i, v in enumerate(log_prob.tolist()):
if i in (blank_id, sos_id):
if i == blank_id:
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v

View File

@ -155,7 +155,6 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
@ -393,9 +392,8 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)

View File

@ -27,7 +27,6 @@ class Decoder(nn.Module):
vocab_size: int,
embedding_dim: int,
blank_id: int,
sos_id: int,
num_layers: int,
hidden_dim: int,
output_dim: int,
@ -42,8 +41,6 @@ class Decoder(nn.Module):
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
sos_id:
The ID of the SOS symbol.
num_layers:
Number of LSTM layers.
hidden_dim:
@ -71,7 +68,6 @@ class Decoder(nn.Module):
dropout=rnn_dropout,
)
self.blank_id = blank_id
self.sos_id = sos_id
self.output_linear = nn.Linear(hidden_dim, output_dim)
def forward(

View File

@ -49,7 +49,7 @@ class Transducer(nn.Module):
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
two attributes: `blank_id` and `sos_id`.
one attributes `blank_id`.
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
@ -58,7 +58,6 @@ class Transducer(nn.Module):
super().__init__()
assert isinstance(encoder, EncoderInterface)
assert hasattr(decoder, "blank_id")
assert hasattr(decoder, "sos_id")
self.encoder = encoder
self.decoder = decoder
@ -97,8 +96,7 @@ class Transducer(nn.Module):
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_id = self.decoder.sos_id
sos_y = add_sos(y, sos_id=sos_id)
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

View File

@ -179,8 +179,6 @@ def get_params() -> AttributeDict:
- num_decoder_layers: Number of decoder layer of transformer decoder.
- weight_decay: The weight_decay for the optimizer.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
@ -206,7 +204,6 @@ def get_params() -> AttributeDict:
"num_decoder_layers": 4,
"decoder_hidden_dim": 512,
# parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(),
}
@ -232,7 +229,6 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
@ -568,9 +564,8 @@ def run(rank, world_size, args):
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
@ -594,7 +589,6 @@ def run(rank, world_size, args):
model_size=params.encoder_hidden_size,
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
if checkpoints and "optimizer" in checkpoints: