not removing result_dir in tedlium conformer ctc2 + add lm stem to compile_hlg_using_openfst.py + add MASTER_ADDR to be prvided to setup_dist (#801)

This commit is contained in:
Daniil 2023-01-01 19:08:32 -05:00 committed by GitHub
parent 67ae5fdf2b
commit 2fd970b682
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 12 deletions

View File

@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_3_gram.fst.txt
- G, the LM, built from data/lm/G_n_gram.fst.txt
The generated HLG is saved in $lang_dir/HLG_fst.pt
@ -46,6 +46,13 @@ from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
parser.add_argument(
"--lang-dir",
type=str,
@ -56,11 +63,13 @@ def get_args():
return parser.parse_args()
def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst:
def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> kaldifst.StdVectorFst:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
lm:
The language stem base name.
Return:
An FST representing HLG.
@ -71,8 +80,8 @@ def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst:
kaldifst.arcsort(L, sort_type="olabel")
logging.info(f"L: #states {L.num_states}")
G_filename_txt = "data/lm/G_3_gram.fst.txt"
G_filename_binary = "data/lm/G_3_gram.fst"
G_filename_txt = f"data/lm/{lm}.fst.txt"
G_filename_binary = f"data/lm/{lm}.fst"
if Path(G_filename_binary).is_file():
logging.info(f"Loading {G_filename_binary}")
G = kaldifst.StdVectorFst.read(G_filename_binary)
@ -171,7 +180,7 @@ def main():
logging.info(f"{filename} already exists - skipping")
return
HLG = compile_HLG(lang_dir)
HLG = compile_HLG(lang_dir, args.lm)
logging.info(f"Saving HLG to {filename}")
torch.save(HLG.as_dict(), filename)

View File

@ -20,7 +20,6 @@
import argparse
import logging
import shutil
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -183,7 +182,7 @@ def get_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--result-dir",
type=str,
default="conformer_ctc2/exp",
default="conformer_ctc2/exp/results",
help="Directory to store results.",
)
@ -635,9 +634,7 @@ def main() -> None:
args.lm_path = Path(args.lm_path)
args.result_dir = Path(args.result_dir)
if args.result_dir.is_dir():
shutil.rmtree(args.result_dir)
args.result_dir.mkdir()
args.result_dir.mkdir(exist_ok=True)
params = get_params()
params.update(vars(args))

View File

@ -21,12 +21,16 @@ import torch
from torch import distributed as dist
def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
def setup_dist(
rank, world_size, master_addr=None, master_port=None, use_ddp_launch=False
):
"""
rank and world_size are used only if use_ddp_launch is False.
"""
if "MASTER_ADDR" not in os.environ:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_ADDR"] = (
"localhost" if master_addr is None else str(master_addr)
)
if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)