Remove sos ID.

Fangjun Kuang 2021-12-18 11:22:12 +08:00
parent 63e1266e3a
commit 9d0d5d19fb
11 changed files with 3 additions and 25 deletions

View File

@@ -111,7 +111,6 @@ def beam_search(
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)
     blank_id = model.decoder.blank_id
-    sos_id = model.decoder.sos_id
     device = model.device

     sos = torch.tensor([blank_id], device=device).reshape(1, 1)
@@ -192,7 +191,7 @@ def beam_search(
         # Second, choose other labels
         for i, v in enumerate(log_prob.tolist()):
-            if i in (blank_id, sos_id):
+            if i == blank_id:
                 continue
             new_ys = y_star.ys + [i]
             new_log_prob = y_star.log_prob + v

View File

@@ -159,7 +159,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -399,7 +398,6 @@ def main():
     # <blk> and <sos/eos> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     logging.info(params)
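After this commit only the blank ID is read from the BPE model; <sos/eos> is still defined in the vocabulary but no longer looked up. A short sketch of the remaining lookups, assuming a SentencePiece model at a placeholder path:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe/bpe.model")  # placeholder path, not from this diff

blank_id = sp.piece_to_id("<blk>")  # <blk> is defined in local/train_bpe_model.py
vocab_size = sp.get_piece_size()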

View File

@@ -27,7 +27,6 @@ class Decoder(nn.Module):
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
-        sos_id: int,
         num_layers: int,
         hidden_dim: int,
         output_dim: int,
@@ -42,8 +41,6 @@ class Decoder(nn.Module):
             Dimension of the input embedding.
           blank_id:
             The ID of the blank symbol.
-          sos_id:
-            The ID of the SOS symbol.
           num_layers:
             Number of LSTM layers.
           hidden_dim:
@@ -71,7 +68,6 @@ class Decoder(nn.Module):
             dropout=rnn_dropout,
         )
         self.blank_id = blank_id
-        self.sos_id = sos_id
         self.output_linear = nn.Linear(hidden_dim, output_dim)

     def forward(
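For reference, a minimal sketch of the prediction network after this commit. The constructor matches the diff above; the forward body and layer wiring are illustrative, trimmed to the parts the commit touches:

import torch
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        blank_id: int,
        num_layers: int,
        hidden_dim: int,
        output_dim: int,
        rnn_dropout: float = 0.0,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=rnn_dropout,
        )
        self.blank_id = blank_id  # the only special symbol the decoder keeps
        self.output_linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) token IDs -> (N, U, output_dim)
        rnn_out, _ = self.rnn(self.embedding(y))
        return self.output_linear(rnn_out)

# Constructed as in test_decoder.py after the change:
decoder = Decoder(
    vocab_size=3,
    embedding_dim=128,
    blank_id=0,
    num_layers=2,
    hidden_dim=6,
    output_dim=10,  # illustrative; the test's value is not shown in this hunk
)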

View File

@@ -148,7 +148,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -197,7 +196,6 @@ def main():
     # <blk> and <sos/eos> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     logging.info(params)

View File

@@ -54,7 +54,7 @@ class Transducer(nn.Module):
           decoder:
             It is the prediction network in the paper. Its input shape
             is (N, U) and its output shape is (N, U, C). It should contain
-            two attributes: `blank_id` and `sos_id`.
+            one attribute: `blank_id`.
           joiner:
             It has two inputs with shapes: (N, T, C) and (N, U, C). Its
             output shape is (N, T, U, C). Note that its output contains
@@ -63,7 +63,6 @@ class Transducer(nn.Module):
         super().__init__()
         assert isinstance(encoder, EncoderInterface)
         assert hasattr(decoder, "blank_id")
-        assert hasattr(decoder, "sos_id")

         self.encoder = encoder
         self.decoder = decoder
@@ -102,8 +101,7 @@ class Transducer(nn.Module):
         y_lens = row_splits[1:] - row_splits[:-1]

         blank_id = self.decoder.blank_id
-        sos_id = self.decoder.sos_id
-        sos_y = add_sos(y, sos_id=sos_id)
+        sos_y = add_sos(y, sos_id=blank_id)

         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
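Training now prepends the blank ID to the target sequences instead of a separate SOS ID. icefall's add_sos works on k2 ragged tensors; the stand-in below is plain Python, kept only to show the semantics:

# Plain-Python stand-in for icefall's add_sos (the real one takes a
# k2.RaggedTensor): prepend sos_id to every token sequence in the batch.
def add_sos(ys, sos_id):
    return [[sos_id] + y for y in ys]

blank_id = 0
y = [[2, 5, 3], [7, 1]]  # token IDs for two utterances
sos_y = add_sos(y, sos_id=blank_id)
assert sos_y == [[0, 2, 5, 3], [0, 7, 1]]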

View File

@@ -145,7 +145,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -211,7 +210,6 @@ def main():
     # <blk> and <sos/eos> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     logging.info(f"{params}")

View File

@@ -36,7 +36,6 @@ def test_conformer():
         nhead=8,
         dim_feedforward=2048,
         num_encoder_layers=12,
-        use_feat_batchnorm=True,
     )
     N = 3
     T = 100

View File

@@ -29,7 +29,6 @@ from decoder import Decoder
 def test_decoder():
     vocab_size = 3
     blank_id = 0
-    sos_id = 2
     embedding_dim = 128
     num_layers = 2
     hidden_dim = 6
@@ -41,7 +40,6 @@ def test_decoder():
         vocab_size=vocab_size,
         embedding_dim=embedding_dim,
         blank_id=blank_id,
-        sos_id=sos_id,
         num_layers=num_layers,
         hidden_dim=hidden_dim,
         output_dim=output_dim,

View File

@@ -39,7 +39,6 @@ def test_transducer():
     # decoder params
     vocab_size = 3
     blank_id = 0
-    sos_id = 2
     embedding_dim = 128
     num_layers = 2
@@ -51,14 +50,12 @@ def test_transducer():
         nhead=8,
         dim_feedforward=2048,
         num_encoder_layers=12,
-        use_feat_batchnorm=True,
     )
     decoder = Decoder(
         vocab_size=vocab_size,
         embedding_dim=embedding_dim,
         blank_id=blank_id,
-        sos_id=sos_id,
         num_layers=num_layers,
         hidden_dim=output_dim,
         output_dim=output_dim,

View File

@@ -36,7 +36,6 @@ def test_transformer():
         nhead=8,
         dim_feedforward=2048,
         num_encoder_layers=12,
-        use_feat_batchnorm=True,
     )
     N = 3
     T = 100

View File

@@ -229,7 +229,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -567,7 +566,6 @@ def run(rank, world_size, args):
     # <blk> and <sos/eos> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     logging.info(params)