mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
black formating
This commit is contained in:
parent
c7f74e410f
commit
891cf55901
@ -14,12 +14,7 @@ import jiwer
|
|||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument("--dec-file", type=str, help="Decoded icefall recogs file")
|
||||||
"--dec-file",
|
|
||||||
type=str,
|
|
||||||
help="Decoded icefall recogs file"
|
|
||||||
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -29,22 +24,22 @@ def cer_(file):
|
|||||||
ref = []
|
ref = []
|
||||||
cer_results = 0
|
cer_results = 0
|
||||||
ref_lens = 0
|
ref_lens = 0
|
||||||
with open(file, 'r', encoding='utf-8') as dec:
|
with open(file, "r", encoding="utf-8") as dec:
|
||||||
for line in dec:
|
for line in dec:
|
||||||
id, target = line.split('\t')
|
id, target = line.split("\t")
|
||||||
id = id[0:-2]
|
id = id[0:-2]
|
||||||
target, txt = target.split("=")
|
target, txt = target.split("=")
|
||||||
if target == 'ref':
|
if target == "ref":
|
||||||
words = txt.strip().strip('[]').split(', ')
|
words = txt.strip().strip("[]").split(", ")
|
||||||
word_list = [word.strip("'") for word in words]
|
word_list = [word.strip("'") for word in words]
|
||||||
ref.append(" ".join(word_list))
|
ref.append(" ".join(word_list))
|
||||||
elif target == 'hyp':
|
elif target == "hyp":
|
||||||
words = txt.strip().strip('[]').split(', ')
|
words = txt.strip().strip("[]").split(", ")
|
||||||
word_list = [word.strip("'") for word in words]
|
word_list = [word.strip("'") for word in words]
|
||||||
hyp.append(" ".join(word_list))
|
hyp.append(" ".join(word_list))
|
||||||
for h, r in zip(hyp, ref):
|
for h, r in zip(hyp, ref):
|
||||||
if r:
|
if r:
|
||||||
cer_results += (jiwer.cer(r, h)*len(r))
|
cer_results += jiwer.cer(r, h) * len(r)
|
||||||
|
|
||||||
ref_lens += len(r)
|
ref_lens += len(r)
|
||||||
print(cer_results / ref_lens)
|
print(cer_results / ref_lens)
|
||||||
@ -55,5 +50,6 @@ def main():
|
|||||||
args = parse.parse_args()
|
args = parse.parse_args()
|
||||||
cer_(args.dec_file)
|
cer_(args.dec_file)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
@ -38,6 +38,7 @@ from lhotse.features.kaldifeat import (
|
|||||||
KaldifeatMelOptions,
|
KaldifeatMelOptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -70,7 +71,7 @@ def get_args():
|
|||||||
def compute_fbank_gpu(args):
|
def compute_fbank_gpu(args):
|
||||||
src_dir = Path("data_seame/manifests")
|
src_dir = Path("data_seame/manifests")
|
||||||
output_dir = Path("data_seame/fbank")
|
output_dir = Path("data_seame/fbank")
|
||||||
num_jobs = min(os.cpu_count(),10)
|
num_jobs = min(os.cpu_count(), 10)
|
||||||
num_mel_bins = 80
|
num_mel_bins = 80
|
||||||
sampling_rate = 16000
|
sampling_rate = 16000
|
||||||
sr = 16000
|
sr = 16000
|
||||||
@ -87,7 +88,10 @@ def compute_fbank_gpu(args):
|
|||||||
suffix = "jsonl.gz"
|
suffix = "jsonl.gz"
|
||||||
breakpoint
|
breakpoint
|
||||||
manifests = read_manifests_if_cached(
|
manifests = read_manifests_if_cached(
|
||||||
prefix=prefix, dataset_parts=dataset_parts, output_dir=src_dir,suffix=suffix,
|
prefix=prefix,
|
||||||
|
dataset_parts=dataset_parts,
|
||||||
|
output_dir=src_dir,
|
||||||
|
suffix=suffix,
|
||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
@ -116,15 +120,11 @@ def compute_fbank_gpu(args):
|
|||||||
cut_set = cut_set.resample(sr)
|
cut_set = cut_set.resample(sr)
|
||||||
|
|
||||||
cut_set = cut_set.trim_to_supervisions(
|
cut_set = cut_set.trim_to_supervisions(
|
||||||
keep_overlapping=False,
|
keep_overlapping=False, keep_all_channels=False
|
||||||
keep_all_channels=False)
|
|
||||||
cut_set = cut_set.filter(lambda c: c.duration >= .2 and c.duration <= 30)
|
|
||||||
if "train" in partition:
|
|
||||||
cut_set = (
|
|
||||||
cut_set
|
|
||||||
+ cut_set.perturb_speed(0.9)
|
|
||||||
+ cut_set.perturb_speed(1.1)
|
|
||||||
)
|
)
|
||||||
|
cut_set = cut_set.filter(lambda c: c.duration >= 0.2 and c.duration <= 30)
|
||||||
|
if "train" in partition:
|
||||||
|
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
cut_set = cut_set.compute_and_store_features_batch(
|
cut_set = cut_set.compute_and_store_features_batch(
|
||||||
extractor=extractor,
|
extractor=extractor,
|
||||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||||
@ -147,10 +147,9 @@ def compute_fbank_gpu(args):
|
|||||||
)
|
)
|
||||||
cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
|
cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
formatter = (
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
args = get_args()
|
args = get_args()
|
||||||
|
@ -37,6 +37,7 @@ from lhotse.features.kaldifeat import (
|
|||||||
KaldifeatMelOptions,
|
KaldifeatMelOptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -69,7 +70,7 @@ def get_args():
|
|||||||
def compute_fbank_gpu(args):
|
def compute_fbank_gpu(args):
|
||||||
src_dir = Path("data_seame/manifests")
|
src_dir = Path("data_seame/manifests")
|
||||||
output_dir = Path("data_seame/fbank")
|
output_dir = Path("data_seame/fbank")
|
||||||
num_jobs = min(os.cpu_count(),10)
|
num_jobs = min(os.cpu_count(), 10)
|
||||||
num_mel_bins = 80
|
num_mel_bins = 80
|
||||||
sampling_rate = 16000
|
sampling_rate = 16000
|
||||||
sr = 16000
|
sr = 16000
|
||||||
@ -80,7 +81,6 @@ def compute_fbank_gpu(args):
|
|||||||
"train10",
|
"train10",
|
||||||
"train50",
|
"train50",
|
||||||
"train30",
|
"train30",
|
||||||
|
|
||||||
)
|
)
|
||||||
prefix = ""
|
prefix = ""
|
||||||
suffix = "jsonl.gz"
|
suffix = "jsonl.gz"
|
||||||
@ -103,15 +103,11 @@ def compute_fbank_gpu(args):
|
|||||||
cut_set = cut_set.resample(sr)
|
cut_set = cut_set.resample(sr)
|
||||||
|
|
||||||
cut_set = cut_set.trim_to_supervisions(
|
cut_set = cut_set.trim_to_supervisions(
|
||||||
keep_overlapping=False,
|
keep_overlapping=False, keep_all_channels=False
|
||||||
keep_all_channels=False)
|
|
||||||
cut_set = cut_set.filter(lambda c: c.duration >= .5 and c.duration <= 30)
|
|
||||||
if "train" in part:
|
|
||||||
cut_set = (
|
|
||||||
cut_set
|
|
||||||
+ cut_set.perturb_speed(0.9)
|
|
||||||
+ cut_set.perturb_speed(1.1)
|
|
||||||
)
|
)
|
||||||
|
cut_set = cut_set.filter(lambda c: c.duration >= 0.5 and c.duration <= 30)
|
||||||
|
if "train" in part:
|
||||||
|
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
cut_set = cut_set.compute_and_store_features_batch(
|
cut_set = cut_set.compute_and_store_features_batch(
|
||||||
extractor=extractor,
|
extractor=extractor,
|
||||||
storage_path=f"{output_dir}/{prefix}_feats_{part}",
|
storage_path=f"{output_dir}/{prefix}_feats_{part}",
|
||||||
@ -134,10 +130,9 @@ def compute_fbank_gpu(args):
|
|||||||
)
|
)
|
||||||
cut_set.to_file(output_dir / f"cuts_{part}.jsonl.gz")
|
cut_set.to_file(output_dir / f"cuts_{part}.jsonl.gz")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
formatter = (
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
args = get_args()
|
args = get_args()
|
||||||
|
@ -7,7 +7,6 @@ from lhotse.qa import fix_manifests, validate_recordings_and_supervisions
|
|||||||
import pdb
|
import pdb
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -44,7 +43,7 @@ def get_parser():
|
|||||||
|
|
||||||
def valid_asr(cut):
|
def valid_asr(cut):
|
||||||
tol = 2e-3
|
tol = 2e-3
|
||||||
i=0
|
i = 0
|
||||||
total_dur = 0
|
total_dur = 0
|
||||||
for c in cut:
|
for c in cut:
|
||||||
if c.supervisions != []:
|
if c.supervisions != []:
|
||||||
@ -52,10 +51,14 @@ def valid_asr(cut):
|
|||||||
|
|
||||||
logging.info(f"Supervision beyond the cut. Cut number: {i}")
|
logging.info(f"Supervision beyond the cut. Cut number: {i}")
|
||||||
total_dur += c.duration
|
total_dur += c.duration
|
||||||
logging.info(f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}")
|
logging.info(
|
||||||
|
f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}"
|
||||||
|
)
|
||||||
elif c.supervisions[0].start < -tol:
|
elif c.supervisions[0].start < -tol:
|
||||||
logging.info(f"Supervision starts before the cut. Cut number: {i}")
|
logging.info(f"Supervision starts before the cut. Cut number: {i}")
|
||||||
logging.info(f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}")
|
logging.info(
|
||||||
|
f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@ -83,7 +86,10 @@ def main():
|
|||||||
logging.info("Validating manifests")
|
logging.info("Validating manifests")
|
||||||
validate_recordings_and_supervisions(recordings, supervisions)
|
validate_recordings_and_supervisions(recordings, supervisions)
|
||||||
|
|
||||||
cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,)
|
cuts = CutSet.from_manifests(
|
||||||
|
recordings=recordings,
|
||||||
|
supervisions=supervisions,
|
||||||
|
)
|
||||||
cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
|
cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
|
||||||
cuts.describe()
|
cuts.describe()
|
||||||
logging.info("Example from cut:")
|
logging.info("Example from cut:")
|
||||||
@ -93,5 +99,6 @@ def main():
|
|||||||
if args.savecut != "":
|
if args.savecut != "":
|
||||||
cuts.to_file(args.savecut)
|
cuts.to_file(args.savecut)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
@ -25,9 +25,7 @@ def main():
|
|||||||
for line in f:
|
for line in f:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
characters = list(line)
|
characters = list(line)
|
||||||
characters = " ".join(
|
characters = " ".join(["V" if char == "*" else char for char in characters])
|
||||||
["V" if char == "*" else char for char in characters]
|
|
||||||
)
|
|
||||||
lex[line] = characters
|
lex[line] = characters
|
||||||
|
|
||||||
with open(args.output, "w", encoding="utf-8") as fp:
|
with open(args.output, "w", encoding="utf-8") as fp:
|
||||||
|
@ -44,11 +44,12 @@ def main():
|
|||||||
if not os.path.exists(langdir):
|
if not os.path.exists(langdir):
|
||||||
os.makedirs(langdir)
|
os.makedirs(langdir)
|
||||||
|
|
||||||
with open(langdir / "transcript_words.txt", 'w') as txt:
|
with open(langdir / "transcript_words.txt", "w") as txt:
|
||||||
for c in cuts:
|
for c in cuts:
|
||||||
#breakpoint()
|
# breakpoint()
|
||||||
text = c.supervisions[0].text
|
text = c.supervisions[0].text
|
||||||
txt.write(text + '\n')
|
txt.write(text + "\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
@ -68,8 +68,13 @@ def main():
|
|||||||
recordings = RecordingSet.from_file(args.rec)
|
recordings = RecordingSet.from_file(args.rec)
|
||||||
supervisions = SupervisionSet.from_file(args.sup)
|
supervisions = SupervisionSet.from_file(args.sup)
|
||||||
logging.info("Fixing manifests")
|
logging.info("Fixing manifests")
|
||||||
cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,)
|
cuts = CutSet.from_manifests(
|
||||||
cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
|
recordings=recordings,
|
||||||
|
supervisions=supervisions,
|
||||||
|
)
|
||||||
|
cuts = cuts.trim_to_supervisions(
|
||||||
|
keep_overlapping=False, keep_all_channels=False
|
||||||
|
)
|
||||||
|
|
||||||
shuffled = cuts.shuffle()
|
shuffled = cuts.shuffle()
|
||||||
total_dur = 0
|
total_dur = 0
|
||||||
@ -86,5 +91,6 @@ def main():
|
|||||||
logging.info(f"Saving {args.outcut}")
|
logging.info(f"Saving {args.outcut}")
|
||||||
cuts.to_file(outdir / args.outcut)
|
cuts.to_file(outdir / args.outcut)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
@ -91,7 +91,7 @@ def main():
|
|||||||
user_defined_symbols = ["<blk>", "<sos/eos>"]
|
user_defined_symbols = ["<blk>", "<sos/eos>"]
|
||||||
unk_id = len(user_defined_symbols)
|
unk_id = len(user_defined_symbols)
|
||||||
if predef_sym:
|
if predef_sym:
|
||||||
syms = predef_sym.split(',')
|
syms = predef_sym.split(",")
|
||||||
for i in syms:
|
for i in syms:
|
||||||
user_defined_symbols.append(i)
|
user_defined_symbols.append(i)
|
||||||
# Note: unk_id is fixed to 2.
|
# Note: unk_id is fixed to 2.
|
||||||
@ -116,5 +116,6 @@ def main():
|
|||||||
shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
|
shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
|
||||||
generate_tokens(lang_dir)
|
generate_tokens(lang_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -25,29 +25,31 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
lids = "en,zh"
|
lids = "en,zh"
|
||||||
lids_dict = {lid:id+1 for id, lid in enumerate(lids.split(","))}
|
lids_dict = {lid: id + 1 for id, lid in enumerate(lids.split(","))}
|
||||||
id2lang = {id+1: lid for id, lid in enumerate(lids.split(","))}
|
id2lang = {id + 1: lid for id, lid in enumerate(lids.split(","))}
|
||||||
bad_id = []
|
bad_id = []
|
||||||
|
|
||||||
|
|
||||||
def extract_info(line, info):
|
def extract_info(line, info):
|
||||||
# Split the line at the first colon to separate the ID
|
# Split the line at the first colon to separate the ID
|
||||||
id_part, rest = line.split(':', 1)
|
id_part, rest = line.split(":", 1)
|
||||||
|
|
||||||
# Extract 'ref' by finding its start and end
|
# Extract 'ref' by finding its start and end
|
||||||
ref_start = rest.find(info)
|
ref_start = rest.find(info)
|
||||||
ref_end = rest.find(']', ref_start)
|
ref_end = rest.find("]", ref_start)
|
||||||
ref = rest[ref_start+len(info):ref_end].replace("'", "").split(', ')
|
ref = rest[ref_start + len(info) : ref_end].replace("'", "").split(", ")
|
||||||
|
|
||||||
# Extract 'lid'
|
# Extract 'lid'
|
||||||
if 'lid=' in rest:
|
if "lid=" in rest:
|
||||||
lid_start = rest.find('lid=[')
|
lid_start = rest.find("lid=[")
|
||||||
lid_end = rest.find(']', lid_start)
|
lid_end = rest.find("]", lid_start)
|
||||||
lid = rest[lid_start+len('lid=['):lid_end].split(', ')
|
lid = rest[lid_start + len("lid=[") : lid_end].split(", ")
|
||||||
else:
|
else:
|
||||||
lid = ['']
|
lid = [""]
|
||||||
|
|
||||||
if lid[0]=='':
|
if lid[0] == "":
|
||||||
bad_id.append(id_part)
|
bad_id.append(id_part)
|
||||||
if " ".join(lid):
|
if " ".join(lid):
|
||||||
lid = [int(i) for i in lid] # Convert each element to integer
|
lid = [int(i) for i in lid] # Convert each element to integer
|
||||||
@ -58,6 +60,7 @@ def is_English(c):
|
|||||||
"""check character is in English"""
|
"""check character is in English"""
|
||||||
return ord(c.lower()) >= ord("a") and ord(c.lower()) <= ord("z")
|
return ord(c.lower()) >= ord("a") and ord(c.lower()) <= ord("z")
|
||||||
|
|
||||||
|
|
||||||
def get_en(text):
|
def get_en(text):
|
||||||
res = []
|
res = []
|
||||||
for w in text:
|
for w in text:
|
||||||
@ -68,6 +71,7 @@ def get_en(text):
|
|||||||
continue
|
continue
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def get_zh(text):
|
def get_zh(text):
|
||||||
res = []
|
res = []
|
||||||
for w in text:
|
for w in text:
|
||||||
@ -79,34 +83,33 @@ def get_zh(text):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def extract_info_lid(line, tag):
|
def extract_info_lid(line, tag):
|
||||||
# Split the line at the first colon to separate the ID
|
# Split the line at the first colon to separate the ID
|
||||||
id_part, rest = line.split(':', 1)
|
id_part, rest = line.split(":", 1)
|
||||||
|
|
||||||
# Extract 'ref' by finding its start and end
|
# Extract 'ref' by finding its start and end
|
||||||
|
|
||||||
ref_start = rest.find(tag)
|
ref_start = rest.find(tag)
|
||||||
ref_end = rest.find(']', ref_start)
|
ref_end = rest.find("]", ref_start)
|
||||||
ref = rest[ref_start+len(tag):ref_end].replace("'", "").split(', ')
|
ref = rest[ref_start + len(tag) : ref_end].replace("'", "").split(", ")
|
||||||
|
|
||||||
return id_part.strip(), ref
|
return id_part.strip(), ref
|
||||||
|
|
||||||
|
|
||||||
def align_lid2(labels_a, labels_b, a, b):
|
def align_lid2(labels_a, labels_b, a, b):
|
||||||
# Alignment
|
# Alignment
|
||||||
EPS = '*'
|
EPS = "*"
|
||||||
ali = align(a, b, EPS, sclite_mode=True)
|
ali = align(a, b, EPS, sclite_mode=True)
|
||||||
|
|
||||||
a2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(a,labels_a))}
|
a2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(a, labels_a))}
|
||||||
b2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(b,labels_b))}
|
b2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(b, labels_b))}
|
||||||
# Comparing labels of aligned elements
|
# Comparing labels of aligned elements
|
||||||
idx_a = 0
|
idx_a = 0
|
||||||
idx_b = 0
|
idx_b = 0
|
||||||
ali_idx=0
|
ali_idx = 0
|
||||||
aligned_a = []
|
aligned_a = []
|
||||||
aligned_b = []
|
aligned_b = []
|
||||||
while idx_a <len(a) and idx_b <len(b) and ali_idx < len(ali):
|
while idx_a < len(a) and idx_b < len(b) and ali_idx < len(ali):
|
||||||
elem_a, elem_b = ali[ali_idx]
|
elem_a, elem_b = ali[ali_idx]
|
||||||
if elem_a == EPS:
|
if elem_a == EPS:
|
||||||
idx_b += 1
|
idx_b += 1
|
||||||
@ -114,8 +117,8 @@ def align_lid2(labels_a, labels_b, a, b):
|
|||||||
idx_a += 1
|
idx_a += 1
|
||||||
elif elem_a != EPS and elem_b != EPS:
|
elif elem_a != EPS and elem_b != EPS:
|
||||||
|
|
||||||
label_a = a2idx[(elem_a,idx_a)]
|
label_a = a2idx[(elem_a, idx_a)]
|
||||||
label_b = b2idx[(elem_b,idx_b)]
|
label_b = b2idx[(elem_b, idx_b)]
|
||||||
aligned_a.append(label_a)
|
aligned_a.append(label_a)
|
||||||
aligned_b.append(label_b)
|
aligned_b.append(label_b)
|
||||||
idx_b += 1
|
idx_b += 1
|
||||||
@ -128,7 +131,7 @@ def align_lid2(labels_a, labels_b, a, b):
|
|||||||
def align_lid(labels_a, labels_b):
|
def align_lid(labels_a, labels_b):
|
||||||
# Alignment
|
# Alignment
|
||||||
res_a, res_b = [], []
|
res_a, res_b = [], []
|
||||||
EPS = '*'
|
EPS = "*"
|
||||||
ali = align(labels_a, labels_b, EPS, sclite_mode=True)
|
ali = align(labels_a, labels_b, EPS, sclite_mode=True)
|
||||||
|
|
||||||
# Comparing labels of aligned elements
|
# Comparing labels of aligned elements
|
||||||
@ -139,17 +142,18 @@ def align_lid(labels_a, labels_b):
|
|||||||
|
|
||||||
|
|
||||||
def read_file(infile, tag):
|
def read_file(infile, tag):
|
||||||
""""returns list of dict (id, lid, text)"""
|
""" "returns list of dict (id, lid, text)"""
|
||||||
res = []
|
res = []
|
||||||
with open(infile, 'r') as file:
|
with open(infile, "r") as file:
|
||||||
for line in file:
|
for line in file:
|
||||||
_, rest = line.split(':', 1)
|
_, rest = line.split(":", 1)
|
||||||
if tag in rest:
|
if tag in rest:
|
||||||
_id, text = extract_info_lid(line, tag)
|
_id, text = extract_info_lid(line, tag)
|
||||||
|
|
||||||
res.append((_id, text))
|
res.append((_id, text))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def wer(results, sclite_mode=False):
|
def wer(results, sclite_mode=False):
|
||||||
subs = defaultdict(int)
|
subs = defaultdict(int)
|
||||||
ins = defaultdict(int)
|
ins = defaultdict(int)
|
||||||
@ -185,12 +189,12 @@ def wer(results, sclite_mode=False):
|
|||||||
return tot_err_rate
|
return tot_err_rate
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
ref_data = read_file(args.rec, 'ref=[')
|
ref_data = read_file(args.rec, "ref=[")
|
||||||
ref_data = sorted(ref_data)
|
ref_data = sorted(ref_data)
|
||||||
hyp_data = read_file(args.rec, 'hyp=[')
|
hyp_data = read_file(args.rec, "hyp=[")
|
||||||
hyp_data = sorted(hyp_data)
|
hyp_data = sorted(hyp_data)
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
|
|
||||||
@ -205,16 +209,9 @@ if __name__ == '__main__':
|
|||||||
hyp_text_en = get_en(hyp_text)
|
hyp_text_en = get_en(hyp_text)
|
||||||
hyp_text_zh = get_zh(hyp_text)
|
hyp_text_zh = get_zh(hyp_text)
|
||||||
|
|
||||||
|
results["en"].append((ref[0], ref_text_en, hyp_text_en))
|
||||||
results['en'].append((ref[0],ref_text_en, hyp_text_en))
|
results["zh"].append((ref[0], ref_text_zh, hyp_text_zh))
|
||||||
results['zh'].append((ref[0],ref_text_zh, hyp_text_zh))
|
|
||||||
|
|
||||||
for key, val in results.items():
|
for key, val in results.items():
|
||||||
print(key)
|
print(key)
|
||||||
res = wer(val)
|
res = wer(val)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -399,25 +399,18 @@ class SeameAsrDataModule:
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_cuts(self) -> CutSet:
|
def train_cuts(self) -> CutSet:
|
||||||
logging.info("Train data: About to get training cuts")
|
logging.info("Train data: About to get training cuts")
|
||||||
return load_manifest_lazy(
|
return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
|
||||||
self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def valid_cuts(self) -> CutSet:
|
def valid_cuts(self) -> CutSet:
|
||||||
logging.info("Dev data: About to get develop cuts")
|
logging.info("Dev data: About to get develop cuts")
|
||||||
return load_manifest_lazy(
|
return load_manifest_lazy(self.args.manifest_dir / "cuts_valid.jsonl.gz")
|
||||||
self.args.manifest_dir / "cuts_valid.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def dev_man(self) -> CutSet:
|
def dev_man(self) -> CutSet:
|
||||||
logging.info("About to get dev_man cuts")
|
logging.info("About to get dev_man cuts")
|
||||||
return load_manifest_lazy(
|
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev_man.jsonl.gz")
|
||||||
self.args.manifest_dir / "cuts_dev_man.jsonl.gz"
|
|
||||||
)
|
|
||||||
def dev_sge(self) -> CutSet:
|
def dev_sge(self) -> CutSet:
|
||||||
logging.info("About to get dev_sge cuts")
|
logging.info("About to get dev_sge cuts")
|
||||||
return load_manifest_lazy(
|
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sge.jsonl.gz")
|
||||||
self.args.manifest_dir / "cuts_dev_sge.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
@ -111,6 +111,7 @@ import re
|
|||||||
|
|
||||||
LOG_EPS = math.log(1e-10)
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
def remove_punc(text):
|
def remove_punc(text):
|
||||||
"""This function removes all English punctuations except the single quote (verbatim)."""
|
"""This function removes all English punctuations except the single quote (verbatim)."""
|
||||||
|
|
||||||
@ -119,20 +120,22 @@ def remove_punc(text):
|
|||||||
# english_punctuations = english_punctuations.replace("'", "")
|
# english_punctuations = english_punctuations.replace("'", "")
|
||||||
|
|
||||||
# Create a translation table that maps each punctuation to a space.
|
# Create a translation table that maps each punctuation to a space.
|
||||||
translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
|
translator = str.maketrans(english_punctuations, " " * len(english_punctuations))
|
||||||
|
|
||||||
# Translate the text using the translation table
|
# Translate the text using the translation table
|
||||||
text = text.translate(translator)
|
text = text.translate(translator)
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def clean(text):
|
def clean(text):
|
||||||
text = remove_punc(text)
|
text = remove_punc(text)
|
||||||
text = text.lower()
|
text = text.lower()
|
||||||
text = re.sub(r'\s+', ' ', text)
|
text = re.sub(r"\s+", " ", text)
|
||||||
text = text.rstrip()
|
text = text.rstrip()
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
@ -1384,5 +1384,6 @@ def main():
|
|||||||
else:
|
else:
|
||||||
run(rank=0, world_size=1, args=args)
|
run(rank=0, world_size=1, args=args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -924,9 +924,9 @@ def modified_beam_search_lm_rescore_LODR(
|
|||||||
# is equivalent to log(1 - sigmoid(logits[..., 0])).
|
# is equivalent to log(1 - sigmoid(logits[..., 0])).
|
||||||
nb_shift = logp_b - logits[..., 0]
|
nb_shift = logp_b - logits[..., 0]
|
||||||
nb_shift = nb_shift.unsqueeze(-1)
|
nb_shift = nb_shift.unsqueeze(-1)
|
||||||
log_probs1 = (logits[..., 1:]/ temperature).log_softmax(dim=-1) + nb_shift
|
log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift
|
||||||
|
|
||||||
#log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
|
# log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
|
||||||
log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1)
|
log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1)
|
||||||
|
|
||||||
log_probs.add_(ys_log_probs)
|
log_probs.add_(ys_log_probs)
|
||||||
|
@ -91,6 +91,7 @@ import re
|
|||||||
|
|
||||||
LOG_EPS = math.log(1e-10)
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
def remove_punc(text):
|
def remove_punc(text):
|
||||||
"""This function removes all English punctuations except the single quote (verbatim)."""
|
"""This function removes all English punctuations except the single quote (verbatim)."""
|
||||||
|
|
||||||
@ -99,20 +100,22 @@ def remove_punc(text):
|
|||||||
# english_punctuations = english_punctuations.replace("'", "")
|
# english_punctuations = english_punctuations.replace("'", "")
|
||||||
|
|
||||||
# Create a translation table that maps each punctuation to a space.
|
# Create a translation table that maps each punctuation to a space.
|
||||||
translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
|
translator = str.maketrans(english_punctuations, " " * len(english_punctuations))
|
||||||
|
|
||||||
# Translate the text using the translation table
|
# Translate the text using the translation table
|
||||||
text = text.translate(translator)
|
text = text.translate(translator)
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def clean(text):
|
def clean(text):
|
||||||
text = remove_punc(text)
|
text = remove_punc(text)
|
||||||
text = text.lower()
|
text = text.lower()
|
||||||
text = re.sub(r'\s+', ' ', text)
|
text = re.sub(r"\s+", " ", text)
|
||||||
text = text.rstrip()
|
text = text.rstrip()
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -813,12 +816,10 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# only load the neural network LM if required
|
# only load the neural network LM if required
|
||||||
if (
|
if params.use_shallow_fusion or params.decoding_method in (
|
||||||
params.use_shallow_fusion
|
|
||||||
or params.decoding_method in (
|
|
||||||
"modified_beam_search_lm_shallow_fusion",
|
"modified_beam_search_lm_shallow_fusion",
|
||||||
"modified_beam_search_LODR",
|
"modified_beam_search_LODR",
|
||||||
"modified_beam_search_lm_rescore_LODR",)
|
"modified_beam_search_lm_rescore_LODR",
|
||||||
):
|
):
|
||||||
LM = LmScorer(
|
LM = LmScorer(
|
||||||
lm_type=params.lm_type,
|
lm_type=params.lm_type,
|
||||||
|
@ -349,7 +349,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--train-size",
|
"--train-size",
|
||||||
type=str,
|
type=str,
|
||||||
default='full',
|
default="full",
|
||||||
help="train datasize",
|
help="train datasize",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -551,7 +551,7 @@ def get_params() -> AttributeDict:
|
|||||||
"valid_interval": 2000, # For the 100h subset, use 800
|
"valid_interval": 2000, # For the 100h subset, use 800
|
||||||
# parameters for zipformer
|
# parameters for zipformer
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
#"model_warm_step": 5000,
|
# "model_warm_step": 5000,
|
||||||
"subsampling_factor": 4, # not passed in, this is fixed.
|
"subsampling_factor": 4, # not passed in, this is fixed.
|
||||||
"warm_step": 5000,
|
"warm_step": 5000,
|
||||||
# parameters for ctc loss
|
# parameters for ctc loss
|
||||||
@ -1199,11 +1199,11 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
seame = SeameAsrDataModule(args)
|
seame = SeameAsrDataModule(args)
|
||||||
|
|
||||||
if params.train_size == '30':
|
if params.train_size == "30":
|
||||||
train_cuts = seame.train30_cuts()
|
train_cuts = seame.train30_cuts()
|
||||||
elif params.train_size == '10':
|
elif params.train_size == "10":
|
||||||
train_cuts = seame.train10_cuts()
|
train_cuts = seame.train10_cuts()
|
||||||
elif params.train_size == '50':
|
elif params.train_size == "50":
|
||||||
train_cuts = seame.train50_cuts()
|
train_cuts = seame.train50_cuts()
|
||||||
else:
|
else:
|
||||||
train_cuts = seame.train_cuts()
|
train_cuts = seame.train_cuts()
|
||||||
@ -1379,5 +1379,6 @@ def main():
|
|||||||
else:
|
else:
|
||||||
run(rank=0, world_size=1, args=args)
|
run(rank=0, world_size=1, args=args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -37,6 +37,7 @@ from icefall.utils import (
|
|||||||
get_texts_with_timestamp,
|
get_texts_with_timestamp,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Result:
|
class Result:
|
||||||
# timestamps[k] contains the frame number on which tokens[k]
|
# timestamps[k] contains the frame number on which tokens[k]
|
||||||
@ -465,7 +466,9 @@ def modified_beam_search(
|
|||||||
lid_current_encoder_out = lid_encoder_out.data[start:end]
|
lid_current_encoder_out = lid_encoder_out.data[start:end]
|
||||||
|
|
||||||
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
asr_lid_current_encoder_out = asr_lid_current_encoder_out.unsqueeze(1).unsqueeze(1)
|
asr_lid_current_encoder_out = asr_lid_current_encoder_out.unsqueeze(
|
||||||
|
1
|
||||||
|
).unsqueeze(1)
|
||||||
lid_current_encoder_out = lid_current_encoder_out.unsqueeze(1).unsqueeze(1)
|
lid_current_encoder_out = lid_current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||||
offset = end
|
offset = end
|
||||||
@ -492,7 +495,6 @@ def modified_beam_search(
|
|||||||
decoder_out = model.joiner.decoder_proj(decoder_out_)
|
decoder_out = model.joiner.decoder_proj(decoder_out_)
|
||||||
lid_decoder_out = model.lid_joiner.decoder_proj(decoder_out_)
|
lid_decoder_out = model.lid_joiner.decoder_proj(decoder_out_)
|
||||||
|
|
||||||
|
|
||||||
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
|
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
|
||||||
|
|
||||||
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||||
@ -879,6 +881,7 @@ def modified_beam_search_lm_shallow_fusion(
|
|||||||
timestamps=ans_timestamps,
|
timestamps=ans_timestamps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def modified_beam_search_auxlm_shallow_fusion(
|
def modified_beam_search_auxlm_shallow_fusion(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
@ -1160,6 +1163,7 @@ def modified_beam_search_auxlm_shallow_fusion(
|
|||||||
timestamps=ans_timestamps,
|
timestamps=ans_timestamps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def modified_beam_search_lm_rescore_LODR(
|
def modified_beam_search_lm_rescore_LODR(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
@ -1282,9 +1286,9 @@ def modified_beam_search_lm_rescore_LODR(
|
|||||||
# is equivalent to log(1 - sigmoid(logits[..., 0])).
|
# is equivalent to log(1 - sigmoid(logits[..., 0])).
|
||||||
nb_shift = logp_b - logits[..., 0]
|
nb_shift = logp_b - logits[..., 0]
|
||||||
nb_shift = nb_shift.unsqueeze(-1)
|
nb_shift = nb_shift.unsqueeze(-1)
|
||||||
log_probs1 = (logits[..., 1:]/ temperature).log_softmax(dim=-1) + nb_shift
|
log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift
|
||||||
|
|
||||||
#log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
|
# log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
|
||||||
log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1)
|
log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1)
|
||||||
|
|
||||||
log_probs.add_(ys_log_probs)
|
log_probs.add_(ys_log_probs)
|
||||||
|
@ -118,6 +118,7 @@ import matplotlib.pyplot as plt
|
|||||||
|
|
||||||
LOG_EPS = math.log(1e-10)
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
def remove_punc(text):
|
def remove_punc(text):
|
||||||
"""This function removes all English punctuations except the single quote (verbatim)."""
|
"""This function removes all English punctuations except the single quote (verbatim)."""
|
||||||
|
|
||||||
@ -126,21 +127,23 @@ def remove_punc(text):
|
|||||||
english_punctuations = english_punctuations.replace("'", "")
|
english_punctuations = english_punctuations.replace("'", "")
|
||||||
|
|
||||||
# Create a translation table that maps each punctuation to a space.
|
# Create a translation table that maps each punctuation to a space.
|
||||||
#translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
|
# translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
|
||||||
translator = str.maketrans('', '', english_punctuations)
|
translator = str.maketrans("", "", english_punctuations)
|
||||||
|
|
||||||
# Translate the text using the translation table
|
# Translate the text using the translation table
|
||||||
text = text.translate(translator)
|
text = text.translate(translator)
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def clean(text):
|
def clean(text):
|
||||||
text = remove_punc(text)
|
text = remove_punc(text)
|
||||||
text = text.lower()
|
text = text.lower()
|
||||||
text = re.sub(r'\s+', ' ', text)
|
text = re.sub(r"\s+", " ", text)
|
||||||
text = text.rstrip()
|
text = text.rstrip()
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -399,20 +402,21 @@ def get_parser():
|
|||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def align_lid(labels_a, labels_b, a, b):
|
def align_lid(labels_a, labels_b, a, b):
|
||||||
# Alignment
|
# Alignment
|
||||||
EPS = '*'
|
EPS = "*"
|
||||||
ali = align(a, b, EPS, sclite_mode=True)
|
ali = align(a, b, EPS, sclite_mode=True)
|
||||||
|
|
||||||
a2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(a,labels_a))}
|
a2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(a, labels_a))}
|
||||||
b2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(b,labels_b))}
|
b2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(b, labels_b))}
|
||||||
# Comparing labels of aligned elements
|
# Comparing labels of aligned elements
|
||||||
idx_a = 0
|
idx_a = 0
|
||||||
idx_b = 0
|
idx_b = 0
|
||||||
ali_idx=0
|
ali_idx = 0
|
||||||
aligned_a = []
|
aligned_a = []
|
||||||
aligned_b = []
|
aligned_b = []
|
||||||
while idx_a <len(a) and idx_b <len(b) and ali_idx < len(ali):
|
while idx_a < len(a) and idx_b < len(b) and ali_idx < len(ali):
|
||||||
elem_a, elem_b = ali[ali_idx]
|
elem_a, elem_b = ali[ali_idx]
|
||||||
if elem_a == EPS:
|
if elem_a == EPS:
|
||||||
idx_b += 1
|
idx_b += 1
|
||||||
@ -420,8 +424,8 @@ def align_lid(labels_a, labels_b, a, b):
|
|||||||
idx_a += 1
|
idx_a += 1
|
||||||
elif elem_a != EPS and elem_b != EPS:
|
elif elem_a != EPS and elem_b != EPS:
|
||||||
|
|
||||||
label_a = a2idx[(elem_a,idx_a)]
|
label_a = a2idx[(elem_a, idx_a)]
|
||||||
label_b = b2idx[(elem_b,idx_b)]
|
label_b = b2idx[(elem_b, idx_b)]
|
||||||
aligned_a.append(label_a)
|
aligned_a.append(label_a)
|
||||||
aligned_b.append(label_b)
|
aligned_b.append(label_b)
|
||||||
idx_b += 1
|
idx_b += 1
|
||||||
@ -430,29 +434,33 @@ def align_lid(labels_a, labels_b, a, b):
|
|||||||
ali_idx += 1
|
ali_idx += 1
|
||||||
return aligned_a, aligned_b
|
return aligned_a, aligned_b
|
||||||
|
|
||||||
def write_lid_results(lid_path, f1_path, text, lid ):
|
|
||||||
|
def write_lid_results(lid_path, f1_path, text, lid):
|
||||||
lid_hyp = []
|
lid_hyp = []
|
||||||
lid_ref = []
|
lid_ref = []
|
||||||
|
|
||||||
with open(lid_path, 'w') as file:
|
with open(lid_path, "w") as file:
|
||||||
# Write each line to the file
|
# Write each line to the file
|
||||||
for text_line, lid_line in zip(text, lid):
|
for text_line, lid_line in zip(text, lid):
|
||||||
file.write(f"{text_line[0]}: ref={text_line[1]} lid={lid_line[1]}" + '\n')
|
file.write(f"{text_line[0]}: ref={text_line[1]} lid={lid_line[1]}" + "\n")
|
||||||
aligned_ref, aligned_hyp = align_lid(lid_line[1],lid_line[2],text_line[1], text_line[2])
|
aligned_ref, aligned_hyp = align_lid(
|
||||||
|
lid_line[1], lid_line[2], text_line[1], text_line[2]
|
||||||
|
)
|
||||||
lid_ref.extend(aligned_ref)
|
lid_ref.extend(aligned_ref)
|
||||||
lid_hyp.extend(aligned_hyp)
|
lid_hyp.extend(aligned_hyp)
|
||||||
file.write(f"{lid_line[0]}: hyp={text_line[2]} lid={lid_line[2]}" + '\n')
|
file.write(f"{lid_line[0]}: hyp={text_line[2]} lid={lid_line[2]}" + "\n")
|
||||||
|
|
||||||
report = classification_report(lid_ref, lid_hyp, zero_division=0)
|
report = classification_report(lid_ref, lid_hyp, zero_division=0)
|
||||||
f1 = f1_score(lid_ref, lid_hyp, average='weighted')
|
f1 = f1_score(lid_ref, lid_hyp, average="weighted")
|
||||||
|
|
||||||
with open(f1_path, 'w') as file:
|
with open(f1_path, "w") as file:
|
||||||
file.write(report)
|
file.write(report)
|
||||||
file.write('\n')
|
file.write("\n")
|
||||||
file.write(f"F1 score: {f1} \n")
|
file.write(f"F1 score: {f1} \n")
|
||||||
filename = os.path.basename(lid_path).replace('.txt', '.png')
|
filename = os.path.basename(lid_path).replace(".txt", ".png")
|
||||||
dirname = os.path.dirname(lid_path)
|
dirname = os.path.dirname(lid_path)
|
||||||
save_conf_mat(os.path.join(dirname,filename), lid_ref, lid_hyp)
|
save_conf_mat(os.path.join(dirname, filename), lid_ref, lid_hyp)
|
||||||
|
|
||||||
|
|
||||||
def save_conf_mat(path, lid_ref, lid_hyp):
|
def save_conf_mat(path, lid_ref, lid_hyp):
|
||||||
all_labels = [1, 2, 3, 4]
|
all_labels = [1, 2, 3, 4]
|
||||||
@ -463,15 +471,23 @@ def save_conf_mat(path, lid_ref, lid_hyp):
|
|||||||
|
|
||||||
# Plot the confusion matrix
|
# Plot the confusion matrix
|
||||||
plt.figure(figsize=(10, 7))
|
plt.figure(figsize=(10, 7))
|
||||||
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
|
sns.heatmap(
|
||||||
plt.title('Confusion Matrix')
|
cm,
|
||||||
plt.xlabel('Predicted Labels')
|
annot=True,
|
||||||
plt.ylabel('True Labels')
|
fmt="d",
|
||||||
|
cmap="Blues",
|
||||||
|
xticklabels=class_names,
|
||||||
|
yticklabels=class_names,
|
||||||
|
)
|
||||||
|
plt.title("Confusion Matrix")
|
||||||
|
plt.xlabel("Predicted Labels")
|
||||||
|
plt.ylabel("True Labels")
|
||||||
plt.savefig(path)
|
plt.savefig(path)
|
||||||
|
|
||||||
|
|
||||||
def most_frequent(List):
|
def most_frequent(List):
|
||||||
return max(set(List), key = List.count)
|
return max(set(List), key=List.count)
|
||||||
|
|
||||||
|
|
||||||
def mapp(enc, LID):
|
def mapp(enc, LID):
|
||||||
pt1 = 0
|
pt1 = 0
|
||||||
@ -492,6 +508,7 @@ def mapp(enc, LID):
|
|||||||
|
|
||||||
return new_lid
|
return new_lid
|
||||||
|
|
||||||
|
|
||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@ -561,7 +578,9 @@ def decode_one_batch(
|
|||||||
value=LOG_EPS,
|
value=LOG_EPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens, lid_encoder_out = model.forward_encoder(feature, feature_lens)
|
encoder_out, encoder_out_lens, lid_encoder_out = model.forward_encoder(
|
||||||
|
feature, feature_lens
|
||||||
|
)
|
||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
B, T, F = feature.shape
|
B, T, F = feature.shape
|
||||||
@ -579,7 +598,7 @@ def decode_one_batch(
|
|||||||
hyp = results[i]
|
hyp = results[i]
|
||||||
token_pieces = sp.IdToPiece(results[i].hyps)
|
token_pieces = sp.IdToPiece(results[i].hyps)
|
||||||
new_lid = mapp(token_pieces, results[i].lid_hyps)
|
new_lid = mapp(token_pieces, results[i].lid_hyps)
|
||||||
hyps.append((sp.decode(results[i].hyps).split(),new_lid))
|
hyps.append((sp.decode(results[i].hyps).split(), new_lid))
|
||||||
|
|
||||||
elif params.decoding_method == "modified_beam_search":
|
elif params.decoding_method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
@ -701,7 +720,7 @@ def decode_dataset(
|
|||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
if params.lid:
|
if params.lid:
|
||||||
lids_dict = {lid:id+1 for id, lid in enumerate(params.lids.split(","))}
|
lids_dict = {lid: id + 1 for id, lid in enumerate(params.lids.split(","))}
|
||||||
|
|
||||||
text_list = [t.split("|") for t in texts]
|
text_list = [t.split("|") for t in texts]
|
||||||
num_tokens = [[len(clean(t).split()) for t in utt] for utt in text_list]
|
num_tokens = [[len(clean(t).split()) for t in utt] for utt in text_list]
|
||||||
@ -758,7 +777,6 @@ def decode_dataset(
|
|||||||
if params.lid:
|
if params.lid:
|
||||||
results_lid[name].extend(this_batch_lid)
|
results_lid[name].extend(this_batch_lid)
|
||||||
|
|
||||||
|
|
||||||
num_cuts += len(texts)
|
num_cuts += len(texts)
|
||||||
|
|
||||||
if batch_idx % log_interval == 0:
|
if batch_idx % log_interval == 0:
|
||||||
@ -766,7 +784,7 @@ def decode_dataset(
|
|||||||
|
|
||||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
if params.lid:
|
if params.lid:
|
||||||
return {"text":results, "lid":results_lid}
|
return {"text": results, "lid": results_lid}
|
||||||
else:
|
else:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@ -820,19 +838,16 @@ def save_results_lid(
|
|||||||
test_set_name: str,
|
test_set_name: str,
|
||||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||||
):
|
):
|
||||||
key = list(results_dict['text'].keys())[0]
|
key = list(results_dict["text"].keys())[0]
|
||||||
results_text = sorted(results_dict['text'][key], key=lambda x: x[0])
|
results_text = sorted(results_dict["text"][key], key=lambda x: x[0])
|
||||||
results_lid = sorted(results_dict['lid'][key], key=lambda x: x[0])
|
results_lid = sorted(results_dict["lid"][key], key=lambda x: x[0])
|
||||||
test_set_f1s = dict()
|
test_set_f1s = dict()
|
||||||
lid_path = (
|
lid_path = params.res_dir / f"lid-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
params.res_dir / f"lid-{test_set_name}-{key}-{params.suffix}.txt"
|
f1_path = params.res_dir / f"f1-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
)
|
|
||||||
f1_path = (
|
|
||||||
params.res_dir / f"f1-{test_set_name}-{key}-{params.suffix}.txt"
|
|
||||||
)
|
|
||||||
write_lid_results(lid_path, f1_path, results_text, results_lid)
|
write_lid_results(lid_path, f1_path, results_text, results_lid)
|
||||||
logging.info(f"The lids are stored in {lid_path}")
|
logging.info(f"The lids are stored in {lid_path}")
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
@ -999,12 +1014,10 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# only load the neural network LM if required
|
# only load the neural network LM if required
|
||||||
if (
|
if params.use_shallow_fusion or params.decoding_method in (
|
||||||
params.use_shallow_fusion
|
|
||||||
or params.decoding_method in (
|
|
||||||
"modified_beam_search_lm_shallow_fusion",
|
"modified_beam_search_lm_shallow_fusion",
|
||||||
"modified_beam_search_LODR",
|
"modified_beam_search_LODR",
|
||||||
"modified_beam_search_lm_rescore_LODR",)
|
"modified_beam_search_lm_rescore_LODR",
|
||||||
):
|
):
|
||||||
LM = LmScorer(
|
LM = LmScorer(
|
||||||
lm_type=params.lm_type,
|
lm_type=params.lm_type,
|
||||||
|
@ -19,6 +19,7 @@ import torch.nn as nn
|
|||||||
from scaling import ScaledLinear
|
from scaling import ScaledLinear
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class Joiner(nn.Module):
|
class Joiner(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -141,7 +141,9 @@ class AsrModel(nn.Module):
|
|||||||
src_key_padding_mask = make_pad_mask(x_lens)
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens, lid_output = self.encoder(x, x_lens, src_key_padding_mask)
|
encoder_out, encoder_out_lens, lid_output = self.encoder(
|
||||||
|
x, x_lens, src_key_padding_mask
|
||||||
|
)
|
||||||
|
|
||||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
|
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
|
||||||
@ -217,7 +219,6 @@ class AsrModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
# Now for the decoder, i.e., the prediction network
|
# Now for the decoder, i.e., the prediction network
|
||||||
|
|
||||||
|
|
||||||
blank_id = self.decoder.blank_id
|
blank_id = self.decoder.blank_id
|
||||||
sos_y = add_sos(y, sos_id=blank_id)
|
sos_y = add_sos(y, sos_id=blank_id)
|
||||||
|
|
||||||
@ -291,11 +292,14 @@ class AsrModel(nn.Module):
|
|||||||
ranges=ranges,
|
ranges=ranges,
|
||||||
)
|
)
|
||||||
lid_logits = self.lid_joiner(
|
lid_logits = self.lid_joiner(
|
||||||
lid_am_pruned, lid_lm_pruned, project_input=False)
|
lid_am_pruned, lid_lm_pruned, project_input=False
|
||||||
|
)
|
||||||
|
|
||||||
# project_input=False since we applied the decoder's input projections
|
# project_input=False since we applied the decoder's input projections
|
||||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False, lid_out=lid_pruned)
|
logits = self.joiner(
|
||||||
|
am_pruned, lm_pruned, project_input=False, lid_out=lid_pruned
|
||||||
|
)
|
||||||
# Add blank logits to lid_logits
|
# Add blank logits to lid_logits
|
||||||
logits = torch.cat((lid_logits[..., 0].unsqueeze(-1), logits), dim=-1)
|
logits = torch.cat((lid_logits[..., 0].unsqueeze(-1), logits), dim=-1)
|
||||||
|
|
||||||
@ -315,7 +319,9 @@ class AsrModel(nn.Module):
|
|||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
pruned_lid_loss = k2.rnnt_loss_pruned(
|
pruned_lid_loss = k2.rnnt_loss_pruned(
|
||||||
logits=lid_logits.float(),
|
logits=lid_logits.float(),
|
||||||
symbols=y_lid.pad(mode="constant", padding_value=blank_id).to(torch.int64),
|
symbols=y_lid.pad(mode="constant", padding_value=blank_id).to(
|
||||||
|
torch.int64
|
||||||
|
),
|
||||||
ranges=ranges,
|
ranges=ranges,
|
||||||
termination_symbol=blank_id,
|
termination_symbol=blank_id,
|
||||||
boundary=boundary,
|
boundary=boundary,
|
||||||
@ -374,7 +380,9 @@ class AsrModel(nn.Module):
|
|||||||
|
|
||||||
# Compute encoder outputs
|
# Compute encoder outputs
|
||||||
if self.lid_joiner != None:
|
if self.lid_joiner != None:
|
||||||
encoder_out, encoder_out_lens, lid_encoder_out = self.forward_encoder(x, x_lens)
|
encoder_out, encoder_out_lens, lid_encoder_out = self.forward_encoder(
|
||||||
|
x, x_lens
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
|
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
|
||||||
|
|
||||||
|
@ -366,7 +366,8 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
"--lid-value-head-dim",
|
"--lid-value-head-dim",
|
||||||
type=str,
|
type=str,
|
||||||
default="12",
|
default="12",
|
||||||
help="Value dimension per head in encoder stacks: a single int or comma-separated list.",)
|
help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lid-pos-head-dim",
|
"--lid-pos-head-dim",
|
||||||
type=str,
|
type=str,
|
||||||
@ -429,6 +430,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
help="Whether to skip positional embedding in the lid encoder.",
|
help="Whether to skip positional embedding in the lid encoder.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -781,9 +783,11 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
causal=params.causal,
|
causal=params.causal,
|
||||||
chunk_size=_to_int_tuple(params.chunk_size),
|
chunk_size=_to_int_tuple(params.chunk_size),
|
||||||
left_context_frames=_to_int_tuple(params.left_context_frames),
|
left_context_frames=_to_int_tuple(params.left_context_frames),
|
||||||
lid_output_layer=params.lid_output_layer if params.use_lid_encoder else None,)
|
lid_output_layer=params.lid_output_layer if params.use_lid_encoder else None,
|
||||||
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
def get_lid_encoder_model(params: AttributeDict) -> nn.Module:
|
def get_lid_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
lid_encoder = Zipformer2(
|
lid_encoder = Zipformer2(
|
||||||
output_downsampling_factor=2,
|
output_downsampling_factor=2,
|
||||||
@ -806,6 +810,7 @@ def get_lid_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
)
|
)
|
||||||
return lid_encoder
|
return lid_encoder
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
@ -826,15 +831,17 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|||||||
)
|
)
|
||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
def get_lid_joiner_model(params: AttributeDict) -> nn.Module:
|
def get_lid_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
lid_joiner = Joiner(
|
lid_joiner = Joiner(
|
||||||
encoder_dim=int(params.lid_encoder_dim.split(",")[-1]),
|
encoder_dim=int(params.lid_encoder_dim.split(",")[-1]),
|
||||||
decoder_dim=params.decoder_dim,
|
decoder_dim=params.decoder_dim,
|
||||||
joiner_dim=params.lid_joiner_dim,
|
joiner_dim=params.lid_joiner_dim,
|
||||||
vocab_size=len(params.lids.split(","))+1,
|
vocab_size=len(params.lids.split(",")) + 1,
|
||||||
)
|
)
|
||||||
return lid_joiner
|
return lid_joiner
|
||||||
|
|
||||||
|
|
||||||
def get_model(params: AttributeDict) -> nn.Module:
|
def get_model(params: AttributeDict) -> nn.Module:
|
||||||
assert params.use_transducer or params.use_ctc, (
|
assert params.use_transducer or params.use_ctc, (
|
||||||
f"At least one of them should be True, "
|
f"At least one of them should be True, "
|
||||||
@ -858,7 +865,6 @@ def get_model(params: AttributeDict) -> nn.Module:
|
|||||||
decoder = None
|
decoder = None
|
||||||
joiner = None
|
joiner = None
|
||||||
|
|
||||||
|
|
||||||
model = AsrModel(
|
model = AsrModel(
|
||||||
encoder_embed=encoder_embed,
|
encoder_embed=encoder_embed,
|
||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
@ -875,7 +881,6 @@ def get_model(params: AttributeDict) -> nn.Module:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint_if_available(
|
def load_checkpoint_if_available(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@ -1018,7 +1023,7 @@ def compute_loss(
|
|||||||
values >= 1.0 are fully warmed up and have all modules present.
|
values >= 1.0 are fully warmed up and have all modules present.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
lids_dict = {lid:id+1 for id, lid in enumerate(params.lids.split(","))}
|
lids_dict = {lid: id + 1 for id, lid in enumerate(params.lids.split(","))}
|
||||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
# at entry, feature is (N, T, C)
|
# at entry, feature is (N, T, C)
|
||||||
@ -1066,9 +1071,7 @@ def compute_loss(
|
|||||||
lid_pruned_loss_is_finite = torch.isfinite(lid_pruned_loss)
|
lid_pruned_loss_is_finite = torch.isfinite(lid_pruned_loss)
|
||||||
|
|
||||||
is_finite = (
|
is_finite = (
|
||||||
simple_loss_is_finite
|
simple_loss_is_finite & pruned_loss_is_finite & lid_pruned_loss_is_finite
|
||||||
& pruned_loss_is_finite
|
|
||||||
& lid_pruned_loss_is_finite
|
|
||||||
)
|
)
|
||||||
if not torch.all(is_finite):
|
if not torch.all(is_finite):
|
||||||
logging.info(
|
logging.info(
|
||||||
@ -1096,12 +1099,13 @@ def compute_loss(
|
|||||||
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
loss += (1 - params.lid_loss_scale) * (
|
||||||
loss += (1-params.lid_loss_scale)*(simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss)
|
simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||||
#loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
)
|
||||||
|
# loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||||
if params.use_lid_joiner:
|
if params.use_lid_joiner:
|
||||||
loss += params.lid_loss_scale * pruned_loss_scale * lid_pruned_loss
|
loss += params.lid_loss_scale * pruned_loss_scale * lid_pruned_loss
|
||||||
#loss += pruned_loss_scale * lid_pruned_loss
|
# loss += pruned_loss_scale * lid_pruned_loss
|
||||||
|
|
||||||
if params.use_ctc:
|
if params.use_ctc:
|
||||||
loss += params.ctc_loss_scale * ctc_loss
|
loss += params.ctc_loss_scale * ctc_loss
|
||||||
|
Loading…
x
Reference in New Issue
Block a user