black formatting

AmirHussein96 2024-04-05 13:00:29 -04:00
parent c7f74e410f
commit 891cf55901
21 changed files with 387 additions and 358 deletions

View File

@ -14,12 +14,7 @@ import jiwer
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dec-file",
type=str,
help="Decoded icefall recogs file"
)
parser.add_argument("--dec-file", type=str, help="Decoded icefall recogs file")
return parser
@ -29,22 +24,22 @@ def cer_(file):
ref = []
cer_results = 0
ref_lens = 0
with open(file, 'r', encoding='utf-8') as dec:
with open(file, "r", encoding="utf-8") as dec:
for line in dec:
id, target = line.split('\t')
id, target = line.split("\t")
id = id[0:-2]
target, txt = target.split("=")
if target == 'ref':
words = txt.strip().strip('[]').split(', ')
if target == "ref":
words = txt.strip().strip("[]").split(", ")
word_list = [word.strip("'") for word in words]
ref.append(" ".join(word_list))
elif target == 'hyp':
words = txt.strip().strip('[]').split(', ')
elif target == "hyp":
words = txt.strip().strip("[]").split(", ")
word_list = [word.strip("'") for word in words]
hyp.append(" ".join(word_list))
for h, r in zip(hyp, ref):
if r:
cer_results += (jiwer.cer(r, h)*len(r))
cer_results += jiwer.cer(r, h) * len(r)
ref_lens += len(r)
print(cer_results / ref_lens)
@ -55,5 +50,6 @@ def main():
args = parse.parse_args()
cer_(args.dec_file)
if __name__ == "__main__":
main()
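
The changes above, and throughout the rest of the commit, look like the mechanical result of running black (presumably with its default 88-column configuration; the exact version and invocation are not recorded here) over these files: string quotes are normalized to double quotes, spaces are added around binary operators, and redundant parentheses around assignment right-hand sides are dropped. A minimal illustrative sketch that re-derives two of the reformatted lines above with black's programmatic entry point, `black.format_str` (assumes `pip install black`; not part of the commit):

```python
# Illustrative sketch, not part of the commit: reproduce black's quote normalization,
# operator spacing, and redundant-parenthesis removal on two statements from the diff.
import black

src = (
    "with open(file, 'r', encoding='utf-8') as dec:\n"
    "    cer_results += (jiwer.cer(r, h)*len(r))\n"
)
print(black.format_str(src, mode=black.Mode()))
# Expected output (matches the reformatted lines in the diff above):
#   with open(file, "r", encoding="utf-8") as dec:
#       cer_results += jiwer.cer(r, h) * len(r)
```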

View File

@ -38,6 +38,7 @@ from lhotse.features.kaldifeat import (
KaldifeatMelOptions,
)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
@ -87,7 +88,10 @@ def compute_fbank_gpu(args):
suffix = "jsonl.gz"
breakpoint
manifests = read_manifests_if_cached(
prefix=prefix, dataset_parts=dataset_parts, output_dir=src_dir,suffix=suffix,
prefix=prefix,
dataset_parts=dataset_parts,
output_dir=src_dir,
suffix=suffix,
)
assert manifests is not None
@ -116,15 +120,11 @@ def compute_fbank_gpu(args):
cut_set = cut_set.resample(sr)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False,
keep_all_channels=False)
cut_set = cut_set.filter(lambda c: c.duration >= .2 and c.duration <= 30)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
keep_overlapping=False, keep_all_channels=False
)
cut_set = cut_set.filter(lambda c: c.duration >= 0.2 and c.duration <= 30)
if "train" in partition:
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
@ -147,10 +147,9 @@ def compute_fbank_gpu(args):
)
cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()

View File

@ -37,6 +37,7 @@ from lhotse.features.kaldifeat import (
KaldifeatMelOptions,
)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
@ -80,7 +81,6 @@ def compute_fbank_gpu(args):
"train10",
"train50",
"train30",
)
prefix = ""
suffix = "jsonl.gz"
@ -103,15 +103,11 @@ def compute_fbank_gpu(args):
cut_set = cut_set.resample(sr)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False,
keep_all_channels=False)
cut_set = cut_set.filter(lambda c: c.duration >= .5 and c.duration <= 30)
if "train" in part:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
keep_overlapping=False, keep_all_channels=False
)
cut_set = cut_set.filter(lambda c: c.duration >= 0.5 and c.duration <= 30)
if "train" in part:
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{part}",
@ -134,10 +130,9 @@ def compute_fbank_gpu(args):
)
cut_set.to_file(output_dir / f"cuts_{part}.jsonl.gz")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()

View File

@ -7,7 +7,6 @@ from lhotse.qa import fix_manifests, validate_recordings_and_supervisions
import pdb
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -52,10 +51,14 @@ def valid_asr(cut):
logging.info(f"Supervision beyond the cut. Cut number: {i}")
total_dur += c.duration
logging.info(f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}")
logging.info(
f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}"
)
elif c.supervisions[0].start < -tol:
logging.info(f"Supervision starts before the cut. Cut number: {i}")
logging.info(f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}")
logging.info(
f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}"
)
else:
continue
else:
@ -83,7 +86,10 @@ def main():
logging.info("Validating manifests")
validate_recordings_and_supervisions(recordings, supervisions)
cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,)
cuts = CutSet.from_manifests(
recordings=recordings,
supervisions=supervisions,
)
cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
cuts.describe()
logging.info("Example from cut:")
@ -93,5 +99,6 @@ def main():
if args.savecut != "":
cuts.to_file(args.savecut)
if __name__ == "__main__":
main()
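
The `CutSet.from_manifests(...)` rewrite above also shows black's "magic trailing comma": because the original call already ended with a trailing comma, black keeps it exploded with one keyword argument per line even though it would fit in 88 columns, and the stray space in `recordings= recordings` is removed (no spaces around `=` for keyword arguments). A short illustrative check under the same assumptions as the earlier sketch:

```python
# Illustrative sketch, not part of the commit: the pre-existing trailing comma forces
# the call onto multiple lines; keyword '=' gets no surrounding spaces.
import black

src = "cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,)\n"
print(black.format_str(src, mode=black.Mode()))
# Expected output (matches the diff above):
#   cuts = CutSet.from_manifests(
#       recordings=recordings,
#       supervisions=supervisions,
#   )
```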

View File

@ -25,9 +25,7 @@ def main():
for line in f:
line = line.strip()
characters = list(line)
characters = " ".join(
["V" if char == "*" else char for char in characters]
)
characters = " ".join(["V" if char == "*" else char for char in characters])
lex[line] = characters
with open(args.output, "w", encoding="utf-8") as fp:

View File

@ -44,11 +44,12 @@ def main():
if not os.path.exists(langdir):
os.makedirs(langdir)
with open(langdir / "transcript_words.txt", 'w') as txt:
with open(langdir / "transcript_words.txt", "w") as txt:
for c in cuts:
# breakpoint()
text = c.supervisions[0].text
txt.write(text + '\n')
txt.write(text + "\n")
if __name__ == "__main__":
main()

View File

@ -68,8 +68,13 @@ def main():
recordings = RecordingSet.from_file(args.rec)
supervisions = SupervisionSet.from_file(args.sup)
logging.info("Fixing manifests")
cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,)
cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
cuts = CutSet.from_manifests(
recordings=recordings,
supervisions=supervisions,
)
cuts = cuts.trim_to_supervisions(
keep_overlapping=False, keep_all_channels=False
)
shuffled = cuts.shuffle()
total_dur = 0
@ -86,5 +91,6 @@ def main():
logging.info(f"Saving {args.outcut}")
cuts.to_file(outdir / args.outcut)
if __name__ == "__main__":
main()

View File

@ -91,7 +91,7 @@ def main():
user_defined_symbols = ["<blk>", "<sos/eos>"]
unk_id = len(user_defined_symbols)
if predef_sym:
syms = predef_sym.split(',')
syms = predef_sym.split(",")
for i in syms:
user_defined_symbols.append(i)
# Note: unk_id is fixed to 2.
@ -116,5 +116,6 @@ def main():
shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
generate_tokens(lang_dir)
if __name__ == "__main__":
main()

View File

@ -25,29 +25,31 @@ def get_parser():
)
return parser
lids = "en,zh"
lids_dict = {lid: id + 1 for id, lid in enumerate(lids.split(","))}
id2lang = {id + 1: lid for id, lid in enumerate(lids.split(","))}
bad_id = []
def extract_info(line, info):
# Split the line at the first colon to separate the ID
id_part, rest = line.split(':', 1)
id_part, rest = line.split(":", 1)
# Extract 'ref' by finding its start and end
ref_start = rest.find(info)
ref_end = rest.find(']', ref_start)
ref = rest[ref_start+len(info):ref_end].replace("'", "").split(', ')
ref_end = rest.find("]", ref_start)
ref = rest[ref_start + len(info) : ref_end].replace("'", "").split(", ")
# Extract 'lid'
if 'lid=' in rest:
lid_start = rest.find('lid=[')
lid_end = rest.find(']', lid_start)
lid = rest[lid_start+len('lid=['):lid_end].split(', ')
if "lid=" in rest:
lid_start = rest.find("lid=[")
lid_end = rest.find("]", lid_start)
lid = rest[lid_start + len("lid=[") : lid_end].split(", ")
else:
lid = ['']
lid = [""]
if lid[0]=='':
if lid[0] == "":
bad_id.append(id_part)
if " ".join(lid):
lid = [int(i) for i in lid] # Convert each element to integer
@ -58,6 +60,7 @@ def is_English(c):
"""check character is in English"""
return ord(c.lower()) >= ord("a") and ord(c.lower()) <= ord("z")
def get_en(text):
res = []
for w in text:
@ -68,6 +71,7 @@ def get_en(text):
continue
return res
def get_zh(text):
res = []
for w in text:
@ -79,23 +83,22 @@ def get_zh(text):
return res
def extract_info_lid(line, tag):
# Split the line at the first colon to separate the ID
id_part, rest = line.split(':', 1)
id_part, rest = line.split(":", 1)
# Extract 'ref' by finding its start and end
ref_start = rest.find(tag)
ref_end = rest.find(']', ref_start)
ref = rest[ref_start+len(tag):ref_end].replace("'", "").split(', ')
ref_end = rest.find("]", ref_start)
ref = rest[ref_start + len(tag) : ref_end].replace("'", "").split(", ")
return id_part.strip(), ref
def align_lid2(labels_a, labels_b, a, b):
# Alignment
EPS = '*'
EPS = "*"
ali = align(a, b, EPS, sclite_mode=True)
a2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(a, labels_a))}
@ -128,7 +131,7 @@ def align_lid2(labels_a, labels_b, a, b):
def align_lid(labels_a, labels_b):
# Alignment
res_a, res_b = [], []
EPS = '*'
EPS = "*"
ali = align(labels_a, labels_b, EPS, sclite_mode=True)
# Comparing labels of aligned elements
@ -141,15 +144,16 @@ def align_lid(labels_a, labels_b):
def read_file(infile, tag):
""" "returns list of dict (id, lid, text)"""
res = []
with open(infile, 'r') as file:
with open(infile, "r") as file:
for line in file:
_, rest = line.split(':', 1)
_, rest = line.split(":", 1)
if tag in rest:
_id, text = extract_info_lid(line, tag)
res.append((_id, text))
return res
def wer(results, sclite_mode=False):
subs = defaultdict(int)
ins = defaultdict(int)
@ -185,12 +189,12 @@ def wer(results, sclite_mode=False):
return tot_err_rate
if __name__ == '__main__':
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
ref_data = read_file(args.rec, 'ref=[')
ref_data = read_file(args.rec, "ref=[")
ref_data = sorted(ref_data)
hyp_data = read_file(args.rec, 'hyp=[')
hyp_data = read_file(args.rec, "hyp=[")
hyp_data = sorted(hyp_data)
results = defaultdict(list)
@ -205,16 +209,9 @@ if __name__ == '__main__':
hyp_text_en = get_en(hyp_text)
hyp_text_zh = get_zh(hyp_text)
results['en'].append((ref[0],ref_text_en, hyp_text_en))
results['zh'].append((ref[0],ref_text_zh, hyp_text_zh))
results["en"].append((ref[0], ref_text_en, hyp_text_en))
results["zh"].append((ref[0], ref_text_zh, hyp_text_zh))
for key, val in results.items():
print(key)
res = wer(val)
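
A subtler change in the file above is the spacing black adds inside slices: when a slice bound is a compound expression such as `ref_start + len(info)`, black pads the colon with spaces, following PEP 8's rule of treating the slice colon as the lowest-priority operator. Another short illustrative check under the same assumptions:

```python
# Illustrative sketch, not part of the commit: black pads the slice colon when a bound
# is a compound expression, and normalizes the remaining single-quoted string.
import black

src = """ref = rest[ref_start+len(info):ref_end].replace("'", "").split(', ')\n"""
print(black.format_str(src, mode=black.Mode()))
# Expected output (matches the diff above):
#   ref = rest[ref_start + len(info) : ref_end].replace("'", "").split(", ")
```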

View File

@ -399,25 +399,18 @@ class SeameAsrDataModule:
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("Train data: About to get training cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("Dev data: About to get develop cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_valid.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "cuts_valid.jsonl.gz")
@lru_cache()
def dev_man(self) -> CutSet:
logging.info("About to get dev_man cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_dev_man.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev_man.jsonl.gz")
def dev_sge(self) -> CutSet:
logging.info("About to get dev_sge cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_dev_sge.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sge.jsonl.gz")

View File

@ -111,6 +111,7 @@ import re
LOG_EPS = math.log(1e-10)
def remove_punc(text):
"""This function removes all English punctuations except the single quote (verbatim)."""
@ -119,20 +120,22 @@ def remove_punc(text):
# english_punctuations = english_punctuations.replace("'", "")
# Create a translation table that maps each punctuation to a space.
translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
translator = str.maketrans(english_punctuations, " " * len(english_punctuations))
# Translate the text using the translation table
text = text.translate(translator)
return text
def clean(text):
text = remove_punc(text)
text = text.lower()
text = re.sub(r'\s+', ' ', text)
text = re.sub(r"\s+", " ", text)
text = text.rstrip()
return text
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter

View File

@ -1384,5 +1384,6 @@ def main():
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
main()

View File

@ -91,6 +91,7 @@ import re
LOG_EPS = math.log(1e-10)
def remove_punc(text):
"""This function removes all English punctuations except the single quote (verbatim)."""
@ -99,20 +100,22 @@ def remove_punc(text):
# english_punctuations = english_punctuations.replace("'", "")
# Create a translation table that maps each punctuation to a space.
translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
translator = str.maketrans(english_punctuations, " " * len(english_punctuations))
# Translate the text using the translation table
text = text.translate(translator)
return text
def clean(text):
text = remove_punc(text)
text = text.lower()
text = re.sub(r'\s+', ' ', text)
text = re.sub(r"\s+", " ", text)
text = text.rstrip()
return text
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -813,12 +816,10 @@ def main():
model.eval()
# only load the neural network LM if required
if (
params.use_shallow_fusion
or params.decoding_method in (
if params.use_shallow_fusion or params.decoding_method in (
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_LODR",
"modified_beam_search_lm_rescore_LODR",)
"modified_beam_search_lm_rescore_LODR",
):
LM = LmScorer(
lm_type=params.lm_type,

View File

@ -349,7 +349,7 @@ def get_parser():
parser.add_argument(
"--train-size",
type=str,
default='full',
default="full",
help="train datasize",
)
@ -1199,11 +1199,11 @@ def run(rank, world_size, args):
seame = SeameAsrDataModule(args)
if params.train_size == '30':
if params.train_size == "30":
train_cuts = seame.train30_cuts()
elif params.train_size == '10':
elif params.train_size == "10":
train_cuts = seame.train10_cuts()
elif params.train_size == '50':
elif params.train_size == "50":
train_cuts = seame.train50_cuts()
else:
train_cuts = seame.train_cuts()
@ -1379,5 +1379,6 @@ def main():
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
main()

View File

@ -37,6 +37,7 @@ from icefall.utils import (
get_texts_with_timestamp,
)
@dataclass
class Result:
# timestamps[k] contains the frame number on which tokens[k]
@ -465,7 +466,9 @@ def modified_beam_search(
lid_current_encoder_out = lid_encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
asr_lid_current_encoder_out = asr_lid_current_encoder_out.unsqueeze(1).unsqueeze(1)
asr_lid_current_encoder_out = asr_lid_current_encoder_out.unsqueeze(
1
).unsqueeze(1)
lid_current_encoder_out = lid_current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
@ -492,7 +495,6 @@ def modified_beam_search(
decoder_out = model.joiner.decoder_proj(decoder_out_)
lid_decoder_out = model.lid_joiner.decoder_proj(decoder_out_)
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
@ -879,6 +881,7 @@ def modified_beam_search_lm_shallow_fusion(
timestamps=ans_timestamps,
)
def modified_beam_search_auxlm_shallow_fusion(
model: nn.Module,
encoder_out: torch.Tensor,
@ -1160,6 +1163,7 @@ def modified_beam_search_auxlm_shallow_fusion(
timestamps=ans_timestamps,
)
def modified_beam_search_lm_rescore_LODR(
model: nn.Module,
encoder_out: torch.Tensor,

View File

@ -118,6 +118,7 @@ import matplotlib.pyplot as plt
LOG_EPS = math.log(1e-10)
def remove_punc(text):
"""This function removes all English punctuations except the single quote (verbatim)."""
@ -127,20 +128,22 @@ def remove_punc(text):
# Create a translation table that maps each punctuation to a space.
# translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
translator = str.maketrans('', '', english_punctuations)
translator = str.maketrans("", "", english_punctuations)
# Translate the text using the translation table
text = text.translate(translator)
return text
def clean(text):
text = remove_punc(text)
text = text.lower()
text = re.sub(r'\s+', ' ', text)
text = re.sub(r"\s+", " ", text)
text = text.rstrip()
return text
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -399,9 +402,10 @@ def get_parser():
return parser
def align_lid(labels_a, labels_b, a, b):
# Alignment
EPS = '*'
EPS = "*"
ali = align(a, b, EPS, sclite_mode=True)
a2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(a, labels_a))}
@ -430,30 +434,34 @@ def align_lid(labels_a, labels_b, a, b):
ali_idx += 1
return aligned_a, aligned_b
def write_lid_results(lid_path, f1_path, text, lid):
lid_hyp = []
lid_ref = []
with open(lid_path, 'w') as file:
with open(lid_path, "w") as file:
# Write each line to the file
for text_line, lid_line in zip(text, lid):
file.write(f"{text_line[0]}: ref={text_line[1]} lid={lid_line[1]}" + '\n')
aligned_ref, aligned_hyp = align_lid(lid_line[1],lid_line[2],text_line[1], text_line[2])
file.write(f"{text_line[0]}: ref={text_line[1]} lid={lid_line[1]}" + "\n")
aligned_ref, aligned_hyp = align_lid(
lid_line[1], lid_line[2], text_line[1], text_line[2]
)
lid_ref.extend(aligned_ref)
lid_hyp.extend(aligned_hyp)
file.write(f"{lid_line[0]}: hyp={text_line[2]} lid={lid_line[2]}" + '\n')
file.write(f"{lid_line[0]}: hyp={text_line[2]} lid={lid_line[2]}" + "\n")
report = classification_report(lid_ref, lid_hyp, zero_division=0)
f1 = f1_score(lid_ref, lid_hyp, average='weighted')
f1 = f1_score(lid_ref, lid_hyp, average="weighted")
with open(f1_path, 'w') as file:
with open(f1_path, "w") as file:
file.write(report)
file.write('\n')
file.write("\n")
file.write(f"F1 score: {f1} \n")
filename = os.path.basename(lid_path).replace('.txt', '.png')
filename = os.path.basename(lid_path).replace(".txt", ".png")
dirname = os.path.dirname(lid_path)
save_conf_mat(os.path.join(dirname, filename), lid_ref, lid_hyp)
def save_conf_mat(path, lid_ref, lid_hyp):
all_labels = [1, 2, 3, 4]
class_names = ["En", "Es", "Ar", "Zh"]
@ -463,16 +471,24 @@ def save_conf_mat(path, lid_ref, lid_hyp):
# Plot the confusion matrix
plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
sns.heatmap(
cm,
annot=True,
fmt="d",
cmap="Blues",
xticklabels=class_names,
yticklabels=class_names,
)
plt.title("Confusion Matrix")
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.savefig(path)
def most_frequent(List):
return max(set(List), key=List.count)
def mapp(enc, LID):
pt1 = 0
new_lid = []
@ -492,6 +508,7 @@ def mapp(enc, LID):
return new_lid
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
@ -561,7 +578,9 @@ def decode_one_batch(
value=LOG_EPS,
)
encoder_out, encoder_out_lens, lid_encoder_out = model.forward_encoder(feature, feature_lens)
encoder_out, encoder_out_lens, lid_encoder_out = model.forward_encoder(
feature, feature_lens
)
hyps = []
B, T, F = feature.shape
@ -758,7 +777,6 @@ def decode_dataset(
if params.lid:
results_lid[name].extend(this_batch_lid)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
@ -820,19 +838,16 @@ def save_results_lid(
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
key = list(results_dict['text'].keys())[0]
results_text = sorted(results_dict['text'][key], key=lambda x: x[0])
results_lid = sorted(results_dict['lid'][key], key=lambda x: x[0])
key = list(results_dict["text"].keys())[0]
results_text = sorted(results_dict["text"][key], key=lambda x: x[0])
results_lid = sorted(results_dict["lid"][key], key=lambda x: x[0])
test_set_f1s = dict()
lid_path = (
params.res_dir / f"lid-{test_set_name}-{key}-{params.suffix}.txt"
)
f1_path = (
params.res_dir / f"f1-{test_set_name}-{key}-{params.suffix}.txt"
)
lid_path = params.res_dir / f"lid-{test_set_name}-{key}-{params.suffix}.txt"
f1_path = params.res_dir / f"f1-{test_set_name}-{key}-{params.suffix}.txt"
write_lid_results(lid_path, f1_path, results_text, results_lid)
logging.info(f"The lids are stored in {lid_path}")
@torch.no_grad()
def main():
parser = get_parser()
@ -999,12 +1014,10 @@ def main():
model.eval()
# only load the neural network LM if required
if (
params.use_shallow_fusion
or params.decoding_method in (
if params.use_shallow_fusion or params.decoding_method in (
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_LODR",
"modified_beam_search_lm_rescore_LODR",)
"modified_beam_search_lm_rescore_LODR",
):
LM = LmScorer(
lm_type=params.lm_type,

View File

@ -19,6 +19,7 @@ import torch.nn as nn
from scaling import ScaledLinear
from typing import Optional
class Joiner(nn.Module):
def __init__(
self,

View File

@ -141,7 +141,9 @@ class AsrModel(nn.Module):
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens, lid_output = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out, encoder_out_lens, lid_output = self.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
@ -217,7 +219,6 @@ class AsrModel(nn.Module):
"""
# Now for the decoder, i.e., the prediction network
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
@ -291,11 +292,14 @@ class AsrModel(nn.Module):
ranges=ranges,
)
lid_logits = self.lid_joiner(
lid_am_pruned, lid_lm_pruned, project_input=False)
lid_am_pruned, lid_lm_pruned, project_input=False
)
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False, lid_out=lid_pruned)
logits = self.joiner(
am_pruned, lm_pruned, project_input=False, lid_out=lid_pruned
)
# Add blank logits to lid_logits
logits = torch.cat((lid_logits[..., 0].unsqueeze(-1), logits), dim=-1)
@ -315,7 +319,9 @@ class AsrModel(nn.Module):
with torch.cuda.amp.autocast(enabled=False):
pruned_lid_loss = k2.rnnt_loss_pruned(
logits=lid_logits.float(),
symbols=y_lid.pad(mode="constant", padding_value=blank_id).to(torch.int64),
symbols=y_lid.pad(mode="constant", padding_value=blank_id).to(
torch.int64
),
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
@ -374,7 +380,9 @@ class AsrModel(nn.Module):
# Compute encoder outputs
if self.lid_joiner != None:
encoder_out, encoder_out_lens, lid_encoder_out = self.forward_encoder(x, x_lens)
encoder_out, encoder_out_lens, lid_encoder_out = self.forward_encoder(
x, x_lens
)
else:
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)

View File

@ -366,7 +366,8 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--lid-value-head-dim",
type=str,
default="12",
help="Value dimension per head in encoder stacks: a single int or comma-separated list.",)
help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--lid-pos-head-dim",
type=str,
@ -429,6 +430,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Whether to skip positional embedding in the lid encoder.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -781,9 +783,11 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
causal=params.causal,
chunk_size=_to_int_tuple(params.chunk_size),
left_context_frames=_to_int_tuple(params.left_context_frames),
lid_output_layer=params.lid_output_layer if params.use_lid_encoder else None,)
lid_output_layer=params.lid_output_layer if params.use_lid_encoder else None,
)
return encoder
def get_lid_encoder_model(params: AttributeDict) -> nn.Module:
lid_encoder = Zipformer2(
output_downsampling_factor=2,
@ -806,6 +810,7 @@ def get_lid_encoder_model(params: AttributeDict) -> nn.Module:
)
return lid_encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
@ -826,6 +831,7 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
)
return joiner
def get_lid_joiner_model(params: AttributeDict) -> nn.Module:
lid_joiner = Joiner(
encoder_dim=int(params.lid_encoder_dim.split(",")[-1]),
@ -835,6 +841,7 @@ def get_lid_joiner_model(params: AttributeDict) -> nn.Module:
)
return lid_joiner
def get_model(params: AttributeDict) -> nn.Module:
assert params.use_transducer or params.use_ctc, (
f"At least one of them should be True, "
@ -858,7 +865,6 @@ def get_model(params: AttributeDict) -> nn.Module:
decoder = None
joiner = None
model = AsrModel(
encoder_embed=encoder_embed,
encoder=encoder,
@ -875,7 +881,6 @@ def get_model(params: AttributeDict) -> nn.Module:
return model
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
@ -1066,9 +1071,7 @@ def compute_loss(
lid_pruned_loss_is_finite = torch.isfinite(lid_pruned_loss)
is_finite = (
simple_loss_is_finite
& pruned_loss_is_finite
& lid_pruned_loss_is_finite
simple_loss_is_finite & pruned_loss_is_finite & lid_pruned_loss_is_finite
)
if not torch.all(is_finite):
logging.info(
@ -1096,8 +1099,9 @@ def compute_loss(
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss += (1-params.lid_loss_scale)*(simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss)
loss += (1 - params.lid_loss_scale) * (
simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
)
# loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.use_lid_joiner:
loss += params.lid_loss_scale * pruned_loss_scale * lid_pruned_loss