commit a006d6494f
parent ad3fb63ad6
Fangjun Kuang, 2022-05-14 09:05:11 +08:00
3 changed files with 152 additions and 61 deletions

pruned_transducer_stateless/pretrained.py

@@ -23,7 +23,7 @@ Usage:
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/bar.wav
(2) beam search
./pruned_transducer_stateless/pretrained.py \
@@ -32,7 +32,7 @@ Usage:
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless/pretrained.py \
@@ -41,7 +41,7 @@ Usage:
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless/pretrained.py \
@@ -50,7 +50,7 @@ Usage:
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`.
@@ -233,6 +233,9 @@ def main():
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)

pruned_transducer_stateless2/pretrained.py

@@ -23,16 +23,34 @@ Usage:
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/bar.wav
(1) beam search
(2) beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`.
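
A usage note (illustrative, not part of this commit's diff): with the --beam, --max-contexts and --max-states options added in the hunks below, a fast beam search run can also set these values explicitly on the command line, e.g. using the default values shown in the diff:

./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
/path/to/foo.wav \
/path/to/bar.wav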
@@ -79,9 +97,7 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
help="""Path to bpe.model.""",
)
parser.add_argument(
@@ -117,7 +133,33 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search",
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
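
To make the --beam help string above concrete, here is a minimal self-contained sketch of the pruning rule it describes (illustrative only, with made-up names; this is not the k2 fast_beam_search implementation, and max_contexts, which caps the number of distinct decoder contexts, is omitted for brevity):

import torch

def prune_frame(scores: torch.Tensor, beam: float, max_states: int) -> torch.Tensor:
    # Keep the states whose score is within `beam` of the best score on this
    # frame, i.e. cutoff = max_score - beam (the same role `beam` plays in
    # Kaldi), then cap the survivors at `max_states`.
    cutoff = scores.max() - beam
    keep = (scores >= cutoff).nonzero(as_tuple=True)[0]
    if keep.numel() > max_states:
        keep = keep[scores[keep].topk(max_states).indices]
    return keep

# Example: with beam=4.0 the cutoff is 0.0 - 4.0 = -4.0, so the state scoring
# -5.0 is pruned and indices [0, 1, 3] survive.
print(prune_frame(torch.tensor([0.0, -1.0, -5.0, -0.5]), beam=4.0, max_states=8))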
@@ -244,9 +286,9 @@ def main():
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=8.0,
max_contexts=32,
max_states=8,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -254,6 +296,7 @@ def main():
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
@@ -263,6 +306,7 @@ def main():
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
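
For context on the encoder_out_lens arguments added in the two hunks above: batched decoding pads all utterances to a common length, and the decoders need the number of valid (non-padding) encoder frames per utterance so each utterance stops at its real end. A minimal sketch of what such lengths look like (illustrative only; in pretrained.py the lengths come from the encoder itself):

import torch

# Two utterances padded to 4 encoder frames; True marks padding.
padding_mask = torch.tensor([
    [False, False, False, False],  # utterance 0: 4 valid frames
    [False, False, True, True],    # utterance 1: 2 valid frames
])
encoder_out_lens = (~padding_mask).sum(dim=1)  # tensor([4, 2])

greedy_search_batch and modified_beam_search use these lengths, as shown above, instead of decoding the padded tail of the shorter utterances.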

pruned_transducer_stateless3/pretrained.py

@@ -23,16 +23,34 @@ Usage:
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/bar.wav
(1) beam search
(2) beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless3/exp/epoch-xx.pt`.
@@ -79,9 +97,7 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
help="""Path to bpe.model.""",
)
parser.add_argument(
@@ -117,7 +133,33 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search",
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
@ -244,9 +286,9 @@ def main():
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=8.0,
max_contexts=32,
max_states=8,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -254,6 +296,7 @@ def main():
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
@@ -263,6 +306,7 @@ def main():
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())