Mirror of https://github.com/k2-fsa/icefall.git
not removing result_dir in tedlium conformer ctc2 + add lm stem to compile_hlg_using_openfst.py + add MASTER_ADDR to be provided to setup_dist (#801)
This commit is contained in:
parent 67ae5fdf2b
commit 2fd970b682
compile_hlg_using_openfst.py:

@@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from
 
     Caution: We use a lexicon that contains disambiguation symbols
 
-  - G, the LM, built from data/lm/G_3_gram.fst.txt
+  - G, the LM, built from data/lm/G_n_gram.fst.txt
 
 The generated HLG is saved in $lang_dir/HLG_fst.pt
 
@@ -46,6 +46,13 @@ from icefall.lexicon import Lexicon
 
 def get_args():
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lm",
+        type=str,
+        default="G_3_gram",
+        help="""Stem name for LM used in HLG compiling.
+        """,
+    )
     parser.add_argument(
         "--lang-dir",
         type=str,
@@ -56,11 +63,13 @@ def get_args():
     return parser.parse_args()
 
 
-def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst:
+def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> kaldifst.StdVectorFst:
     """
     Args:
       lang_dir:
         The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+      lm:
+        The language stem base name.
 
     Return:
       An FST representing HLG.
@@ -71,8 +80,8 @@ def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst:
     kaldifst.arcsort(L, sort_type="olabel")
     logging.info(f"L: #states {L.num_states}")
 
-    G_filename_txt = "data/lm/G_3_gram.fst.txt"
-    G_filename_binary = "data/lm/G_3_gram.fst"
+    G_filename_txt = f"data/lm/{lm}.fst.txt"
+    G_filename_binary = f"data/lm/{lm}.fst"
     if Path(G_filename_binary).is_file():
         logging.info(f"Loading {G_filename_binary}")
         G = kaldifst.StdVectorFst.read(G_filename_binary)
@@ -171,7 +180,7 @@ def main():
        logging.info(f"{filename} already exists - skipping")
        return
 
-    HLG = compile_HLG(lang_dir)
+    HLG = compile_HLG(lang_dir, args.lm)
     logging.info(f"Saving HLG to {filename}")
     torch.save(HLG.as_dict(), filename)
 
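The net effect of the compile_hlg_using_openfst.py changes above is that the G FST is no longer hard-coded to G_3_gram; its paths are derived from the new --lm stem. A minimal sketch of that path resolution, assuming the corresponding files exist under data/lm/ (the G_4_gram stem below is only an illustrative value):

from pathlib import Path


def g_paths_for_stem(lm: str = "G_3_gram"):
    """Return the text and (optionally cached) binary G FST paths for an LM stem."""
    g_txt = Path(f"data/lm/{lm}.fst.txt")  # text-format G FST
    g_bin = Path(f"data/lm/{lm}.fst")      # cached binary G FST, reused if present
    return g_txt, g_bin


# e.g. invoking the script with --lm G_4_gram would make compile_HLG look for:
g_txt, g_bin = g_paths_for_stem("G_4_gram")
print("load cached" if g_bin.is_file() else "compile from text", g_txt, g_bin)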
tedlium conformer ctc2 script:

@@ -20,7 +20,6 @@
 
 import argparse
 import logging
-import shutil
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -183,7 +182,7 @@ def get_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--result-dir",
         type=str,
-        default="conformer_ctc2/exp",
+        default="conformer_ctc2/exp/results",
         help="Directory to store results.",
     )
 
@@ -635,9 +634,7 @@ def main() -> None:
     args.lm_path = Path(args.lm_path)
     args.result_dir = Path(args.result_dir)
 
-    if args.result_dir.is_dir():
-        shutil.rmtree(args.result_dir)
-    args.result_dir.mkdir()
+    args.result_dir.mkdir(exist_ok=True)
 
     params = get_params()
     params.update(vars(args))
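With the tedlium conformer ctc2 change above, an existing results directory is no longer wiped on each run: the default moves to conformer_ctc2/exp/results and the script only creates the directory if it is missing. A small standalone sketch of the new behaviour (parents=True and the placeholder file name are added here only so the snippet runs from an empty tree; the script itself assumes the exp directory already exists):

from pathlib import Path

result_dir = Path("conformer_ctc2/exp/results")  # new default from the diff above

# First run: create the directory; it is kept from run to run, never deleted.
result_dir.mkdir(parents=True, exist_ok=True)
(result_dir / "errs-from-earlier-run.txt").touch()

# Second run: no shutil.rmtree() anymore, so earlier results survive.
result_dir.mkdir(parents=True, exist_ok=True)
assert (result_dir / "errs-from-earlier-run.txt").exists()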
setup_dist:

@@ -21,12 +21,16 @@ import torch
 from torch import distributed as dist
 
 
-def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
+def setup_dist(
+    rank, world_size, master_addr=None, master_port=None, use_ddp_launch=False
+):
     """
     rank and world_size are used only if use_ddp_launch is False.
     """
     if "MASTER_ADDR" not in os.environ:
-        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_ADDR"] = (
+            "localhost" if master_addr is None else str(master_addr)
+        )
 
     if "MASTER_PORT" not in os.environ:
         os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)
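The setup_dist change lets callers pass the rank-0 host explicitly instead of always falling back to "localhost"; a MASTER_ADDR already set in the environment still takes precedence. A minimal caller sketch under that assumption (the host and port values are illustrative, and only the environment-variable logic from the diff is reproduced here):

import os


def setup_dist(rank, world_size, master_addr=None, master_port=None, use_ddp_launch=False):
    # Trimmed copy of the logic shown in the diff above: pre-set environment
    # variables win; otherwise the arguments (or the defaults) are used.
    if "MASTER_ADDR" not in os.environ:
        os.environ["MASTER_ADDR"] = "localhost" if master_addr is None else str(master_addr)
    if "MASTER_PORT" not in os.environ:
        os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)


setup_dist(rank=0, world_size=2, master_addr="10.0.0.1", master_port=12355)
print(os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"])  # -> 10.0.0.1 12355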