mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00

black formatting

This commit is contained in:
parent c7f74e410f
commit 891cf55901
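Note: this commit only re-runs the black code formatter over the recipe scripts below; quote style, wrapping, and spacing change, behavior does not. As a minimal sketch of what black does (assuming the black package is installed; the exact options used by icefall are not shown in this diff):

# Hedged sketch: reformat a source string the same way `black <file>` would.
import black

src = "parser.add_argument(\n    '--dec-file',\n    type=str\n)\n"
formatted = black.format_str(src, mode=black.Mode())
print(formatted)  # -> parser.add_argument("--dec-file", type=str)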
@@ -14,12 +14,7 @@ import jiwer
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dec-file",
type=str,
help="Decoded icefall recogs file"
)
parser.add_argument("--dec-file", type=str, help="Decoded icefall recogs file")
return parser

@@ -29,22 +24,22 @@ def cer_(file):
ref = []
cer_results = 0
ref_lens = 0
with open(file, 'r', encoding='utf-8') as dec:
with open(file, "r", encoding="utf-8") as dec:
for line in dec:
id, target = line.split('\t')
id, target = line.split("\t")
id = id[0:-2]
target, txt = target.split("=")
if target == 'ref':
words = txt.strip().strip('[]').split(', ')
if target == "ref":
words = txt.strip().strip("[]").split(", ")
word_list = [word.strip("'") for word in words]
ref.append(" ".join(word_list))
elif target == 'hyp':
words = txt.strip().strip('[]').split(', ')
elif target == "hyp":
words = txt.strip().strip("[]").split(", ")
word_list = [word.strip("'") for word in words]
hyp.append(" ".join(word_list))
for h, r in zip(hyp, ref):
if r:
cer_results += (jiwer.cer(r, h)*len(r))
cer_results += jiwer.cer(r, h) * len(r)
ref_lens += len(r)
print(cer_results / ref_lens)

@@ -55,5 +50,6 @@ def main():
args = parse.parse_args()
cer_(args.dec_file)
if __name__ == "__main__":
main()
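For reference, a minimal sketch of the length-weighted CER this script computes with jiwer; the ref/hyp strings here are hypothetical, the real ones are parsed from the recogs file above:

# Length-weighted CER over (ref, hyp) pairs using jiwer.
import jiwer

pairs = [("ni hao world", "ni how world"), ("hello there", "hello their")]
total_err, total_len = 0.0, 0
for ref, hyp in pairs:
    if ref:
        total_err += jiwer.cer(ref, hyp) * len(ref)
        total_len += len(ref)
print(total_err / total_len)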
@@ -38,6 +38,7 @@ from lhotse.features.kaldifeat import (
KaldifeatMelOptions,
)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(

@@ -70,7 +71,7 @@ def get_args():
def compute_fbank_gpu(args):
src_dir = Path("data_seame/manifests")
output_dir = Path("data_seame/fbank")
num_jobs = min(os.cpu_count(),10)
num_jobs = min(os.cpu_count(), 10)
num_mel_bins = 80
sampling_rate = 16000
sr = 16000

@@ -87,7 +88,10 @@ def compute_fbank_gpu(args):
suffix = "jsonl.gz"
breakpoint
manifests = read_manifests_if_cached(
prefix=prefix, dataset_parts=dataset_parts, output_dir=src_dir,suffix=suffix,
prefix=prefix,
dataset_parts=dataset_parts,
output_dir=src_dir,
suffix=suffix,
)
assert manifests is not None

@@ -116,15 +120,11 @@ def compute_fbank_gpu(args):
cut_set = cut_set.resample(sr)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False,
keep_all_channels=False)
cut_set = cut_set.filter(lambda c: c.duration >= .2 and c.duration <= 30)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
keep_overlapping=False, keep_all_channels=False
)
cut_set = cut_set.filter(lambda c: c.duration >= 0.2 and c.duration <= 30)
if "train" in partition:
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",

@@ -147,10 +147,9 @@ def compute_fbank_gpu(args):
)
cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
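The reformatted pipeline above boils down to the following lhotse flow. A rough sketch, assuming a CUDA kaldifeat extractor and a cuts file that already exists (the file path is hypothetical; thresholds mirror the script):

# Hedged sketch of the GPU fbank extraction flow.
from lhotse import CutSet
from lhotse.features.kaldifeat import (
    KaldifeatFbank,
    KaldifeatFbankConfig,
    KaldifeatMelOptions,
)

extractor = KaldifeatFbank(
    KaldifeatFbankConfig(device="cuda", mel_opts=KaldifeatMelOptions(num_bins=80))
)
cut_set = CutSet.from_file("data_seame/manifests/cuts_example.jsonl.gz")  # hypothetical path
cut_set = cut_set.resample(16000)
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
cut_set = cut_set.filter(lambda c: 0.2 <= c.duration <= 30)
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)  # train splits only
cut_set = cut_set.compute_and_store_features_batch(
    extractor=extractor,
    storage_path="data_seame/fbank/feats_example",
    num_workers=4,
)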
@@ -37,6 +37,7 @@ from lhotse.features.kaldifeat import (
KaldifeatMelOptions,
)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(

@@ -69,7 +70,7 @@ def get_args():
def compute_fbank_gpu(args):
src_dir = Path("data_seame/manifests")
output_dir = Path("data_seame/fbank")
num_jobs = min(os.cpu_count(),10)
num_jobs = min(os.cpu_count(), 10)
num_mel_bins = 80
sampling_rate = 16000
sr = 16000

@@ -80,7 +81,6 @@ def compute_fbank_gpu(args):
"train10",
"train50",
"train30",
)
prefix = ""
suffix = "jsonl.gz"

@@ -103,15 +103,11 @@ def compute_fbank_gpu(args):
cut_set = cut_set.resample(sr)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False,
keep_all_channels=False)
cut_set = cut_set.filter(lambda c: c.duration >= .5 and c.duration <= 30)
if "train" in part:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
keep_overlapping=False, keep_all_channels=False
)
cut_set = cut_set.filter(lambda c: c.duration >= 0.5 and c.duration <= 30)
if "train" in part:
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{part}",

@@ -134,10 +130,9 @@ def compute_fbank_gpu(args):
)
cut_set.to_file(output_dir / f"cuts_{part}.jsonl.gz")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
@@ -7,7 +7,6 @@ from lhotse.qa import fix_manifests, validate_recordings_and_supervisions
import pdb
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter

@@ -44,7 +43,7 @@ def get_parser():
def valid_asr(cut):
tol = 2e-3
i=0
i = 0
total_dur = 0
for c in cut:
if c.supervisions != []:

@@ -52,10 +51,14 @@ def valid_asr(cut):
logging.info(f"Supervision beyond the cut. Cut number: {i}")
total_dur += c.duration
logging.info(f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}")
logging.info(
f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}"
)
elif c.supervisions[0].start < -tol:
logging.info(f"Supervision starts before the cut. Cut number: {i}")
logging.info(f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}")
logging.info(
f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}"
)
else:
continue
else:

@@ -83,7 +86,10 @@ def main():
logging.info("Validating manifests")
validate_recordings_and_supervisions(recordings, supervisions)
cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,)
cuts = CutSet.from_manifests(
recordings=recordings,
supervisions=supervisions,
)
cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
cuts.describe()
logging.info("Example from cut:")

@@ -93,5 +99,6 @@ def main():
if args.savecut != "":
cuts.to_file(args.savecut)
if __name__ == "__main__":
main()
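A minimal sketch of the manifest check this script performs (manifest file names are hypothetical; fix_manifests and validate_recordings_and_supervisions come from lhotse.qa as imported above):

# Hedged sketch: fix, validate, and cut the manifests.
from lhotse import CutSet, RecordingSet, SupervisionSet
from lhotse.qa import fix_manifests, validate_recordings_and_supervisions

recordings = RecordingSet.from_file("recordings.jsonl.gz")      # hypothetical
supervisions = SupervisionSet.from_file("supervisions.jsonl.gz")  # hypothetical
recordings, supervisions = fix_manifests(recordings, supervisions)
validate_recordings_and_supervisions(recordings, supervisions)

cuts = CutSet.from_manifests(recordings=recordings, supervisions=supervisions)
cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
cuts.describe()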
@@ -25,9 +25,7 @@ def main():
for line in f:
line = line.strip()
characters = list(line)
characters = " ".join(
["V" if char == "*" else char for char in characters]
)
characters = " ".join(["V" if char == "*" else char for char in characters])
lex[line] = characters
with open(args.output, "w", encoding="utf-8") as fp:

@@ -44,11 +44,12 @@ def main():
if not os.path.exists(langdir):
os.makedirs(langdir)
with open(langdir / "transcript_words.txt", 'w') as txt:
with open(langdir / "transcript_words.txt", "w") as txt:
for c in cuts:
#breakpoint()
# breakpoint()
text = c.supervisions[0].text
txt.write(text + '\n')
txt.write(text + "\n")
if __name__ == "__main__":
main()
@@ -68,8 +68,13 @@ def main():
recordings = RecordingSet.from_file(args.rec)
supervisions = SupervisionSet.from_file(args.sup)
logging.info("Fixing manifests")
cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,)
cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
cuts = CutSet.from_manifests(
recordings=recordings,
supervisions=supervisions,
)
cuts = cuts.trim_to_supervisions(
keep_overlapping=False, keep_all_channels=False
)
shuffled = cuts.shuffle()
total_dur = 0

@@ -86,5 +91,6 @@ def main():
logging.info(f"Saving {args.outcut}")
cuts.to_file(outdir / args.outcut)
if __name__ == "__main__":
main()

@@ -91,7 +91,7 @@ def main():
user_defined_symbols = ["<blk>", "<sos/eos>"]
unk_id = len(user_defined_symbols)
if predef_sym:
syms = predef_sym.split(',')
syms = predef_sym.split(",")
for i in syms:
user_defined_symbols.append(i)
# Note: unk_id is fixed to 2.

@@ -116,5 +116,6 @@ def main():
shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
generate_tokens(lang_dir)
if __name__ == "__main__":
main()
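A hedged sketch of the BPE training step behind the hunk above, showing how the reserved symbols and unk_id relate; the input path, output prefix, and vocab size are hypothetical and the real script passes more options:

# Sketch: train a sentencepiece model with reserved symbols before <unk>.
import sentencepiece as spm

user_defined_symbols = ["<blk>", "<sos/eos>"]  # symbols from --predef-sym would be appended here
spm.SentencePieceTrainer.train(
    input="transcript_words.txt",       # hypothetical transcript path
    model_prefix="lang_bpe/bpe",        # hypothetical output prefix
    vocab_size=2000,
    model_type="unigram",
    user_defined_symbols=user_defined_symbols,
    unk_id=len(user_defined_symbols),   # <unk> comes right after the reserved symbols (2 here)
)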
@@ -25,29 +25,31 @@ def get_parser():
)
return parser
lids = "en,zh"
lids_dict = {lid:id+1 for id, lid in enumerate(lids.split(","))}
id2lang = {id+1: lid for id, lid in enumerate(lids.split(","))}
lids_dict = {lid: id + 1 for id, lid in enumerate(lids.split(","))}
id2lang = {id + 1: lid for id, lid in enumerate(lids.split(","))}
bad_id = []
def extract_info(line, info):
# Split the line at the first colon to separate the ID
id_part, rest = line.split(':', 1)
id_part, rest = line.split(":", 1)
# Extract 'ref' by finding its start and end
ref_start = rest.find(info)
ref_end = rest.find(']', ref_start)
ref = rest[ref_start+len(info):ref_end].replace("'", "").split(', ')
ref_end = rest.find("]", ref_start)
ref = rest[ref_start + len(info) : ref_end].replace("'", "").split(", ")
# Extract 'lid'
if 'lid=' in rest:
lid_start = rest.find('lid=[')
lid_end = rest.find(']', lid_start)
lid = rest[lid_start+len('lid=['):lid_end].split(', ')
if "lid=" in rest:
lid_start = rest.find("lid=[")
lid_end = rest.find("]", lid_start)
lid = rest[lid_start + len("lid=[") : lid_end].split(", ")
else:
lid = ['']
lid = [""]
if lid[0]=='':
if lid[0] == "":
bad_id.append(id_part)
if " ".join(lid):
lid = [int(i) for i in lid] # Convert each element to integer

@@ -58,6 +60,7 @@ def is_English(c):
"""check character is in English"""
return ord(c.lower()) >= ord("a") and ord(c.lower()) <= ord("z")
def get_en(text):
res = []
for w in text:

@@ -68,6 +71,7 @@ def get_en(text):
continue
return res
def get_zh(text):
res = []
for w in text:

@@ -79,34 +83,33 @@ def get_zh(text):
return res
def extract_info_lid(line, tag):
# Split the line at the first colon to separate the ID
id_part, rest = line.split(':', 1)
id_part, rest = line.split(":", 1)
# Extract 'ref' by finding its start and end
ref_start = rest.find(tag)
ref_end = rest.find(']', ref_start)
ref = rest[ref_start+len(tag):ref_end].replace("'", "").split(', ')
ref_end = rest.find("]", ref_start)
ref = rest[ref_start + len(tag) : ref_end].replace("'", "").split(", ")
return id_part.strip(), ref
def align_lid2(labels_a, labels_b, a, b):
# Alignment
EPS = '*'
EPS = "*"
ali = align(a, b, EPS, sclite_mode=True)
a2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(a,labels_a))}
b2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(b,labels_b))}
a2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(a, labels_a))}
b2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(b, labels_b))}
# Comparing labels of aligned elements
idx_a = 0
idx_b = 0
ali_idx=0
ali_idx = 0
aligned_a = []
aligned_b = []
while idx_a <len(a) and idx_b <len(b) and ali_idx < len(ali):
while idx_a < len(a) and idx_b < len(b) and ali_idx < len(ali):
elem_a, elem_b = ali[ali_idx]
if elem_a == EPS:
idx_b += 1

@@ -114,8 +117,8 @@ def align_lid2(labels_a, labels_b, a, b):
idx_a += 1
elif elem_a != EPS and elem_b != EPS:
label_a = a2idx[(elem_a,idx_a)]
label_b = b2idx[(elem_b,idx_b)]
label_a = a2idx[(elem_a, idx_a)]
label_b = b2idx[(elem_b, idx_b)]
aligned_a.append(label_a)
aligned_b.append(label_b)
idx_b += 1

@@ -128,7 +131,7 @@ def align_lid2(labels_a, labels_b, a, b):
def align_lid(labels_a, labels_b):
# Alignment
res_a, res_b = [], []
EPS = '*'
EPS = "*"
ali = align(labels_a, labels_b, EPS, sclite_mode=True)
# Comparing labels of aligned elements

@@ -139,17 +142,18 @@ def align_lid(labels_a, labels_b):
def read_file(infile, tag):
""""returns list of dict (id, lid, text)"""
""" "returns list of dict (id, lid, text)"""
res = []
with open(infile, 'r') as file:
with open(infile, "r") as file:
for line in file:
_, rest = line.split(':', 1)
_, rest = line.split(":", 1)
if tag in rest:
_id, text = extract_info_lid(line, tag)
res.append((_id, text))
return res
def wer(results, sclite_mode=False):
subs = defaultdict(int)
ins = defaultdict(int)

@@ -185,12 +189,12 @@ def wer(results, sclite_mode=False):
return tot_err_rate
if __name__ == '__main__':
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
ref_data = read_file(args.rec, 'ref=[')
ref_data = read_file(args.rec, "ref=[")
ref_data = sorted(ref_data)
hyp_data = read_file(args.rec, 'hyp=[')
hyp_data = read_file(args.rec, "hyp=[")
hyp_data = sorted(hyp_data)
results = defaultdict(list)

@@ -205,16 +209,9 @@ if __name__ == '__main__':
hyp_text_en = get_en(hyp_text)
hyp_text_zh = get_zh(hyp_text)
results['en'].append((ref[0],ref_text_en, hyp_text_en))
results['zh'].append((ref[0],ref_text_zh, hyp_text_zh))
results["en"].append((ref[0], ref_text_en, hyp_text_en))
results["zh"].append((ref[0], ref_text_zh, hyp_text_zh))
for key, val in results.items():
print(key)
res = wer(val)
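The idea in this script is to split each ref/hyp pair into English and Mandarin tokens and score each language separately, aligning with kaldialign as the code above does. A rough sketch with hypothetical token lists:

# Hedged sketch: per-language split plus an sclite-style alignment.
from kaldialign import align

EPS = "*"
ref = ["i", "want", "喝", "water"]
hyp = ["i", "want", "he", "water"]

def is_en(w):
    return all("a" <= ch.lower() <= "z" for ch in w)

ref_en = [w for w in ref if is_en(w)]   # English-only subset
hyp_en = [w for w in hyp if is_en(w)]

ali = align(ref_en, hyp_en, EPS, sclite_mode=True)
errs = sum(1 for r, h in ali if r != h)
print(f"errors on English tokens: {errs} / {len(ref_en)}")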
@@ -399,25 +399,18 @@ class SeameAsrDataModule:
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("Train data: About to get training cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("Dev data: About to get develop cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_valid.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "cuts_valid.jsonl.gz")
@lru_cache()
def dev_man(self) -> CutSet:
logging.info("About to get dev_man cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_dev_man.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev_man.jsonl.gz")
def dev_sge(self) -> CutSet:
logging.info("About to get dev_sge cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_dev_sge.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sge.jsonl.gz")
@@ -111,6 +111,7 @@ import re
LOG_EPS = math.log(1e-10)
def remove_punc(text):
"""This function removes all English punctuations except the single quote (verbatim)."""

@@ -119,20 +120,22 @@ def remove_punc(text):
# english_punctuations = english_punctuations.replace("'", "")
# Create a translation table that maps each punctuation to a space.
translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
translator = str.maketrans(english_punctuations, " " * len(english_punctuations))
# Translate the text using the translation table
text = text.translate(translator)
return text
def clean(text):
text = remove_punc(text)
text = text.lower()
text = re.sub(r'\s+', ' ', text)
text = re.sub(r"\s+", " ", text)
text = text.rstrip()
return text
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
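A minimal self-contained sketch of the text normalisation in remove_punc/clean above, assuming english_punctuations is string.punctuation (its definition is outside this hunk): punctuation is mapped to spaces, the text is lower-cased, and whitespace is squeezed.

# Sketch of the clean() normalisation.
import re
import string

punct = string.punctuation
translator = str.maketrans(punct, " " * len(punct))

def clean(text: str) -> str:
    text = text.translate(translator)
    text = text.lower()
    text = re.sub(r"\s+", " ", text)
    return text.rstrip()

print(clean("Hello, world!!  OK?"))  # -> "hello world ok"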
@@ -1384,5 +1384,6 @@ def main():
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
main()

@@ -924,9 +924,9 @@ def modified_beam_search_lm_rescore_LODR(
# is equivalent to log(1 - sigmoid(logits[..., 0])).
nb_shift = logp_b - logits[..., 0]
nb_shift = nb_shift.unsqueeze(-1)
log_probs1 = (logits[..., 1:]/ temperature).log_softmax(dim=-1) + nb_shift
log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift
#log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
# log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1)
log_probs.add_(ys_log_probs)
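The hunk above factorises the output distribution into blank vs non-blank: the first logit acts as a binary blank detector and the remaining logits form a token distribution conditioned on "not blank". A sketch with hypothetical tensors, assuming logp_b is the log-sigmoid of the blank logit (the hunk's comment about log(1 - sigmoid(...)) implies this, but its computation is not shown here):

# Sketch: blank/non-blank factorisation; each row of log_probs sums to ~1 in probability.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 11)  # hypothetical (num_hyps, 1 blank + 10 tokens)
temperature = 1.0

logp_b = F.logsigmoid(logits[..., 0])                 # log P(blank)
nb_shift = (logp_b - logits[..., 0]).unsqueeze(-1)    # = log(1 - sigmoid(blank logit))
log_probs_nb = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift
log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs_nb), dim=-1)
print(log_probs.exp().sum(dim=-1))  # ~1.0 per hypothesis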
@@ -91,6 +91,7 @@ import re
LOG_EPS = math.log(1e-10)
def remove_punc(text):
"""This function removes all English punctuations except the single quote (verbatim)."""

@@ -99,20 +100,22 @@ def remove_punc(text):
# english_punctuations = english_punctuations.replace("'", "")
# Create a translation table that maps each punctuation to a space.
translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
translator = str.maketrans(english_punctuations, " " * len(english_punctuations))
# Translate the text using the translation table
text = text.translate(translator)
return text
def clean(text):
text = remove_punc(text)
text = text.lower()
text = re.sub(r'\s+', ' ', text)
text = re.sub(r"\s+", " ", text)
text = text.rstrip()
return text
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter

@@ -813,12 +816,10 @@ def main():
model.eval()
# only load the neural network LM if required
if (
params.use_shallow_fusion
or params.decoding_method in (
if params.use_shallow_fusion or params.decoding_method in (
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_LODR",
"modified_beam_search_lm_rescore_LODR",)
"modified_beam_search_lm_rescore_LODR",
):
LM = LmScorer(
lm_type=params.lm_type,

@@ -349,7 +349,7 @@ def get_parser():
parser.add_argument(
"--train-size",
type=str,
default='full',
default="full",
help="train datasize",
)

@@ -551,7 +551,7 @@ def get_params() -> AttributeDict:
"valid_interval": 2000, # For the 100h subset, use 800
# parameters for zipformer
"feature_dim": 80,
#"model_warm_step": 5000,
# "model_warm_step": 5000,
"subsampling_factor": 4, # not passed in, this is fixed.
"warm_step": 5000,
# parameters for ctc loss

@@ -1199,11 +1199,11 @@ def run(rank, world_size, args):
seame = SeameAsrDataModule(args)
if params.train_size == '30':
if params.train_size == "30":
train_cuts = seame.train30_cuts()
elif params.train_size == '10':
elif params.train_size == "10":
train_cuts = seame.train10_cuts()
elif params.train_size == '50':
elif params.train_size == "50":
train_cuts = seame.train50_cuts()
else:
train_cuts = seame.train_cuts()

@@ -1379,5 +1379,6 @@ def main():
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
main()

@@ -37,6 +37,7 @@ from icefall.utils import (
get_texts_with_timestamp,
)
@dataclass
class Result:
# timestamps[k] contains the frame number on which tokens[k]

@@ -465,7 +466,9 @@ def modified_beam_search(
lid_current_encoder_out = lid_encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
asr_lid_current_encoder_out = asr_lid_current_encoder_out.unsqueeze(1).unsqueeze(1)
asr_lid_current_encoder_out = asr_lid_current_encoder_out.unsqueeze(
1
).unsqueeze(1)
lid_current_encoder_out = lid_current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end

@@ -492,7 +495,6 @@ def modified_beam_search(
decoder_out = model.joiner.decoder_proj(decoder_out_)
lid_decoder_out = model.lid_joiner.decoder_proj(decoder_out_)
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor

@@ -879,6 +881,7 @@ def modified_beam_search_lm_shallow_fusion(
timestamps=ans_timestamps,
)
def modified_beam_search_auxlm_shallow_fusion(
model: nn.Module,
encoder_out: torch.Tensor,

@@ -1160,6 +1163,7 @@ def modified_beam_search_auxlm_shallow_fusion(
timestamps=ans_timestamps,
)
def modified_beam_search_lm_rescore_LODR(
model: nn.Module,
encoder_out: torch.Tensor,

@@ -1282,9 +1286,9 @@ def modified_beam_search_lm_rescore_LODR(
# is equivalent to log(1 - sigmoid(logits[..., 0])).
nb_shift = logp_b - logits[..., 0]
nb_shift = nb_shift.unsqueeze(-1)
log_probs1 = (logits[..., 1:]/ temperature).log_softmax(dim=-1) + nb_shift
log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift
#log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
# log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1)
log_probs.add_(ys_log_probs)
@@ -118,6 +118,7 @@ import matplotlib.pyplot as plt
LOG_EPS = math.log(1e-10)
def remove_punc(text):
"""This function removes all English punctuations except the single quote (verbatim)."""

@@ -126,21 +127,23 @@ def remove_punc(text):
english_punctuations = english_punctuations.replace("'", "")
# Create a translation table that maps each punctuation to a space.
#translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
translator = str.maketrans('', '', english_punctuations)
# translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
translator = str.maketrans("", "", english_punctuations)
# Translate the text using the translation table
text = text.translate(translator)
return text
def clean(text):
text = remove_punc(text)
text = text.lower()
text = re.sub(r'\s+', ' ', text)
text = re.sub(r"\s+", " ", text)
text = text.rstrip()
return text
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter

@@ -399,20 +402,21 @@ def get_parser():
return parser
def align_lid(labels_a, labels_b, a, b):
# Alignment
EPS = '*'
EPS = "*"
ali = align(a, b, EPS, sclite_mode=True)
a2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(a,labels_a))}
b2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(b,labels_b))}
a2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(a, labels_a))}
b2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(b, labels_b))}
# Comparing labels of aligned elements
idx_a = 0
idx_b = 0
ali_idx=0
ali_idx = 0
aligned_a = []
aligned_b = []
while idx_a <len(a) and idx_b <len(b) and ali_idx < len(ali):
while idx_a < len(a) and idx_b < len(b) and ali_idx < len(ali):
elem_a, elem_b = ali[ali_idx]
if elem_a == EPS:
idx_b += 1

@@ -420,8 +424,8 @@ def align_lid(labels_a, labels_b, a, b):
idx_a += 1
elif elem_a != EPS and elem_b != EPS:
label_a = a2idx[(elem_a,idx_a)]
label_b = b2idx[(elem_b,idx_b)]
label_a = a2idx[(elem_a, idx_a)]
label_b = b2idx[(elem_b, idx_b)]
aligned_a.append(label_a)
aligned_b.append(label_b)
idx_b += 1

@@ -430,29 +434,33 @@ def align_lid(labels_a, labels_b, a, b):
ali_idx += 1
return aligned_a, aligned_b
def write_lid_results(lid_path, f1_path, text, lid ):
def write_lid_results(lid_path, f1_path, text, lid):
lid_hyp = []
lid_ref = []
with open(lid_path, 'w') as file:
with open(lid_path, "w") as file:
# Write each line to the file
for text_line, lid_line in zip(text, lid):
file.write(f"{text_line[0]}: ref={text_line[1]} lid={lid_line[1]}" + '\n')
aligned_ref, aligned_hyp = align_lid(lid_line[1],lid_line[2],text_line[1], text_line[2])
file.write(f"{text_line[0]}: ref={text_line[1]} lid={lid_line[1]}" + "\n")
aligned_ref, aligned_hyp = align_lid(
lid_line[1], lid_line[2], text_line[1], text_line[2]
)
lid_ref.extend(aligned_ref)
lid_hyp.extend(aligned_hyp)
file.write(f"{lid_line[0]}: hyp={text_line[2]} lid={lid_line[2]}" + '\n')
file.write(f"{lid_line[0]}: hyp={text_line[2]} lid={lid_line[2]}" + "\n")
report = classification_report(lid_ref, lid_hyp, zero_division=0)
f1 = f1_score(lid_ref, lid_hyp, average='weighted')
f1 = f1_score(lid_ref, lid_hyp, average="weighted")
with open(f1_path, 'w') as file:
with open(f1_path, "w") as file:
file.write(report)
file.write('\n')
file.write("\n")
file.write(f"F1 score: {f1} \n")
filename = os.path.basename(lid_path).replace('.txt', '.png')
filename = os.path.basename(lid_path).replace(".txt", ".png")
dirname = os.path.dirname(lid_path)
save_conf_mat(os.path.join(dirname,filename), lid_ref, lid_hyp)
save_conf_mat(os.path.join(dirname, filename), lid_ref, lid_hyp)
def save_conf_mat(path, lid_ref, lid_hyp):
all_labels = [1, 2, 3, 4]

@@ -463,15 +471,23 @@ def save_conf_mat(path, lid_ref, lid_hyp):
# Plot the confusion matrix
plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
sns.heatmap(
cm,
annot=True,
fmt="d",
cmap="Blues",
xticklabels=class_names,
yticklabels=class_names,
)
plt.title("Confusion Matrix")
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.savefig(path)
def most_frequent(List):
return max(set(List), key = List.count)
return max(set(List), key=List.count)
def mapp(enc, LID):
pt1 = 0

@@ -492,6 +508,7 @@ def mapp(enc, LID):
return new_lid
def decode_one_batch(
params: AttributeDict,
model: nn.Module,

@@ -561,7 +578,9 @@ def decode_one_batch(
value=LOG_EPS,
)
encoder_out, encoder_out_lens, lid_encoder_out = model.forward_encoder(feature, feature_lens)
encoder_out, encoder_out_lens, lid_encoder_out = model.forward_encoder(
feature, feature_lens
)
hyps = []
B, T, F = feature.shape

@@ -579,7 +598,7 @@ def decode_one_batch(
hyp = results[i]
token_pieces = sp.IdToPiece(results[i].hyps)
new_lid = mapp(token_pieces, results[i].lid_hyps)
hyps.append((sp.decode(results[i].hyps).split(),new_lid))
hyps.append((sp.decode(results[i].hyps).split(), new_lid))
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(

@@ -701,7 +720,7 @@ def decode_dataset(
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
if params.lid:
lids_dict = {lid:id+1 for id, lid in enumerate(params.lids.split(","))}
lids_dict = {lid: id + 1 for id, lid in enumerate(params.lids.split(","))}
text_list = [t.split("|") for t in texts]
num_tokens = [[len(clean(t).split()) for t in utt] for utt in text_list]

@@ -758,7 +777,6 @@ def decode_dataset(
if params.lid:
results_lid[name].extend(this_batch_lid)
num_cuts += len(texts)
if batch_idx % log_interval == 0:

@@ -766,7 +784,7 @@ def decode_dataset(
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
if params.lid:
return {"text":results, "lid":results_lid}
return {"text": results, "lid": results_lid}
else:
return results

@@ -820,19 +838,16 @@ def save_results_lid(
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
key = list(results_dict['text'].keys())[0]
results_text = sorted(results_dict['text'][key], key=lambda x: x[0])
results_lid = sorted(results_dict['lid'][key], key=lambda x: x[0])
key = list(results_dict["text"].keys())[0]
results_text = sorted(results_dict["text"][key], key=lambda x: x[0])
results_lid = sorted(results_dict["lid"][key], key=lambda x: x[0])
test_set_f1s = dict()
lid_path = (
params.res_dir / f"lid-{test_set_name}-{key}-{params.suffix}.txt"
)
f1_path = (
params.res_dir / f"f1-{test_set_name}-{key}-{params.suffix}.txt"
)
lid_path = params.res_dir / f"lid-{test_set_name}-{key}-{params.suffix}.txt"
f1_path = params.res_dir / f"f1-{test_set_name}-{key}-{params.suffix}.txt"
write_lid_results(lid_path, f1_path, results_text, results_lid)
logging.info(f"The lids are stored in {lid_path}")
@torch.no_grad()
def main():
parser = get_parser()

@@ -999,12 +1014,10 @@ def main():
model.eval()
# only load the neural network LM if required
if (
params.use_shallow_fusion
or params.decoding_method in (
if params.use_shallow_fusion or params.decoding_method in (
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_LODR",
"modified_beam_search_lm_rescore_LODR",)
"modified_beam_search_lm_rescore_LODR",
):
LM = LmScorer(
lm_type=params.lm_type,
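A minimal sketch of the LID scoring used in write_lid_results/save_conf_mat above: a classification report, a weighted F1, and a saved confusion matrix. Label IDs and the output file name are hypothetical stand-ins:

# Sketch of the LID evaluation and confusion-matrix plot.
import matplotlib
matplotlib.use("Agg")  # headless, as in batch decoding
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, f1_score

lid_ref = [1, 1, 2, 2, 1]  # 1=en, 2=zh (hypothetical)
lid_hyp = [1, 2, 2, 2, 1]

print(classification_report(lid_ref, lid_hyp, zero_division=0))
print("weighted F1:", f1_score(lid_ref, lid_hyp, average="weighted"))

cm = confusion_matrix(lid_ref, lid_hyp, labels=[1, 2])
plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=["en", "zh"], yticklabels=["en", "zh"])
plt.title("Confusion Matrix")
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.savefig("confusion_matrix.png")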
@@ -19,6 +19,7 @@ import torch.nn as nn
from scaling import ScaledLinear
from typing import Optional
class Joiner(nn.Module):
def __init__(
self,

@@ -141,7 +141,9 @@ class AsrModel(nn.Module):
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens, lid_output = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out, encoder_out_lens, lid_output = self.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

@@ -217,7 +219,6 @@ class AsrModel(nn.Module):
"""
# Now for the decoder, i.e., the prediction network
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)

@@ -291,11 +292,14 @@ class AsrModel(nn.Module):
ranges=ranges,
)
lid_logits = self.lid_joiner(
lid_am_pruned, lid_lm_pruned, project_input=False)
lid_am_pruned, lid_lm_pruned, project_input=False
)
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False, lid_out=lid_pruned)
logits = self.joiner(
am_pruned, lm_pruned, project_input=False, lid_out=lid_pruned
)
# Add blank logits to lid_logits
logits = torch.cat((lid_logits[..., 0].unsqueeze(-1), logits), dim=-1)

@@ -315,7 +319,9 @@ class AsrModel(nn.Module):
with torch.cuda.amp.autocast(enabled=False):
pruned_lid_loss = k2.rnnt_loss_pruned(
logits=lid_logits.float(),
symbols=y_lid.pad(mode="constant", padding_value=blank_id).to(torch.int64),
symbols=y_lid.pad(mode="constant", padding_value=blank_id).to(
torch.int64
),
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,

@@ -374,7 +380,9 @@ class AsrModel(nn.Module):
# Compute encoder outputs
if self.lid_joiner != None:
encoder_out, encoder_out_lens, lid_encoder_out = self.forward_encoder(x, x_lens)
encoder_out, encoder_out_lens, lid_encoder_out = self.forward_encoder(
x, x_lens
)
else:
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
@@ -366,7 +366,8 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--lid-value-head-dim",
type=str,
default="12",
help="Value dimension per head in encoder stacks: a single int or comma-separated list.",)
help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--lid-pos-head-dim",
type=str,

@@ -429,6 +430,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Whether to skip positional embedding in the lid encoder.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter

@@ -781,9 +783,11 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
causal=params.causal,
chunk_size=_to_int_tuple(params.chunk_size),
left_context_frames=_to_int_tuple(params.left_context_frames),
lid_output_layer=params.lid_output_layer if params.use_lid_encoder else None,)
lid_output_layer=params.lid_output_layer if params.use_lid_encoder else None,
)
return encoder
def get_lid_encoder_model(params: AttributeDict) -> nn.Module:
lid_encoder = Zipformer2(
output_downsampling_factor=2,

@@ -806,6 +810,7 @@ def get_lid_encoder_model(params: AttributeDict) -> nn.Module:
)
return lid_encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,

@@ -826,15 +831,17 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
)
return joiner
def get_lid_joiner_model(params: AttributeDict) -> nn.Module:
lid_joiner = Joiner(
encoder_dim=int(params.lid_encoder_dim.split(",")[-1]),
decoder_dim=params.decoder_dim,
joiner_dim=params.lid_joiner_dim,
vocab_size=len(params.lids.split(","))+1,
vocab_size=len(params.lids.split(",")) + 1,
)
return lid_joiner
def get_model(params: AttributeDict) -> nn.Module:
assert params.use_transducer or params.use_ctc, (
f"At least one of them should be True, "

@@ -858,7 +865,6 @@ def get_model(params: AttributeDict) -> nn.Module:
decoder = None
joiner = None
model = AsrModel(
encoder_embed=encoder_embed,
encoder=encoder,

@@ -875,7 +881,6 @@ def get_model(params: AttributeDict) -> nn.Module:
return model
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,

@@ -1018,7 +1023,7 @@ def compute_loss(
values >= 1.0 are fully warmed up and have all modules present.
"""
lids_dict = {lid:id+1 for id, lid in enumerate(params.lids.split(","))}
lids_dict = {lid: id + 1 for id, lid in enumerate(params.lids.split(","))}
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)

@@ -1066,9 +1071,7 @@ def compute_loss(
lid_pruned_loss_is_finite = torch.isfinite(lid_pruned_loss)
is_finite = (
simple_loss_is_finite
& pruned_loss_is_finite
& lid_pruned_loss_is_finite
simple_loss_is_finite & pruned_loss_is_finite & lid_pruned_loss_is_finite
)
if not torch.all(is_finite):
logging.info(

@@ -1096,12 +1099,13 @@ def compute_loss(
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss += (1-params.lid_loss_scale)*(simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss)
#loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
loss += (1 - params.lid_loss_scale) * (
simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
)
# loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.use_lid_joiner:
loss += params.lid_loss_scale * pruned_loss_scale * lid_pruned_loss
#loss += pruned_loss_scale * lid_pruned_loss
# loss += pruned_loss_scale * lid_pruned_loss
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
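For clarity, the loss interpolation in compute_loss above mixes the ASR transducer losses and the LID pruned loss with a single lid_loss_scale knob. A sketch with hypothetical stand-in values (the real scales are warm-up dependent and come from params):

# Sketch of the ASR/LID loss mixing.
import torch

simple_loss = torch.tensor(12.0)
pruned_loss = torch.tensor(8.0)
lid_pruned_loss = torch.tensor(3.0)
simple_loss_scale, pruned_loss_scale = 0.5, 1.0  # warm-up dependent in the script
lid_loss_scale = 0.3                             # params.lid_loss_scale (hypothetical)

loss = (1 - lid_loss_scale) * (
    simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
)
loss += lid_loss_scale * pruned_loss_scale * lid_pruned_loss
print(loss)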