Minor fixes

This commit is contained in:
pkufool 2023-06-19 12:12:33 +08:00
parent 802bf98f59
commit a7d0588827
4 changed files with 28 additions and 11 deletions

View File

@ -526,6 +526,8 @@ def fast_beam_search(
project_input=False,
)
ilme_logits = ilme_logits.squeeze(1).squeeze(1)
if blank_penalty != 0:
ilme_logits[:, 0] -= blank_penalty
ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)
log_probs -= ilme_scale * ilme_log_probs

View File

@ -100,7 +100,7 @@ from beam_search import (
modified_beam_search,
)
from lhotse.cut import Cut
from train import add_model_arguments, get_params, get_transducer_model
from train import add_model_arguments, get_model, get_params
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
@ -227,6 +227,16 @@ def get_parser():
""",
)
parser.add_argument(
"--ilme-scale",
type=float,
default=0.2,
help="""
Used only when --decoding_method is fast_beam_search_LG.
It specifies the scale for the internal language model estimation.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
@ -381,6 +391,7 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
blank_penalty=params.blank_penalty,
ilme_scale=params.ilme_scale,
)
for hyp in hyp_tokens:
sentence = "".join([lexicon.word_table[i] for i in hyp])
@ -458,6 +469,7 @@ def decode_one_batch(
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ilme_scale_{params.ilme_scale}"
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
@ -624,6 +636,7 @@ def main():
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"_ilme_scale_{params.ilme_scale}"
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
@ -656,7 +669,7 @@ def main():
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
@ -739,7 +752,7 @@ def main():
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
@ -782,6 +795,9 @@ def main():
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
test_sets = ["TEST_NET"]
test_dl = [test_net_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,

View File

@ -50,7 +50,7 @@ from streaming_beam_search import (
)
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
@ -761,7 +761,7 @@ def main():
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:

View File

@ -66,7 +66,7 @@ from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from model import AsrModel
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
@ -578,20 +578,19 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
def get_model(params: AttributeDict) -> nn.Module:
encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
model = AsrModel(
encoder_embed=encoder_embed,
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(","))),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
)
return model
@ -758,7 +757,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
simple_loss, pruned_loss, _ = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -1086,7 +1085,7 @@ def run(rank, world_size, args):
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")