black formatting

AmirHussein96 2024-04-05 13:00:29 -04:00
parent c7f74e410f
commit 891cf55901
21 changed files with 387 additions and 358 deletions
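
The changes below are the mechanical result of running black over the recipe's Python files. As an illustrative sketch only (the exact invocation is not recorded in this commit, and the target path is a placeholder), the same kind of reformatting is typically produced by:

    # Sketch: apply black with its default settings; "." is a placeholder target,
    # not a path taken from this commit.
    import subprocess

    subprocess.run(["black", "."], check=True)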


@@ -14,12 +14,7 @@ import jiwer
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--dec-file",
-        type=str,
-        help="Decoded icefall recogs file"
-    )
+    parser.add_argument("--dec-file", type=str, help="Decoded icefall recogs file")
     return parser
@@ -29,22 +24,22 @@ def cer_(file):
     ref = []
     cer_results = 0
     ref_lens = 0
-    with open(file, 'r', encoding='utf-8') as dec:
+    with open(file, "r", encoding="utf-8") as dec:
         for line in dec:
-            id, target = line.split('\t')
+            id, target = line.split("\t")
             id = id[0:-2]
             target, txt = target.split("=")
-            if target == 'ref':
-                words = txt.strip().strip('[]').split(', ')
+            if target == "ref":
+                words = txt.strip().strip("[]").split(", ")
                 word_list = [word.strip("'") for word in words]
                 ref.append(" ".join(word_list))
-            elif target == 'hyp':
-                words = txt.strip().strip('[]').split(', ')
+            elif target == "hyp":
+                words = txt.strip().strip("[]").split(", ")
                 word_list = [word.strip("'") for word in words]
                 hyp.append(" ".join(word_list))
     for h, r in zip(hyp, ref):
         if r:
-            cer_results += (jiwer.cer(r, h)*len(r))
+            cer_results += jiwer.cer(r, h) * len(r)
             ref_lens += len(r)
     print(cer_results / ref_lens)
@@ -55,5 +50,6 @@ def main():
     args = parse.parse_args()
     cer_(args.dec_file)

 if __name__ == "__main__":
     main()


@@ -38,6 +38,7 @@ from lhotse.features.kaldifeat import (
     KaldifeatMelOptions,
 )

 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -70,7 +71,7 @@ def get_args():
 def compute_fbank_gpu(args):
     src_dir = Path("data_seame/manifests")
     output_dir = Path("data_seame/fbank")
-    num_jobs = min(os.cpu_count(),10)
+    num_jobs = min(os.cpu_count(), 10)
     num_mel_bins = 80
     sampling_rate = 16000
     sr = 16000
@@ -87,7 +88,10 @@ def compute_fbank_gpu(args):
     suffix = "jsonl.gz"
     breakpoint
     manifests = read_manifests_if_cached(
-        prefix=prefix, dataset_parts=dataset_parts, output_dir=src_dir,suffix=suffix,
+        prefix=prefix,
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        suffix=suffix,
     )
     assert manifests is not None
@@ -116,15 +120,11 @@ def compute_fbank_gpu(args):
         cut_set = cut_set.resample(sr)
         cut_set = cut_set.trim_to_supervisions(
-            keep_overlapping=False,
-            keep_all_channels=False)
-        cut_set = cut_set.filter(lambda c: c.duration >= .2 and c.duration <= 30)
-        if "train" in partition:
-            cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
-            )
+            keep_overlapping=False, keep_all_channels=False
+        )
+        cut_set = cut_set.filter(lambda c: c.duration >= 0.2 and c.duration <= 30)
+        if "train" in partition:
+            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
         cut_set = cut_set.compute_and_store_features_batch(
             extractor=extractor,
             storage_path=f"{output_dir}/{prefix}_feats_{partition}",
@@ -147,10 +147,9 @@ def compute_fbank_gpu(args):
         )
         cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
     args = get_args()


@@ -37,6 +37,7 @@ from lhotse.features.kaldifeat import (
     KaldifeatMelOptions,
 )

 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -69,7 +70,7 @@ def get_args():
 def compute_fbank_gpu(args):
     src_dir = Path("data_seame/manifests")
     output_dir = Path("data_seame/fbank")
-    num_jobs = min(os.cpu_count(),10)
+    num_jobs = min(os.cpu_count(), 10)
     num_mel_bins = 80
     sampling_rate = 16000
     sr = 16000
@@ -80,7 +81,6 @@ def compute_fbank_gpu(args):
         "train10",
         "train50",
         "train30",
     )
     prefix = ""
     suffix = "jsonl.gz"
@@ -103,15 +103,11 @@ def compute_fbank_gpu(args):
         cut_set = cut_set.resample(sr)
         cut_set = cut_set.trim_to_supervisions(
-            keep_overlapping=False,
-            keep_all_channels=False)
-        cut_set = cut_set.filter(lambda c: c.duration >= .5 and c.duration <= 30)
-        if "train" in part:
-            cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
-            )
+            keep_overlapping=False, keep_all_channels=False
+        )
+        cut_set = cut_set.filter(lambda c: c.duration >= 0.5 and c.duration <= 30)
+        if "train" in part:
+            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
         cut_set = cut_set.compute_and_store_features_batch(
             extractor=extractor,
             storage_path=f"{output_dir}/{prefix}_feats_{part}",
@@ -134,10 +130,9 @@ def compute_fbank_gpu(args):
         )
         cut_set.to_file(output_dir / f"cuts_{part}.jsonl.gz")

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
     args = get_args()


@@ -7,7 +7,6 @@ from lhotse.qa import fix_manifests, validate_recordings_and_supervisions
 import pdb

 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -44,7 +43,7 @@ def get_parser():
 def valid_asr(cut):
     tol = 2e-3
-    i=0
+    i = 0
     total_dur = 0
     for c in cut:
         if c.supervisions != []:
@@ -52,10 +51,14 @@ def valid_asr(cut):
                 logging.info(f"Supervision beyond the cut. Cut number: {i}")
                 total_dur += c.duration
-                logging.info(f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}")
+                logging.info(
+                    f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}"
+                )
             elif c.supervisions[0].start < -tol:
                 logging.info(f"Supervision starts before the cut. Cut number: {i}")
-                logging.info(f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}")
+                logging.info(
+                    f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}"
+                )
             else:
                 continue
         else:
@@ -83,7 +86,10 @@ def main():
     logging.info("Validating manifests")
     validate_recordings_and_supervisions(recordings, supervisions)
-    cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,)
+    cuts = CutSet.from_manifests(
+        recordings=recordings,
+        supervisions=supervisions,
+    )
     cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
     cuts.describe()
     logging.info("Example from cut:")
@@ -93,5 +99,6 @@ def main():
     if args.savecut != "":
         cuts.to_file(args.savecut)

 if __name__ == "__main__":
     main()


@@ -25,9 +25,7 @@ def main():
         for line in f:
             line = line.strip()
             characters = list(line)
-            characters = " ".join(
-                ["V" if char == "*" else char for char in characters]
-            )
+            characters = " ".join(["V" if char == "*" else char for char in characters])
             lex[line] = characters

     with open(args.output, "w", encoding="utf-8") as fp:


@@ -44,11 +44,12 @@ def main():
     if not os.path.exists(langdir):
         os.makedirs(langdir)
-    with open(langdir / "transcript_words.txt", 'w') as txt:
+    with open(langdir / "transcript_words.txt", "w") as txt:
         for c in cuts:
-            #breakpoint()
+            # breakpoint()
             text = c.supervisions[0].text
-            txt.write(text + '\n')
+            txt.write(text + "\n")

 if __name__ == "__main__":
     main()


@@ -68,8 +68,13 @@ def main():
     recordings = RecordingSet.from_file(args.rec)
     supervisions = SupervisionSet.from_file(args.sup)
     logging.info("Fixing manifests")
-    cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,)
-    cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
+    cuts = CutSet.from_manifests(
+        recordings=recordings,
+        supervisions=supervisions,
+    )
+    cuts = cuts.trim_to_supervisions(
+        keep_overlapping=False, keep_all_channels=False
+    )
     shuffled = cuts.shuffle()
     total_dur = 0
@@ -86,5 +91,6 @@ def main():
     logging.info(f"Saving {args.outcut}")
     cuts.to_file(outdir / args.outcut)

 if __name__ == "__main__":
     main()


@@ -91,7 +91,7 @@ def main():
     user_defined_symbols = ["<blk>", "<sos/eos>"]
     unk_id = len(user_defined_symbols)
     if predef_sym:
-        syms = predef_sym.split(',')
+        syms = predef_sym.split(",")
         for i in syms:
             user_defined_symbols.append(i)
     # Note: unk_id is fixed to 2.
@@ -116,5 +116,6 @@ def main():
     shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
     generate_tokens(lang_dir)

 if __name__ == "__main__":
     main()


@@ -25,29 +25,31 @@ def get_parser():
     )
     return parser

 lids = "en,zh"
-lids_dict = {lid:id+1 for id, lid in enumerate(lids.split(","))}
-id2lang = {id+1: lid for id, lid in enumerate(lids.split(","))}
+lids_dict = {lid: id + 1 for id, lid in enumerate(lids.split(","))}
+id2lang = {id + 1: lid for id, lid in enumerate(lids.split(","))}
 bad_id = []

 def extract_info(line, info):
     # Split the line at the first colon to separate the ID
-    id_part, rest = line.split(':', 1)
+    id_part, rest = line.split(":", 1)
     # Extract 'ref' by finding its start and end
     ref_start = rest.find(info)
-    ref_end = rest.find(']', ref_start)
-    ref = rest[ref_start+len(info):ref_end].replace("'", "").split(', ')
+    ref_end = rest.find("]", ref_start)
+    ref = rest[ref_start + len(info) : ref_end].replace("'", "").split(", ")
     # Extract 'lid'
-    if 'lid=' in rest:
-        lid_start = rest.find('lid=[')
-        lid_end = rest.find(']', lid_start)
-        lid = rest[lid_start+len('lid=['):lid_end].split(', ')
+    if "lid=" in rest:
+        lid_start = rest.find("lid=[")
+        lid_end = rest.find("]", lid_start)
+        lid = rest[lid_start + len("lid=[") : lid_end].split(", ")
     else:
-        lid = ['']
-    if lid[0]=='':
+        lid = [""]
+    if lid[0] == "":
         bad_id.append(id_part)
     if " ".join(lid):
         lid = [int(i) for i in lid]  # Convert each element to integer
@@ -58,6 +60,7 @@ def is_English(c):
     """check character is in English"""
     return ord(c.lower()) >= ord("a") and ord(c.lower()) <= ord("z")

 def get_en(text):
     res = []
     for w in text:
@@ -68,6 +71,7 @@ def get_en(text):
             continue
     return res

 def get_zh(text):
     res = []
     for w in text:
@@ -79,34 +83,33 @@ def get_zh(text):
     return res

 def extract_info_lid(line, tag):
     # Split the line at the first colon to separate the ID
-    id_part, rest = line.split(':', 1)
+    id_part, rest = line.split(":", 1)
     # Extract 'ref' by finding its start and end
     ref_start = rest.find(tag)
-    ref_end = rest.find(']', ref_start)
-    ref = rest[ref_start+len(tag):ref_end].replace("'", "").split(', ')
+    ref_end = rest.find("]", ref_start)
+    ref = rest[ref_start + len(tag) : ref_end].replace("'", "").split(", ")
     return id_part.strip(), ref

 def align_lid2(labels_a, labels_b, a, b):
     # Alignment
-    EPS = '*'
+    EPS = "*"
     ali = align(a, b, EPS, sclite_mode=True)
-    a2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(a,labels_a))}
-    b2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(b,labels_b))}
+    a2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(a, labels_a))}
+    b2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(b, labels_b))}
     # Comparing labels of aligned elements
     idx_a = 0
     idx_b = 0
-    ali_idx=0
+    ali_idx = 0
     aligned_a = []
     aligned_b = []
-    while idx_a <len(a) and idx_b <len(b) and ali_idx < len(ali):
+    while idx_a < len(a) and idx_b < len(b) and ali_idx < len(ali):
         elem_a, elem_b = ali[ali_idx]
         if elem_a == EPS:
             idx_b += 1
@@ -114,8 +117,8 @@ def align_lid2(labels_a, labels_b, a, b):
             idx_a += 1
         elif elem_a != EPS and elem_b != EPS:
-            label_a = a2idx[(elem_a,idx_a)]
-            label_b = b2idx[(elem_b,idx_b)]
+            label_a = a2idx[(elem_a, idx_a)]
+            label_b = b2idx[(elem_b, idx_b)]
             aligned_a.append(label_a)
             aligned_b.append(label_b)
             idx_b += 1
@@ -128,7 +131,7 @@ def align_lid2(labels_a, labels_b, a, b):
 def align_lid(labels_a, labels_b):
     # Alignment
     res_a, res_b = [], []
-    EPS = '*'
+    EPS = "*"
     ali = align(labels_a, labels_b, EPS, sclite_mode=True)
     # Comparing labels of aligned elements
@@ -139,17 +142,18 @@ def align_lid(labels_a, labels_b):
 def read_file(infile, tag):
-    """"returns list of dict (id, lid, text)"""
+    """ "returns list of dict (id, lid, text)"""
     res = []
-    with open(infile, 'r') as file:
+    with open(infile, "r") as file:
         for line in file:
-            _, rest = line.split(':', 1)
+            _, rest = line.split(":", 1)
             if tag in rest:
                 _id, text = extract_info_lid(line, tag)
                 res.append((_id, text))
     return res

 def wer(results, sclite_mode=False):
     subs = defaultdict(int)
     ins = defaultdict(int)
@@ -185,12 +189,12 @@ def wer(results, sclite_mode=False):
     return tot_err_rate

-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = get_parser()
     args = parser.parse_args()
-    ref_data = read_file(args.rec, 'ref=[')
+    ref_data = read_file(args.rec, "ref=[")
     ref_data = sorted(ref_data)
-    hyp_data = read_file(args.rec, 'hyp=[')
+    hyp_data = read_file(args.rec, "hyp=[")
     hyp_data = sorted(hyp_data)
     results = defaultdict(list)
@@ -205,16 +209,9 @@ if __name__ == '__main__':
         hyp_text_en = get_en(hyp_text)
         hyp_text_zh = get_zh(hyp_text)
-        results['en'].append((ref[0],ref_text_en, hyp_text_en))
-        results['zh'].append((ref[0],ref_text_zh, hyp_text_zh))
+        results["en"].append((ref[0], ref_text_en, hyp_text_en))
+        results["zh"].append((ref[0], ref_text_zh, hyp_text_zh))
     for key, val in results.items():
         print(key)
         res = wer(val)


@@ -399,25 +399,18 @@ class SeameAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("Train data: About to get training cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")

     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("Dev data: About to get develop cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_valid.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_valid.jsonl.gz")

     @lru_cache()
     def dev_man(self) -> CutSet:
         logging.info("About to get dev_man cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_dev_man.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_dev_man.jsonl.gz")

     def dev_sge(self) -> CutSet:
         logging.info("About to get dev_sge cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_dev_sge.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sge.jsonl.gz")


@@ -111,6 +111,7 @@ import re
 LOG_EPS = math.log(1e-10)

 def remove_punc(text):
     """This function removes all English punctuations except the single quote (verbatim)."""
@@ -119,20 +120,22 @@ def remove_punc(text):
     # english_punctuations = english_punctuations.replace("'", "")
     # Create a translation table that maps each punctuation to a space.
-    translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
+    translator = str.maketrans(english_punctuations, " " * len(english_punctuations))
     # Translate the text using the translation table
     text = text.translate(translator)
     return text

 def clean(text):
     text = remove_punc(text)
     text = text.lower()
-    text = re.sub(r'\s+', ' ', text)
+    text = re.sub(r"\s+", " ", text)
     text = text.rstrip()
     return text

 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter


@@ -1384,5 +1384,6 @@ def main():
     else:
         run(rank=0, world_size=1, args=args)

 if __name__ == "__main__":
     main()


@@ -924,9 +924,9 @@ def modified_beam_search_lm_rescore_LODR(
         # is equivalent to log(1 - sigmoid(logits[..., 0])).
         nb_shift = logp_b - logits[..., 0]
         nb_shift = nb_shift.unsqueeze(-1)
-        log_probs1 = (logits[..., 1:]/ temperature).log_softmax(dim=-1) + nb_shift
-        #log_probs = (logits / temperature).log_softmax(dim=-1)  # (num_hyps, vocab_size)
+        log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift
+        # log_probs = (logits / temperature).log_softmax(dim=-1)  # (num_hyps, vocab_size)
         log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1)
         log_probs.add_(ys_log_probs)


@@ -91,6 +91,7 @@ import re
 LOG_EPS = math.log(1e-10)

 def remove_punc(text):
     """This function removes all English punctuations except the single quote (verbatim)."""
@@ -99,20 +100,22 @@ def remove_punc(text):
     # english_punctuations = english_punctuations.replace("'", "")
     # Create a translation table that maps each punctuation to a space.
-    translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
+    translator = str.maketrans(english_punctuations, " " * len(english_punctuations))
     # Translate the text using the translation table
     text = text.translate(translator)
     return text

 def clean(text):
     text = remove_punc(text)
     text = text.lower()
-    text = re.sub(r'\s+', ' ', text)
+    text = re.sub(r"\s+", " ", text)
     text = text.rstrip()
     return text

 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -813,12 +816,10 @@ def main():
     model.eval()

     # only load the neural network LM if required
-    if (
-        params.use_shallow_fusion
-        or params.decoding_method in (
+    if params.use_shallow_fusion or params.decoding_method in (
         "modified_beam_search_lm_shallow_fusion",
         "modified_beam_search_LODR",
-        "modified_beam_search_lm_rescore_LODR",)
+        "modified_beam_search_lm_rescore_LODR",
     ):
         LM = LmScorer(
             lm_type=params.lm_type,


@@ -349,7 +349,7 @@ def get_parser():
     parser.add_argument(
         "--train-size",
         type=str,
-        default='full',
+        default="full",
         help="train datasize",
     )
@@ -551,7 +551,7 @@ def get_params() -> AttributeDict:
             "valid_interval": 2000,  # For the 100h subset, use 800
             # parameters for zipformer
             "feature_dim": 80,
-            #"model_warm_step": 5000,
+            # "model_warm_step": 5000,
             "subsampling_factor": 4,  # not passed in, this is fixed.
             "warm_step": 5000,
             # parameters for ctc loss
@@ -1199,11 +1199,11 @@ def run(rank, world_size, args):
     seame = SeameAsrDataModule(args)
-    if params.train_size == '30':
+    if params.train_size == "30":
         train_cuts = seame.train30_cuts()
-    elif params.train_size == '10':
+    elif params.train_size == "10":
         train_cuts = seame.train10_cuts()
-    elif params.train_size == '50':
+    elif params.train_size == "50":
         train_cuts = seame.train50_cuts()
     else:
         train_cuts = seame.train_cuts()
@@ -1379,5 +1379,6 @@ def main():
     else:
         run(rank=0, world_size=1, args=args)

 if __name__ == "__main__":
     main()


@@ -37,6 +37,7 @@ from icefall.utils import (
     get_texts_with_timestamp,
 )

 @dataclass
 class Result:
     # timestamps[k] contains the frame number on which tokens[k]
@@ -465,7 +466,9 @@ def modified_beam_search(
         lid_current_encoder_out = lid_encoder_out.data[start:end]
         current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
-        asr_lid_current_encoder_out = asr_lid_current_encoder_out.unsqueeze(1).unsqueeze(1)
+        asr_lid_current_encoder_out = asr_lid_current_encoder_out.unsqueeze(
+            1
+        ).unsqueeze(1)
         lid_current_encoder_out = lid_current_encoder_out.unsqueeze(1).unsqueeze(1)
         # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
         offset = end
@@ -492,7 +495,6 @@ def modified_beam_search(
         decoder_out = model.joiner.decoder_proj(decoder_out_)
         lid_decoder_out = model.lid_joiner.decoder_proj(decoder_out_)
         # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)

         # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
@@ -879,6 +881,7 @@ def modified_beam_search_lm_shallow_fusion(
         timestamps=ans_timestamps,
     )

 def modified_beam_search_auxlm_shallow_fusion(
     model: nn.Module,
     encoder_out: torch.Tensor,
@@ -1160,6 +1163,7 @@ def modified_beam_search_auxlm_shallow_fusion(
         timestamps=ans_timestamps,
     )

 def modified_beam_search_lm_rescore_LODR(
     model: nn.Module,
     encoder_out: torch.Tensor,
@@ -1282,9 +1286,9 @@ def modified_beam_search_lm_rescore_LODR(
         # is equivalent to log(1 - sigmoid(logits[..., 0])).
         nb_shift = logp_b - logits[..., 0]
         nb_shift = nb_shift.unsqueeze(-1)
-        log_probs1 = (logits[..., 1:]/ temperature).log_softmax(dim=-1) + nb_shift
-        #log_probs = (logits / temperature).log_softmax(dim=-1)  # (num_hyps, vocab_size)
+        log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift
+        # log_probs = (logits / temperature).log_softmax(dim=-1)  # (num_hyps, vocab_size)
         log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1)
         log_probs.add_(ys_log_probs)


@@ -118,6 +118,7 @@ import matplotlib.pyplot as plt
 LOG_EPS = math.log(1e-10)

 def remove_punc(text):
     """This function removes all English punctuations except the single quote (verbatim)."""
@@ -126,21 +127,23 @@ def remove_punc(text):
     english_punctuations = english_punctuations.replace("'", "")
     # Create a translation table that maps each punctuation to a space.
-    #translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
-    translator = str.maketrans('', '', english_punctuations)
+    # translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
+    translator = str.maketrans("", "", english_punctuations)
     # Translate the text using the translation table
     text = text.translate(translator)
     return text

 def clean(text):
     text = remove_punc(text)
     text = text.lower()
-    text = re.sub(r'\s+', ' ', text)
+    text = re.sub(r"\s+", " ", text)
     text = text.rstrip()
     return text

 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -399,20 +402,21 @@ def get_parser():
     return parser

 def align_lid(labels_a, labels_b, a, b):
     # Alignment
-    EPS = '*'
+    EPS = "*"
     ali = align(a, b, EPS, sclite_mode=True)
-    a2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(a,labels_a))}
-    b2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(b,labels_b))}
+    a2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(a, labels_a))}
+    b2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(b, labels_b))}
     # Comparing labels of aligned elements
     idx_a = 0
     idx_b = 0
-    ali_idx=0
+    ali_idx = 0
     aligned_a = []
     aligned_b = []
-    while idx_a <len(a) and idx_b <len(b) and ali_idx < len(ali):
+    while idx_a < len(a) and idx_b < len(b) and ali_idx < len(ali):
         elem_a, elem_b = ali[ali_idx]
         if elem_a == EPS:
             idx_b += 1
@@ -420,8 +424,8 @@ def align_lid(labels_a, labels_b, a, b):
             idx_a += 1
         elif elem_a != EPS and elem_b != EPS:
-            label_a = a2idx[(elem_a,idx_a)]
-            label_b = b2idx[(elem_b,idx_b)]
+            label_a = a2idx[(elem_a, idx_a)]
+            label_b = b2idx[(elem_b, idx_b)]
             aligned_a.append(label_a)
             aligned_b.append(label_b)
             idx_b += 1
@@ -430,29 +434,33 @@ def align_lid(labels_a, labels_b, a, b):
             ali_idx += 1
     return aligned_a, aligned_b

-def write_lid_results(lid_path, f1_path, text, lid ):
+def write_lid_results(lid_path, f1_path, text, lid):
     lid_hyp = []
     lid_ref = []
-    with open(lid_path, 'w') as file:
+    with open(lid_path, "w") as file:
         # Write each line to the file
         for text_line, lid_line in zip(text, lid):
-            file.write(f"{text_line[0]}: ref={text_line[1]} lid={lid_line[1]}" + '\n')
-            aligned_ref, aligned_hyp = align_lid(lid_line[1],lid_line[2],text_line[1], text_line[2])
+            file.write(f"{text_line[0]}: ref={text_line[1]} lid={lid_line[1]}" + "\n")
+            aligned_ref, aligned_hyp = align_lid(
+                lid_line[1], lid_line[2], text_line[1], text_line[2]
+            )
             lid_ref.extend(aligned_ref)
             lid_hyp.extend(aligned_hyp)
-            file.write(f"{lid_line[0]}: hyp={text_line[2]} lid={lid_line[2]}" + '\n')
+            file.write(f"{lid_line[0]}: hyp={text_line[2]} lid={lid_line[2]}" + "\n")
     report = classification_report(lid_ref, lid_hyp, zero_division=0)
-    f1 = f1_score(lid_ref, lid_hyp, average='weighted')
-    with open(f1_path, 'w') as file:
+    f1 = f1_score(lid_ref, lid_hyp, average="weighted")
+    with open(f1_path, "w") as file:
         file.write(report)
-        file.write('\n')
+        file.write("\n")
         file.write(f"F1 score: {f1} \n")
-    filename = os.path.basename(lid_path).replace('.txt', '.png')
+    filename = os.path.basename(lid_path).replace(".txt", ".png")
     dirname = os.path.dirname(lid_path)
-    save_conf_mat(os.path.join(dirname,filename), lid_ref, lid_hyp)
+    save_conf_mat(os.path.join(dirname, filename), lid_ref, lid_hyp)

 def save_conf_mat(path, lid_ref, lid_hyp):
     all_labels = [1, 2, 3, 4]
@@ -463,15 +471,23 @@ def save_conf_mat(path, lid_ref, lid_hyp):
     # Plot the confusion matrix
     plt.figure(figsize=(10, 7))
-    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
-    plt.title('Confusion Matrix')
-    plt.xlabel('Predicted Labels')
-    plt.ylabel('True Labels')
+    sns.heatmap(
+        cm,
+        annot=True,
+        fmt="d",
+        cmap="Blues",
+        xticklabels=class_names,
+        yticklabels=class_names,
+    )
+    plt.title("Confusion Matrix")
+    plt.xlabel("Predicted Labels")
+    plt.ylabel("True Labels")
     plt.savefig(path)

 def most_frequent(List):
-    return max(set(List), key = List.count)
+    return max(set(List), key=List.count)

 def mapp(enc, LID):
     pt1 = 0
@@ -492,6 +508,7 @@ def mapp(enc, LID):
     return new_lid

 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -561,7 +578,9 @@ def decode_one_batch(
             value=LOG_EPS,
         )
-    encoder_out, encoder_out_lens, lid_encoder_out = model.forward_encoder(feature, feature_lens)
+    encoder_out, encoder_out_lens, lid_encoder_out = model.forward_encoder(
+        feature, feature_lens
+    )
     hyps = []
     B, T, F = feature.shape
@@ -579,7 +598,7 @@ def decode_one_batch(
             hyp = results[i]
             token_pieces = sp.IdToPiece(results[i].hyps)
             new_lid = mapp(token_pieces, results[i].lid_hyps)
-            hyps.append((sp.decode(results[i].hyps).split(),new_lid))
+            hyps.append((sp.decode(results[i].hyps).split(), new_lid))
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
@@ -701,7 +720,7 @@ def decode_dataset(
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         if params.lid:
-            lids_dict = {lid:id+1 for id, lid in enumerate(params.lids.split(","))}
+            lids_dict = {lid: id + 1 for id, lid in enumerate(params.lids.split(","))}
             text_list = [t.split("|") for t in texts]
             num_tokens = [[len(clean(t).split()) for t in utt] for utt in text_list]
@@ -758,7 +777,6 @@ def decode_dataset(
             if params.lid:
                 results_lid[name].extend(this_batch_lid)

         num_cuts += len(texts)

         if batch_idx % log_interval == 0:
@@ -766,7 +784,7 @@ def decode_dataset(
             logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     if params.lid:
-        return {"text":results, "lid":results_lid}
+        return {"text": results, "lid": results_lid}
     else:
         return results
@@ -820,19 +838,16 @@ def save_results_lid(
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
-    key = list(results_dict['text'].keys())[0]
-    results_text = sorted(results_dict['text'][key], key=lambda x: x[0])
-    results_lid = sorted(results_dict['lid'][key], key=lambda x: x[0])
+    key = list(results_dict["text"].keys())[0]
+    results_text = sorted(results_dict["text"][key], key=lambda x: x[0])
+    results_lid = sorted(results_dict["lid"][key], key=lambda x: x[0])
     test_set_f1s = dict()
-    lid_path = (
-        params.res_dir / f"lid-{test_set_name}-{key}-{params.suffix}.txt"
-    )
-    f1_path = (
-        params.res_dir / f"f1-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    lid_path = params.res_dir / f"lid-{test_set_name}-{key}-{params.suffix}.txt"
+    f1_path = params.res_dir / f"f1-{test_set_name}-{key}-{params.suffix}.txt"
     write_lid_results(lid_path, f1_path, results_text, results_lid)
     logging.info(f"The lids are stored in {lid_path}")

 @torch.no_grad()
 def main():
     parser = get_parser()
@@ -999,12 +1014,10 @@ def main():
     model.eval()

     # only load the neural network LM if required
-    if (
-        params.use_shallow_fusion
-        or params.decoding_method in (
+    if params.use_shallow_fusion or params.decoding_method in (
         "modified_beam_search_lm_shallow_fusion",
         "modified_beam_search_LODR",
-        "modified_beam_search_lm_rescore_LODR",)
+        "modified_beam_search_lm_rescore_LODR",
     ):
         LM = LmScorer(
             lm_type=params.lm_type,


@@ -19,6 +19,7 @@ import torch.nn as nn
 from scaling import ScaledLinear
 from typing import Optional

 class Joiner(nn.Module):
     def __init__(
         self,


@@ -141,7 +141,9 @@ class AsrModel(nn.Module):
         src_key_padding_mask = make_pad_mask(x_lens)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
-        encoder_out, encoder_out_lens, lid_output = self.encoder(x, x_lens, src_key_padding_mask)
+        encoder_out, encoder_out_lens, lid_output = self.encoder(
+            x, x_lens, src_key_padding_mask
+        )
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
         assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
@@ -217,7 +219,6 @@ class AsrModel(nn.Module):
         """
         # Now for the decoder, i.e., the prediction network
         blank_id = self.decoder.blank_id

         sos_y = add_sos(y, sos_id=blank_id)
@@ -291,11 +292,14 @@ class AsrModel(nn.Module):
             ranges=ranges,
         )
         lid_logits = self.lid_joiner(
-            lid_am_pruned, lid_lm_pruned, project_input=False)
+            lid_am_pruned, lid_lm_pruned, project_input=False
+        )
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(am_pruned, lm_pruned, project_input=False, lid_out=lid_pruned)
+        logits = self.joiner(
+            am_pruned, lm_pruned, project_input=False, lid_out=lid_pruned
+        )
         # Add blank logits to lid_logits
         logits = torch.cat((lid_logits[..., 0].unsqueeze(-1), logits), dim=-1)
@@ -315,7 +319,9 @@ class AsrModel(nn.Module):
         with torch.cuda.amp.autocast(enabled=False):
             pruned_lid_loss = k2.rnnt_loss_pruned(
                 logits=lid_logits.float(),
-                symbols=y_lid.pad(mode="constant", padding_value=blank_id).to(torch.int64),
+                symbols=y_lid.pad(mode="constant", padding_value=blank_id).to(
+                    torch.int64
+                ),
                 ranges=ranges,
                 termination_symbol=blank_id,
                 boundary=boundary,
@@ -374,7 +380,9 @@ class AsrModel(nn.Module):
         # Compute encoder outputs
         if self.lid_joiner != None:
-            encoder_out, encoder_out_lens, lid_encoder_out = self.forward_encoder(x, x_lens)
+            encoder_out, encoder_out_lens, lid_encoder_out = self.forward_encoder(
+                x, x_lens
+            )
         else:
             encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)


@@ -366,7 +366,8 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--lid-value-head-dim",
         type=str,
         default="12",
-        help="Value dimension per head in encoder stacks: a single int or comma-separated list.",)
+        help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
     parser.add_argument(
         "--lid-pos-head-dim",
         type=str,
@@ -429,6 +430,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to skip positional embedding in the lid encoder.",
     )

 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -781,9 +783,11 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         causal=params.causal,
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
-        lid_output_layer=params.lid_output_layer if params.use_lid_encoder else None,)
+        lid_output_layer=params.lid_output_layer if params.use_lid_encoder else None,
+    )
     return encoder

 def get_lid_encoder_model(params: AttributeDict) -> nn.Module:
     lid_encoder = Zipformer2(
         output_downsampling_factor=2,
@@ -806,6 +810,7 @@ def get_lid_encoder_model(params: AttributeDict) -> nn.Module:
     )
     return lid_encoder

 def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
@@ -826,15 +831,17 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     )
     return joiner

 def get_lid_joiner_model(params: AttributeDict) -> nn.Module:
     lid_joiner = Joiner(
         encoder_dim=int(params.lid_encoder_dim.split(",")[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.lid_joiner_dim,
-        vocab_size=len(params.lids.split(","))+1,
+        vocab_size=len(params.lids.split(",")) + 1,
     )
     return lid_joiner

 def get_model(params: AttributeDict) -> nn.Module:
     assert params.use_transducer or params.use_ctc, (
         f"At least one of them should be True, "
@@ -858,7 +865,6 @@ def get_model(params: AttributeDict) -> nn.Module:
         decoder = None
         joiner = None

     model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
@@ -875,7 +881,6 @@ def get_model(params: AttributeDict) -> nn.Module:
     return model

 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
@@ -1018,7 +1023,7 @@ def compute_loss(
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    lids_dict = {lid:id+1 for id, lid in enumerate(params.lids.split(","))}
+    lids_dict = {lid: id + 1 for id, lid in enumerate(params.lids.split(","))}
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
@@ -1066,9 +1071,7 @@ def compute_loss(
             lid_pruned_loss_is_finite = torch.isfinite(lid_pruned_loss)
             is_finite = (
-                simple_loss_is_finite
-                & pruned_loss_is_finite
-                & lid_pruned_loss_is_finite
+                simple_loss_is_finite & pruned_loss_is_finite & lid_pruned_loss_is_finite
             )
             if not torch.all(is_finite):
                 logging.info(
@@ -1096,12 +1099,13 @@ def compute_loss(
                 else 0.1 + 0.9 * (batch_idx_train / warm_step)
             )
-            loss += (1-params.lid_loss_scale)*(simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss)
-            #loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+            loss += (1 - params.lid_loss_scale) * (
+                simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+            )
+            # loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
         if params.use_lid_joiner:
             loss += params.lid_loss_scale * pruned_loss_scale * lid_pruned_loss
-            #loss += pruned_loss_scale * lid_pruned_loss
+            # loss += pruned_loss_scale * lid_pruned_loss
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss