not removing result_dir in tedlium conformer ctc2 + add lm stem to compile_hlg_using_openfst.py + add MASTER_ADDR to be prvided to setup_dist (#801)

This commit is contained in:
Daniil 2023-01-01 19:08:32 -05:00 committed by GitHub
parent 67ae5fdf2b
commit 2fd970b682
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 12 deletions

View File

@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from
Caution: We use a lexicon that contains disambiguation symbols Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_3_gram.fst.txt - G, the LM, built from data/lm/G_n_gram.fst.txt
The generated HLG is saved in $lang_dir/HLG_fst.pt The generated HLG is saved in $lang_dir/HLG_fst.pt
@ -46,6 +46,13 @@ from icefall.lexicon import Lexicon
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
parser.add_argument( parser.add_argument(
"--lang-dir", "--lang-dir",
type=str, type=str,
@ -56,11 +63,13 @@ def get_args():
return parser.parse_args() return parser.parse_args()
def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst: def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> kaldifst.StdVectorFst:
""" """
Args: Args:
lang_dir: lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000. The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
lm:
The language stem base name.
Return: Return:
An FST representing HLG. An FST representing HLG.
@ -71,8 +80,8 @@ def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst:
kaldifst.arcsort(L, sort_type="olabel") kaldifst.arcsort(L, sort_type="olabel")
logging.info(f"L: #states {L.num_states}") logging.info(f"L: #states {L.num_states}")
G_filename_txt = "data/lm/G_3_gram.fst.txt" G_filename_txt = f"data/lm/{lm}.fst.txt"
G_filename_binary = "data/lm/G_3_gram.fst" G_filename_binary = f"data/lm/{lm}.fst"
if Path(G_filename_binary).is_file(): if Path(G_filename_binary).is_file():
logging.info(f"Loading {G_filename_binary}") logging.info(f"Loading {G_filename_binary}")
G = kaldifst.StdVectorFst.read(G_filename_binary) G = kaldifst.StdVectorFst.read(G_filename_binary)
@ -171,7 +180,7 @@ def main():
logging.info(f"{filename} already exists - skipping") logging.info(f"{filename} already exists - skipping")
return return
HLG = compile_HLG(lang_dir) HLG = compile_HLG(lang_dir, args.lm)
logging.info(f"Saving HLG to {filename}") logging.info(f"Saving HLG to {filename}")
torch.save(HLG.as_dict(), filename) torch.save(HLG.as_dict(), filename)

View File

@ -20,7 +20,6 @@
import argparse import argparse
import logging import logging
import shutil
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@ -183,7 +182,7 @@ def get_parser() -> argparse.ArgumentParser:
parser.add_argument( parser.add_argument(
"--result-dir", "--result-dir",
type=str, type=str,
default="conformer_ctc2/exp", default="conformer_ctc2/exp/results",
help="Directory to store results.", help="Directory to store results.",
) )
@ -635,9 +634,7 @@ def main() -> None:
args.lm_path = Path(args.lm_path) args.lm_path = Path(args.lm_path)
args.result_dir = Path(args.result_dir) args.result_dir = Path(args.result_dir)
if args.result_dir.is_dir(): args.result_dir.mkdir(exist_ok=True)
shutil.rmtree(args.result_dir)
args.result_dir.mkdir()
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))

View File

@ -21,12 +21,16 @@ import torch
from torch import distributed as dist from torch import distributed as dist
def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False): def setup_dist(
rank, world_size, master_addr=None, master_port=None, use_ddp_launch=False
):
""" """
rank and world_size are used only if use_ddp_launch is False. rank and world_size are used only if use_ddp_launch is False.
""" """
if "MASTER_ADDR" not in os.environ: if "MASTER_ADDR" not in os.environ:
os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_ADDR"] = (
"localhost" if master_addr is None else str(master_addr)
)
if "MASTER_PORT" not in os.environ: if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port) os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)