Remove sos ID.

Fangjun Kuang 2021-12-18 11:22:12 +08:00
parent 63e1266e3a
commit 9d0d5d19fb
11 changed files with 3 additions and 25 deletions

View File

@@ -111,7 +111,6 @@ def beam_search(
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
-sos_id = model.decoder.sos_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
@@ -192,7 +191,7 @@ def beam_search(
# Second, choose other labels
for i, v in enumerate(log_prob.tolist()):
-if i in (blank_id, sos_id):
+if i == blank_id:
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
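
Note: after this change, decoding is primed with the blank symbol instead of a dedicated start-of-sequence symbol. The snippet below is a minimal sketch of that pattern, not the recipe's exact code; blank_id = 0 and the toy log-probabilities are assumptions made for illustration.

    import torch

    blank_id = 0  # assumed blank ID, matching the tests in this commit
    device = torch.device("cpu")

    # The decoder is primed with a single blank token of shape (1, 1),
    # mirroring the unchanged `sos = torch.tensor(...)` line above.
    sos = torch.tensor([blank_id], device=device).reshape(1, 1)

    # The beam expansion now skips only the blank symbol; previously both
    # blank_id and sos_id were excluded.
    log_prob = torch.randn(5).log_softmax(dim=-1)  # toy scores over a 5-token vocab
    for i, v in enumerate(log_prob.tolist()):
        if i == blank_id:
            continue
        print(f"extend hypothesis with label {i}, log-prob {v:.3f}")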

View File

@@ -159,7 +159,6 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
-sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
@@ -399,7 +398,6 @@ def main():
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
-params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
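
For reference, a minimal sketch of the simplified parameter setup after this change: only <blk> is looked up in the BPE model, and <sos/eos> is no longer needed. The model path below is hypothetical, and sentencepiece is assumed to be installed.

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("path/to/bpe.model")  # hypothetical path to a trained BPE model

    blank_id = sp.piece_to_id("<blk>")  # <blk> is defined in local/train_bpe_model.py
    vocab_size = sp.get_piece_size()
    print(blank_id, vocab_size)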

View File

@@ -27,7 +27,6 @@ class Decoder(nn.Module):
vocab_size: int,
embedding_dim: int,
blank_id: int,
-sos_id: int,
num_layers: int,
hidden_dim: int,
output_dim: int,
@@ -42,8 +41,6 @@ class Decoder(nn.Module):
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
-sos_id:
-The ID of the SOS symbol.
num_layers:
Number of LSTM layers.
hidden_dim:
@@ -71,7 +68,6 @@
dropout=rnn_dropout,
)
self.blank_id = blank_id
-self.sos_id = sos_id
self.output_linear = nn.Linear(hidden_dim, output_dim)
def forward(
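
A minimal usage sketch of the updated constructor: sos_id is gone from both the argument list and the attributes, so only blank_id is carried by the decoder. The toy dimensions are mostly taken from test_decoder.py in this commit; the import assumes this recipe's decoder.py is on the Python path.

    from decoder import Decoder  # assumes this recipe's decoder.py is importable

    decoder = Decoder(
        vocab_size=3,      # toy values, mostly taken from test_decoder.py below
        embedding_dim=128,
        blank_id=0,
        num_layers=2,
        hidden_dim=6,
        output_dim=4,      # chosen arbitrarily for this sketch
    )
    assert decoder.blank_id == 0
    assert not hasattr(decoder, "sos_id")  # the attribute no longer exists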

View File

@@ -148,7 +148,6 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
-sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
@@ -197,7 +196,6 @@ def main():
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
-params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)

View File

@@ -54,7 +54,7 @@ class Transducer(nn.Module):
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
-two attributes: `blank_id` and `sos_id`.
+one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
@@ -63,7 +63,6 @@
super().__init__()
assert isinstance(encoder, EncoderInterface)
assert hasattr(decoder, "blank_id")
-assert hasattr(decoder, "sos_id")
self.encoder = encoder
self.decoder = decoder
@@ -102,8 +101,7 @@
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
-sos_id = self.decoder.sos_id
-sos_y = add_sos(y, sos_id=sos_id)
+sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
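
To make the change in the loss path concrete, here is a minimal sketch of what the call now does when given blank_id: the blank symbol is prepended to every label sequence and also serves as the padding value. It assumes k2 and icefall are installed and that add_sos is importable from icefall.utils, as in the recipes.

    import k2
    from icefall.utils import add_sos

    blank_id = 0
    y = k2.RaggedTensor([[2, 1, 2], [1, 2]])  # toy label sequences

    sos_y = add_sos(y, sos_id=blank_id)       # [[0, 2, 1, 2], [0, 1, 2]]
    sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
    print(sos_y_padded)  # a (2, 4) tensor padded with blank_id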

View File

@@ -145,7 +145,6 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
-sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
@@ -211,7 +210,6 @@ def main():
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
-params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")

View File

@@ -36,7 +36,6 @@ def test_conformer():
nhead=8,
dim_feedforward=2048,
num_encoder_layers=12,
-use_feat_batchnorm=True,
)
N = 3
T = 100

View File

@@ -29,7 +29,6 @@ from decoder import Decoder
def test_decoder():
vocab_size = 3
blank_id = 0
-sos_id = 2
embedding_dim = 128
num_layers = 2
hidden_dim = 6
@@ -41,7 +40,6 @@ def test_decoder():
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
-sos_id=sos_id,
num_layers=num_layers,
hidden_dim=hidden_dim,
output_dim=output_dim,

View File

@@ -39,7 +39,6 @@ def test_transducer():
# decoder params
vocab_size = 3
blank_id = 0
-sos_id = 2
embedding_dim = 128
num_layers = 2
@@ -51,14 +50,12 @@ def test_transducer():
nhead=8,
dim_feedforward=2048,
num_encoder_layers=12,
-use_feat_batchnorm=True,
)
decoder = Decoder(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
-sos_id=sos_id,
num_layers=num_layers,
hidden_dim=output_dim,
output_dim=output_dim,

View File

@@ -36,7 +36,6 @@ def test_transformer():
nhead=8,
dim_feedforward=2048,
num_encoder_layers=12,
-use_feat_batchnorm=True,
)
N = 3
T = 100

View File

@@ -229,7 +229,6 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
-sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
@@ -567,7 +566,6 @@ def run(rank, world_size, args):
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
-params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)