diff --git a/egs/fisher_swbd/ASR/local/compute_fbank_fisher_swbd_eval2000.py b/egs/fisher_swbd/ASR/local/compute_fbank_fisher_swbd_eval2000.py index eb6ff6d83..d2bb3e3e2 100755 --- a/egs/fisher_swbd/ASR/local/compute_fbank_fisher_swbd_eval2000.py +++ b/egs/fisher_swbd/ASR/local/compute_fbank_fisher_swbd_eval2000.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) + + # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -47,18 +48,19 @@ def compute_fbank_fisher_swbd_eval2000(): num_jobs = min(25, os.cpu_count()) num_mel_bins = 80 sampling_rate = 8000 - dataset_parts = ( - "eval2000", - "fisher", - "swbd", - ) - test_dataset=("eval2000",) + dataset_parts = ("eval2000", "fisher", "swbd") + test_dataset = ("eval2000",) manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, output_dir=src_dir, lazy=True, suffix="jsonl" + dataset_parts=dataset_parts, + output_dir=src_dir, + lazy=True, + suffix="jsonl", ) assert manifests is not None - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=sampling_rate)) + extractor = Fbank( + FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=sampling_rate) + ) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -67,10 +69,9 @@ def compute_fbank_fisher_swbd_eval2000(): continue logging.info(f"Processing {partition}") cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], + recordings=m["recordings"], supervisions=m["supervisions"] ) - #if "train" in partition: + # if "train" in partition: if partition not in test_dataset: logging.info(f"Adding speed perturbations to : {partition}") cut_set = ( diff --git a/egs/fisher_swbd/ASR/local/compute_fbank_musan.py b/egs/fisher_swbd/ASR/local/compute_fbank_musan.py index acb477540..5f217fd63 100755 --- a/egs/fisher_swbd/ASR/local/compute_fbank_musan.py +++ b/egs/fisher_swbd/ASR/local/compute_fbank_musan.py @@ -47,11 +47,7 @@ def compute_fbank_musan(): num_jobs = min(15, os.cpu_count()) num_mel_bins = 80 sampling_rate = 8000 - dataset_parts = ( - "music", - "speech", - "noise", - ) + dataset_parts = ("music", "speech", "noise") prefix = "musan" suffix = "jsonl.gz" manifests = read_manifests_if_cached( @@ -75,7 +71,9 @@ def compute_fbank_musan(): logging.info("Extracting features for Musan") - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=sampling_rate) + extractor = Fbank( + FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=sampling_rate) + ) with get_executor() as ex: # Initialize the executor only once. # create chunks of Musan with duration 5 - 10 seconds diff --git a/egs/fisher_swbd/ASR/local/extract_json_cuts.py b/egs/fisher_swbd/ASR/local/extract_json_cuts.py index 932f14714..df200c4ea 100644 --- a/egs/fisher_swbd/ASR/local/extract_json_cuts.py +++ b/egs/fisher_swbd/ASR/local/extract_json_cuts.py @@ -1,58 +1,56 @@ #!/usr/bin/env python3 - -# +# script to extract cutids corresponding to a list of source audio files. 
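+# Usage (argument placeholders mirror sys.argv[1..3] below; paths are illustrative):
+#   python3 extract_json_cuts.py <sph_list> <cuts_jsonl> <out_partition_jsonl>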
+# It takes three arguments: list of audio (.sph) , cut jsonl and out jsonl -import sys, json ; -import ntpath; +import sys, json +import ntpath -list_of_sph = sys.argv[1]; -jsonfile = sys.argv[2]; -out_partition_json = sys.argv[3]; +list_of_sph = sys.argv[1] +jsonfile = sys.argv[2] +out_partition_json = sys.argv[3] -list_of_sph=[line.rstrip('\n') for line in open(list_of_sph)] +list_of_sph = [line.rstrip("\n") for line in open(list_of_sph)] -sph_basename_list=[] +sph_basename_list = [] for f in list_of_sph: - bsname=ntpath.basename(f) - #print(bsname) + bsname = ntpath.basename(f) sph_basename_list.append(ntpath.basename(f)) -json_str=[line.rstrip('\n') for line in open(jsonfile)] +json_str = [line.rstrip("\n") for line in open(jsonfile)] num_json = len(json_str) -#cutid2sph=dict() -out_partition=open(out_partition_json,'w',encoding='utf-8') +out_partition = open(out_partition_json, "w", encoding="utf-8") for i in range(num_json): - if json_str[i] != '': - #print(json_str[i]) + if json_str[i] != "": + # print(json_str[i]) cur_json = json.loads(json_str[i]) - #print(cur_json) - cur_cutid= cur_json['id'] - cur_rec = cur_json['recording'] - cur_sources = cur_rec['sources'] - #print(cur_cutid) - #print(cur_rec) - #print(cur_sources) + # print(cur_json) + cur_cutid = cur_json["id"] + cur_rec = cur_json["recording"] + cur_sources = cur_rec["sources"] + # print(cur_cutid) + # print(cur_rec) + # print(cur_sources) for s in cur_sources: - cur_sph = s['source'] - cur_sph_basename=ntpath.basename(cur_sph) - #print(cur_sph) - #print(cur_sph_basename) - if cur_sph_basename in sph_basename_list : + cur_sph = s["source"] + cur_sph_basename = ntpath.basename(cur_sph) + # print(cur_sph) + # print(cur_sph_basename) + if cur_sph_basename in sph_basename_list: out_json_line = json_str[i] out_partition.write(out_json_line) out_partition.write("\n") - #for keys in cur_json: - #cur_cutid= cur_json['id'] - #cur_rec = cur_json['recording_id'] - #print(cur_cutid) - - + # for keys in cur_json: + # cur_cutid= cur_json['id'] + # cur_rec = cur_json['recording_id'] + # print(cur_cutid) + + """ for keys in cur_json: #print(keys) @@ -64,9 +62,3 @@ for i in range(num_json): out_partition.write(out_json_line) out_partition.write("\n") """ - - - - - - diff --git a/egs/fisher_swbd/ASR/local/extract_json_supervision.py b/egs/fisher_swbd/ASR/local/extract_json_supervision.py index 8dd714f23..0241e8012 100644 --- a/egs/fisher_swbd/ASR/local/extract_json_supervision.py +++ b/egs/fisher_swbd/ASR/local/extract_json_supervision.py @@ -1,40 +1,35 @@ #!/usr/bin/env python3 -# +# -import sys, json ; -import ntpath; +import sys, json +import ntpath -list_of_sph = sys.argv[1]; -jsonfile = sys.argv[2]; -out_partition_json = sys.argv[3]; +list_of_sph = sys.argv[1] +jsonfile = sys.argv[2] +out_partition_json = sys.argv[3] -list_of_sph=[line.rstrip('\n') for line in open(list_of_sph)] +list_of_sph = [line.rstrip("\n") for line in open(list_of_sph)] -sph_basename_list=[] +sph_basename_list = [] for f in list_of_sph: - bsname=ntpath.basename(f) - #print(bsname) + bsname = ntpath.basename(f) sph_basename_list.append(ntpath.basename(f)) -json_str=[line.rstrip('\n') for line in open(jsonfile)] +json_str = [line.rstrip("\n") for line in open(jsonfile)] num_json = len(json_str) -out_partition=open(out_partition_json,'w',encoding='utf-8') +out_partition = open(out_partition_json, "w", encoding="utf-8") for i in range(num_json): - if json_str[i] != '': - #print(json_str[i]) + if json_str[i] != "": cur_json = json.loads(json_str[i]) - 
#print(cur_json)
-        cur_rec = cur_json['recording_id']
-        #print(cur_rec)
+        cur_rec = cur_json["recording_id"]
         cur_sph_basename = cur_rec + ".sph"
-        #print(cur_sph_basename)
-        if cur_sph_basename in sph_basename_list :
+        if cur_sph_basename in sph_basename_list:
             out_json_line = json_str[i]
             out_partition.write(out_json_line)
             out_partition.write("\n")
diff --git a/egs/fisher_swbd/ASR/local/extract_list_of_sph.py b/egs/fisher_swbd/ASR/local/extract_list_of_sph.py
index 708291820..d80cb8f6e 100644
--- a/egs/fisher_swbd/ASR/local/extract_list_of_sph.py
+++ b/egs/fisher_swbd/ASR/local/extract_list_of_sph.py
@@ -1,38 +1,20 @@
 #!/usr/bin/env python3
-
+# Extract the list of source .sph files referenced by a cuts jsonl.
 # python3 extract_list_of_sph.py dev_cuts_swbd.jsonl > data/fbank/dev_swbd_sph.list
-import sys, json ;
+import sys, json
+
 inputfile = sys.argv[1]
-json_str=[line.rstrip('\n') for line in open(inputfile)]
+json_str = [line.rstrip("\n") for line in open(inputfile)]
 num_json = len(json_str)
-#print(num_json)
-#with open(inputfile, 'r',encoding='utf-8') as Jsonfile:
-#    print("Converting JSON encoded data into Python dictionary")
-#    json_dict = json.load(Jsonfile)
-#    for k,v in json_dict:
-#        print(k,v)
-
-
-
 for i in range(num_json):
-    if json_str[i] != '':
-        #print(json_str[i])
+    if json_str[i] != "":
         cur_json = json.loads(json_str[i])
-        # print(cur_json)
         for keys in cur_json:
-            #print(keys)
-            cur_rec = cur_json['recording']
-            cur_sources = cur_rec['sources']
-            #print(cur_sources)
+            cur_rec = cur_json["recording"]
+            cur_sources = cur_rec["sources"]
             for s in cur_sources:
-                cur_sph = s['source']
+                cur_sph = s["source"]
                 print(cur_sph)
-                #cur_sph = cur_sources[2]
-                #print(cur_sph)
-
-
-
-#print(json.load(sys.stdin)['source'])
diff --git a/egs/fisher_swbd/ASR/local/normalize_eval2000.py b/egs/fisher_swbd/ASR/local/normalize_eval2000.py
index 244f20ea4..76efa176d 100644
--- a/egs/fisher_swbd/ASR/local/normalize_eval2000.py
+++ b/egs/fisher_swbd/ASR/local/normalize_eval2000.py
@@ -16,67 +16,72 @@ def get_args():
     parser.add_argument("output_sups")
     return parser.parse_args()

-def remove_punctutation_and_other_symbol(text:str) -> str:
-    text = text.replace("--"," ")
-    text = text.replace("//"," ")
-    text = text.replace("."," ")
-    text = text.replace("?"," ")
-    text = text.replace("~"," ")
-    text = text.replace(","," ")
-    text = text.replace(";"," ")
-    text = text.replace("("," ")
-    text = text.replace(")"," ")
-    text = text.replace("&"," ")
-    text = text.replace("%"," ")
-    text = text.replace("*"," ")
-    text = text.replace("{"," ")
-    text = text.replace("}"," ")
+
+def remove_punctutation_and_other_symbol(text: str) -> str:
+    text = text.replace("--", " ")
+    text = text.replace("//", " ")
+    text = text.replace(".", " ")
+    text = text.replace("?", " ")
+    text = text.replace("~", " ")
+    text = text.replace(",", " ")
+    text = text.replace(";", " ")
+    text = text.replace("(", " ")
+    text = text.replace(")", " ")
+    text = text.replace("&", " ")
+    text = text.replace("%", " ")
+    text = text.replace("*", " ")
+    text = text.replace("{", " ")
+    text = text.replace("}", " ")
     return text

+
 def eval2000_clean_eform(text: str, eform_count) -> str:
     string_to_remove = []
-    piece=text.split("\">")
-    for i in range(0,len(piece)):
-        s=piece[i]+"\">"
-        res = re.search(r'<contraction e_form="(.*?)">',s)
+    piece = text.split('">')
+    for i in range(0, len(piece)):
+        s = piece[i] + '">'
+        res = re.search(r'<contraction e_form="(.*?)">', s)
         if res is not None:
-            res_rm= res.group(1)
+            res_rm = res.group(1)
             string_to_remove.append(res_rm)
     for p in string_to_remove:
         eform_string = p
         text = text.replace(eform_string, " ")
     eform_1 = '<contraction e_form="'
-    text = text.replace(eform_2," ")
-    #print("TEXT final: ", text)
+    text = text.replace(eform_1, " ")
+    eform_2 = '">'
+    text = text.replace(eform_2, " ")
+    # print("TEXT final: ", text)
     return text
-
-def replace_silphone(text: str) -> str:
+
+
+def replace_silphone(text: str) -> str:
     text = text.replace("[/BABY CRYING]", " ")
-    text = text.replace("[/CHILD]" , " ")
-    text = text.replace("[[DISTORTED]]" , " ")
-    text = text.replace("[/DISTORTION]" , " ")
-    text = text.replace("[[DRAWN OUT]]" , " ")
-    text = text.replace("[[DRAWN-OUT]]" , " ")
-    text = text.replace("[[FAINT]]" , " ")
-    text = text.replace("[SMACK]" , " ")
-    text = text.replace("[[MUMBLES]]" , " ")
-    text = text.replace("[[HIGH PITCHED SQUEAKY VOICE]]" , " ")
-    text = text.replace("[[IN THE LAUGH]]" , "[LAUGHTER]")
-    text = text.replace("[[LAST WORD SPOKEN WITH A LAUGH]]" , "[LAUGHTER]")
-    text = text.replace("[[PART OF FIRST SYLLABLE OF PREVIOUS WORD CUT OFF]]" , " ")
-    text = text.replace("[[PREVIOUS WORD SPOKEN WITH A LAUGH]]" , " ")
-    text = text.replace("[[PREVIOUS TWO WORDS SPOKEN WHILE LAUGHING]]" , " ")
-    text = text.replace("[[PROLONGED]]" , " ")
-    text = text.replace("[/RUNNING WATER]" , " ")
-    text = text.replace("[[SAYS LAUGHING]]" , "[LAUGHTER]")
-    text = text.replace("[[SINGING]]" , " ")
-    text = text.replace("[[SPOKEN WHILE LAUGHING]]" , "[LAUGHTER]")
-    text = text.replace("[/STATIC]" , " ")
-    text = text.replace("['THIRTIETH' DRAWN OUT]" , " ")
-    text = text.replace("[/VOICES]" , " ")
-    text = text.replace("[[WHISPERED]]" , " ")
+    text = text.replace("[/CHILD]", " ")
+    text = text.replace("[[DISTORTED]]", " ")
+    text = text.replace("[/DISTORTION]", " ")
+    text = text.replace("[[DRAWN OUT]]", " ")
+    text = text.replace("[[DRAWN-OUT]]", " ")
+    text = text.replace("[[FAINT]]", " ")
+    text = text.replace("[SMACK]", " ")
+    text = text.replace("[[MUMBLES]]", " ")
+    text = text.replace("[[HIGH PITCHED SQUEAKY VOICE]]", " ")
+    text = text.replace("[[IN THE LAUGH]]", "[LAUGHTER]")
+    text = text.replace("[[LAST WORD SPOKEN WITH A LAUGH]]", "[LAUGHTER]")
+    text = text.replace(
+        "[[PART OF FIRST SYLLABLE OF PREVIOUS WORD CUT OFF]]", " "
+    )
+    text = text.replace("[[PREVIOUS WORD SPOKEN WITH A LAUGH]]", " ")
+    text = text.replace("[[PREVIOUS TWO WORDS SPOKEN WHILE LAUGHING]]", " ")
+    text = text.replace("[[PROLONGED]]", " ")
+    text = text.replace("[/RUNNING WATER]", " ")
+    text = text.replace("[[SAYS LAUGHING]]", "[LAUGHTER]")
+    text = text.replace("[[SINGING]]", " ")
+    text = text.replace("[[SPOKEN WHILE LAUGHING]]", "[LAUGHTER]")
+    text = text.replace("[/STATIC]", " ")
+    text = text.replace("['THIRTIETH' DRAWN OUT]", " ")
+    text = text.replace("[/VOICES]", " ")
+    text = text.replace("[[WHISPERED]]", " ")
     text = text.replace("[DISTORTION]", " ")
     text = text.replace("[DISTORTION, HIGH VOLUME ON WAVES]", " ")
     text = text.replace("[BACKGROUND LAUGHTER]", "[LAUGHTER]")
@@ -95,24 +100,24 @@ def replace_silphone(text: str) -> str:
     text = text.replace("[BABY CRYING]", " ")
     text = text.replace("[METALLIC KNOCKING SOUND]", " ")
     text = text.replace("[METALLIC SOUND]", " ")
-    
+
     text = text.replace("[PHONE JIGGLING]", " ")
     text = text.replace("[BACKGROUND SOUND]", " ")
     text = text.replace("[BACKGROUND VOICE]", " ")
-    text = text.replace("[BACKGROUND VOICES]", " ")
+    text = text.replace("[BACKGROUND VOICES]", " ")
     text = text.replace("[BACKGROUND NOISE]", " ")
     text = text.replace("[CAR HORNS IN BACKGROUND]", " ")
     text = text.replace("[CAR HORNS]", " ")
-    text = text.replace("[CARNATING]", " ")
+    text = 
text.replace("[CARNATING]", " ") text = text.replace("[CRYING CHILD]", " ") text = text.replace("[CHOPPING SOUND]", " ") text = text.replace("[BANGING]", " ") text = text.replace("[CLICKING NOISE]", " ") - text = text.replace("[CLATTERING]", " ") + text = text.replace("[CLATTERING]", " ") text = text.replace("[ECHO]", " ") - text = text.replace("[KNOCK]", " ") - text = text.replace("[NOISE-GOOD]", "[NOISE]") - text = text.replace("[RIGHT]", " ") + text = text.replace("[KNOCK]", " ") + text = text.replace("[NOISE-GOOD]", "[NOISE]") + text = text.replace("[RIGHT]", " ") text = text.replace("[SOUND]", " ") text = text.replace("[SQUEAK]", " ") text = text.replace("[STATIC]", " ") @@ -131,64 +136,65 @@ def replace_silphone(text: str) -> str: text = text.replace("Y[OU]I-", "YOU I") text = text.replace("-[A]ND", "AND") text = text.replace("JU[ST]", "JUST") - text = text.replace("{BREATH}" , " ") - text = text.replace("{BREATHY}" , " ") - text = text.replace("{CHANNEL NOISE}" , " ") - text = text.replace("{CLEAR THROAT}" , " ") + text = text.replace("{BREATH}", " ") + text = text.replace("{BREATHY}", " ") + text = text.replace("{CHANNEL NOISE}", " ") + text = text.replace("{CLEAR THROAT}", " ") - text = text.replace("{CLEARING THROAT}" , " ") - text = text.replace("{CLEARS THROAT}" , " ") - text = text.replace("{COUGH}" , " ") - text = text.replace("{DRAWN OUT}" , " ") - text = text.replace("{EXHALATION}" , " ") - text = text.replace("{EXHALE}" , " ") - text = text.replace("{GASP}" , " ") - text = text.replace("{HIGH SQUEAL}" , " ") - text = text.replace("{INHALE}" , " ") - text = text.replace("{LAUGH}" , "[LAUGHTER]") - text = text.replace("{LAUGH}" , "[LAUGHTER]") - text = text.replace("{LAUGH}" , "[LAUGHTER]") - text = text.replace("{LIPSMACK}" , " ") - text = text.replace("{LIPSMACK}" , " ") + text = text.replace("{CLEARING THROAT}", " ") + text = text.replace("{CLEARS THROAT}", " ") + text = text.replace("{COUGH}", " ") + text = text.replace("{DRAWN OUT}", " ") + text = text.replace("{EXHALATION}", " ") + text = text.replace("{EXHALE}", " ") + text = text.replace("{GASP}", " ") + text = text.replace("{HIGH SQUEAL}", " ") + text = text.replace("{INHALE}", " ") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LIPSMACK}", " ") + text = text.replace("{LIPSMACK}", " ") - text = text.replace("{NOISE OF DISGUST}" , " ") - text = text.replace("{SIGH}" , " ") - text = text.replace("{SNIFF}" , " ") - text = text.replace("{SNORT}" , " ") - text = text.replace("{SHARP EXHALATION}" , " ") - text = text.replace("{BREATH LAUGH}" , " ") + text = text.replace("{NOISE OF DISGUST}", " ") + text = text.replace("{SIGH}", " ") + text = text.replace("{SNIFF}", " ") + text = text.replace("{SNORT}", " ") + text = text.replace("{SHARP EXHALATION}", " ") + text = text.replace("{BREATH LAUGH}", " ") return text -def remove_languagetag(text:str) -> str: - langtag = re.findall(r'<(.*?)>', text) + +def remove_languagetag(text: str) -> str: + langtag = re.findall(r"<(.*?)>", text) for t in langtag: text = text.replace(t, " ") - text = text.replace("<"," ") - text = text.replace(">"," ") + text = text.replace("<", " ") + text = text.replace(">", " ") return text - + + def eval2000_normalizer(text: str) -> str: - #print("TEXT original: ",text) - eform_count=text.count("contraction e_form") - #print("eform corunt:", eform_count) - if eform_count>0: - text = eval2000_clean_eform(text,eform_count) + # print("TEXT 
original: ",text) + eform_count = text.count("contraction e_form") + # print("eform corunt:", eform_count) + if eform_count > 0: + text = eval2000_clean_eform(text, eform_count) text = text.upper() text = remove_languagetag(text) text = replace_silphone(text) text = remove_punctutation_and_other_symbol(text) text = text.replace("IGNORE_TIME_SEGMENT_IN_SCORING", " ") text = text.replace("IGNORE_TIME_SEGMENT_SCORING", " ") - spaces = re.findall(r'\s+', text) + spaces = re.findall(r"\s+", text) for sp in spaces: - text = text.replace(sp," ") - text = text.strip() - #text = self.whitespace_regexp.sub(" ", text).strip() - #print(text) + text = text.replace(sp, " ") + text = text.strip() + # text = self.whitespace_regexp.sub(" ", text).strip() + # print(text) return text - def main(): args = get_args() sups = load_manifest_lazy_or_eager(args.input_sups) @@ -203,6 +209,7 @@ def main(): skip += 1 continue writer.write(sup) - + + if __name__ == "__main__": main() diff --git a/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py b/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py index b3f012e84..431ebd439 100755 --- a/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py +++ b/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py @@ -108,9 +108,7 @@ def lexicon_to_fst_no_sil( disambig_token = token2id["#0"] disambig_word = word2id["#0"] arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, + arcs, disambig_token=disambig_token, disambig_word=disambig_word ) final_state = next_state @@ -223,9 +221,7 @@ def main(): write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, + lexicon, token2id=token_sym_table, word2id=word_sym_table ) L_disambig = lexicon_to_fst_no_sil( diff --git a/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py b/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py index 0549d7306..6a504d4c6 100755 --- a/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py +++ b/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py @@ -68,12 +68,7 @@ def get_args(): def get_g2p_sym2int(): # These symbols are removed from from g2p_en's vocabulary - excluded_symbols = [ - "", - "", - "", - "", - ] + excluded_symbols = ["", "", "", ""] symbols = [p for p in sorted(G2p().phonemes) if p not in excluded_symbols] # reserve 0 and 1 for blank and sos/eos/pad tokens @@ -345,9 +340,7 @@ def lexicon_to_fst( disambig_token = token2id["#0"] disambig_word = word2id["#0"] arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, + arcs, disambig_token=disambig_token, disambig_word=disambig_word ) final_state = next_state @@ -396,9 +389,7 @@ def main(): print(vocab[:10]) if not lexicon_filename.is_file(): - lexicon = [ - ("!SIL", [sil_token]), - ] + lexicon = [("!SIL", [sil_token])] for symbol in special_symbols: lexicon.append((symbol, [symbol[1:-1]])) lexicon += [ diff --git a/egs/fisher_swbd/ASR/local/train_bpe_model.py b/egs/fisher_swbd/ASR/local/train_bpe_model.py index bc5812810..2581a90cb 100755 --- a/egs/fisher_swbd/ASR/local/train_bpe_model.py +++ b/egs/fisher_swbd/ASR/local/train_bpe_model.py @@ -43,16 +43,10 @@ def get_args(): """, ) - parser.add_argument( - "--transcript", - type=str, - help="Training transcript.", - ) + parser.add_argument("--transcript", type=str, help="Training transcript.") parser.add_argument( - "--vocab-size", - type=int, - help="Vocabulary size for BPE training", + "--vocab-size", type=int, help="Vocabulary size for BPE training" ) return 
parser.parse_args() diff --git a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/asr_datamodule.py index 706cdf34c..ef255cc47 100644 --- a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -1,6 +1,6 @@ # Copyright 2021 Piotr Żelasko # Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# +# # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -356,13 +356,10 @@ class FisherSwbdSpeechAsrDataModule: ) else: validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, + cut_transforms=transforms, return_cuts=self.args.return_cuts ) valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, + cuts_valid, max_duration=self.args.max_duration, shuffle=False ) logging.info("About to create dev dataloader") valid_dl = DataLoader( @@ -384,9 +381,7 @@ class FisherSwbdSpeechAsrDataModule: return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, + cuts, max_duration=self.args.max_duration, shuffle=False ) logging.debug("About to create test dataloader") test_dl = DataLoader( @@ -396,41 +391,52 @@ class FisherSwbdSpeechAsrDataModule: num_workers=self.args.num_workers, ) return test_dl - + @lru_cache() def train_fisher_cuts(self) -> CutSet: logging.info("About to get fisher cuts") return load_manifest_lazy( self.args.manifest_dir / "train_cuts_fisher.jsonl.gz" ) + @lru_cache() def train_swbd_cuts(self) -> CutSet: logging.info("About to get train swbd cuts") return load_manifest_lazy( self.args.manifest_dir / "train_cuts_swbd.jsonl.gz" ) + @lru_cache() def dev_fisher_cuts(self) -> CutSet: logging.info("About to get dev fisher cuts") - return load_manifest_lazy(self.args.manifest_dir / "dev_cuts_fisher.jsonl.gz" + return load_manifest_lazy( + self.args.manifest_dir / "dev_cuts_fisher.jsonl.gz" ) + @lru_cache() def dev_swbd_cuts(self) -> CutSet: logging.info("About to get dev swbd cuts") - return load_manifest_lazy(self.args.manifest_dir / "dev_cuts_swbd.jsonl.gz" + return load_manifest_lazy( + self.args.manifest_dir / "dev_cuts_swbd.jsonl.gz" ) + @lru_cache() def test_eval2000_cuts(self) -> CutSet: logging.info("About to get test eval2000 cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_eval2000.jsonl.gz" + return load_manifest_lazy( + self.args.manifest_dir / "cuts_eval2000.jsonl.gz" ) + @lru_cache() def test_swbd_cuts(self) -> CutSet: logging.info("About to get test eval2000 swbd cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_eval2000_swbd.jsonl.gz" + return load_manifest_lazy( + self.args.manifest_dir / "cuts_eval2000_swbd.jsonl.gz" ) + @lru_cache() def test_callhome_cuts(self) -> CutSet: logging.info("About to get test eval2000 callhome cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_eval2000_callhome.jsonl.gz" + return load_manifest_lazy( + self.args.manifest_dir / "cuts_eval2000_callhome.jsonl.gz" ) diff --git a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/beam_search.py b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/beam_search.py index ed6a6ea82..940da93ac 100644 --- a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/beam_search.py @@ -550,9 +550,7 @@ def greedy_search( def greedy_search_batch( - model: 
Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, + model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor ) -> List[List[int]]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: @@ -591,9 +589,7 @@ def greedy_search_batch( hyps = [[blank_id] * context_size for _ in range(N)] decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, + hyps, device=device, dtype=torch.int64 ) # (N, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -630,9 +626,7 @@ def greedy_search_batch( # update decoder output decoder_input = [h[-context_size:] for h in hyps[:batch_size]] decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, + decoder_input, device=device, dtype=torch.int64 ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -894,9 +888,7 @@ def modified_beam_search( ) # (num_hyps, 1, 1, encoder_out_dim) logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, + current_encoder_out, decoder_out, project_input=False ) # (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) @@ -953,9 +945,7 @@ def modified_beam_search( def _deprecated_modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, + model: Transducer, encoder_out: torch.Tensor, beam: int = 4 ) -> List[int]: """It limits the maximum number of symbols per frame to 1. @@ -1023,9 +1013,7 @@ def _deprecated_modified_beam_search( ) # (num_hyps, 1, 1, encoder_out_dim) logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, + current_encoder_out, decoder_out, project_input=False ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -1097,9 +1085,7 @@ def beam_search( device = next(model.parameters()).device decoder_input = torch.tensor( - [blank_id] * context_size, - device=device, - dtype=torch.int64, + [blank_id] * context_size, device=device, dtype=torch.int64 ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -1318,9 +1304,7 @@ def fast_beam_search_with_nbest_rescoring( num_unique_paths = len(word_ids_list) b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, + num_unique_paths, dtype=torch.int32, device=lattice.device ) rescored_word_fsas = k2.intersect_device( @@ -1334,8 +1318,7 @@ def fast_beam_search_with_nbest_rescoring( rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, + use_double_scores=True, log_semiring=False ) ans: Dict[str, List[List[int]]] = {} diff --git a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/conformer.py b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/conformer.py index fb8123838..f40a950c6 100644 --- a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/conformer.py @@ -223,19 +223,10 @@ class Conformer(EncoderInterface): init_states: List[torch.Tensor] = [ torch.zeros( - ( - self.encoder_layers, - left_context, - self.d_model, - ), - device=device, + (self.encoder_layers, left_context, self.d_model), device=device ), torch.zeros( - ( - self.encoder_layers, - self.cnn_module_kernel - 1, - self.d_model, - ), + (self.encoder_layers, 
self.cnn_module_kernel - 1, self.d_model), device=device, ), ] @@ -330,7 +321,9 @@ class Conformer(EncoderInterface): {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)}, given {states[1].shape}.""" - lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output + lengths -= ( + 2 + ) # we will cut off 1 frame on each side of encoder_embed output src_key_padding_mask = make_pad_mask(lengths) @@ -829,9 +822,7 @@ class RelPositionalEncoding(torch.nn.Module): self.pe = pe.to(device=x.device, dtype=x.dtype) def forward( - self, - x: torch.Tensor, - left_context: int = 0, + self, x: torch.Tensor, left_context: int = 0 ) -> Tuple[Tensor, Tensor]: """Add positional encoding. @@ -875,10 +866,7 @@ class RelPositionMultiheadAttention(nn.Module): """ def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, + self, embed_dim: int, num_heads: int, dropout: float = 0.0 ) -> None: super(RelPositionMultiheadAttention, self).__init__() self.embed_dim = embed_dim @@ -1272,8 +1260,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz, num_heads, tgt_len, src_len ) attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), + key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf") ) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len @@ -1420,10 +1407,7 @@ class ConvolutionModule(nn.Module): ) def forward( - self, - x: Tensor, - cache: Optional[Tensor] = None, - right_context: int = 0, + self, x: Tensor, cache: Optional[Tensor] = None, right_context: int = 0 ) -> Tuple[Tensor, Tensor]: """Compute convolution module. diff --git a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/decode.py b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/decode.py index 0ebe20372..e5a4fc313 100755 --- a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/decode.py @@ -384,9 +384,7 @@ def decode_one_batch( feature_lens += params.left_context feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, + feature, pad=(0, 0, 0, params.left_context), value=LOG_EPS ) if params.simulate_streaming: @@ -778,7 +776,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - + fisherswbd = FisherSwbdSpeechAsrDataModule(args) test_eval2000_cuts = fisherswbd.test_eval2000_cuts() @@ -803,9 +801,7 @@ def main(): ) save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, + params=params, test_set_name=test_set, results_dict=results_dict ) logging.info("Done!") diff --git a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/decode_stream.py b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/decode_stream.py index ba5e80555..2e90cdac4 100644 --- a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/decode_stream.py +++ b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/decode_stream.py @@ -92,10 +92,7 @@ class DecodeStream(object): """Return True if all the features are processed.""" return self._done - def set_features( - self, - features: torch.Tensor, - ) -> None: + def set_features(self, features: torch.Tensor) -> None: """Set features tensor of current utterance.""" assert features.dim() == 2, features.dim() self.features = torch.nn.functional.pad( diff --git a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/pretrained.py b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/pretrained.py 
index f52cb22ab..99cf5273d 100755 --- a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/pretrained.py @@ -96,11 +96,7 @@ def get_parser(): "icefall.checkpoint.save_checkpoint().", ) - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) + parser.add_argument("--bpe-model", type=str, help="""Path to bpe.model.""") parser.add_argument( "--method", diff --git a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/streaming_decode.py index 52fe34e88..76d50d28d 100755 --- a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -39,7 +39,8 @@ import numpy as np import sentencepiece as spm import torch import torch.nn as nn -#from asr_datamodule import LibriSpeechAsrDataModule + +# from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import FisherSwbdSpeechAsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions @@ -187,9 +188,7 @@ def get_parser(): def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], + model: nn.Module, encoder_out: torch.Tensor, streams: List[DecodeStream] ) -> List[List[int]]: assert len(streams) == encoder_out.size(0) @@ -236,10 +235,7 @@ def greedy_search( device=device, dtype=torch.int64, ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) + decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) hyp_tokens = [] @@ -290,9 +286,7 @@ def fast_beam_search( def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], + params: AttributeDict, model: nn.Module, decode_streams: List[DecodeStream] ) -> List[int]: """Decode one chunk frames of features for each decode_streams and return the indexes of finished streams in a List. @@ -502,10 +496,7 @@ def decode_dataset( if params.decoding_method == "greedy_search": hyp = hyp[params.context_size :] # noqa decode_results.append( - ( - decode_streams[i].ground_truth.split(), - sp.decode(hyp).split(), - ) + (decode_streams[i].ground_truth.split(), sp.decode(hyp).split()) ) del decode_streams[i] @@ -661,7 +652,7 @@ def main(): fisherswbd = FisherSwbdSpeechAsrDataModule(args) test_eval2000_cuts = fisherswbd.test_eval2000_cuts() - test_swbd_cuts = fisherswbd.test_swbd_cuts () + test_swbd_cuts = fisherswbd.test_swbd_cuts() test_callhome_cuts = fisherswbd.test_callhome_cuts() test_eval2000_dl = fisherswbd.test_dataloaders(test_eval2000_cuts) @@ -681,9 +672,7 @@ def main(): ) save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, + params=params, test_set_name=test_set, results_dict=results_dict ) logging.info("Done!") diff --git a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/train.py b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/train.py index 354bd1d74..5b5af6720 100755 --- a/egs/fisher_swbd/ASR/pruned_transducer_stateless2/train.py +++ b/egs/fisher_swbd/ASR/pruned_transducer_stateless2/train.py @@ -155,10 +155,7 @@ def get_parser(): ) parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", + "--num-epochs", type=int, default=30, help="Number of epochs to train." 
) parser.add_argument( @@ -480,10 +477,7 @@ def load_checkpoint_if_available( assert filename.is_file(), f"{filename} does not exist!" saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, + filename, model=model, optimizer=optimizer, scheduler=scheduler ) keys = [ @@ -646,11 +640,7 @@ def compute_validation_loss( for batch_idx, batch in enumerate(valid_dl): loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, + params=params, model=model, sp=sp, batch=batch, is_training=False ) assert loss.requires_grad is False tot_loss = tot_loss + loss_info @@ -767,9 +757,7 @@ def train_one_epoch( ) del params.cur_batch_idx remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, + out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank ) if batch_idx % params.log_interval == 0: @@ -830,8 +818,6 @@ def run(rank, world_size, args): """ params = get_params() params.update(vars(args)) - if params.full_libri is False: - params.valid_interval = 1600 fix_random_seed(params.seed) if world_size > 1: @@ -897,11 +883,11 @@ def run(rank, world_size, args): if params.print_diagnostics: diagnostic = diagnostics.attach_diagnostics(model) - librispeech = FisherSwbdSpeechAsrDataModule(args) + fisherswbd = FisherSwbdSpeechAsrDataModule(args) train_cuts = fisherswbd.train_fisher_cuts() train_cuts += fisherswbd.train_swbd_cuts() - + def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds # @@ -991,9 +977,7 @@ def run(rank, world_size, args): def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, + batch: dict, params: AttributeDict, sp: spm.SentencePieceProcessor ) -> None: """Display the batch statistics and save the batch into disk.