remove hlg related modifications

Guo Liyong 2022-07-16 13:37:31 +08:00
parent 473efcd531
commit 1ebf714fb7
3 changed files with 7 additions and 31 deletions


@@ -47,19 +47,10 @@ def get_args():
         """,
     )
-    parser.add_argument(
-        "--h-graph",
-        type=str,
-        help="""one of ["H", "Trivial"]
-        H: k2.ctc_topo
-        Trivial: k2.trivial_graph
-        """
-    )
     return parser.parse_args()


-def compile_HLG(lang_dir: str, h_graph: str = "H") -> k2.Fsa:
+def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     Args:
       lang_dir:
@@ -71,14 +62,7 @@ def compile_HLG(lang_dir: str, h_graph: str = "H") -> k2.Fsa:
     lexicon = Lexicon(lang_dir)
     max_token_id = max(lexicon.tokens)
     logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
-    if h_graph == "H":
-        H = k2.ctc_topo(max_token_id)
-    elif h_graph == "Trivial":
-        H = k2.trivial_graph(max_token_id - 1)
-    else:
-        raise ValueError(f"Unsupported h_graph: {h_graph}")
+    H = k2.ctc_topo(max_token_id)
     L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

     if Path("data/lm/G_3_gram.pt").is_file():
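For context, the two topologies that the removed --h-graph switch chose between are both standard k2 constructions. A minimal sketch (the max_token_id value here is made up for illustration; in the script it comes from max(lexicon.tokens) as shown above):

import k2

max_token_id = 500  # illustrative value only

# CTC topology: models blank and repeated tokens the way CTC expects.
H_ctc = k2.ctc_topo(max_token_id)

# Trivial topology: a single looping state with one self-loop per token
# (plus the final arc), i.e. no special blank/repeat handling.
# The removed code passed max_token_id - 1 here.
H_trivial = k2.trivial_graph(max_token_id - 1)

print(H_ctc.num_arcs, H_trivial.num_arcs)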
@@ -154,17 +138,15 @@ def main():
     args = get_args()
     lang_dir = Path(args.lang_dir)

-    if (lang_dir / f"{args.h_graph}LG.pt").is_file():
-        logging.info(
-            f"{lang_dir}/{args.h_graph}LG.pt already exists - skipping"
-        )
+    if (lang_dir / "HLG.pt").is_file():
+        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
         return

     logging.info(f"Processing {lang_dir}")

     HLG = compile_HLG(lang_dir)
-    logging.info(f"Saving {args.h_graph}LG.pt to {lang_dir}")
-    torch.save(HLG.as_dict(), f"{lang_dir}/{args.h_graph}LG.pt")
+    logging.info(f"Saving HLG.pt to {lang_dir}")
+    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")


 if __name__ == "__main__":

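The body of compile_HLG is only partly visible in the hunks above; the rest performs the actual graph composition. A rough, simplified sketch of that pipeline using standard k2 calls (disambiguation-symbol and epsilon handling of the real script are omitted; the lang-dir path in the usage note is illustrative):

import torch
import k2

def compile_hlg_sketch(lang_dir: str, max_token_id: int) -> k2.Fsa:
    """Illustrative only: compose H, L and G into a decoding graph."""
    H = k2.ctc_topo(max_token_id)
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
    G = k2.Fsa.from_dict(torch.load("data/lm/G_3_gram.pt"))

    LG = k2.compose(k2.arc_sort(L), k2.arc_sort(G))
    LG = k2.connect(LG)
    LG = k2.determinize(LG)
    # the real script also strips disambiguation symbols and epsilons here
    HLG = k2.compose(H, k2.arc_sort(LG), inner_labels="tokens")
    return k2.arc_sort(k2.connect(HLG))

# Usage mirroring main() above (illustrative path):
#   HLG = compile_hlg_sketch("data/lang_bpe_500", max_token_id=500)
#   torch.save(HLG.as_dict(), "data/lang_bpe_500/HLG.pt")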

@@ -614,15 +614,14 @@ def greedy_search_batch(
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
)
# logits'shape (batch_size, 1, 1, vocab_size)
# logits'shape (batch_size, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
if ngram_rescoring:
all_logits[start:end] = logits
assert logits.ndim == 2, logits.shape
logits_argmax = logits.argmax(dim=1)
logits_softmax = logits.softmax(dim=1)
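This hunk is the per-frame core of batched greedy search: the joiner produces one logit vector per stream, and the argmax (plus, on this branch, the softmax kept for n-gram rescoring) decides whether each stream emits a token. A self-contained sketch of that decision step (shapes and blank_id = 0 are assumptions for illustration):

import torch

batch_size, vocab_size = 3, 500   # illustrative sizes
blank_id = 0                      # assumed blank token id
hyps = [[] for _ in range(batch_size)]

logits = torch.randn(batch_size, vocab_size)  # stands in for the joiner output
y = logits.argmax(dim=1).tolist()             # best token per stream
emitted = False
for i, token in enumerate(y):
    if token != blank_id:
        hyps[i].append(token)  # non-blank: extend this stream's hypothesis
        emitted = True
# if any stream emitted, the decoder would be re-run on the updated hypotheses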
@@ -729,9 +728,6 @@ def greedy_search_batch(
subsampling_factor=1,
)
lm_weight = 0.5 # (TODO): tuning this.
lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight)
best_path = one_best_decoding(
lattice=lattice,
use_double_scores=True,

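This hunk drops the lm_weight interpolation that downweighted the n-gram LM contribution in the lattice before taking the best path. A hedged sketch of that interpolation, assuming (as in the icefall decoding helpers) that lattice.lm_scores holds the LM component of lattice.scores:

import k2
from icefall.decode import one_best_decoding

def rescore_and_decode(lattice: k2.Fsa, lm_weight: float = 0.5) -> k2.Fsa:
    # scores = (am + lm) - lm * (1 - lm_weight) = am + lm_weight * lm,
    # i.e. keep only lm_weight of the LM contribution
    # (0.5 was the removed, still-to-be-tuned default).
    lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight)
    return one_best_decoding(lattice=lattice, use_double_scores=True)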

@@ -659,8 +659,6 @@ def main():
)
decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
decoding_graph.lm_scores = decoding_graph.scores.clone()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")