mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fix decoding for gigaspeech in the libri + giga setup. (#345)
This commit is contained in:
parent
e1c3e98980
commit
8635fb4334
@ -69,7 +69,8 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import AsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search,
|
fast_beam_search_nbest_oracle,
|
||||||
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -100,27 +101,28 @@ def get_parser():
|
|||||||
"--epoch",
|
"--epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=28,
|
default=28,
|
||||||
help="It specifies the checkpoint to use for decoding."
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
"Note: Epoch counts from 0.",
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
default=15,
|
default=15,
|
||||||
help="Number of checkpoints to average. Automatically select "
|
help="Number of checkpoints to average. Automatically select "
|
||||||
"consecutive checkpoints before the checkpoint specified by "
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
"'--epoch'. ",
|
"'--epoch' and '--iter'",
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--avg-last-n",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="""If positive, --epoch and --avg are ignored and it
|
|
||||||
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
|
|
||||||
where xxx is the number of processed batches while
|
|
||||||
saving that checkpoint.
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -146,6 +148,7 @@ def get_parser():
|
|||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
- fast_beam_search
|
- fast_beam_search
|
||||||
|
- fast_beam_search_nbest_oracle
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -165,7 +168,8 @@ def get_parser():
|
|||||||
help="""A floating point value to calculate the cutoff score during beam
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
`beam` in Kaldi.
|
`beam` in Kaldi.
|
||||||
Used only when --decoding-method is fast_beam_search""",
|
Used only when --decoding-method is
|
||||||
|
fast_beam_search or fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -173,7 +177,7 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
fast_beam_search""",
|
fast_beam_search or fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -181,7 +185,7 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
fast_beam_search""",
|
fast_beam_search or fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -199,6 +203,23 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-paths",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="""Number of paths for computed nbest oracle WER
|
||||||
|
when the decoding method is fast_beam_search_nbest_oracle.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nbest-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="""Scale applied to lattice scores when computing nbest paths.
|
||||||
|
Used only when the decoding_method is fast_beam_search_nbest_oracle.
|
||||||
|
""",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -243,7 +264,8 @@ def decode_one_batch(
|
|||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
only when --decoding_method is fast_beam_search.
|
only when --decoding_method is
|
||||||
|
fast_beam_search or fast_beam_search_nbest_oracle.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -264,7 +286,7 @@ def decode_one_batch(
|
|||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
hyp_tokens = fast_beam_search(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
@ -275,6 +297,21 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||||
|
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
ref_texts=sp.encode(supervisions["text"]),
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
elif (
|
elif (
|
||||||
params.decoding_method == "greedy_search"
|
params.decoding_method == "greedy_search"
|
||||||
and params.max_sym_per_frame == 1
|
and params.max_sym_per_frame == 1
|
||||||
@ -328,6 +365,16 @@ def decode_one_batch(
|
|||||||
f"max_states_{params.max_states}"
|
f"max_states_{params.max_states}"
|
||||||
): hyps
|
): hyps
|
||||||
}
|
}
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||||
|
return {
|
||||||
|
(
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}_"
|
||||||
|
f"num_paths_{params.num_paths}_"
|
||||||
|
f"nbest_scale_{params.nbest_scale}"
|
||||||
|
): hyps
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
return {f"beam_size_{params.beam_size}": hyps}
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
@ -463,17 +510,30 @@ def main():
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_nbest_oracle",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / "giga" / params.decoding_method
|
params.res_dir = params.exp_dir / "giga" / params.decoding_method
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
if "fast_beam_search" in params.decoding_method:
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
params.suffix += f"-max-states-{params.max_states}"
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
params.suffix += f"-num-paths-{params.num_paths}"
|
||||||
|
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||||
elif "beam_search" in params.decoding_method:
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam_size}"
|
params.suffix += (
|
||||||
|
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
@ -490,8 +550,9 @@ def main():
|
|||||||
sp = spm.SentencePieceProcessor()
|
sp = spm.SentencePieceProcessor()
|
||||||
sp.load(params.bpe_model)
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> and <unk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.unk_id = sp.unk_id()
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
@ -499,8 +560,20 @@ def main():
|
|||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = get_transducer_model(params)
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
if params.avg_last_n > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
@ -519,13 +592,17 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
|
model.unk_id = params.unk_id
|
||||||
|
|
||||||
# In beam_search.py, we are using model.decoder() and model.joiner(),
|
# In beam_search.py, we are using model.decoder() and model.joiner(),
|
||||||
# so we have to switch to the branch for the GigaSpeech dataset.
|
# so we have to switch to the branch for the GigaSpeech dataset.
|
||||||
model.decoder = model.decoder_giga
|
model.decoder = model.decoder_giga
|
||||||
model.joiner = model.joiner_giga
|
model.joiner = model.joiner_giga
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method in (
|
||||||
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_nbest_oracle",
|
||||||
|
):
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
else:
|
else:
|
||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user