mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-18 21:44:18 +00:00)

psd algorithm

This commit is contained in:
  parent bc2882ddcc
  commit 0a99ceb6ba
@@ -47,10 +47,19 @@ def get_args():
         """,
     )
 
+    parser.add_argument(
+        "--h-graph",
+        type=str,
+        help="""one of ["H", "Trivial"]
+        H: k2.ctc_topo
+        Trivial: k2.trivial_graph
+        """,
+    )
+
     return parser.parse_args()
 
 
-def compile_HLG(lang_dir: str) -> k2.Fsa:
+def compile_HLG(lang_dir: str, h_graph: str = "H") -> k2.Fsa:
     """
     Args:
       lang_dir:
@@ -62,7 +71,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     lexicon = Lexicon(lang_dir)
     max_token_id = max(lexicon.tokens)
     logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
-    H = k2.ctc_topo(max_token_id)
+
+    if h_graph == "H":
+        H = k2.ctc_topo(max_token_id)
+    elif h_graph == "Trivial":
+        H = k2.trivial_graph(max_token_id - 1)
+    else:
+        raise ValueError(f"Unsupported h_graph: {h_graph}")
+
     L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
 
     if Path("data/lm/G_3_gram.pt").is_file():
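For illustration only (a sketch, not part of this commit): what the two --h-graph choices build, on a toy token set. k2.ctc_topo produces the standard CTC topology with blank self-loops, while k2.trivial_graph builds a trivial acceptor that accepts any token sequence, which fits the rescoring path added below where blank frames are dropped before intersection.

# Sketch: the two topologies selected by --h-graph, on a toy token set
# (ids 1..3, with 0 reserved for blank).
import k2

max_token_id = 3

# --h-graph H: standard CTC topology, as built by k2.ctc_topo.
H = k2.ctc_topo(max_token_id)

# --h-graph Trivial: a trivial acceptor over the token set
# (the commit passes max_token_id - 1 here).
T = k2.trivial_graph(max_token_id - 1)

print(H.shape, H.num_arcs)
print(T.shape, T.num_arcs)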
@@ -138,15 +154,17 @@ def main():
     args = get_args()
     lang_dir = Path(args.lang_dir)
 
-    if (lang_dir / "HLG.pt").is_file():
-        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
+    if (lang_dir / f"{args.h_graph}LG.pt").is_file():
+        logging.info(
+            f"{lang_dir}/{args.h_graph}LG.pt already exists - skipping"
+        )
         return
 
     logging.info(f"Processing {lang_dir}")
 
     HLG = compile_HLG(lang_dir)
-    logging.info(f"Saving HLG.pt to {lang_dir}")
-    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
+    logging.info(f"Saving {args.h_graph}LG.pt to {lang_dir}")
+    torch.save(HLG.as_dict(), f"{lang_dir}/{args.h_graph}LG.pt")
 
 
 if __name__ == "__main__":
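A side note on naming (an observation, not from the diff): since the output path is f"{args.h_graph}LG.pt", the script still writes HLG.pt when --h-graph H is used and writes TrivialLG.pt for --h-graph Trivial.

# Sketch: output names implied by the new save path, assuming the
# data/lang_bpe_500 directory used elsewhere in this commit.
for h_graph in ["H", "Trivial"]:
    print(f"data/lang_bpe_500/{h_graph}LG.pt")
# data/lang_bpe_500/HLG.pt
# data/lang_bpe_500/TrivialLG.pt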
@@ -14,16 +14,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import warnings
 from dataclasses import dataclass
 from typing import Dict, List, Optional
 
 import k2
 import torch
+from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence
 from model import Transducer
 
-from icefall.decode import Nbest, one_best_decoding
-from icefall.utils import get_texts
+from icefall.decode import get_lattice, Nbest, one_best_decoding
+from icefall.utils import get_alignments, get_texts
 
 
 def fast_beam_search_one_best(
@@ -534,6 +536,8 @@ def greedy_search_batch(
     model: Transducer,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
+    decoding_graph: Optional[k2.Fsa] = None,
+    ngram_rescoring: bool = False,
 ) -> List[List[int]]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
     Args:
@@ -583,6 +587,18 @@ def greedy_search_batch(
 
     encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
 
+    if ngram_rescoring:
+        vocab_size = model.decoder.vocab_size
+        total_t = encoder_out.shape[0]
+        # Cache all joiner outputs during greedy search; non-blank frames
+        # are selected from them before n-gram rescoring.
+        all_logits = torch.zeros([total_t, vocab_size], device=device)
+
+        # A flag indicating whether a frame is a blank frame:
+        # 0 for a blank frame and 1 for a non-blank frame.
+        # Used to select non-blank frames for n-gram rescoring.
+        non_blank_flag = torch.zeros([total_t], device=device)
+
     offset = 0
     for batch_size in batch_size_list:
         start = offset
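For intuition (a sketch, not part of the commit): all_logits and non_blank_flag are stored in the same time-major, packed layout as packed_encoder_out.data, which is why they can later be wrapped in a PackedSequence and unpacked to [N, T, vocab_size]. A toy round trip, assuming small sizes:

# Sketch: packed (time-major) buffer -> padded [N, T, V] tensor.
import torch
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence

V = 5
batch_size_list = [3, 3, 2, 1]        # number of active utterances per frame
total_t = sum(batch_size_list)         # 9 rows, same layout as the packed data

all_logits = torch.randn(total_t, V)
packed = PackedSequence(all_logits, torch.tensor(batch_size_list))
padded, lens = pad_packed_sequence(packed, batch_first=True)

print(padded.shape)  # torch.Size([3, 4, 5])  -> [N, T, V]
print(lens)          # tensor([4, 3, 2])      -> frames per utterance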
@@ -600,7 +616,36 @@ def greedy_search_batch(
             # logits'shape (batch_size, 1, 1, vocab_size)
 
             logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
             assert logits.ndim == 2, logits.shape
+
+            if ngram_rescoring:
+                all_logits[start:end] = logits
+
+                assert logits.ndim == 2, logits.shape
+                logits_argmax = logits.argmax(dim=1)
+                logits_softmax = logits.softmax(dim=1)
+
+                # Detailed in the function verify_non_blank_logits below.
+                selection_verification = True
+
+                # 0 for a blank frame and 1 for a non-blank frame.
+                non_blank_flag[start:end] = torch.where(
+                    logits_argmax == blank_id, 0, 1
+                )
+
+                if False:
+                    # In the paper https://arxiv.org/pdf/2101.06856.pdf
+                    # a gama_blank threshold value is used to determine blank frames.
+                    # Currently, results are worse than baseline greedy_search
+                    # and also very sensitive to gama_blank.
+                    # (TODO): debug this later.
+                    gama_blank = 0.50
+                    non_blank_flag[start:end] = torch.where(
+                        logits_softmax[:, 0] >= gama_blank, 0, 1
+                    )
+
+                    # verify_non_blank_logits only works with logits_argmax == blank_id.
+                    selection_verification = False
+
             y = logits.argmax(dim=1).tolist()
             emitted = False
             for i, v in enumerate(y):
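For illustration (a sketch, not part of the commit): the difference between the enabled frame-selection rule (argmax == blank_id) and the disabled threshold rule from https://arxiv.org/pdf/2101.06856.pdf, on toy logits with blank_id = 0:

# Sketch: two ways of flagging blank frames, on toy joiner outputs.
import torch

blank_id = 0
logits = torch.tensor([[4.0, 1.0, 0.5],    # clearly blank
                       [0.2, 3.0, 0.1],    # clearly non-blank
                       [1.0, 0.9, 0.8]])   # borderline: argmax is blank,
                                           # but P(blank) is only ~0.37

# Rule enabled in the commit: blank iff argmax == blank_id.
flag_argmax = torch.where(logits.argmax(dim=1) == blank_id, 0, 1)

# Disabled alternative: blank iff P(blank) >= gama_blank.
gama_blank = 0.50
flag_thresh = torch.where(logits.softmax(dim=1)[:, 0] >= gama_blank, 0, 1)

print(flag_argmax)  # tensor([0, 1, 0])
print(flag_thresh)  # tensor([0, 1, 1])  -- the borderline frame is kept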
@@ -624,6 +669,105 @@ def greedy_search_batch(
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])
 
+    if not ngram_rescoring:
+        return ans
+
+    assert decoding_graph is not None
+
+    # Transform logits to shape [N, T, vocab_size] to make it easier
+    # to select non-blank frames.
+    packed_all_logits = PackedSequence(
+        all_logits, torch.tensor(batch_size_list)
+    )
+    all_logits_unpacked, _ = pad_packed_sequence(
+        packed_all_logits, batch_first=True
+    )
+
+    # Transform non_blank_flag to shape [N, T].
+    packed_non_blank_flag = PackedSequence(
+        non_blank_flag, torch.tensor(batch_size_list)
+    )
+    non_blank_flag_unpacked, _ = pad_packed_sequence(
+        packed_non_blank_flag, batch_first=True
+    )
+
+    non_blank_logits_lens = torch.sum(non_blank_flag_unpacked, dim=1)
+    max_frame_to_rescore = non_blank_logits_lens.max()
+
+    non_blank_logits = torch.zeros(
+        [N, int(max_frame_to_rescore), vocab_size], device=device
+    )
+
+    # torch.index_select only accepts a single dimension to index from,
+    # so we need to generate non_blank_logits one by one.
+    # Maybe there is a more efficient way to do this.
+    for i in range(N):
+        cur_non_blank_index = torch.where(non_blank_flag_unpacked[i, :] != 0)[0]
+        assert non_blank_logits_lens[i] == cur_non_blank_index.shape[0]
+        non_blank_logits[
+            i, : int(non_blank_logits_lens[i]), :
+        ] = torch.index_select(
+            all_logits_unpacked[i, :], 0, cur_non_blank_index
+        )
+
+    def verify_non_blank_logits():
+        # A way to verify that non_blank_logits are selected correctly from all_logits.
+        hyps_before_rescore = non_blank_logits.argmax(dim=2)
+        for i in range(N):
+            usi = unsorted_indices[i]
+            hyp_to_verify = hyps_before_rescore[usi][
+                : int(non_blank_logits_lens[usi])
+            ].tolist()
+            assert ans[i] == hyp_to_verify
+        logging.info("Verified non-blank logits.")
+
+    # TODO: skip verification once we finally have a workable rescoring method.
+    if selection_verification:
+        verify_non_blank_logits()
+
+    # Split log_softmax into two separate steps,
+    # so we could do blank deweighting in the probability domain if needed.
+    logits_to_rescore_softmax = non_blank_logits.softmax(dim=2)
+    logits_to_rescore = logits_to_rescore_softmax.log()
+
+    # In the paper https://arxiv.org/pdf/2101.06856.pdf
+    # blank deweighting is applied before non-blank frames are selected.
+    # However, in the current setup, that results in a higher WER,
+    # so just apply this blank deweighting before n-gram rescoring.
+    # (TODO): debug this blank deweighting issue.
+
+    blank_deweight = 100
+    logits_to_rescore[:, :, 0] -= blank_deweight
+
+    supervision_segments = torch.zeros([N, 3], dtype=torch.int32)
+    supervision_segments[:, 0] = torch.arange(0, N, dtype=torch.int32)
+    supervision_segments[:, 2] = non_blank_logits_lens.to(torch.int32)
+
+    lattice = get_lattice(
+        nnet_output=logits_to_rescore,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=20,
+        output_beam=8,
+        min_active_states=30,
+        max_active_states=1000,
+        subsampling_factor=1,
+    )
+
+    lm_weight = 0.3  # (TODO): tune this.
+    lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight)
+
+    best_path = one_best_decoding(
+        lattice=lattice,
+        use_double_scores=True,
+    )
+
+    token_ids = get_alignments(best_path, "labels")
+
+    ans = []
+    for i in range(N):
+        usi = unsorted_indices[i]
+        ans.append(token_ids[usi][: int(non_blank_logits_lens[usi])])
     return ans
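For illustration (a sketch, not part of the commit): what the per-utterance selection loop above does on toy tensors. The same gather can be written with a boolean mask, which is one possible answer to the "more efficient way" noted in the comment.

# Sketch: select non-blank frames per utterance into a fixed-size buffer.
import torch

N, T, V = 2, 4, 3
all_logits_unpacked = torch.randn(N, T, V)
non_blank_flag_unpacked = torch.tensor([[1., 0., 1., 1.],
                                        [0., 1., 0., 0.]])

non_blank_logits_lens = non_blank_flag_unpacked.sum(dim=1)   # tensor([3., 1.])
max_frame_to_rescore = int(non_blank_logits_lens.max())

non_blank_logits = torch.zeros(N, max_frame_to_rescore, V)
for i in range(N):
    mask = non_blank_flag_unpacked[i] != 0                   # boolean mask
    non_blank_logits[i, : int(non_blank_logits_lens[i])] = all_logits_unpacked[i][mask]

print(non_blank_logits.shape)  # torch.Size([2, 3, 3]); row 1 is zero-padded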
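For reference (a sketch, not part of the commit): the supervision_segments layout expected by get_lattice() is one row per utterance of (index in batch, start frame, number of frames); here the start frame is always 0 and the length is the count of retained non-blank frames.

# Sketch: supervision_segments for three utterances with 7, 5 and 3
# non-blank frames kept for rescoring.
import torch

non_blank_logits_lens = torch.tensor([7.0, 5.0, 3.0])
N = non_blank_logits_lens.shape[0]

supervision_segments = torch.zeros([N, 3], dtype=torch.int32)
supervision_segments[:, 0] = torch.arange(0, N, dtype=torch.int32)
supervision_segments[:, 2] = non_blank_logits_lens.to(torch.int32)

print(supervision_segments)
# tensor([[0, 0, 7],
#         [1, 0, 5],
#         [2, 0, 3]], dtype=torch.int32)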
@@ -136,6 +136,28 @@ def get_parser():
         "`epoch` are loaded for averaging. ",
     )
 
+    parser.add_argument(
+        "--ngram-rescoring",
+        type=str2bool,
+        default=False,
+        help="Whether to use ngram_rescoring.",
+    )
+
+    parser.add_argument(
+        "--decoding-graph",
+        type=str,
+        default="trivial_graph",
+        help="One of [trivial_graph, HLG, Trivial_LG, LG]; "
+        "used by greedy_search_batch with --ngram-rescoring=True.",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="./data/lang_bpe_500/",
+        help="Path to decoding graphs",
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -293,6 +315,8 @@ def decode_one_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            decoding_graph=decoding_graph,
+            ngram_rescoring=params.ngram_rescoring,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -498,6 +522,10 @@ def main():
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
 
+    if params.ngram_rescoring:
+        params.suffix += "-ngram-rescoring"
+        params.suffix += f"-{params.decoding_graph}"
+
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
 
@@ -605,6 +633,26 @@ def main():
     else:
         decoding_graph = None
 
+    if params.ngram_rescoring and params.decoding_method == "greedy_search":
+        assert params.decoding_graph in [
+            "trivial_graph",
+            "HLG",
+            "Trivial_LG",
+        ], f"Unsupported decoding graph {params.decoding_graph}"
+        if params.decoding_graph == "trivial_graph":
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
+        else:
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(
+                    f"data/lang_bpe_500/{params.decoding_graph}.pt",
+                    map_location=device,
+                )
+            )
+
+        decoding_graph.lm_scores = decoding_graph.scores.clone()
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
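A note on the lm_scores clone (an explanation, not from the diff): roughly speaking, after intersection with this graph each lattice arc score is the sum of the acoustic score and the graph (n-gram) score, so the rescoring step in greedy_search_batch, lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight), leaves the acoustic score plus lm_weight times the n-gram score, i.e. a down-weighted log-linear combination. A quick arithmetic check:

# Sketch: the effect of the interpolation used in greedy_search_batch.
am, lm = -3.0, -2.0          # toy acoustic and n-gram log-scores on one arc
lm_weight = 0.3

total = am + lm                              # lattice.scores after intersection
rescored = total - lm * (1 - lm_weight)      # the update applied in the commit

assert abs(rescored - (am + lm_weight * lm)) < 1e-9
print(rescored)  # -3.6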