mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00

psd algorithm

This commit is contained in:
parent bc2882ddcc
commit 0a99ceb6ba
@@ -47,10 +47,19 @@ def get_args():
         """,
     )
 
+    parser.add_argument(
+        "--h-graph",
+        type=str,
+        help="""one of ["H", "Trivial"]
+        H: k2.ctc_topo
+        Trivial: k2.trivial_graph
+        """,
+    )
+
     return parser.parse_args()
 
 
-def compile_HLG(lang_dir: str) -> k2.Fsa:
+def compile_HLG(lang_dir: str, h_graph: str = "H") -> k2.Fsa:
     """
     Args:
       lang_dir:
@@ -62,7 +71,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     lexicon = Lexicon(lang_dir)
     max_token_id = max(lexicon.tokens)
     logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
-    H = k2.ctc_topo(max_token_id)
+
+    if h_graph == "H":
+        H = k2.ctc_topo(max_token_id)
+    elif h_graph == "Trivial":
+        H = k2.trivial_graph(max_token_id - 1)
+    else:
+        raise ValueError(f"Unsupported h_graph: {h_graph}")
+
     L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
 
     if Path("data/lm/G_3_gram.pt").is_file():
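For readers unfamiliar with the two topologies behind `--h-graph`, here is a minimal standalone sketch (not part of the commit; toy sizes) that builds both with stock k2 calls:

import k2

max_token_id = 3  # toy vocabulary: blank = 0, real tokens 1..3

# "H": the CTC topology, with blank self-loops and arcs that collapse repeats.
H = k2.ctc_topo(max_token_id)

# "Trivial": a single looping state that accepts each symbol;
# note the commit passes max_token_id - 1 to it.
trivial = k2.trivial_graph(max_token_id - 1)

print(H.num_arcs, trivial.num_arcs)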
@@ -138,15 +154,17 @@ def main():
     args = get_args()
     lang_dir = Path(args.lang_dir)
 
-    if (lang_dir / "HLG.pt").is_file():
-        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
+    if (lang_dir / f"{args.h_graph}LG.pt").is_file():
+        logging.info(
+            f"{lang_dir}/{args.h_graph}LG.pt already exists - skipping"
+        )
         return
 
     logging.info(f"Processing {lang_dir}")
 
     HLG = compile_HLG(lang_dir)
-    logging.info(f"Saving HLG.pt to {lang_dir}")
-    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
+    logging.info(f"Saving {args.h_graph}LG.pt to {lang_dir}")
+    torch.save(HLG.as_dict(), f"{lang_dir}/{args.h_graph}LG.pt")
 
 
 if __name__ == "__main__":
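Note that `compile_HLG(lang_dir)` above is still called with the default `h_graph="H"`, even though the output file name now depends on `args.h_graph`; presumably the flag should be forwarded, e.g. (hypothetical, not in this commit):

HLG = compile_HLG(lang_dir, args.h_graph)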
@@ -14,16 +14,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import warnings
 from dataclasses import dataclass
 from typing import Dict, List, Optional
 
 import k2
 import torch
+from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence
 from model import Transducer
 
-from icefall.decode import Nbest, one_best_decoding
-from icefall.utils import get_texts
+from icefall.decode import get_lattice, Nbest, one_best_decoding
+from icefall.utils import get_alignments, get_texts
 
 
 def fast_beam_search_one_best(
@@ -534,6 +536,8 @@ def greedy_search_batch(
     model: Transducer,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
+    decoding_graph: Optional[k2.Fsa] = None,
+    ngram_rescoring: bool = False,
 ) -> List[List[int]]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
     Args:
@@ -583,6 +587,18 @@ def greedy_search_batch(
 
     encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
 
+    if ngram_rescoring:
+        vocab_size = model.decoder.vocab_size
+        total_t = encoder_out.shape[0]
+        # Cache all joiner outputs during greedy search,
+        # from which non-blank frames are selected before n-gram rescoring.
+        all_logits = torch.zeros([total_t, vocab_size], device=device)
+
+        # A flag indicating whether a frame is a blank frame.
+        # 0 for a blank frame and 1 for a non-blank frame.
+        # Used to select non-blank frames for n-gram rescoring.
+        non_blank_flag = torch.zeros([total_t], device=device)
+
     offset = 0
     for batch_size in batch_size_list:
         start = offset
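The caching above relies on the packed layout: `packed_encoder_out.data` concatenates the per-time-step batches, so `encoder_out.shape[0]` is the total frame count over the whole batch and `batch_size_list` gives the batch size at each time step. A standalone sketch of that round trip (toy sizes, not from the commit):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

N, T, V = 2, 4, 5
x = torch.randn(N, T, V)
lens = torch.tensor([4, 2])  # must be sorted in decreasing order

packed = pack_padded_sequence(x, lens, batch_first=True)
# packed.data groups frames by time step: batch_sizes == [2, 2, 1, 1],
# so its first dimension is 4 + 2 = 6, the total number of valid frames.
assert packed.data.shape[0] == int(packed.batch_sizes.sum())

# pad_packed_sequence restores the [N, T, V] view, as done before rescoring.
unpacked, out_lens = pad_packed_sequence(packed, batch_first=True)
assert unpacked.shape == (N, int(lens.max()), V)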
@@ -600,7 +616,36 @@ def greedy_search_batch(
         # logits' shape: (batch_size, 1, 1, vocab_size)
 
         logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
-        assert logits.ndim == 2, logits.shape
+
+        if ngram_rescoring:
+            all_logits[start:end] = logits
+
+            assert logits.ndim == 2, logits.shape
+            logits_argmax = logits.argmax(dim=1)
+            logits_softmax = logits.softmax(dim=1)
+
+            # Detailed in the function verify_non_blank_logits below.
+            selection_verification = True
+
+            # 0 for a blank frame and 1 for a non-blank frame.
+            non_blank_flag[start:end] = torch.where(
+                logits_argmax == blank_id, 0, 1
+            )
+
+            if False:
+                # In the paper https://arxiv.org/pdf/2101.06856.pdf,
+                # a gama_blank threshold value is used to determine blank frames.
+                # Currently, results are worse than baseline greedy_search
+                # and also very sensitive to gama_blank.
+                # TODO: debug this later.
+                gama_blank = 0.50
+                non_blank_flag[start:end] = torch.where(
+                    logits_softmax[:, 0] >= gama_blank, 0, 1
+                )
+
+                # verify_non_blank_logits only works with logits_argmax == blank_id.
+                selection_verification = False
+
         y = logits.argmax(dim=1).tolist()
         emitted = False
         for i, v in enumerate(y):
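To make the two selection rules concrete, here is a toy example (logits made up; blank_id assumed to be 0, as elsewhere in the diff):

import torch

blank_id = 0
logits = torch.tensor(
    [
        [2.0, 0.1, 0.3],  # blank wins the argmax   -> flag 0
        [0.2, 1.5, 0.3],  # token 1 wins the argmax -> flag 1
    ]
)

# Active rule: a frame is blank iff blank wins the argmax.
flag_argmax = torch.where(logits.argmax(dim=1) == blank_id, 0, 1)

# Disabled rule (arXiv:2101.06856): blank iff p(blank) >= gama_blank.
gama_blank = 0.50
flag_thresh = torch.where(logits.softmax(dim=1)[:, 0] >= gama_blank, 0, 1)

print(flag_argmax.tolist(), flag_thresh.tolist())  # [0, 1] [0, 1]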
@@ -624,6 +669,105 @@ def greedy_search_batch(
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])
 
+    if not ngram_rescoring:
+        return ans
+
+    assert decoding_graph is not None
+
+    # Transform logits to shape [N, T, vocab_size] to make it easier
+    # to select non-blank frames.
+    packed_all_logits = PackedSequence(
+        all_logits, torch.tensor(batch_size_list)
+    )
+    all_logits_unpacked, _ = pad_packed_sequence(
+        packed_all_logits, batch_first=True
+    )
+
+    # Transform non_blank_flag to shape [N, T]
+    packed_non_blank_flag = PackedSequence(
+        non_blank_flag, torch.tensor(batch_size_list)
+    )
+    non_blank_flag_unpacked, _ = pad_packed_sequence(
+        packed_non_blank_flag, batch_first=True
+    )
+
+    non_blank_logits_lens = torch.sum(non_blank_flag_unpacked, dim=1)
+    max_frame_to_rescore = non_blank_logits_lens.max()
+
+    non_blank_logits = torch.zeros(
+        [N, int(max_frame_to_rescore), vocab_size], device=device
+    )
+
+    # torch.index_select only accepts a single dimension to index from,
+    # so we need to generate non_blank_logits one by one.
+    # Maybe there is a more efficient way to do this.
+    for i in range(N):
+        cur_non_blank_index = torch.where(non_blank_flag_unpacked[i, :] != 0)[0]
+        assert non_blank_logits_lens[i] == cur_non_blank_index.shape[0]
+        non_blank_logits[
+            i, : int(non_blank_logits_lens[i]), :
+        ] = torch.index_select(
+            all_logits_unpacked[i, :], 0, cur_non_blank_index
+        )
+
+    def verify_non_blank_logits():
+        # A way to verify non_blank_logits are selected correctly from all_logits.
+        hyps_before_rescore = non_blank_logits.argmax(dim=2)
+        for i in range(N):
+            usi = unsorted_indices[i]
+            hyp_to_verify = hyps_before_rescore[usi][
+                : int(non_blank_logits_lens[usi])
+            ].tolist()
+            assert ans[i] == hyp_to_verify
+        logging.info("Verified non-blank logits.")
+
+    # TODO: skip verification once we finally have a workable rescoring method.
+    if selection_verification:
+        verify_non_blank_logits()
+
+    # Split log_softmax into two separate steps,
+    # so we could do blank deweighting in the probability domain if needed.
+    logits_to_rescore_softmax = non_blank_logits.softmax(dim=2)
+    logits_to_rescore = logits_to_rescore_softmax.log()
+
+    # In the paper https://arxiv.org/pdf/2101.06856.pdf,
+    # blank deweighting is applied before non-blank frames are selected.
+    # However, in the current setup, that results in a higher WER,
+    # so this blank deweighting is done just before n-gram rescoring.
+    # TODO: debug this blank deweighting issue.
+
+    blank_deweight = 100
+    logits_to_rescore[:, :, 0] -= blank_deweight
+
+    supervision_segments = torch.zeros([N, 3], dtype=torch.int32)
+    supervision_segments[:, 0] = torch.arange(0, N, dtype=torch.int32)
+    supervision_segments[:, 2] = non_blank_logits_lens.to(torch.int32)
+
+    lattice = get_lattice(
+        nnet_output=logits_to_rescore,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=20,
+        output_beam=8,
+        min_active_states=30,
+        max_active_states=1000,
+        subsampling_factor=1,
+    )
+
+    lm_weight = 0.3  # TODO: tune this.
+    lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight)
+
+    best_path = one_best_decoding(
+        lattice=lattice,
+        use_double_scores=True,
+    )
+
+    token_ids = get_alignments(best_path, "labels")
+
+    ans = []
+    for i in range(N):
+        usi = unsorted_indices[i]
+        ans.append(token_ids[usi][: int(non_blank_logits_lens[usi])])
     return ans
 
 
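On the "more efficient way" TODO above: the per-utterance index_select loop can be replaced by one stable sort plus a gather. An untested sketch with the same shapes as in the diff (`select_non_blank` is a hypothetical helper; `stable=True` for torch.argsort needs a reasonably recent PyTorch):

import torch

def select_non_blank(all_logits_unpacked, non_blank_flag_unpacked):
    # all_logits_unpacked: [N, T, V]; non_blank_flag_unpacked: [N, T] of 0/1.
    N, T, V = all_logits_unpacked.shape
    lens = non_blank_flag_unpacked.sum(dim=1).long()  # non-blank frames per row
    # A stable descending sort moves the 1-flagged (non-blank) frames to the
    # front of each row while preserving their temporal order.
    order = torch.argsort(
        non_blank_flag_unpacked, dim=1, descending=True, stable=True
    )
    gathered = torch.gather(
        all_logits_unpacked, 1, order.unsqueeze(-1).expand(N, T, V)
    )
    out = gathered[:, : int(lens.max())].clone()
    # Zero the padded tail, matching the torch.zeros initialization in the diff.
    pad = torch.arange(out.shape[1], device=out.device) >= lens.unsqueeze(1)
    out[pad] = 0
    return out, lens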
@@ -136,6 +136,28 @@ def get_parser():
         "`epoch` are loaded for averaging. ",
     )
 
+    parser.add_argument(
+        "--ngram-rescoring",
+        type=str2bool,
+        default=False,
+        help="Whether to use ngram_rescoring.",
+    )
+
+    parser.add_argument(
+        "--decoding-graph",
+        type=str,
+        default="trivial_graph",
+        help="one of [trivial_graph, HLG, Trivial_LG, LG], "
+        "used by greedy_search_batch with ngram-rescoring=True.",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="./data/lang_bpe_500/",
+        help="Path to decoding graphs",
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
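Note that the `--decoding-graph` help string also lists LG, while the assertion added in the last hunk below accepts only trivial_graph, HLG, and Trivial_LG. A typical invocation of the new options might look like this (the script name and surrounding flags are illustrative, not from the commit):

./decode.py \
  --decoding-method greedy_search \
  --ngram-rescoring True \
  --decoding-graph Trivial_LG \
  --lang-dir ./data/lang_bpe_500/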
@@ -293,6 +315,8 @@ def decode_one_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            decoding_graph=decoding_graph,
+            ngram_rescoring=params.ngram_rescoring,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -498,6 +522,10 @@ def main():
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
 
+    if params.ngram_rescoring:
+        params.suffix += "-ngram-rescoring"
+        params.suffix += f"-{params.decoding_graph}"
+
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
 
@@ -605,6 +633,26 @@ def main():
     else:
         decoding_graph = None
 
+    if params.ngram_rescoring and params.decoding_method == "greedy_search":
+        assert params.decoding_graph in [
+            "trivial_graph",
+            "HLG",
+            "Trivial_LG",
+        ], f"Unsupported decoding graph {params.decoding_graph}"
+        if params.decoding_graph == "trivial_graph":
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
+        else:
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(
+                    f"data/lang_bpe_500/{params.decoding_graph}.pt",
+                    map_location=device,
+                )
+            )
+
+        decoding_graph.lm_scores = decoding_graph.scores.clone()
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
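The `lm_scores = scores.clone()` line above is what makes the interpolation in greedy_search_batch meaningful: assuming the usual k2 convention, after intersection `lattice.scores` holds acoustic-plus-graph scores while `lattice.lm_scores` carries the graph (n-gram) part, so `scores - lm_scores * (1 - lm_weight)` leaves `am + lm_weight * lm`. A toy check (numbers made up):

import torch

am = torch.tensor([1.0, -2.0])
lm = torch.tensor([0.5, -1.0])
lm_weight = 0.3

scores = am + lm  # what lattice.scores holds after intersection
rescored = scores - lm * (1 - lm_weight)
assert torch.allclose(rescored, am + lm_weight * lm)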