changes for pretrained.py

luomingshuang 2022-04-12 16:36:43 +08:00
parent b9c0e8e957
commit 05fce8e3a3
2 changed files with 112 additions and 33 deletions

pruned_transducer_stateless/beam_search.py (View File)

@@ -88,7 +88,7 @@ def fast_beam_search(
         # (shape.NumElements(), 1, encoder_out_dim)
         # fmt: off
         current_encoder_out = torch.index_select(
-            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1)
+            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).long()
         )
         # fmt: on
         logits = model.joiner(
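
A note on the added `.long()`: k2's `shape.row_ids(1)` is assumed to return an int32 tensor, while some PyTorch builds only accept int64 (LongTensor) indices in `torch.index_select`, so the cast avoids a dtype error there. A minimal standalone sketch with made-up tensor sizes:

    import torch

    # Stand-in for the int32 tensor that shape.row_ids(1) is assumed to return.
    row_ids = torch.tensor([0, 0, 1, 2], dtype=torch.int32)
    encoder_out = torch.randn(3, 10, 8)  # (batch, time, feature), toy sizes

    # Casting to int64 keeps index_select working on builds that reject int32 indices.
    current_encoder_out = torch.index_select(encoder_out[:, 4:5, :], 0, row_ids.long())
    print(current_encoder_out.shape)  # torch.Size([4, 1, 8])
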
@@ -486,10 +486,7 @@ def modified_beam_search(
         for i in range(batch_size):
             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
-            topk_hyp_indexes = torch.div(
-                topk_indexes, vocab_size, rounding_mode="trunc"
-            )
-            topk_hyp_indexes = topk_hyp_indexes.tolist()
+            topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
             topk_token_indexes = (topk_indexes % vocab_size).tolist()
             for k in range(len(topk_hyp_indexes)):
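
Both the old and new lines in this hunk compute the same thing: `topk` runs over log-probs flattened across (hypothesis, token) pairs, so integer division by `vocab_size` recovers the hypothesis index and the remainder recovers the token index. A toy sketch with made-up sizes:

    import torch

    vocab_size = 5
    # Flattened scores for 2 hypotheses x 5 tokens; index = hyp * vocab_size + token.
    log_probs = torch.arange(10, dtype=torch.float32)
    topk_log_probs, topk_indexes = log_probs.topk(3)  # indexes 9, 8, 7

    topk_hyp_indexes = (topk_indexes // vocab_size).tolist()   # [1, 1, 1]
    topk_token_indexes = (topk_indexes % vocab_size).tolist()  # [4, 3, 2]
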

pruned_transducer_stateless/pretrained.py (View File)

@@ -36,7 +36,6 @@ Usage:
 /path/to/foo.wav \
 /path/to/bar.wav
 (3) modified beam search
 ./pruned_transducer_stateless/pretrained.py \
 --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
@@ -46,6 +45,17 @@ Usage:
 /path/to/foo.wav \
 /path/to/bar.wav
+(4) fast beam search
+./pruned_transducer_stateless/pretrained.py \
+--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
+--bpe-model ./data/lang_bpe_500/bpe.model \
+--method fast_beam_search \
+--beam 4 \
+--max-contexts 4 \
+--max-states 8 \
+/path/to/foo.wav \
+/path/to/bar.wav
 You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`.
 Note: ./pruned_transducer_stateless/exp/pretrained.pt is generated by
@@ -58,12 +68,19 @@ import logging
 import math
 from typing import List

+import k2
 import kaldifeat
 import sentencepiece as spm
 import torch
 import torch.nn as nn
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
+from beam_search import (
+    beam_search,
+    fast_beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -97,12 +114,14 @@ def get_parser():
     )

     parser.add_argument(
-        "--method",
+        "--decoding-method",
         type=str,
         default="greedy_search",
         help="""Possible values are:
         - greedy_search
         - beam_search
+        - modified_beam_search
+        - fast_beam_search
         """,
     )
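
A side note on the rename: argparse derives the attribute name from the option string, so `--decoding-method` is read back as `args.decoding_method`, which matches the `params.decoding_method` accesses in the hunks below. A self-contained sketch:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--decoding-method", type=str, default="greedy_search")
    args = parser.parse_args(["--decoding-method", "fast_beam_search"])
    print(args.decoding_method)  # fast_beam_search
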
@@ -123,6 +142,32 @@ def get_parser():
         help="Used only when --method is beam_search and modified_beam_search ",
     )

+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
     parser.add_argument(
         "--context-size",
         type=int,
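
The cutoff rule quoted in the `--beam` help above can be illustrated without k2: a state survives pruning only while its score stays within `beam` of the best score seen so far. A rough plain-Python sketch of that rule (not the k2 implementation):

    scores = [-1.2, -3.5, -9.0, -2.0]
    beam = 4.0
    cutoff = max(scores) - beam          # -5.2
    surviving = [s for s in scores if s >= cutoff]
    print(surviving)                     # [-1.2, -3.5, -2.0]
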
@@ -134,7 +179,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame. Used only when
         --method is greedy_search.
         """,
@@ -268,6 +313,11 @@ def main():
     model.eval()
     model.device = device

+    if params.decoding_method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
     logging.info("Constructing Fbank computer")
     opts = kaldifeat.FbankOptions()
     opts.device = device
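
fast_beam_search needs a decoding graph even when no external LM is used, so the new branch builds a trivial one: `k2.trivial_graph(max_token, device)` returns an FSA that accepts any sequence of token ids 1..max_token, and `params.vocab_size - 1` is the largest BPE token id. A small sketch, assuming a 500-symbol BPE model as in the usage examples above:

    import k2

    vocab_size = 500  # assumed to match ./data/lang_bpe_500/bpe.model
    # A graph with no LM constraint: it accepts any sequence of ids 1..vocab_size-1.
    decoding_graph = k2.trivial_graph(vocab_size - 1, device="cpu")
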
@@ -299,34 +349,66 @@ def main():
         x=features, x_lens=feature_lengths
     )

-    num_waves = encoder_out.size(0)
     hyps = []
-    msg = f"Using {params.method}"
-    if params.method == "beam_search":
+    msg = f"Using {params.decoding_method}"
+    if params.decoding_method == "beam_search":
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)
-    for i in range(num_waves):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(f"Unsupported method: {params.method}")
-
-        hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
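
A note on the batched branches above: `sp.decode(hyp_tokens)` is handed a list of token-id lists, and recent sentencepiece releases return a list of strings in that case, which is why each branch splits every decoded string into words. A small sketch with made-up token ids:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("./data/lang_bpe_500/bpe.model")  # model path taken from the usage examples

    hyp_tokens = [[23, 45, 67], [12, 89]]  # made-up token ids
    texts = sp.decode(hyp_tokens)          # one decoded string per utterance
    hyps = [text.split() for text in texts]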