Minor fixes

This commit is contained in:
pkufool 2023-06-19 12:12:33 +08:00
parent 802bf98f59
commit a7d0588827
4 changed files with 28 additions and 11 deletions

View File

@ -526,6 +526,8 @@ def fast_beam_search(
project_input=False, project_input=False,
) )
ilme_logits = ilme_logits.squeeze(1).squeeze(1) ilme_logits = ilme_logits.squeeze(1).squeeze(1)
if blank_penalty != 0:
ilme_logits[:, 0] -= blank_penalty
ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1) ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)
log_probs -= ilme_scale * ilme_log_probs log_probs -= ilme_scale * ilme_log_probs

View File

@ -100,7 +100,7 @@ from beam_search import (
modified_beam_search, modified_beam_search,
) )
from lhotse.cut import Cut from lhotse.cut import Cut
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_model, get_params
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import ( from icefall.checkpoint import (
@ -227,6 +227,16 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--ilme-scale",
type=float,
default=0.2,
help="""
Used only when --decoding_method is fast_beam_search_LG.
It specifies the scale for the internal language model estimation.
""",
)
parser.add_argument( parser.add_argument(
"--max-contexts", "--max-contexts",
type=int, type=int,
@ -381,6 +391,7 @@ def decode_one_batch(
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
blank_penalty=params.blank_penalty, blank_penalty=params.blank_penalty,
ilme_scale=params.ilme_scale,
) )
for hyp in hyp_tokens: for hyp in hyp_tokens:
sentence = "".join([lexicon.word_table[i] for i in hyp]) sentence = "".join([lexicon.word_table[i] for i in hyp])
@ -458,6 +469,7 @@ def decode_one_batch(
key += f"_num_paths_{params.num_paths}_" key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}" key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method: if "LG" in params.decoding_method:
key += f"_ilme_scale_{params.ilme_scale}"
key += f"_ngram_lm_scale_{params.ngram_lm_scale}" key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps} return {key: hyps}
@ -624,6 +636,7 @@ def main():
params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method: if "LG" in params.decoding_method:
params.suffix += f"_ilme_scale_{params.ilme_scale}"
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
@ -656,7 +669,7 @@ def main():
logging.info(params) logging.info(params)
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params) model = get_model(params)
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
@ -739,7 +752,7 @@ def main():
model.eval() model.eval()
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG": if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
lg_filename = params.lang_dir / "LG.pt" lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}") logging.info(f"Loading {lg_filename}")
@ -782,6 +795,9 @@ def main():
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dl = [dev_dl, test_net_dl, test_meeting_dl] test_dl = [dev_dl, test_net_dl, test_meeting_dl]
test_sets = ["TEST_NET"]
test_dl = [test_net_dl]
for test_set, test_dl in zip(test_sets, test_dl): for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset( results_dict = decode_dataset(
dl=test_dl, dl=test_dl,

View File

@ -50,7 +50,7 @@ from streaming_beam_search import (
) )
from torch import Tensor, nn from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
@ -761,7 +761,7 @@ def main():
logging.info(params) logging.info(params)
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params) model = get_model(params)
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:

View File

@ -66,7 +66,7 @@ from joiner import Joiner
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Transducer from model import AsrModel
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from scaling import ScheduledFloat from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling from subsampling import Conv2dSubsampling
@ -578,20 +578,19 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
return joiner return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module: def get_model(params: AttributeDict) -> nn.Module:
encoder_embed = get_encoder_embed(params) encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params) encoder = get_encoder_model(params)
decoder = get_decoder_model(params) decoder = get_decoder_model(params)
joiner = get_joiner_model(params) joiner = get_joiner_model(params)
model = Transducer( model = AsrModel(
encoder_embed=encoder_embed, encoder_embed=encoder_embed,
encoder=encoder, encoder=encoder,
decoder=decoder, decoder=decoder,
joiner=joiner, joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(","))), encoder_dim=int(max(params.encoder_dim.split(","))),
decoder_dim=params.decoder_dim, decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
) )
return model return model
@ -758,7 +757,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model( simple_loss, pruned_loss, _ = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -1086,7 +1085,7 @@ def run(rank, world_size, args):
logging.info(params) logging.info(params)
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params) model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")