From 891cf5590172d5e6594c7e52c9d2147612a50ac2 Mon Sep 17 00:00:00 2001 From: AmirHussein96 Date: Fri, 5 Apr 2024 13:00:29 -0400 Subject: [PATCH] black formatting --- egs/seame/ASR/local/cer.py | 32 ++- .../ASR/local/compute_fbank_gpu_seame.py | 29 ++- .../local/compute_fbank_gpu_seame_sample.py | 25 +-- egs/seame/ASR/local/compute_fbank_musan.py | 2 +- egs/seame/ASR/local/cuts_validate.py | 25 ++- egs/seame/ASR/local/prepare_lexicon.py | 4 +- egs/seame/ASR/local/prepare_transcripts.py | 15 +- egs/seame/ASR/local/sample_hours.py | 18 +- egs/seame/ASR/local/train_bpe_model.py | 3 +- egs/seame/ASR/local/wer_lang.py | 149 +++++++------ egs/seame/ASR/zipformer/asr_datamodule.py | 17 +- egs/seame/ASR/zipformer/decode.py | 11 +- egs/seame/ASR/zipformer/train.py | 1 + egs/seame/ASR/zipformer_hat/beam_search.py | 8 +- egs/seame/ASR/zipformer_hat/decode.py | 29 +-- egs/seame/ASR/zipformer_hat/train.py | 15 +- .../ASR/zipformer_hat_lid/beam_search.py | 16 +- egs/seame/ASR/zipformer_hat_lid/decode.py | 203 ++++++++++-------- egs/seame/ASR/zipformer_hat_lid/joiner.py | 3 +- egs/seame/ASR/zipformer_hat_lid/model.py | 92 ++++---- egs/seame/ASR/zipformer_hat_lid/train.py | 48 +++-- 21 files changed, 387 insertions(+), 358 deletions(-) diff --git a/egs/seame/ASR/local/cer.py b/egs/seame/ASR/local/cer.py index 01ba53fd9..b57c4d4b5 100644 --- a/egs/seame/ASR/local/cer.py +++ b/egs/seame/ASR/local/cer.py @@ -8,18 +8,13 @@ This file cer from icefall decoded "recogs" file: id [hyp] yxy """ -import argparse +import argparse import jiwer def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--dec-file", - type=str, - help="Decoded icefall recogs file" - - ) + parser.add_argument("--dec-file", type=str, help="Decoded icefall recogs file") return parser @@ -29,31 +24,32 @@ def cer_(file): ref = [] cer_results = 0 ref_lens = 0 - with open(file, 'r', encoding='utf-8') as dec: + with open(file, "r", encoding="utf-8") as dec: for line in dec: - id, target = line.split('\t') + id, target = line.split("\t") id = id[0:-2] target, txt = target.split("=") - if target == 'ref': - words = txt.strip().strip('[]').split(', ') + if target == "ref": + words = txt.strip().strip("[]").split(", ") word_list = [word.strip("'") for word in words] ref.append(" ".join(word_list)) - elif target == 'hyp': - words = txt.strip().strip('[]').split(', ') + elif target == "hyp": + words = txt.strip().strip("[]").split(", ") word_list = [word.strip("'") for word in words] hyp.append(" ".join(word_list)) for h, r in zip(hyp, ref): if r: - cer_results += (jiwer.cer(r, h)*len(r)) - + cer_results += jiwer.cer(r, h) * len(r) + ref_lens += len(r) print(cer_results / ref_lens) def main(): parse = get_args() - args = parse.parse_args() + args = parse.parse_args() cer_(args.dec_file) - + + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/egs/seame/ASR/local/compute_fbank_gpu_seame.py b/egs/seame/ASR/local/compute_fbank_gpu_seame.py index 1e4ace80c..84134ab78 100755 --- a/egs/seame/ASR/local/compute_fbank_gpu_seame.py +++ b/egs/seame/ASR/local/compute_fbank_gpu_seame.py @@ -38,6 +38,7 @@ from lhotse.features.kaldifeat import ( KaldifeatMelOptions, ) + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( @@ -70,7 +71,7 @@ def get_args(): def compute_fbank_gpu(args): src_dir = Path("data_seame/manifests") output_dir = Path("data_seame/fbank") - num_jobs = min(os.cpu_count(),10) + num_jobs = min(os.cpu_count(), 10) num_mel_bins = 80 sampling_rate = 16000 sr = 16000 @@ -87,7 +87,10
@@ def compute_fbank_gpu(args): suffix = "jsonl.gz" breakpoint manifests = read_manifests_if_cached( - prefix=prefix, dataset_parts=dataset_parts, output_dir=src_dir,suffix=suffix, + prefix=prefix, + dataset_parts=dataset_parts, + output_dir=src_dir, + suffix=suffix, ) assert manifests is not None @@ -116,15 +120,11 @@ def compute_fbank_gpu(args): cut_set = cut_set.resample(sr) cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, - keep_all_channels=False) - cut_set = cut_set.filter(lambda c: c.duration >= .2 and c.duration <= 30) + keep_overlapping=False, keep_all_channels=False + ) + cut_set = cut_set.filter(lambda c: c.duration >= 0.2 and c.duration <= 30) if "train" in partition: - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, storage_path=f"{output_dir}/{prefix}_feats_{partition}", @@ -133,7 +133,7 @@ def compute_fbank_gpu(args): num_workers=num_jobs, storage_type=LilcomChunkyWriter, overwrite=True, - ) + ) cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz") else: logging.info(f"Processing {partition}") @@ -144,13 +144,12 @@ def compute_fbank_gpu(args): num_workers=num_jobs, storage_type=LilcomChunkyWriter, overwrite=True, - ) + ) cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz") + if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/seame/ASR/local/compute_fbank_gpu_seame_sample.py b/egs/seame/ASR/local/compute_fbank_gpu_seame_sample.py index 4c03da0d1..cbeda9671 100755 --- a/egs/seame/ASR/local/compute_fbank_gpu_seame_sample.py +++ b/egs/seame/ASR/local/compute_fbank_gpu_seame_sample.py @@ -37,6 +37,7 @@ from lhotse.features.kaldifeat import ( KaldifeatMelOptions, ) + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( @@ -69,7 +70,7 @@ def get_args(): def compute_fbank_gpu(args): src_dir = Path("data_seame/manifests") output_dir = Path("data_seame/fbank") - num_jobs = min(os.cpu_count(),10) + num_jobs = min(os.cpu_count(), 10) num_mel_bins = 80 sampling_rate = 16000 sr = 16000 @@ -80,7 +81,6 @@ def compute_fbank_gpu(args): "train10", "train50", "train30", - ) prefix = "" suffix = "jsonl.gz" @@ -103,15 +103,11 @@ def compute_fbank_gpu(args): cut_set = cut_set.resample(sr) cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, - keep_all_channels=False) - cut_set = cut_set.filter(lambda c: c.duration >= .5 and c.duration <= 30) + keep_overlapping=False, keep_all_channels=False + ) + cut_set = cut_set.filter(lambda c: c.duration >= 0.5 and c.duration <= 30) if "train" in part: - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, storage_path=f"{output_dir}/{prefix}_feats_{part}", @@ -119,7 +115,7 @@ def compute_fbank_gpu(args): num_workers=num_jobs, storage_type=LilcomChunkyWriter, overwrite=True, - ) + ) cut_set.to_file(output_dir / f"cuts_{part}.jsonl.gz") else: logging.info(f"Processing {part}") @@ -131,13 +127,12 @@ def compute_fbank_gpu(args): num_workers=num_jobs, storage_type=LilcomChunkyWriter, overwrite=True, - ) 
+ ) cut_set.to_file(output_dir / f"cuts_{part}.jsonl.gz") + if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/seame/ASR/local/compute_fbank_musan.py b/egs/seame/ASR/local/compute_fbank_musan.py index 48905de6f..e1b104cc7 100755 --- a/egs/seame/ASR/local/compute_fbank_musan.py +++ b/egs/seame/ASR/local/compute_fbank_musan.py @@ -106,4 +106,4 @@ if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_musan() \ No newline at end of file + compute_fbank_musan() diff --git a/egs/seame/ASR/local/cuts_validate.py b/egs/seame/ASR/local/cuts_validate.py index f5cfb4728..8117a2364 100644 --- a/egs/seame/ASR/local/cuts_validate.py +++ b/egs/seame/ASR/local/cuts_validate.py @@ -7,7 +7,6 @@ from lhotse.qa import fix_manifests, validate_recordings_and_supervisions import pdb - def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -44,7 +43,7 @@ def get_parser(): def valid_asr(cut): tol = 2e-3 - i=0 + i = 0 total_dur = 0 for c in cut: if c.supervisions != []: @@ -52,10 +51,14 @@ def valid_asr(cut): logging.info(f"Supervision beyond the cut. Cut number: {i}") total_dur += c.duration - logging.info(f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}") + logging.info( + f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}" + ) elif c.supervisions[0].start < -tol: logging.info(f"Supervision starts before the cut. 
Cut number: {i}") - logging.info(f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}") + logging.info( + f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}" + ) else: continue else: @@ -63,7 +66,7 @@ def valid_asr(cut): logging.info(f"id: {c.id}") i += 1 logging.info(f"filtered duration: {total_dur}") - + def main(): @@ -74,7 +77,7 @@ def main(): else: recordings = RecordingSet.from_file(args.rec) supervisions = SupervisionSet.from_file(args.sup) - # breakpoint() + # breakpoint() logging.info("Example from supervisions:") logging.info(supervisions[0]) logging.info("Example from recordings") @@ -82,8 +85,11 @@ def main(): recordings, supervisions = fix_manifests(recordings, supervisions) logging.info("Validating manifests") validate_recordings_and_supervisions(recordings, supervisions) - - cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,) + + cuts = CutSet.from_manifests( + recordings=recordings, + supervisions=supervisions, + ) cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False) cuts.describe() logging.info("Example from cut:") @@ -93,5 +99,6 @@ def main(): if args.savecut != "": cuts.to_file(args.savecut) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/egs/seame/ASR/local/prepare_lexicon.py b/egs/seame/ASR/local/prepare_lexicon.py index 807579503..1997f2741 100755 --- a/egs/seame/ASR/local/prepare_lexicon.py +++ b/egs/seame/ASR/local/prepare_lexicon.py @@ -25,9 +25,7 @@ def main(): for line in f: line = line.strip() characters = list(line) - characters = " ".join( - ["V" if char == "*" else char for char in characters] - ) + characters = " ".join(["V" if char == "*" else char for char in characters]) lex[line] = characters with open(args.output, "w", encoding="utf-8") as fp: diff --git a/egs/seame/ASR/local/prepare_transcripts.py b/egs/seame/ASR/local/prepare_transcripts.py index a9da2d695..b97d65c01 100755 --- a/egs/seame/ASR/local/prepare_transcripts.py +++ b/egs/seame/ASR/local/prepare_transcripts.py @@ -30,7 +30,7 @@ def get_parser(): help="name of the lang-dir", ) return parser - + def main(): @@ -40,15 +40,16 @@ def main(): logging.info("Reading the cuts") cuts = CutSet.from_file(args.cut) langdir = Path(args.langdir) - + if not os.path.exists(langdir): os.makedirs(langdir) - - with open(langdir / "transcript_words.txt", 'w') as txt: + + with open(langdir / "transcript_words.txt", "w") as txt: for c in cuts: - #breakpoint() + # breakpoint() text = c.supervisions[0].text - txt.write(text + '\n') + txt.write(text + "\n") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/egs/seame/ASR/local/sample_hours.py b/egs/seame/ASR/local/sample_hours.py index 93dcf040a..c442d6ede 100644 --- a/egs/seame/ASR/local/sample_hours.py +++ b/egs/seame/ASR/local/sample_hours.py @@ -50,7 +50,7 @@ def get_parser(): ) return parser - + def main(): @@ -61,15 +61,20 @@ def main(): logging.info(f"Loading {args.cut}") cuts = CutSet.from_file(args.cut) outdir = Path(os.path.dirname(args.cut)) - + else: outdir = Path(os.path.dirname(args.sup)) logging.info(f"Loading supervisions") recordings = RecordingSet.from_file(args.rec) supervisions = SupervisionSet.from_file(args.sup) logging.info("Fixing manifests") - cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,) - cuts = cuts.trim_to_supervisions(keep_overlapping=False, 
keep_all_channels=False) + cuts = CutSet.from_manifests( + recordings=recordings, + supervisions=supervisions, + ) + cuts = cuts.trim_to_supervisions( + keep_overlapping=False, keep_all_channels=False + ) shuffled = cuts.shuffle() total_dur = 0 @@ -82,9 +87,10 @@ def main(): break cuts = cuts.filter(lambda c: c.id in cuts_list) cuts.describe() - + logging.info(f"Saving {args.outcut}") cuts.to_file(outdir / args.outcut) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/egs/seame/ASR/local/train_bpe_model.py b/egs/seame/ASR/local/train_bpe_model.py index 2594158bd..71d258786 100755 --- a/egs/seame/ASR/local/train_bpe_model.py +++ b/egs/seame/ASR/local/train_bpe_model.py @@ -91,7 +91,7 @@ def main(): user_defined_symbols = ["", ""] unk_id = len(user_defined_symbols) if predef_sym: - syms = predef_sym.split(',') + syms = predef_sym.split(",") for i in syms: user_defined_symbols.append(i) # Note: unk_id is fixed to 2. @@ -116,5 +116,6 @@ def main(): shutil.copyfile(model_file, f"{lang_dir}/bpe.model") generate_tokens(lang_dir) + if __name__ == "__main__": main() diff --git a/egs/seame/ASR/local/wer_lang.py b/egs/seame/ASR/local/wer_lang.py index 30f5114d3..d55790a8e 100644 --- a/egs/seame/ASR/local/wer_lang.py +++ b/egs/seame/ASR/local/wer_lang.py @@ -25,29 +25,31 @@ def get_parser(): ) return parser + lids = "en,zh" -lids_dict = {lid:id+1 for id, lid in enumerate(lids.split(","))} -id2lang = {id+1: lid for id, lid in enumerate(lids.split(","))} +lids_dict = {lid: id + 1 for id, lid in enumerate(lids.split(","))} +id2lang = {id + 1: lid for id, lid in enumerate(lids.split(","))} bad_id = [] + def extract_info(line, info): # Split the line at the first colon to separate the ID - id_part, rest = line.split(':', 1) - + id_part, rest = line.split(":", 1) + # Extract 'ref' by finding its start and end ref_start = rest.find(info) - ref_end = rest.find(']', ref_start) - ref = rest[ref_start+len(info):ref_end].replace("'", "").split(', ') - - # Extract 'lid' - if 'lid=' in rest: - lid_start = rest.find('lid=[') - lid_end = rest.find(']', lid_start) - lid = rest[lid_start+len('lid=['):lid_end].split(', ') - else: - lid = [''] + ref_end = rest.find("]", ref_start) + ref = rest[ref_start + len(info) : ref_end].replace("'", "").split(", ") - if lid[0]=='': + # Extract 'lid' + if "lid=" in rest: + lid_start = rest.find("lid=[") + lid_end = rest.find("]", lid_start) + lid = rest[lid_start + len("lid=[") : lid_end].split(", ") + else: + lid = [""] + + if lid[0] == "": bad_id.append(id_part) if " ".join(lid): lid = [int(i) for i in lid] # Convert each element to integer @@ -58,6 +60,7 @@ def is_English(c): """check character is in English""" return ord(c.lower()) >= ord("a") and ord(c.lower()) <= ord("z") + def get_en(text): res = [] for w in text: @@ -68,6 +71,7 @@ def get_en(text): continue return res + def get_zh(text): res = [] for w in text: @@ -79,84 +83,84 @@ def get_zh(text): return res - def extract_info_lid(line, tag): # Split the line at the first colon to separate the ID - id_part, rest = line.split(':', 1) - + id_part, rest = line.split(":", 1) + # Extract 'ref' by finding its start and end - + ref_start = rest.find(tag) - ref_end = rest.find(']', ref_start) - ref = rest[ref_start+len(tag):ref_end].replace("'", "").split(', ') - + ref_end = rest.find("]", ref_start) + ref = rest[ref_start + len(tag) : ref_end].replace("'", "").split(", ") + return id_part.strip(), ref def align_lid2(labels_a, labels_b, a, b): - # Alignment - EPS = '*' - ali = align(a, b, EPS, 
sclite_mode=True) + # Alignment + EPS = "*" + ali = align(a, b, EPS, sclite_mode=True) - a2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(a,labels_a))} - b2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(b,labels_b))} - # Comparing labels of aligned elements - idx_a = 0 - idx_b = 0 - ali_idx=0 - aligned_a = [] - aligned_b = [] - while idx_a CutSet: logging.info("Train data: About to get training cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_train_shuf.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz") @lru_cache() def valid_cuts(self) -> CutSet: logging.info("Dev data: About to get develop cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_valid.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_valid.jsonl.gz") @lru_cache() def dev_man(self) -> CutSet: logging.info("About to get dev_man cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_dev_man.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_dev_man.jsonl.gz") + def dev_sge(self) -> CutSet: logging.info("About to get dev_sge cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_dev_sge.jsonl.gz" - ) \ No newline at end of file + return load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sge.jsonl.gz") diff --git a/egs/seame/ASR/zipformer/decode.py b/egs/seame/ASR/zipformer/decode.py index cc3747eb9..70ce05160 100755 --- a/egs/seame/ASR/zipformer/decode.py +++ b/egs/seame/ASR/zipformer/decode.py @@ -111,6 +111,7 @@ import re LOG_EPS = math.log(1e-10) + def remove_punc(text): """This function removes all English punctuations except the single quote (verbatim).""" @@ -119,20 +120,22 @@ def remove_punc(text): # english_punctuations = english_punctuations.replace("'", "") # Create a translation table that maps each punctuation to a space. - translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations)) - + translator = str.maketrans(english_punctuations, " " * len(english_punctuations)) + # Translate the text using the translation table text = text.translate(translator) - + return text + def clean(text): text = remove_punc(text) text = text.lower() - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"\s+", " ", text) text = text.rstrip() return text + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter diff --git a/egs/seame/ASR/zipformer/train.py b/egs/seame/ASR/zipformer/train.py index 39066b477..84748c8c9 100755 --- a/egs/seame/ASR/zipformer/train.py +++ b/egs/seame/ASR/zipformer/train.py @@ -1384,5 +1384,6 @@ def main(): else: run(rank=0, world_size=1, args=args) + if __name__ == "__main__": main() diff --git a/egs/seame/ASR/zipformer_hat/beam_search.py b/egs/seame/ASR/zipformer_hat/beam_search.py index 5e2eecd3a..ed6d8963a 100644 --- a/egs/seame/ASR/zipformer_hat/beam_search.py +++ b/egs/seame/ASR/zipformer_hat/beam_search.py @@ -800,7 +800,7 @@ def modified_beam_search_lm_shallow_fusion( hyps=ans, timestamps=ans_timestamps, ) - + def modified_beam_search_lm_rescore_LODR( model: nn.Module, @@ -924,9 +924,9 @@ def modified_beam_search_lm_rescore_LODR( # is equivalent to log(1 - sigmoid(logits[..., 0])). 
nb_shift = logp_b - logits[..., 0] nb_shift = nb_shift.unsqueeze(-1) - log_probs1 = (logits[..., 1:]/ temperature).log_softmax(dim=-1) + nb_shift + log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift - #log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + # log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) log_probs.add_(ys_log_probs) @@ -1333,4 +1333,4 @@ def modified_beam_search_LODR( for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - return ans \ No newline at end of file + return ans diff --git a/egs/seame/ASR/zipformer_hat/decode.py b/egs/seame/ASR/zipformer_hat/decode.py index b14f37910..c473292e3 100755 --- a/egs/seame/ASR/zipformer_hat/decode.py +++ b/egs/seame/ASR/zipformer_hat/decode.py @@ -91,6 +91,7 @@ import re LOG_EPS = math.log(1e-10) + def remove_punc(text): """This function removes all English punctuations except the single quote (verbatim).""" @@ -99,20 +100,22 @@ def remove_punc(text): # english_punctuations = english_punctuations.replace("'", "") # Create a translation table that maps each punctuation to a space. - translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations)) - + translator = str.maketrans(english_punctuations, " " * len(english_punctuations)) + # Translate the text using the translation table text = text.translate(translator) - + return text + def clean(text): text = remove_punc(text) text = text.lower() - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"\s+", " ", text) text = text.rstrip() return text + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -485,8 +488,8 @@ def decode_one_batch( hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": - lm_scale_list = [0.05 * i for i in range(4, 10)] - hyp_tokens = modified_beam_search_lm_rescore_LODR( + lm_scale_list = [0.05 * i for i in range(4, 10)] + hyp_tokens = modified_beam_search_lm_rescore_LODR( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, @@ -496,7 +499,7 @@ def decode_one_batch( sp=sp, lm_scale_list=lm_scale_list, ) - for hyp in sp.decode(hyp_tokens): + for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) else: @@ -583,7 +586,7 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - + if params.clean: tmp_hyp = " ".join(hyp_words) tmp_hyp = clean(tmp_hyp) @@ -813,12 +816,10 @@ def main(): model.eval() # only load the neural network LM if required - if ( - params.use_shallow_fusion - or params.decoding_method in ( - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - "modified_beam_search_lm_rescore_LODR",) + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + "modified_beam_search_lm_rescore_LODR", ): LM = LmScorer( lm_type=params.lm_type, diff --git a/egs/seame/ASR/zipformer_hat/train.py b/egs/seame/ASR/zipformer_hat/train.py index d8973cae5..74d964568 100755 --- a/egs/seame/ASR/zipformer_hat/train.py +++ b/egs/seame/ASR/zipformer_hat/train.py @@ -349,10 +349,10 @@ def get_parser(): parser.add_argument( "--train-size", type=str, - default='full', + default="full", help="train datasize", ) - + parser.add_argument( "--lr-batches", type=float, @@ -551,7 +551,7 @@ def get_params() -> AttributeDict: 
"valid_interval": 2000, # For the 100h subset, use 800 # parameters for zipformer "feature_dim": 80, - #"model_warm_step": 5000, + # "model_warm_step": 5000, "subsampling_factor": 4, # not passed in, this is fixed. "warm_step": 5000, # parameters for ctc loss @@ -644,7 +644,7 @@ def get_model(params: AttributeDict) -> nn.Module: else: decoder = None joiner = None - + model = AsrModel( encoder_embed=encoder_embed, encoder=encoder, @@ -1199,11 +1199,11 @@ def run(rank, world_size, args): seame = SeameAsrDataModule(args) - if params.train_size == '30': + if params.train_size == "30": train_cuts = seame.train30_cuts() - elif params.train_size == '10': + elif params.train_size == "10": train_cuts = seame.train10_cuts() - elif params.train_size == '50': + elif params.train_size == "50": train_cuts = seame.train50_cuts() else: train_cuts = seame.train_cuts() @@ -1379,5 +1379,6 @@ def main(): else: run(rank=0, world_size=1, args=args) + if __name__ == "__main__": main() diff --git a/egs/seame/ASR/zipformer_hat_lid/beam_search.py b/egs/seame/ASR/zipformer_hat_lid/beam_search.py index 9af280c1c..b6b26628b 100644 --- a/egs/seame/ASR/zipformer_hat_lid/beam_search.py +++ b/egs/seame/ASR/zipformer_hat_lid/beam_search.py @@ -37,6 +37,7 @@ from icefall.utils import ( get_texts_with_timestamp, ) + @dataclass class Result: # timestamps[k] contains the frame number on which tokens[k] @@ -465,7 +466,9 @@ def modified_beam_search( lid_current_encoder_out = lid_encoder_out.data[start:end] current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - asr_lid_current_encoder_out = asr_lid_current_encoder_out.unsqueeze(1).unsqueeze(1) + asr_lid_current_encoder_out = asr_lid_current_encoder_out.unsqueeze( + 1 + ).unsqueeze(1) lid_current_encoder_out = lid_current_encoder_out.unsqueeze(1).unsqueeze(1) # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) offset = end @@ -492,7 +495,6 @@ def modified_beam_search( decoder_out = model.joiner.decoder_proj(decoder_out_) lid_decoder_out = model.lid_joiner.decoder_proj(decoder_out_) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor @@ -521,7 +523,7 @@ def modified_beam_search( project_input=False, lid_out=asr_lid_current_encoder_out, ) # (num_hyps, 1, 1, vocab_size) - + lid_logits = model.lid_joiner( lid_current_encoder_out, lid_decoder_out, @@ -879,6 +881,7 @@ def modified_beam_search_lm_shallow_fusion( timestamps=ans_timestamps, ) + def modified_beam_search_auxlm_shallow_fusion( model: nn.Module, encoder_out: torch.Tensor, @@ -1160,6 +1163,7 @@ def modified_beam_search_auxlm_shallow_fusion( timestamps=ans_timestamps, ) + def modified_beam_search_lm_rescore_LODR( model: nn.Module, encoder_out: torch.Tensor, @@ -1282,9 +1286,9 @@ def modified_beam_search_lm_rescore_LODR( # is equivalent to log(1 - sigmoid(logits[..., 0])). 
nb_shift = logp_b - logits[..., 0] nb_shift = nb_shift.unsqueeze(-1) - log_probs1 = (logits[..., 1:]/ temperature).log_softmax(dim=-1) + nb_shift + log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift - #log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + # log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) log_probs.add_(ys_log_probs) @@ -1691,4 +1695,4 @@ def modified_beam_search_LODR( for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - return ans \ No newline at end of file + return ans diff --git a/egs/seame/ASR/zipformer_hat_lid/decode.py b/egs/seame/ASR/zipformer_hat_lid/decode.py index a41991729..5456f2794 100755 --- a/egs/seame/ASR/zipformer_hat_lid/decode.py +++ b/egs/seame/ASR/zipformer_hat_lid/decode.py @@ -118,6 +118,7 @@ import matplotlib.pyplot as plt LOG_EPS = math.log(1e-10) + def remove_punc(text): """This function removes all English punctuations except the single quote (verbatim).""" @@ -126,21 +127,23 @@ def remove_punc(text): english_punctuations = english_punctuations.replace("'", "") # Create a translation table that maps each punctuation to a space. - #translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations)) - translator = str.maketrans('', '', english_punctuations) - + # translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations)) + translator = str.maketrans("", "", english_punctuations) + # Translate the text using the translation table text = text.translate(translator) - + return text + def clean(text): text = remove_punc(text) text = text.lower() - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"\s+", " ", text) text = text.rstrip() return text + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -399,60 +402,65 @@ def get_parser(): return parser + def align_lid(labels_a, labels_b, a, b): - # Alignment - EPS = '*' - ali = align(a, b, EPS, sclite_mode=True) + # Alignment + EPS = "*" + ali = align(a, b, EPS, sclite_mode=True) - a2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(a,labels_a))} - b2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(b,labels_b))} - # Comparing labels of aligned elements - idx_a = 0 - idx_b = 0 - ali_idx=0 - aligned_a = [] - aligned_b = [] - while idx_a (T, N, C) - encoder_out, encoder_out_lens, lid_output = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out, encoder_out_lens, lid_output = self.encoder( + x, x_lens, src_key_padding_mask + ) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) @@ -216,7 +218,6 @@ class AsrModel(nn.Module): part """ # Now for the decoder, i.e., the prediction network - blank_id = self.decoder.blank_id sos_y = add_sos(y, sos_id=blank_id) @@ -285,17 +286,20 @@ class AsrModel(nn.Module): # logits : [B, T, prune_range, vocab_size] if self.lid_joiner is not None: - lid_am_pruned, lid_lm_pruned = k2.do_rnnt_pruning( + lid_am_pruned, lid_lm_pruned = k2.do_rnnt_pruning( am=self.lid_joiner.encoder_proj(lid_encoder_out), lm=self.lid_joiner.decoder_proj(decoder_out), ranges=ranges, ) - lid_logits = self.lid_joiner( - lid_am_pruned, lid_lm_pruned, project_input=False) - + lid_logits = self.lid_joiner( + lid_am_pruned, lid_lm_pruned, project_input=False + ) + # project_input=False since we applied the decoder's input projections # prior to do_rnnt_pruning (this is an 
optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False, lid_out=lid_pruned) + logits = self.joiner( + am_pruned, lm_pruned, project_input=False, lid_out=lid_pruned + ) # Add blank logits to lid_logits logits = torch.cat((lid_logits[..., 0].unsqueeze(-1), logits), dim=-1) @@ -310,21 +314,23 @@ class AsrModel(nn.Module): use_hat_loss=True, ) - # Compute HAT loss for auxiliary lm joiner + # Compute HAT loss for auxiliary lm joiner if self.lid_joiner is not None: - with torch.cuda.amp.autocast(enabled=False): - pruned_lid_loss = k2.rnnt_loss_pruned( - logits=lid_logits.float(), - symbols=y_lid.pad(mode="constant", padding_value=blank_id).to(torch.int64), - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - use_hat_loss=True, - ) - return simple_loss, pruned_loss, pruned_lid_loss + with torch.cuda.amp.autocast(enabled=False): + pruned_lid_loss = k2.rnnt_loss_pruned( + logits=lid_logits.float(), + symbols=y_lid.pad(mode="constant", padding_value=blank_id).to( + torch.int64 + ), + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + use_hat_loss=True, + ) + return simple_loss, pruned_loss, pruned_lid_loss else: - return simple_loss, pruned_loss + return simple_loss, pruned_loss def forward( self, @@ -374,7 +380,9 @@ class AsrModel(nn.Module): # Compute encoder outputs if self.lid_joiner != None: - encoder_out, encoder_out_lens, lid_encoder_out = self.forward_encoder(x, x_lens) + encoder_out, encoder_out_lens, lid_encoder_out = self.forward_encoder( + x, x_lens + ) else: encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) @@ -382,30 +390,30 @@ class AsrModel(nn.Module): y_lens = row_splits[1:] - row_splits[:-1] if self.use_transducer: - + # Compute transducer loss if self.lid_joiner != None: - simple_loss, pruned_loss, pruned_loss_lm = self.forward_transducer( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - lid_encoder_out=lid_encoder_out, - y=y.to(x.device), - y_lens=y_lens, - y_lid=y_lid, - prune_range=prune_range, - am_scale=am_scale, - lm_scale=lm_scale, - ) + simple_loss, pruned_loss, pruned_loss_lm = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + lid_encoder_out=lid_encoder_out, + y=y.to(x.device), + y_lens=y_lens, + y_lid=y_lid, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) else: simple_loss, pruned_loss = self.forward_transducer( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - y=y.to(x.device), - y_lens=y_lens, - prune_range=prune_range, - am_scale=am_scale, - lm_scale=lm_scale, - ) + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(x.device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) else: simple_loss = torch.empty(0) pruned_loss = torch.empty(0) diff --git a/egs/seame/ASR/zipformer_hat_lid/train.py b/egs/seame/ASR/zipformer_hat_lid/train.py index cb16caa74..1f6648a9c 100755 --- a/egs/seame/ASR/zipformer_hat_lid/train.py +++ b/egs/seame/ASR/zipformer_hat_lid/train.py @@ -366,7 +366,8 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--lid-value-head-dim", type=str, default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.",) + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) parser.add_argument( "--lid-pos-head-dim", type=str, @@ -429,6 +430,7 @@ def add_model_arguments(parser: 
argparse.ArgumentParser): help="Whether to skip positional embedding in the lid encoder.", ) + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -781,9 +783,11 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: causal=params.causal, chunk_size=_to_int_tuple(params.chunk_size), left_context_frames=_to_int_tuple(params.left_context_frames), - lid_output_layer=params.lid_output_layer if params.use_lid_encoder else None,) + lid_output_layer=params.lid_output_layer if params.use_lid_encoder else None, + ) return encoder + def get_lid_encoder_model(params: AttributeDict) -> nn.Module: lid_encoder = Zipformer2( output_downsampling_factor=2, @@ -806,6 +810,7 @@ def get_lid_encoder_model(params: AttributeDict) -> nn.Module: ) return lid_encoder + def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, @@ -826,15 +831,17 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: ) return joiner + def get_lid_joiner_model(params: AttributeDict) -> nn.Module: lid_joiner = Joiner( encoder_dim=int(params.lid_encoder_dim.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.lid_joiner_dim, - vocab_size=len(params.lids.split(","))+1, + vocab_size=len(params.lids.split(",")) + 1, ) return lid_joiner + def get_model(params: AttributeDict) -> nn.Module: assert params.use_transducer or params.use_ctc, ( f"At least one of them should be True, " @@ -858,7 +865,6 @@ def get_model(params: AttributeDict) -> nn.Module: decoder = None joiner = None - model = AsrModel( encoder_embed=encoder_embed, encoder=encoder, @@ -875,7 +881,6 @@ def get_model(params: AttributeDict) -> nn.Module: return model - def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, @@ -916,7 +921,7 @@ def load_checkpoint_if_available( return None assert filename.is_file(), f"{filename} does not exist!" - + saved_params = load_checkpoint( filename, model=model, @@ -1017,8 +1022,8 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" - - lids_dict = {lid:id+1 for id, lid in enumerate(params.lids.split(","))} + + lids_dict = {lid: id + 1 for id, lid in enumerate(params.lids.split(","))} device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) @@ -1066,9 +1071,7 @@ def compute_loss( lid_pruned_loss_is_finite = torch.isfinite(lid_pruned_loss) is_finite = ( - simple_loss_is_finite - & pruned_loss_is_finite - & lid_pruned_loss_is_finite + simple_loss_is_finite & pruned_loss_is_finite & lid_pruned_loss_is_finite ) if not torch.all(is_finite): logging.info( @@ -1091,17 +1094,18 @@ def compute_loss( else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - - loss += (1-params.lid_loss_scale)*(simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss) - #loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss += (1 - params.lid_loss_scale) * ( + simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + ) + # loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if params.use_lid_joiner: loss += params.lid_loss_scale * pruned_loss_scale * lid_pruned_loss - #loss += pruned_loss_scale * lid_pruned_loss + # loss += pruned_loss_scale * lid_pruned_loss if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss @@ -1119,7 +1123,7 @@ def compute_loss( info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_lid_joiner: - info["lid_pruned_loss"] = lid_pruned_loss.detach().cpu().item() + info["lid_pruned_loss"] = lid_pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() @@ -1456,7 +1460,7 @@ def run(rank, world_size, args): ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - + # if checkpoints and "optimizer" in checkpoints: # logging.info("Loading optimizer state dict") # optimizer.load_state_dict(checkpoints["optimizer"])