Remove HLG-related modifications
This commit is contained in:
parent
473efcd531
commit
1ebf714fb7
@ -47,19 +47,10 @@ def get_args():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--h-graph",
|
||||
type=str,
|
||||
help="""one of ["H", "Trivial"]
|
||||
H: k2.ctc_topo
|
||||
Trivial: k2.trivial_graph
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compile_HLG(lang_dir: str, h_graph: str = "H") -> k2.Fsa:
|
||||
def compile_HLG(lang_dir: str) -> k2.Fsa:
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
@ -71,14 +62,7 @@ def compile_HLG(lang_dir: str, h_graph: str = "H") -> k2.Fsa:
|
||||
lexicon = Lexicon(lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
|
||||
|
||||
if h_graph == "H":
|
||||
H = k2.ctc_topo(max_token_id)
|
||||
elif h_graph == "Trivial":
|
||||
H = k2.trivial_graph(max_token_id - 1)
|
||||
else:
|
||||
raise ValueError(f"Unsupported h_graph: {h_graph}")
|
||||
|
||||
H = k2.ctc_topo(max_token_id)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
|
||||
if Path("data/lm/G_3_gram.pt").is_file():
|
||||
@ -154,17 +138,15 @@ def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
|
||||
if (lang_dir / f"{args.h_graph}LG.pt").is_file():
|
||||
logging.info(
|
||||
f"{lang_dir}/{args.h_graph}LG.pt already exists - skipping"
|
||||
)
|
||||
if (lang_dir / "HLG.pt").is_file():
|
||||
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
|
||||
return
|
||||
|
||||
logging.info(f"Processing {lang_dir}")
|
||||
|
||||
HLG = compile_HLG(lang_dir)
|
||||
logging.info(f"Saving {args.h_graph}LG.pt to {lang_dir}")
|
||||
torch.save(HLG.as_dict(), f"{lang_dir}/{args.h_graph}LG.pt")
|
||||
logging.info(f"Saving HLG.pt to {lang_dir}")
|
||||
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -614,15 +614,14 @@ def greedy_search_batch(
|
||||
logits = model.joiner(
|
||||
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
||||
)
|
||||
# logits'shape (batch_size, 1, 1, vocab_size)
|
||||
|
||||
# logits'shape (batch_size, 1, 1, vocab_size)
|
||||
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
|
||||
|
||||
if ngram_rescoring:
|
||||
all_logits[start:end] = logits
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
logits_argmax = logits.argmax(dim=1)
|
||||
logits_softmax = logits.softmax(dim=1)
|
||||
|
||||
|
||||
@ -729,9 +728,6 @@ def greedy_search_batch(
|
||||
subsampling_factor=1,
|
||||
)
|
||||
|
||||
lm_weight = 0.5 # (TODO): tuning this.
|
||||
lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight)
|
||||
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice,
|
||||
use_double_scores=True,
|
||||
|
||||
@ -659,8 +659,6 @@ def main():
|
||||
)
|
||||
decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
|
||||
|
||||
decoding_graph.lm_scores = decoding_graph.scores.clone()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user