mirror of https://github.com/k2-fsa/icefall.git

Minor fixes

This commit is contained in:
parent 802bf98f59
commit a7d0588827
@@ -526,6 +526,8 @@ def fast_beam_search(
                 project_input=False,
             )
             ilme_logits = ilme_logits.squeeze(1).squeeze(1)
+            if blank_penalty != 0:
+                ilme_logits[:, 0] -= blank_penalty
             ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)
             log_probs -= ilme_scale * ilme_log_probs
 
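The hunk above makes fast_beam_search apply the blank penalty to the internal-LM (ILME) logits as well, so the ILME correction is computed from the same penalized distribution as the main joint output. A minimal, runnable sketch of the arithmetic, with toy shapes and assumed values for ilme_scale, temperature, and blank_penalty:

    import torch

    # Assumed stand-ins for the function arguments.
    ilme_scale, temperature, blank_penalty = 0.2, 1.0, 0.5
    num_hyps, vocab_size = 4, 500

    log_probs = torch.randn(num_hyps, vocab_size).log_softmax(dim=-1)
    ilme_logits = torch.randn(num_hyps, vocab_size)

    # New in this hunk: penalize blank (token 0) in the ILME logits too,
    # mirroring the penalty already applied to the main logits.
    if blank_penalty != 0:
        ilme_logits[:, 0] -= blank_penalty

    # Subtract the scaled internal-LM estimate from the joint log-probs.
    ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)
    log_probs = log_probs - ilme_scale * ilme_log_probs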
@@ -100,7 +100,7 @@ from beam_search import (
     modified_beam_search,
 )
 from lhotse.cut import Cut
-from train import add_model_arguments, get_params, get_transducer_model
+from train import add_model_arguments, get_model, get_params
 
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
@@ -227,6 +227,16 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--ilme-scale",
+        type=float,
+        default=0.2,
+        help="""
+        Used only when --decoding_method is fast_beam_search_LG.
+        It specifies the scale for the internal language model estimation.
+        """,
+    )
+
     parser.add_argument(
         "--max-contexts",
         type=int,
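For reference, the new flag round-trips through argparse in the usual way; a self-contained sketch (the parser here is illustrative, not the recipe's full get_parser):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ilme-scale",
        type=float,
        default=0.2,
        help="Scale for internal LM estimation (fast_beam_search_LG only).",
    )

    # Dashes in the flag become underscores on the parsed namespace.
    args = parser.parse_args(["--ilme-scale", "0.3"])
    assert args.ilme_scale == 0.3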
@@ -381,6 +391,7 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
             blank_penalty=params.blank_penalty,
+            ilme_scale=params.ilme_scale,
         )
         for hyp in hyp_tokens:
             sentence = "".join([lexicon.word_table[i] for i in hyp])
@@ -458,6 +469,7 @@ def decode_one_batch(
             key += f"_num_paths_{params.num_paths}_"
             key += f"nbest_scale_{params.nbest_scale}"
         if "LG" in params.decoding_method:
+            key += f"_ilme_scale_{params.ilme_scale}"
             key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
 
         return {key: hyps}
@@ -624,6 +636,7 @@ def main():
             params.suffix += f"-nbest-scale-{params.nbest_scale}"
             params.suffix += f"-num-paths-{params.num_paths}"
         if "LG" in params.decoding_method:
+            params.suffix += f"_ilme_scale_{params.ilme_scale}"
             params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
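Note that the two insertions above use underscore separators (_ilme_scale_) while the neighboring suffix pieces use dashes, so the resulting result keys and filenames mix both styles. A toy illustration (the base suffix value is hypothetical):

    class Params:  # minimal stand-in for icefall's AttributeDict
        ilme_scale = 0.2
        ngram_lm_scale = 0.1

    params = Params()
    suffix = "epoch-999-avg-1"  # hypothetical base suffix
    suffix += f"_ilme_scale_{params.ilme_scale}"
    suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    print(suffix)  # epoch-999-avg-1_ilme_scale_0.2-ngram-lm-scale-0.1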
@@ -656,7 +669,7 @@ def main():
     logging.info(params)
 
     logging.info("About to create model")
-    model = get_transducer_model(params)
+    model = get_model(params)
 
     if not params.use_averaged_model:
         if params.iter > 0:
@@ -739,7 +752,7 @@ def main():
     model.eval()
 
     if "fast_beam_search" in params.decoding_method:
-        if params.decoding_method == "fast_beam_search_nbest_LG":
+        if "LG" in params.decoding_method:
             lexicon = Lexicon(params.lang_dir)
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
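The loosened condition matters because the old equality check only matched fast_beam_search_nbest_LG, so the plain fast_beam_search_LG method never loaded the LG graph. A quick check:

    methods = [
        "fast_beam_search",
        "fast_beam_search_LG",
        "fast_beam_search_nbest_LG",
    ]
    # Old condition: only the nbest variant loads LG.pt.
    print([m for m in methods if m == "fast_beam_search_nbest_LG"])
    # New condition: both LG variants do.
    print([m for m in methods if "LG" in m])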
@@ -782,6 +795,9 @@ def main():
     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]
 
+    test_sets = ["TEST_NET"]
+    test_dl = [test_net_dl]
+
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
             dl=test_dl,
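The added lines immediately shadow the full test-set lists, so only TEST_NET is decoded and the DEV and TEST_MEETING loaders are built but unused. A runnable toy version with placeholder loaders:

    dev_dl, test_net_dl, test_meeting_dl = "dev", "net", "meeting"  # placeholders

    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
    test_dl = [dev_dl, test_net_dl, test_meeting_dl]

    # The reassignment wins, restricting decoding to TEST_NET.
    test_sets = ["TEST_NET"]
    test_dl = [test_net_dl]

    for test_set, dl in zip(test_sets, test_dl):
        print(test_set, dl)  # TEST_NET net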
@@ -50,7 +50,7 @@ from streaming_beam_search import (
 )
 from torch import Tensor, nn
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_transducer_model
+from train import add_model_arguments, get_model, get_params
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -761,7 +761,7 @@ def main():
     logging.info(params)
 
     logging.info("About to create model")
-    model = get_transducer_model(params)
+    model = get_model(params)
 
     if not params.use_averaged_model:
         if params.iter > 0:
@@ -66,7 +66,7 @@ from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-from model import Transducer
+from model import AsrModel
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
@@ -578,20 +578,19 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     return joiner
 
 
-def get_transducer_model(params: AttributeDict) -> nn.Module:
+def get_model(params: AttributeDict) -> nn.Module:
     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
 
-    model = Transducer(
+    model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
         encoder_dim=int(max(params.encoder_dim.split(","))),
         decoder_dim=params.decoder_dim,
-        joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
     )
     return model
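One subtlety in the unchanged context line: max() over the split strings compares lexicographically, not numerically. It happens to pick the numerically largest value for the usual zipformer defaults, but the two readings can disagree for other configurations; a sketch:

    encoder_dim = "192,256,384,512,384,256"  # typical zipformer default

    # Lexicographic max over strings; works here by coincidence.
    print(int(max(encoder_dim.split(","))))  # 512

    # Numeric max is the safer idiom when dims like "96,128" are possible,
    # since max("96", "128") == "96" lexicographically.
    print(max(int(d) for d in encoder_dim.split(",")))  # 512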
@@ -758,7 +757,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, _ = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
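compute_loss now unpacks three return values from the model and discards the third. The diff itself does not show what that value is (in current icefall, AsrModel.forward also returns a CTC loss, but that is an assumption here, not something this hunk confirms). A toy stand-in:

    def forward():
        # Stand-in for AsrModel.forward, which now returns a third value.
        return 1.0, 2.0, 0.0  # simple_loss, pruned_loss, extra (assumed)

    simple_loss, pruned_loss, _ = forward()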
@@ -1086,7 +1085,7 @@ def run(rank, world_size, args):
     logging.info(params)
 
     logging.info("About to create model")
-    model = get_transducer_model(params)
+    model = get_model(params)
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")