adding black reformatted files

This commit is contained in:
s-mousmita 2022-07-28 07:58:20 -04:00
parent 67e3607863
commit 83e2b30a22
17 changed files with 240 additions and 349 deletions

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -47,18 +48,19 @@ def compute_fbank_fisher_swbd_eval2000():
num_jobs = min(25, os.cpu_count())
num_mel_bins = 80
sampling_rate = 8000
dataset_parts = (
"eval2000",
"fisher",
"swbd",
)
test_dataset=("eval2000",)
dataset_parts = ("eval2000", "fisher", "swbd")
test_dataset = ("eval2000",)
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts, output_dir=src_dir, lazy=True, suffix="jsonl"
dataset_parts=dataset_parts,
output_dir=src_dir,
lazy=True,
suffix="jsonl",
)
assert manifests is not None
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=sampling_rate))
extractor = Fbank(
FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=sampling_rate)
)
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@ -67,10 +69,9 @@ def compute_fbank_fisher_swbd_eval2000():
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
recordings=m["recordings"], supervisions=m["supervisions"]
)
#if "train" in partition:
# if "train" in partition:
if partition not in test_dataset:
logging.info(f"Adding speed perturbations to : {partition}")
cut_set = (

View File

@ -47,11 +47,7 @@ def compute_fbank_musan():
num_jobs = min(15, os.cpu_count())
num_mel_bins = 80
sampling_rate = 8000
dataset_parts = (
"music",
"speech",
"noise",
)
dataset_parts = ("music", "speech", "noise")
prefix = "musan"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
@ -75,7 +71,9 @@ def compute_fbank_musan():
logging.info("Extracting features for Musan")
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=sampling_rate)
extractor = Fbank(
FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=sampling_rate)
)
with get_executor() as ex: # Initialize the executor only once.
# create chunks of Musan with duration 5 - 10 seconds

View File

@ -1,58 +1,56 @@
#!/usr/bin/env python3
#
# script to extract cutids corresponding to a list of source audio files.
# It takes three arguments: list of audio (.sph) , cut jsonl and out jsonl
import json
import ntpath
import sys

# Command-line arguments:
#   1. text file listing source audio (.sph) paths, one per line
#   2. input cut manifest, one JSON object per line
#   3. output manifest that receives the matching cut lines
list_of_sph = sys.argv[1]
jsonfile = sys.argv[2]
out_partition_json = sys.argv[3]

# Only file names are compared, so directories are stripped up front.
# A set makes every membership test below O(1).
with open(list_of_sph) as sph_file:
    sph_basenames = {ntpath.basename(line.rstrip("\n")) for line in sph_file}

# All cuts are read before the output is opened for writing, exactly as the
# original did, so the output is only truncated after the input was consumed.
with open(jsonfile) as cuts_file:
    json_str = [line.rstrip("\n") for line in cuts_file]

with open(out_partition_json, "w", encoding="utf-8") as out_partition:
    for cut_line in json_str:
        if cut_line == "":
            continue  # skip blank lines
        cur_json = json.loads(cut_line)
        for s in cur_json["recording"]["sources"]:
            # As before, the cut line is written once per matching source.
            if ntpath.basename(s["source"]) in sph_basenames:
                out_partition.write(cut_line)
                out_partition.write("\n")
"""
for keys in cur_json:
#print(keys)
@ -64,9 +62,3 @@ for i in range(num_json):
out_partition.write(out_json_line)
out_partition.write("\n")
"""

View File

@ -1,40 +1,35 @@
#!/usr/bin/env python3
#
#
import json
import ntpath
import sys

# Command-line arguments:
#   1. text file listing source audio (.sph) paths, one per line
#   2. input manifest, one JSON object per line, each with a "recording_id"
#   3. output manifest that receives the matching lines
list_of_sph = sys.argv[1]
jsonfile = sys.argv[2]
out_partition_json = sys.argv[3]

# Only file names are compared, so directories are stripped up front.
# A set makes every membership test below O(1).
with open(list_of_sph) as sph_file:
    sph_basenames = {ntpath.basename(line.rstrip("\n")) for line in sph_file}

# All entries are read before the output is opened for writing, exactly as
# the original did, so the output is only truncated after the input was read.
with open(jsonfile) as manifest_file:
    json_str = [line.rstrip("\n") for line in manifest_file]

with open(out_partition_json, "w", encoding="utf-8") as out_partition:
    for entry_line in json_str:
        if entry_line == "":
            continue  # skip blank lines
        cur_json = json.loads(entry_line)
        # The audio file name is derived from the recording id.
        cur_sph_basename = cur_json["recording_id"] + ".sph"
        if cur_sph_basename in sph_basenames:
            out_partition.write(entry_line)
            out_partition.write("\n")

View File

@ -1,38 +1,20 @@
#!/usr/bin/env python3
# extract list of sph from a cut jsonl
# python3 extract_list_of_sph.py dev_cuts_swbd.jsonl > data/fbank/dev_swbd_sph.list
import json
import sys

# Path of the cut manifest (one JSON object per line).
inputfile = sys.argv[1]

with open(inputfile) as cuts_file:
    for raw_line in cuts_file:
        cut_line = raw_line.rstrip("\n")
        if cut_line == "":
            continue  # skip blank lines
        cur_json = json.loads(cut_line)
        # NOTE(review): the original wrapped this lookup in
        # `for keys in cur_json:` without ever using `keys`. The loop was
        # redundant, so every source path is now printed once per cut.
        for s in cur_json["recording"]["sources"]:
            print(s["source"])

View File

@ -16,67 +16,72 @@ def get_args():
parser.add_argument("output_sups")
return parser.parse_args()
def remove_punctutation_and_other_symbol(text: str) -> str:
    """Replace punctuation and other listed symbols in ``text`` with spaces.

    The two-character sequences ``--`` and ``//`` are handled first and each
    becomes a single space; a lone ``-`` or ``/`` is left untouched. Every
    other listed symbol is replaced one-for-one, so the result may contain
    runs of spaces.

    Args:
        text: The transcript text to clean.

    Returns:
        The text with every listed symbol replaced by a space.
    """
    # Order matters only for the leading two-character entries.
    symbols = (
        "--",
        "//",
        ".",
        "?",
        "~",
        ",",
        ";",
        "(",
        ")",
        "&",
        "%",
        "*",
        "{",
        "}",
    )
    for symbol in symbols:
        text = text.replace(symbol, " ")
    return text
def eval2000_clean_eform(text: str, eform_count) -> str:
    """Blank out ``<contraction e_form...">`` markup in ``text``.

    The attribute text found between ``<contraction e_form`` and the closing
    ``">`` is replaced by a space wherever it occurs in ``text``; afterwards
    the opening marker and every ``">`` are replaced by a space as well.

    Args:
        text: The transcript text, possibly containing contraction markup.
        eform_count: Number of markers counted by the caller. It is not used
            here and is kept only so existing callers keep working.

    Returns:
        The text with the markup replaced by spaces.
    """
    opening = "<contraction e_form"
    closing = '">'

    # Examine the text piece by piece. Each piece gets the closing marker
    # appended again so that the pattern can match it.
    attributes = []
    for piece in text.split(closing):
        found = re.search(r"<contraction e_form(.*?)\">", piece + closing)
        if found is not None:
            attributes.append(found.group(1))

    for attribute in attributes:
        # An empty attribute must be skipped: str.replace("", " ") would
        # insert a space between every character of the text.
        if attribute:
            text = text.replace(attribute, " ")

    text = text.replace(opening, " ")
    text = text.replace(closing, " ")
    return text
def replace_silphone(text: str) -> str:
def replace_silphone(text: str) -> str:
text = text.replace("[/BABY CRYING]", " ")
text = text.replace("[/CHILD]" , " ")
text = text.replace("[[DISTORTED]]" , " ")
text = text.replace("[/DISTORTION]" , " ")
text = text.replace("[[DRAWN OUT]]" , " ")
text = text.replace("[[DRAWN-OUT]]" , " ")
text = text.replace("[[FAINT]]" , " ")
text = text.replace("[SMACK]" , " ")
text = text.replace("[[MUMBLES]]" , " ")
text = text.replace("[[HIGH PITCHED SQUEAKY VOICE]]" , " ")
text = text.replace("[[IN THE LAUGH]]" , "[LAUGHTER]")
text = text.replace("[[LAST WORD SPOKEN WITH A LAUGH]]" , "[LAUGHTER]")
text = text.replace("[[PART OF FIRST SYLLABLE OF PREVIOUS WORD CUT OFF]]" , " ")
text = text.replace("[[PREVIOUS WORD SPOKEN WITH A LAUGH]]" , " ")
text = text.replace("[[PREVIOUS TWO WORDS SPOKEN WHILE LAUGHING]]" , " ")
text = text.replace("[[PROLONGED]]" , " ")
text = text.replace("[/RUNNING WATER]" , " ")
text = text.replace("[[SAYS LAUGHING]]" , "[LAUGHTER]")
text = text.replace("[[SINGING]]" , " ")
text = text.replace("[[SPOKEN WHILE LAUGHING]]" , "[LAUGHTER]")
text = text.replace("[/STATIC]" , " ")
text = text.replace("['THIRTIETH' DRAWN OUT]" , " ")
text = text.replace("[/VOICES]" , " ")
text = text.replace("[[WHISPERED]]" , " ")
text = text.replace("[/CHILD]", " ")
text = text.replace("[[DISTORTED]]", " ")
text = text.replace("[/DISTORTION]", " ")
text = text.replace("[[DRAWN OUT]]", " ")
text = text.replace("[[DRAWN-OUT]]", " ")
text = text.replace("[[FAINT]]", " ")
text = text.replace("[SMACK]", " ")
text = text.replace("[[MUMBLES]]", " ")
text = text.replace("[[HIGH PITCHED SQUEAKY VOICE]]", " ")
text = text.replace("[[IN THE LAUGH]]", "[LAUGHTER]")
text = text.replace("[[LAST WORD SPOKEN WITH A LAUGH]]", "[LAUGHTER]")
text = text.replace(
"[[PART OF FIRST SYLLABLE OF PREVIOUS WORD CUT OFF]]", " "
)
text = text.replace("[[PREVIOUS WORD SPOKEN WITH A LAUGH]]", " ")
text = text.replace("[[PREVIOUS TWO WORDS SPOKEN WHILE LAUGHING]]", " ")
text = text.replace("[[PROLONGED]]", " ")
text = text.replace("[/RUNNING WATER]", " ")
text = text.replace("[[SAYS LAUGHING]]", "[LAUGHTER]")
text = text.replace("[[SINGING]]", " ")
text = text.replace("[[SPOKEN WHILE LAUGHING]]", "[LAUGHTER]")
text = text.replace("[/STATIC]", " ")
text = text.replace("['THIRTIETH' DRAWN OUT]", " ")
text = text.replace("[/VOICES]", " ")
text = text.replace("[[WHISPERED]]", " ")
text = text.replace("[DISTORTION]", " ")
text = text.replace("[DISTORTION, HIGH VOLUME ON WAVES]", " ")
text = text.replace("[BACKGROUND LAUGHTER]", "[LAUGHTER]")
@ -95,24 +100,24 @@ def replace_silphone(text: str) -> str:
text = text.replace("[BABY CRYING]", " ")
text = text.replace("[METALLIC KNOCKING SOUND]", " ")
text = text.replace("[METALLIC SOUND]", " ")
text = text.replace("[PHONE JIGGLING]", " ")
text = text.replace("[BACKGROUND SOUND]", " ")
text = text.replace("[BACKGROUND VOICE]", " ")
text = text.replace("[BACKGROUND VOICES]", " ")
text = text.replace("[BACKGROUND VOICES]", " ")
text = text.replace("[BACKGROUND NOISE]", " ")
text = text.replace("[CAR HORNS IN BACKGROUND]", " ")
text = text.replace("[CAR HORNS]", " ")
text = text.replace("[CARNATING]", " ")
text = text.replace("[CARNATING]", " ")
text = text.replace("[CRYING CHILD]", " ")
text = text.replace("[CHOPPING SOUND]", " ")
text = text.replace("[BANGING]", " ")
text = text.replace("[CLICKING NOISE]", " ")
text = text.replace("[CLATTERING]", " ")
text = text.replace("[CLATTERING]", " ")
text = text.replace("[ECHO]", " ")
text = text.replace("[KNOCK]", " ")
text = text.replace("[NOISE-GOOD]", "[NOISE]")
text = text.replace("[RIGHT]", " ")
text = text.replace("[KNOCK]", " ")
text = text.replace("[NOISE-GOOD]", "[NOISE]")
text = text.replace("[RIGHT]", " ")
text = text.replace("[SOUND]", " ")
text = text.replace("[SQUEAK]", " ")
text = text.replace("[STATIC]", " ")
@ -131,64 +136,65 @@ def replace_silphone(text: str) -> str:
text = text.replace("Y[OU]I-", "YOU I")
text = text.replace("-[A]ND", "AND")
text = text.replace("JU[ST]", "JUST")
text = text.replace("{BREATH}" , " ")
text = text.replace("{BREATHY}" , " ")
text = text.replace("{CHANNEL NOISE}" , " ")
text = text.replace("{CLEAR THROAT}" , " ")
text = text.replace("{BREATH}", " ")
text = text.replace("{BREATHY}", " ")
text = text.replace("{CHANNEL NOISE}", " ")
text = text.replace("{CLEAR THROAT}", " ")
text = text.replace("{CLEARING THROAT}" , " ")
text = text.replace("{CLEARS THROAT}" , " ")
text = text.replace("{COUGH}" , " ")
text = text.replace("{DRAWN OUT}" , " ")
text = text.replace("{EXHALATION}" , " ")
text = text.replace("{EXHALE}" , " ")
text = text.replace("{GASP}" , " ")
text = text.replace("{HIGH SQUEAL}" , " ")
text = text.replace("{INHALE}" , " ")
text = text.replace("{LAUGH}" , "[LAUGHTER]")
text = text.replace("{LAUGH}" , "[LAUGHTER]")
text = text.replace("{LAUGH}" , "[LAUGHTER]")
text = text.replace("{LIPSMACK}" , " ")
text = text.replace("{LIPSMACK}" , " ")
text = text.replace("{CLEARING THROAT}", " ")
text = text.replace("{CLEARS THROAT}", " ")
text = text.replace("{COUGH}", " ")
text = text.replace("{DRAWN OUT}", " ")
text = text.replace("{EXHALATION}", " ")
text = text.replace("{EXHALE}", " ")
text = text.replace("{GASP}", " ")
text = text.replace("{HIGH SQUEAL}", " ")
text = text.replace("{INHALE}", " ")
text = text.replace("{LAUGH}", "[LAUGHTER]")
text = text.replace("{LAUGH}", "[LAUGHTER]")
text = text.replace("{LAUGH}", "[LAUGHTER]")
text = text.replace("{LIPSMACK}", " ")
text = text.replace("{LIPSMACK}", " ")
text = text.replace("{NOISE OF DISGUST}" , " ")
text = text.replace("{SIGH}" , " ")
text = text.replace("{SNIFF}" , " ")
text = text.replace("{SNORT}" , " ")
text = text.replace("{SHARP EXHALATION}" , " ")
text = text.replace("{BREATH LAUGH}" , " ")
text = text.replace("{NOISE OF DISGUST}", " ")
text = text.replace("{SIGH}", " ")
text = text.replace("{SNIFF}", " ")
text = text.replace("{SNORT}", " ")
text = text.replace("{SHARP EXHALATION}", " ")
text = text.replace("{BREATH LAUGH}", " ")
return text
def remove_languagetag(text: str) -> str:
    """Blank out ``<...>`` tags and any stray angle brackets in ``text``.

    Each complete tag is replaced by three spaces (one each for ``<``, the
    tag content and ``>``), which is what the previous implementation
    produced for a tag. Unlike that implementation, text outside the
    brackets that happens to equal a tag's content is left alone, and an
    empty tag ``<>`` no longer inserts a space between every character.

    Args:
        text: The transcript text.

    Returns:
        The text with tags and remaining ``<`` / ``>`` replaced by spaces.
    """
    text = re.sub(r"<.*?>", "   ", text)
    # Unpaired brackets are still turned into single spaces.
    text = text.replace("<", " ")
    text = text.replace(">", " ")
    return text
def eval2000_normalizer(text: str) -> str:
    """Normalize one eval2000 transcript line.

    Steps, in order: strip contraction e_form markup (when present),
    upper-case, remove ``<...>`` tags, apply ``replace_silphone``, remove
    punctuation, drop the ``IGNORE_TIME_SEGMENT_*`` markers, then collapse
    whitespace and trim both ends.

    Args:
        text: The raw transcript text.

    Returns:
        The normalized text.
    """
    eform_count = text.count("contraction e_form")
    if eform_count > 0:
        text = eval2000_clean_eform(text, eform_count)
    text = text.upper()
    text = remove_languagetag(text)
    text = replace_silphone(text)
    text = remove_punctutation_and_other_symbol(text)
    text = text.replace("IGNORE_TIME_SEGMENT_IN_SCORING", " ")
    text = text.replace("IGNORE_TIME_SEGMENT_SCORING", " ")
    # Collapse every whitespace run in a single pass. The previous
    # findall/replace loop depended on the order of the runs: replacing a
    # two-space run first shortened a later three-space run to two spaces,
    # which then no longer matched and stayed in the output.
    return re.sub(r"\s+", " ", text).strip()
def main():
args = get_args()
sups = load_manifest_lazy_or_eager(args.input_sups)
@ -203,6 +209,7 @@ def main():
skip += 1
continue
writer.write(sup)
if __name__ == "__main__":
main()

View File

@ -108,9 +108,7 @@ def lexicon_to_fst_no_sil(
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
arcs, disambig_token=disambig_token, disambig_word=disambig_word
)
final_state = next_state
@ -223,9 +221,7 @@ def main():
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
lexicon, token2id=token_sym_table, word2id=word_sym_table
)
L_disambig = lexicon_to_fst_no_sil(

View File

@ -68,12 +68,7 @@ def get_args():
def get_g2p_sym2int():
# These symbols are removed from g2p_en's vocabulary
excluded_symbols = [
"<pad>",
"<s>",
"</s>",
"<unk>",
]
excluded_symbols = ["<pad>", "<s>", "</s>", "<unk>"]
symbols = [p for p in sorted(G2p().phonemes) if p not in excluded_symbols]
# reserve 0 and 1 for blank and sos/eos/pad tokens
@ -345,9 +340,7 @@ def lexicon_to_fst(
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
arcs, disambig_token=disambig_token, disambig_word=disambig_word
)
final_state = next_state
@ -396,9 +389,7 @@ def main():
print(vocab[:10])
if not lexicon_filename.is_file():
lexicon = [
("!SIL", [sil_token]),
]
lexicon = [("!SIL", [sil_token])]
for symbol in special_symbols:
lexicon.append((symbol, [symbol[1:-1]]))
lexicon += [

View File

@ -43,16 +43,10 @@ def get_args():
""",
)
parser.add_argument(
"--transcript",
type=str,
help="Training transcript.",
)
parser.add_argument("--transcript", type=str, help="Training transcript.")
parser.add_argument(
"--vocab-size",
type=int,
help="Vocabulary size for BPE training",
"--vocab-size", type=int, help="Vocabulary size for BPE training"
)
return parser.parse_args()

View File

@ -1,6 +1,6 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
#
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -356,13 +356,10 @@ class FisherSwbdSpeechAsrDataModule:
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
cut_transforms=transforms, return_cuts=self.args.return_cuts
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
cuts_valid, max_duration=self.args.max_duration, shuffle=False
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
@ -384,9 +381,7 @@ class FisherSwbdSpeechAsrDataModule:
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
@ -396,41 +391,52 @@ class FisherSwbdSpeechAsrDataModule:
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_fisher_cuts(self) -> CutSet:
    """Load ``train_cuts_fisher.jsonl.gz`` from ``self.args.manifest_dir``.

    The result is memoised by ``lru_cache``, keyed on the instance.
    """
    logging.info("About to get fisher cuts")
    manifest_path = self.args.manifest_dir / "train_cuts_fisher.jsonl.gz"
    return load_manifest_lazy(manifest_path)
@lru_cache()
def train_swbd_cuts(self) -> CutSet:
    """Load ``train_cuts_swbd.jsonl.gz`` from ``self.args.manifest_dir``.

    The result is memoised by ``lru_cache``, keyed on the instance.
    """
    logging.info("About to get train swbd cuts")
    manifest_path = self.args.manifest_dir / "train_cuts_swbd.jsonl.gz"
    return load_manifest_lazy(manifest_path)
@lru_cache()
def dev_fisher_cuts(self) -> CutSet:
    """Load ``dev_cuts_fisher.jsonl.gz`` from ``self.args.manifest_dir``.

    The result is memoised by ``lru_cache``, keyed on the instance.
    """
    logging.info("About to get dev fisher cuts")
    # Only the reformatted call is kept; the stale, unterminated
    # pre-format `return` line has been dropped.
    manifest_path = self.args.manifest_dir / "dev_cuts_fisher.jsonl.gz"
    return load_manifest_lazy(manifest_path)
@lru_cache()
def dev_swbd_cuts(self) -> CutSet:
    """Load ``dev_cuts_swbd.jsonl.gz`` from ``self.args.manifest_dir``.

    The result is memoised by ``lru_cache``, keyed on the instance.
    """
    logging.info("About to get dev swbd cuts")
    # Only the reformatted call is kept; the stale, unterminated
    # pre-format `return` line has been dropped.
    manifest_path = self.args.manifest_dir / "dev_cuts_swbd.jsonl.gz"
    return load_manifest_lazy(manifest_path)
@lru_cache()
def test_eval2000_cuts(self) -> CutSet:
    """Load ``cuts_eval2000.jsonl.gz`` from ``self.args.manifest_dir``.

    The result is memoised by ``lru_cache``, keyed on the instance.
    """
    logging.info("About to get test eval2000 cuts")
    # Only the reformatted call is kept; the stale, unterminated
    # pre-format `return` line has been dropped.
    manifest_path = self.args.manifest_dir / "cuts_eval2000.jsonl.gz"
    return load_manifest_lazy(manifest_path)
@lru_cache()
def test_swbd_cuts(self) -> CutSet:
    """Load ``cuts_eval2000_swbd.jsonl.gz`` from ``self.args.manifest_dir``.

    The result is memoised by ``lru_cache``, keyed on the instance.
    """
    logging.info("About to get test eval2000 swbd cuts")
    # Only the reformatted call is kept; the stale, unterminated
    # pre-format `return` line has been dropped.
    manifest_path = self.args.manifest_dir / "cuts_eval2000_swbd.jsonl.gz"
    return load_manifest_lazy(manifest_path)
@lru_cache()
def test_callhome_cuts(self) -> CutSet:
    """Load ``cuts_eval2000_callhome.jsonl.gz`` from ``self.args.manifest_dir``.

    The result is memoised by ``lru_cache``, keyed on the instance.
    """
    logging.info("About to get test eval2000 callhome cuts")
    # Only the reformatted call is kept; the stale, unterminated
    # pre-format `return` line has been dropped.
    manifest_path = self.args.manifest_dir / "cuts_eval2000_callhome.jsonl.gz"
    return load_manifest_lazy(manifest_path)

View File

@ -550,9 +550,7 @@ def greedy_search(
def greedy_search_batch(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
@ -591,9 +589,7 @@ def greedy_search_batch(
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
hyps, device=device, dtype=torch.int64
) # (N, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
@ -630,9 +626,7 @@ def greedy_search_batch(
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
decoder_input, device=device, dtype=torch.int64
)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
@ -894,9 +888,7 @@ def modified_beam_search(
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
current_encoder_out, decoder_out, project_input=False
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
@ -953,9 +945,7 @@ def modified_beam_search(
def _deprecated_modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
model: Transducer, encoder_out: torch.Tensor, beam: int = 4
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
@ -1023,9 +1013,7 @@ def _deprecated_modified_beam_search(
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
current_encoder_out, decoder_out, project_input=False
)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
@ -1097,9 +1085,7 @@ def beam_search(
device = next(model.parameters()).device
decoder_input = torch.tensor(
[blank_id] * context_size,
device=device,
dtype=torch.int64,
[blank_id] * context_size, device=device, dtype=torch.int64
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
@ -1318,9 +1304,7 @@ def fast_beam_search_with_nbest_rescoring(
num_unique_paths = len(word_ids_list)
b_to_a_map = torch.zeros(
num_unique_paths,
dtype=torch.int32,
device=lattice.device,
num_unique_paths, dtype=torch.int32, device=lattice.device
)
rescored_word_fsas = k2.intersect_device(
@ -1334,8 +1318,7 @@ def fast_beam_search_with_nbest_rescoring(
rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas)
rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
ngram_lm_scores = rescored_word_fsas.get_tot_scores(
use_double_scores=True,
log_semiring=False,
use_double_scores=True, log_semiring=False
)
ans: Dict[str, List[List[int]]] = {}

View File

@ -223,19 +223,10 @@ class Conformer(EncoderInterface):
init_states: List[torch.Tensor] = [
torch.zeros(
(
self.encoder_layers,
left_context,
self.d_model,
),
device=device,
(self.encoder_layers, left_context, self.d_model), device=device
),
torch.zeros(
(
self.encoder_layers,
self.cnn_module_kernel - 1,
self.d_model,
),
(self.encoder_layers, self.cnn_module_kernel - 1, self.d_model),
device=device,
),
]
@ -330,7 +321,9 @@ class Conformer(EncoderInterface):
{(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
given {states[1].shape}."""
lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
lengths -= (
2
) # we will cut off 1 frame on each side of encoder_embed output
src_key_padding_mask = make_pad_mask(lengths)
@ -829,9 +822,7 @@ class RelPositionalEncoding(torch.nn.Module):
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(
self,
x: torch.Tensor,
left_context: int = 0,
self, x: torch.Tensor, left_context: int = 0
) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
@ -875,10 +866,7 @@ class RelPositionMultiheadAttention(nn.Module):
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
self, embed_dim: int, num_heads: int, dropout: float = 0.0
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
@ -1272,8 +1260,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
@ -1420,10 +1407,7 @@ class ConvolutionModule(nn.Module):
)
def forward(
self,
x: Tensor,
cache: Optional[Tensor] = None,
right_context: int = 0,
self, x: Tensor, cache: Optional[Tensor] = None, right_context: int = 0
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.

View File

@ -384,9 +384,7 @@ def decode_one_batch(
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
feature, pad=(0, 0, 0, params.left_context), value=LOG_EPS
)
if params.simulate_streaming:
@ -778,7 +776,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
fisherswbd = FisherSwbdSpeechAsrDataModule(args)
test_eval2000_cuts = fisherswbd.test_eval2000_cuts()
@ -803,9 +801,7 @@ def main():
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")

View File

@ -92,10 +92,7 @@ class DecodeStream(object):
"""Return True if all the features are processed."""
return self._done
def set_features(
self,
features: torch.Tensor,
) -> None:
def set_features(self, features: torch.Tensor) -> None:
"""Set features tensor of current utterance."""
assert features.dim() == 2, features.dim()
self.features = torch.nn.functional.pad(

View File

@ -96,11 +96,7 @@ def get_parser():
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument("--bpe-model", type=str, help="""Path to bpe.model.""")
parser.add_argument(
"--method",

View File

@ -39,7 +39,8 @@ import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
#from asr_datamodule import LibriSpeechAsrDataModule
# from asr_datamodule import LibriSpeechAsrDataModule
from asr_datamodule import FisherSwbdSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
@ -187,9 +188,7 @@ def get_parser():
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
model: nn.Module, encoder_out: torch.Tensor, streams: List[DecodeStream]
) -> List[List[int]]:
assert len(streams) == encoder_out.size(0)
@ -236,10 +235,7 @@ def greedy_search(
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
hyp_tokens = []
@ -290,9 +286,7 @@ def fast_beam_search(
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
params: AttributeDict, model: nn.Module, decode_streams: List[DecodeStream]
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
@ -502,10 +496,7 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(hyp).split(),
)
(decode_streams[i].ground_truth.split(), sp.decode(hyp).split())
)
del decode_streams[i]
@ -661,7 +652,7 @@ def main():
fisherswbd = FisherSwbdSpeechAsrDataModule(args)
test_eval2000_cuts = fisherswbd.test_eval2000_cuts()
test_swbd_cuts = fisherswbd.test_swbd_cuts ()
test_swbd_cuts = fisherswbd.test_swbd_cuts()
test_callhome_cuts = fisherswbd.test_callhome_cuts()
test_eval2000_dl = fisherswbd.test_dataloaders(test_eval2000_cuts)
@ -681,9 +672,7 @@ def main():
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")

View File

@ -155,10 +155,7 @@ def get_parser():
)
parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
"--num-epochs", type=int, default=30, help="Number of epochs to train."
)
parser.add_argument(
@ -480,10 +477,7 @@ def load_checkpoint_if_available(
assert filename.is_file(), f"{filename} does not exist!"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
filename, model=model, optimizer=optimizer, scheduler=scheduler
)
keys = [
@ -646,11 +640,7 @@ def compute_validation_loss(
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=False,
params=params, model=model, sp=sp, batch=batch, is_training=False
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
@ -767,9 +757,7 @@ def train_one_epoch(
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank
)
if batch_idx % params.log_interval == 0:
@ -830,8 +818,6 @@ def run(rank, world_size, args):
"""
params = get_params()
params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 1600
fix_random_seed(params.seed)
if world_size > 1:
@ -897,11 +883,11 @@ def run(rank, world_size, args):
if params.print_diagnostics:
diagnostic = diagnostics.attach_diagnostics(model)
librispeech = FisherSwbdSpeechAsrDataModule(args)
fisherswbd = FisherSwbdSpeechAsrDataModule(args)
train_cuts = fisherswbd.train_fisher_cuts()
train_cuts += fisherswbd.train_swbd_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
@ -991,9 +977,7 @@ def run(rank, world_size, args):
def display_and_save_batch(
batch: dict,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
batch: dict, params: AttributeDict, sp: spm.SentencePieceProcessor
) -> None:
"""Display the batch statistics and save the batch into disk.