From 2fd970b6821d47dacb2e6513321520db21fff67b Mon Sep 17 00:00:00 2001 From: Daniil Date: Sun, 1 Jan 2023 19:08:32 -0500 Subject: [PATCH] not removing result_dir in tedlium conformer ctc2 + add lm stem to compile_hlg_using_openfst.py + add MASTER_ADDR to be prvided to setup_dist (#801) --- .../ASR/local/compile_hlg_using_openfst.py | 19 ++++++++++++++----- egs/tedlium3/ASR/conformer_ctc2/decode.py | 7 ++----- icefall/dist.py | 8 ++++++-- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/local/compile_hlg_using_openfst.py b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py index 9e5e3df69..15fc47ef1 100755 --- a/egs/librispeech/ASR/local/compile_hlg_using_openfst.py +++ b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py @@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from Caution: We use a lexicon that contains disambiguation symbols - - G, the LM, built from data/lm/G_3_gram.fst.txt + - G, the LM, built from data/lm/G_n_gram.fst.txt The generated HLG is saved in $lang_dir/HLG_fst.pt @@ -46,6 +46,13 @@ from icefall.lexicon import Lexicon def get_args(): parser = argparse.ArgumentParser() + parser.add_argument( + "--lm", + type=str, + default="G_3_gram", + help="""Stem name for LM used in HLG compiling. + """, + ) parser.add_argument( "--lang-dir", type=str, @@ -56,11 +63,13 @@ def get_args(): return parser.parse_args() -def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst: +def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> kaldifst.StdVectorFst: """ Args: lang_dir: The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + lm: + The language stem base name. Return: An FST representing HLG. @@ -71,8 +80,8 @@ def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst: kaldifst.arcsort(L, sort_type="olabel") logging.info(f"L: #states {L.num_states}") - G_filename_txt = "data/lm/G_3_gram.fst.txt" - G_filename_binary = "data/lm/G_3_gram.fst" + G_filename_txt = f"data/lm/{lm}.fst.txt" + G_filename_binary = f"data/lm/{lm}.fst" if Path(G_filename_binary).is_file(): logging.info(f"Loading {G_filename_binary}") G = kaldifst.StdVectorFst.read(G_filename_binary) @@ -171,7 +180,7 @@ def main(): logging.info(f"{filename} already exists - skipping") return - HLG = compile_HLG(lang_dir) + HLG = compile_HLG(lang_dir, args.lm) logging.info(f"Saving HLG to {filename}") torch.save(HLG.as_dict(), filename) diff --git a/egs/tedlium3/ASR/conformer_ctc2/decode.py b/egs/tedlium3/ASR/conformer_ctc2/decode.py index ce4dcd142..28d39de70 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/decode.py +++ b/egs/tedlium3/ASR/conformer_ctc2/decode.py @@ -20,7 +20,6 @@ import argparse import logging -import shutil from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -183,7 +182,7 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument( "--result-dir", type=str, - default="conformer_ctc2/exp", + default="conformer_ctc2/exp/results", help="Directory to store results.", ) @@ -635,9 +634,7 @@ def main() -> None: args.lm_path = Path(args.lm_path) args.result_dir = Path(args.result_dir) - if args.result_dir.is_dir(): - shutil.rmtree(args.result_dir) - args.result_dir.mkdir() + args.result_dir.mkdir(exist_ok=True) params = get_params() params.update(vars(args)) diff --git a/icefall/dist.py b/icefall/dist.py index 9df1c5bd1..672948623 100644 --- a/icefall/dist.py +++ b/icefall/dist.py @@ -21,12 +21,16 @@ import torch from torch import distributed as dist -def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False): +def setup_dist( + rank, world_size, master_addr=None, master_port=None, use_ddp_launch=False +): """ rank and world_size are used only if use_ddp_launch is False. """ if "MASTER_ADDR" not in os.environ: - os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_ADDR"] = ( + "localhost" if master_addr is None else str(master_addr) + ) if "MASTER_PORT" not in os.environ: os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)