Mirror of https://github.com/k2-fsa/icefall.git

Commit 83e2b30a22 (parent 67e3607863): adding black reformatted files
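The hunks below are the mechanical output of running black over these recipe files; the commit does not record the exact black version or options used. As a minimal illustration (not part of the commit, and assuming the black package is installed), black's library entry point reproduces the kind of rewrite seen in the first hunks: spaces are added around "=" and single quotes are normalized to double quotes.

# Illustrative only; "old" is copied from the first hunk below.
import black

old = "test_dataset=('eval2000',)\n"
print(black.format_str(old, mode=black.Mode()))
# prints: test_dataset = ("eval2000",)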
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)

#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -47,18 +48,19 @@ def compute_fbank_fisher_swbd_eval2000():
num_jobs = min(25, os.cpu_count())
num_mel_bins = 80
sampling_rate = 8000
dataset_parts = (
"eval2000",
"fisher",
"swbd",
)
test_dataset=("eval2000",)
dataset_parts = ("eval2000", "fisher", "swbd")
test_dataset = ("eval2000",)
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts, output_dir=src_dir, lazy=True, suffix="jsonl"
dataset_parts=dataset_parts,
output_dir=src_dir,
lazy=True,
suffix="jsonl",
)
assert manifests is not None

extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=sampling_rate))
extractor = Fbank(
FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=sampling_rate)
)

with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@@ -67,10 +69,9 @@ def compute_fbank_fisher_swbd_eval2000():
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
recordings=m["recordings"], supervisions=m["supervisions"]
)
#if "train" in partition:
# if "train" in partition:
if partition not in test_dataset:
logging.info(f"Adding speed perturbations to : {partition}")
cut_set = (

@@ -47,11 +47,7 @@ def compute_fbank_musan():
num_jobs = min(15, os.cpu_count())
num_mel_bins = 80
sampling_rate = 8000
dataset_parts = (
"music",
"speech",
"noise",
)
dataset_parts = ("music", "speech", "noise")
prefix = "musan"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
@@ -75,7 +71,9 @@ def compute_fbank_musan():

logging.info("Extracting features for Musan")

extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=sampling_rate)
extractor = Fbank(
FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=sampling_rate)
)

with get_executor() as ex: # Initialize the executor only once.
# create chunks of Musan with duration 5 - 10 seconds

@@ -1,58 +1,56 @@
#!/usr/bin/env python3

#
# script to extract cutids corresponding to a list of source audio files.
# It takes three arguments: list of audio (.sph) , cut jsonl and out jsonl

import sys, json ;
import ntpath;
import sys, json
import ntpath

list_of_sph = sys.argv[1];
jsonfile = sys.argv[2];
out_partition_json = sys.argv[3];
list_of_sph = sys.argv[1]
jsonfile = sys.argv[2]
out_partition_json = sys.argv[3]

list_of_sph=[line.rstrip('\n') for line in open(list_of_sph)]
list_of_sph = [line.rstrip("\n") for line in open(list_of_sph)]

sph_basename_list=[]
sph_basename_list = []

for f in list_of_sph:
bsname=ntpath.basename(f)
#print(bsname)
bsname = ntpath.basename(f)
sph_basename_list.append(ntpath.basename(f))

json_str=[line.rstrip('\n') for line in open(jsonfile)]
json_str = [line.rstrip("\n") for line in open(jsonfile)]
num_json = len(json_str)

#cutid2sph=dict()
out_partition=open(out_partition_json,'w',encoding='utf-8')
out_partition = open(out_partition_json, "w", encoding="utf-8")

for i in range(num_json):
if json_str[i] != '':
#print(json_str[i])
if json_str[i] != "":
# print(json_str[i])
cur_json = json.loads(json_str[i])
#print(cur_json)
cur_cutid= cur_json['id']
cur_rec = cur_json['recording']
cur_sources = cur_rec['sources']
#print(cur_cutid)
#print(cur_rec)
#print(cur_sources)
# print(cur_json)
cur_cutid = cur_json["id"]
cur_rec = cur_json["recording"]
cur_sources = cur_rec["sources"]
# print(cur_cutid)
# print(cur_rec)
# print(cur_sources)
for s in cur_sources:
cur_sph = s['source']
cur_sph_basename=ntpath.basename(cur_sph)
#print(cur_sph)
#print(cur_sph_basename)
if cur_sph_basename in sph_basename_list :
cur_sph = s["source"]
cur_sph_basename = ntpath.basename(cur_sph)
# print(cur_sph)
# print(cur_sph_basename)
if cur_sph_basename in sph_basename_list:
out_json_line = json_str[i]
out_partition.write(out_json_line)
out_partition.write("\n")
#for keys in cur_json:
#cur_cutid= cur_json['id']
#cur_rec = cur_json['recording_id']
#print(cur_cutid)

# for keys in cur_json:
# cur_cutid= cur_json['id']
# cur_rec = cur_json['recording_id']
# print(cur_cutid)

"""
for keys in cur_json:
#print(keys)
@@ -64,9 +62,3 @@ for i in range(num_json):
out_partition.write(out_json_line)
out_partition.write("\n")
"""

@@ -1,40 +1,35 @@
#!/usr/bin/env python3
#
#

import sys, json ;
import ntpath;
import sys, json
import ntpath

list_of_sph = sys.argv[1];
jsonfile = sys.argv[2];
out_partition_json = sys.argv[3];
list_of_sph = sys.argv[1]
jsonfile = sys.argv[2]
out_partition_json = sys.argv[3]

list_of_sph=[line.rstrip('\n') for line in open(list_of_sph)]
list_of_sph = [line.rstrip("\n") for line in open(list_of_sph)]

sph_basename_list=[]
sph_basename_list = []

for f in list_of_sph:
bsname=ntpath.basename(f)
#print(bsname)
bsname = ntpath.basename(f)
sph_basename_list.append(ntpath.basename(f))

json_str=[line.rstrip('\n') for line in open(jsonfile)]
json_str = [line.rstrip("\n") for line in open(jsonfile)]
num_json = len(json_str)

out_partition=open(out_partition_json,'w',encoding='utf-8')
out_partition = open(out_partition_json, "w", encoding="utf-8")

for i in range(num_json):
if json_str[i] != '':
#print(json_str[i])
if json_str[i] != "":
cur_json = json.loads(json_str[i])
#print(cur_json)
cur_rec = cur_json['recording_id']
#print(cur_rec)
cur_rec = cur_json["recording_id"]
cur_sph_basename = cur_rec + ".sph"
#print(cur_sph_basename)
if cur_sph_basename in sph_basename_list :
if cur_sph_basename in sph_basename_list:
out_json_line = json_str[i]
out_partition.write(out_json_line)
out_partition.write("\n")

@@ -1,38 +1,20 @@
#!/usr/bin/env python3

# extract list of sph from a cut jsonl
# python3 extract_list_of_sph.py dev_cuts_swbd.jsonl > data/fbank/dev_swbd_sph.list

import sys, json ;
import sys, json

inputfile = sys.argv[1]
json_str=[line.rstrip('\n') for line in open(inputfile)]
json_str = [line.rstrip("\n") for line in open(inputfile)]
num_json = len(json_str)

#print(num_json)
#with open(inputfile, 'r',encoding='utf-8') as Jsonfile:
# print("Converting JSON encoded data into Python dictionary")
# json_dict = json.load(Jsonfile)
# for k,v in json_dict:
# print(k,v)

for i in range(num_json):
if json_str[i] != '':
#print(json_str[i])
if json_str[i] != "":
cur_json = json.loads(json_str[i])
# print(cur_json)
for keys in cur_json:
#print(keys)
cur_rec = cur_json['recording']
cur_sources = cur_rec['sources']
#print(cur_sources)
cur_rec = cur_json["recording"]
cur_sources = cur_rec["sources"]
for s in cur_sources:
cur_sph = s['source']
cur_sph = s["source"]
print(cur_sph)
#cur_sph = cur_sources[2]
#print(cur_sph)

#print(json.load(sys.stdin)['source'])

@@ -16,67 +16,72 @@ def get_args():
parser.add_argument("output_sups")
return parser.parse_args()

def remove_punctutation_and_other_symbol(text:str) -> str:
text = text.replace("--"," ")
text = text.replace("//"," ")
text = text.replace("."," ")
text = text.replace("?"," ")
text = text.replace("~"," ")
text = text.replace(","," ")
text = text.replace(";"," ")
text = text.replace("("," ")
text = text.replace(")"," ")
text = text.replace("&"," ")
text = text.replace("%"," ")
text = text.replace("*"," ")
text = text.replace("{"," ")
text = text.replace("}"," ")

def remove_punctutation_and_other_symbol(text: str) -> str:
text = text.replace("--", " ")
text = text.replace("//", " ")
text = text.replace(".", " ")
text = text.replace("?", " ")
text = text.replace("~", " ")
text = text.replace(",", " ")
text = text.replace(";", " ")
text = text.replace("(", " ")
text = text.replace(")", " ")
text = text.replace("&", " ")
text = text.replace("%", " ")
text = text.replace("*", " ")
text = text.replace("{", " ")
text = text.replace("}", " ")
return text

def eval2000_clean_eform(text: str, eform_count) -> str:
string_to_remove = []
piece=text.split("\">")
for i in range(0,len(piece)):
s=piece[i]+"\">"
res = re.search(r'<contraction e_form(.*?)\">', s)
piece = text.split('">')
for i in range(0, len(piece)):
s = piece[i] + '">'
res = re.search(r"<contraction e_form(.*?)\">", s)
if res is not None:
res_rm= res.group(1)
res_rm = res.group(1)
string_to_remove.append(res_rm)
for p in string_to_remove:
eform_string = p
text = text.replace(eform_string, " ")
eform_1 = "<contraction e_form"
text = text.replace(eform_1, " ")
eform_2="\">"
text = text.replace(eform_2," ")
#print("TEXT final: ", text)
eform_2 = '">'
text = text.replace(eform_2, " ")
# print("TEXT final: ", text)
return text

def replace_silphone(text: str) -> str:

def replace_silphone(text: str) -> str:
text = text.replace("[/BABY CRYING]", " ")
text = text.replace("[/CHILD]" , " ")
text = text.replace("[[DISTORTED]]" , " ")
text = text.replace("[/DISTORTION]" , " ")
text = text.replace("[[DRAWN OUT]]" , " ")
text = text.replace("[[DRAWN-OUT]]" , " ")
text = text.replace("[[FAINT]]" , " ")
text = text.replace("[SMACK]" , " ")
text = text.replace("[[MUMBLES]]" , " ")
text = text.replace("[[HIGH PITCHED SQUEAKY VOICE]]" , " ")
text = text.replace("[[IN THE LAUGH]]" , "[LAUGHTER]")
text = text.replace("[[LAST WORD SPOKEN WITH A LAUGH]]" , "[LAUGHTER]")
text = text.replace("[[PART OF FIRST SYLLABLE OF PREVIOUS WORD CUT OFF]]" , " ")
text = text.replace("[[PREVIOUS WORD SPOKEN WITH A LAUGH]]" , " ")
text = text.replace("[[PREVIOUS TWO WORDS SPOKEN WHILE LAUGHING]]" , " ")
text = text.replace("[[PROLONGED]]" , " ")
text = text.replace("[/RUNNING WATER]" , " ")
text = text.replace("[[SAYS LAUGHING]]" , "[LAUGHTER]")
text = text.replace("[[SINGING]]" , " ")
text = text.replace("[[SPOKEN WHILE LAUGHING]]" , "[LAUGHTER]")
text = text.replace("[/STATIC]" , " ")
text = text.replace("['THIRTIETH' DRAWN OUT]" , " ")
text = text.replace("[/VOICES]" , " ")
text = text.replace("[[WHISPERED]]" , " ")
text = text.replace("[/CHILD]", " ")
text = text.replace("[[DISTORTED]]", " ")
text = text.replace("[/DISTORTION]", " ")
text = text.replace("[[DRAWN OUT]]", " ")
text = text.replace("[[DRAWN-OUT]]", " ")
text = text.replace("[[FAINT]]", " ")
text = text.replace("[SMACK]", " ")
text = text.replace("[[MUMBLES]]", " ")
text = text.replace("[[HIGH PITCHED SQUEAKY VOICE]]", " ")
text = text.replace("[[IN THE LAUGH]]", "[LAUGHTER]")
text = text.replace("[[LAST WORD SPOKEN WITH A LAUGH]]", "[LAUGHTER]")
text = text.replace(
"[[PART OF FIRST SYLLABLE OF PREVIOUS WORD CUT OFF]]", " "
)
text = text.replace("[[PREVIOUS WORD SPOKEN WITH A LAUGH]]", " ")
text = text.replace("[[PREVIOUS TWO WORDS SPOKEN WHILE LAUGHING]]", " ")
text = text.replace("[[PROLONGED]]", " ")
text = text.replace("[/RUNNING WATER]", " ")
text = text.replace("[[SAYS LAUGHING]]", "[LAUGHTER]")
text = text.replace("[[SINGING]]", " ")
text = text.replace("[[SPOKEN WHILE LAUGHING]]", "[LAUGHTER]")
text = text.replace("[/STATIC]", " ")
text = text.replace("['THIRTIETH' DRAWN OUT]", " ")
text = text.replace("[/VOICES]", " ")
text = text.replace("[[WHISPERED]]", " ")
text = text.replace("[DISTORTION]", " ")
text = text.replace("[DISTORTION, HIGH VOLUME ON WAVES]", " ")
text = text.replace("[BACKGROUND LAUGHTER]", "[LAUGHTER]")
@@ -95,24 +100,24 @@ def replace_silphone(text: str) -> str:
text = text.replace("[BABY CRYING]", " ")
text = text.replace("[METALLIC KNOCKING SOUND]", " ")
text = text.replace("[METALLIC SOUND]", " ")

text = text.replace("[PHONE JIGGLING]", " ")
text = text.replace("[BACKGROUND SOUND]", " ")
text = text.replace("[BACKGROUND VOICE]", " ")
text = text.replace("[BACKGROUND VOICES]", " ")
text = text.replace("[BACKGROUND VOICES]", " ")
text = text.replace("[BACKGROUND NOISE]", " ")
text = text.replace("[CAR HORNS IN BACKGROUND]", " ")
text = text.replace("[CAR HORNS]", " ")
text = text.replace("[CARNATING]", " ")
text = text.replace("[CARNATING]", " ")
text = text.replace("[CRYING CHILD]", " ")
text = text.replace("[CHOPPING SOUND]", " ")
text = text.replace("[BANGING]", " ")
text = text.replace("[CLICKING NOISE]", " ")
text = text.replace("[CLATTERING]", " ")
text = text.replace("[CLATTERING]", " ")
text = text.replace("[ECHO]", " ")
text = text.replace("[KNOCK]", " ")
text = text.replace("[NOISE-GOOD]", "[NOISE]")
text = text.replace("[RIGHT]", " ")
text = text.replace("[KNOCK]", " ")
text = text.replace("[NOISE-GOOD]", "[NOISE]")
text = text.replace("[RIGHT]", " ")
text = text.replace("[SOUND]", " ")
text = text.replace("[SQUEAK]", " ")
text = text.replace("[STATIC]", " ")
@@ -131,64 +136,65 @@ def replace_silphone(text: str) -> str:
text = text.replace("Y[OU]I-", "YOU I")
text = text.replace("-[A]ND", "AND")
text = text.replace("JU[ST]", "JUST")
text = text.replace("{BREATH}" , " ")
text = text.replace("{BREATHY}" , " ")
text = text.replace("{CHANNEL NOISE}" , " ")
text = text.replace("{CLEAR THROAT}" , " ")
text = text.replace("{BREATH}", " ")
text = text.replace("{BREATHY}", " ")
text = text.replace("{CHANNEL NOISE}", " ")
text = text.replace("{CLEAR THROAT}", " ")

text = text.replace("{CLEARING THROAT}" , " ")
text = text.replace("{CLEARS THROAT}" , " ")
text = text.replace("{COUGH}" , " ")
text = text.replace("{DRAWN OUT}" , " ")
text = text.replace("{EXHALATION}" , " ")
text = text.replace("{EXHALE}" , " ")
text = text.replace("{GASP}" , " ")
text = text.replace("{HIGH SQUEAL}" , " ")
text = text.replace("{INHALE}" , " ")
text = text.replace("{LAUGH}" , "[LAUGHTER]")
text = text.replace("{LAUGH}" , "[LAUGHTER]")
text = text.replace("{LAUGH}" , "[LAUGHTER]")
text = text.replace("{LIPSMACK}" , " ")
text = text.replace("{LIPSMACK}" , " ")
text = text.replace("{CLEARING THROAT}", " ")
text = text.replace("{CLEARS THROAT}", " ")
text = text.replace("{COUGH}", " ")
text = text.replace("{DRAWN OUT}", " ")
text = text.replace("{EXHALATION}", " ")
text = text.replace("{EXHALE}", " ")
text = text.replace("{GASP}", " ")
text = text.replace("{HIGH SQUEAL}", " ")
text = text.replace("{INHALE}", " ")
text = text.replace("{LAUGH}", "[LAUGHTER]")
text = text.replace("{LAUGH}", "[LAUGHTER]")
text = text.replace("{LAUGH}", "[LAUGHTER]")
text = text.replace("{LIPSMACK}", " ")
text = text.replace("{LIPSMACK}", " ")

text = text.replace("{NOISE OF DISGUST}" , " ")
text = text.replace("{SIGH}" , " ")
text = text.replace("{SNIFF}" , " ")
text = text.replace("{SNORT}" , " ")
text = text.replace("{SHARP EXHALATION}" , " ")
text = text.replace("{BREATH LAUGH}" , " ")
text = text.replace("{NOISE OF DISGUST}", " ")
text = text.replace("{SIGH}", " ")
text = text.replace("{SNIFF}", " ")
text = text.replace("{SNORT}", " ")
text = text.replace("{SHARP EXHALATION}", " ")
text = text.replace("{BREATH LAUGH}", " ")
return text

def remove_languagetag(text:str) -> str:
langtag = re.findall(r'<(.*?)>', text)

def remove_languagetag(text: str) -> str:
langtag = re.findall(r"<(.*?)>", text)
for t in langtag:
text = text.replace(t, " ")
text = text.replace("<"," ")
text = text.replace(">"," ")
text = text.replace("<", " ")
text = text.replace(">", " ")
return text

def eval2000_normalizer(text: str) -> str:
#print("TEXT original: ",text)
eform_count=text.count("contraction e_form")
#print("eform corunt:", eform_count)
if eform_count>0:
text = eval2000_clean_eform(text,eform_count)
# print("TEXT original: ",text)
eform_count = text.count("contraction e_form")
# print("eform corunt:", eform_count)
if eform_count > 0:
text = eval2000_clean_eform(text, eform_count)
text = text.upper()
text = remove_languagetag(text)
text = replace_silphone(text)
text = remove_punctutation_and_other_symbol(text)
text = text.replace("IGNORE_TIME_SEGMENT_IN_SCORING", " ")
text = text.replace("IGNORE_TIME_SEGMENT_SCORING", " ")
spaces = re.findall(r'\s+', text)
spaces = re.findall(r"\s+", text)
for sp in spaces:
text = text.replace(sp," ")
text = text.strip()
#text = self.whitespace_regexp.sub(" ", text).strip()
#print(text)
text = text.replace(sp, " ")
text = text.strip()
# text = self.whitespace_regexp.sub(" ", text).strip()
# print(text)
return text

def main():
args = get_args()
sups = load_manifest_lazy_or_eager(args.input_sups)
@@ -203,6 +209,7 @@ def main():
skip += 1
continue
writer.write(sup)

if __name__ == "__main__":
main()

@@ -108,9 +108,7 @@ def lexicon_to_fst_no_sil(
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
arcs, disambig_token=disambig_token, disambig_word=disambig_word
)

final_state = next_state
@@ -223,9 +221,7 @@ def main():
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
lexicon, token2id=token_sym_table, word2id=word_sym_table
)

L_disambig = lexicon_to_fst_no_sil(

@@ -68,12 +68,7 @@ def get_args():
def get_g2p_sym2int():

# These symbols are removed from from g2p_en's vocabulary
excluded_symbols = [
"<pad>",
"<s>",
"</s>",
"<unk>",
]
excluded_symbols = ["<pad>", "<s>", "</s>", "<unk>"]

symbols = [p for p in sorted(G2p().phonemes) if p not in excluded_symbols]
# reserve 0 and 1 for blank and sos/eos/pad tokens
@@ -345,9 +340,7 @@ def lexicon_to_fst(
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
arcs, disambig_token=disambig_token, disambig_word=disambig_word
)

final_state = next_state
@@ -396,9 +389,7 @@ def main():
print(vocab[:10])

if not lexicon_filename.is_file():
lexicon = [
("!SIL", [sil_token]),
]
lexicon = [("!SIL", [sil_token])]
for symbol in special_symbols:
lexicon.append((symbol, [symbol[1:-1]]))
lexicon += [

@@ -43,16 +43,10 @@ def get_args():
""",
)

parser.add_argument(
"--transcript",
type=str,
help="Training transcript.",
)
parser.add_argument("--transcript", type=str, help="Training transcript.")

parser.add_argument(
"--vocab-size",
type=int,
help="Vocabulary size for BPE training",
"--vocab-size", type=int, help="Vocabulary size for BPE training"
)

return parser.parse_args()

@@ -1,6 +1,6 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
#
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -356,13 +356,10 @@ class FisherSwbdSpeechAsrDataModule:
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
cut_transforms=transforms, return_cuts=self.args.return_cuts
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
cuts_valid, max_duration=self.args.max_duration, shuffle=False
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
@@ -384,9 +381,7 @@ class FisherSwbdSpeechAsrDataModule:
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
@@ -396,41 +391,52 @@ class FisherSwbdSpeechAsrDataModule:
num_workers=self.args.num_workers,
)
return test_dl

@lru_cache()
def train_fisher_cuts(self) -> CutSet:
logging.info("About to get fisher cuts")
return load_manifest_lazy(
self.args.manifest_dir / "train_cuts_fisher.jsonl.gz"
)

@lru_cache()
def train_swbd_cuts(self) -> CutSet:
logging.info("About to get train swbd cuts")
return load_manifest_lazy(
self.args.manifest_dir / "train_cuts_swbd.jsonl.gz"
)

@lru_cache()
def dev_fisher_cuts(self) -> CutSet:
logging.info("About to get dev fisher cuts")
return load_manifest_lazy(self.args.manifest_dir / "dev_cuts_fisher.jsonl.gz"
return load_manifest_lazy(
self.args.manifest_dir / "dev_cuts_fisher.jsonl.gz"
)

@lru_cache()
def dev_swbd_cuts(self) -> CutSet:
logging.info("About to get dev swbd cuts")
return load_manifest_lazy(self.args.manifest_dir / "dev_cuts_swbd.jsonl.gz"
return load_manifest_lazy(
self.args.manifest_dir / "dev_cuts_swbd.jsonl.gz"
)

@lru_cache()
def test_eval2000_cuts(self) -> CutSet:
logging.info("About to get test eval2000 cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_eval2000.jsonl.gz"
return load_manifest_lazy(
self.args.manifest_dir / "cuts_eval2000.jsonl.gz"
)

@lru_cache()
def test_swbd_cuts(self) -> CutSet:
logging.info("About to get test eval2000 swbd cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_eval2000_swbd.jsonl.gz"
return load_manifest_lazy(
self.args.manifest_dir / "cuts_eval2000_swbd.jsonl.gz"
)

@lru_cache()
def test_callhome_cuts(self) -> CutSet:
logging.info("About to get test eval2000 callhome cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_eval2000_callhome.jsonl.gz"
return load_manifest_lazy(
self.args.manifest_dir / "cuts_eval2000_callhome.jsonl.gz"
)

@@ -550,9 +550,7 @@ def greedy_search(

def greedy_search_batch(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
@@ -591,9 +589,7 @@ def greedy_search_batch(
hyps = [[blank_id] * context_size for _ in range(N)]

decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
hyps, device=device, dtype=torch.int64
) # (N, context_size)

decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -630,9 +626,7 @@ def greedy_search_batch(
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
decoder_input, device=device, dtype=torch.int64
)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
@@ -894,9 +888,7 @@ def modified_beam_search(
) # (num_hyps, 1, 1, encoder_out_dim)

logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
current_encoder_out, decoder_out, project_input=False
) # (num_hyps, 1, 1, vocab_size)

logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
@@ -953,9 +945,7 @@ def modified_beam_search(

def _deprecated_modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
model: Transducer, encoder_out: torch.Tensor, beam: int = 4
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.

@@ -1023,9 +1013,7 @@ def _deprecated_modified_beam_search(
) # (num_hyps, 1, 1, encoder_out_dim)

logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
current_encoder_out, decoder_out, project_input=False
)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
@@ -1097,9 +1085,7 @@ def beam_search(
device = next(model.parameters()).device

decoder_input = torch.tensor(
[blank_id] * context_size,
device=device,
dtype=torch.int64,
[blank_id] * context_size, device=device, dtype=torch.int64
).reshape(1, context_size)

decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -1318,9 +1304,7 @@ def fast_beam_search_with_nbest_rescoring(
num_unique_paths = len(word_ids_list)

b_to_a_map = torch.zeros(
num_unique_paths,
dtype=torch.int32,
device=lattice.device,
num_unique_paths, dtype=torch.int32, device=lattice.device
)

rescored_word_fsas = k2.intersect_device(
@@ -1334,8 +1318,7 @@ def fast_beam_search_with_nbest_rescoring(
rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas)
rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
ngram_lm_scores = rescored_word_fsas.get_tot_scores(
use_double_scores=True,
log_semiring=False,
use_double_scores=True, log_semiring=False
)

ans: Dict[str, List[List[int]]] = {}

@@ -223,19 +223,10 @@ class Conformer(EncoderInterface):

init_states: List[torch.Tensor] = [
torch.zeros(
(
self.encoder_layers,
left_context,
self.d_model,
),
device=device,
(self.encoder_layers, left_context, self.d_model), device=device
),
torch.zeros(
(
self.encoder_layers,
self.cnn_module_kernel - 1,
self.d_model,
),
(self.encoder_layers, self.cnn_module_kernel - 1, self.d_model),
device=device,
),
]
@@ -330,7 +321,9 @@ class Conformer(EncoderInterface):
{(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
given {states[1].shape}."""

lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
lengths -= (
2
) # we will cut off 1 frame on each side of encoder_embed output

src_key_padding_mask = make_pad_mask(lengths)

@@ -829,9 +822,7 @@ class RelPositionalEncoding(torch.nn.Module):
self.pe = pe.to(device=x.device, dtype=x.dtype)

def forward(
self,
x: torch.Tensor,
left_context: int = 0,
self, x: torch.Tensor, left_context: int = 0
) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.

@@ -875,10 +866,7 @@ class RelPositionMultiheadAttention(nn.Module):
"""

def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
self, embed_dim: int, num_heads: int, dropout: float = 0.0
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
@@ -1272,8 +1260,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
@@ -1420,10 +1407,7 @@ class ConvolutionModule(nn.Module):
)

def forward(
self,
x: Tensor,
cache: Optional[Tensor] = None,
right_context: int = 0,
self, x: Tensor, cache: Optional[Tensor] = None, right_context: int = 0
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.

@@ -384,9 +384,7 @@ def decode_one_batch(

feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
feature, pad=(0, 0, 0, params.left_context), value=LOG_EPS
)

if params.simulate_streaming:
@@ -778,7 +776,7 @@ def main():

num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")

fisherswbd = FisherSwbdSpeechAsrDataModule(args)

test_eval2000_cuts = fisherswbd.test_eval2000_cuts()
@@ -803,9 +801,7 @@ def main():
)

save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
params=params, test_set_name=test_set, results_dict=results_dict
)

logging.info("Done!")

@@ -92,10 +92,7 @@ class DecodeStream(object):
"""Return True if all the features are processed."""
return self._done

def set_features(
self,
features: torch.Tensor,
) -> None:
def set_features(self, features: torch.Tensor) -> None:
"""Set features tensor of current utterance."""
assert features.dim() == 2, features.dim()
self.features = torch.nn.functional.pad(

@@ -96,11 +96,7 @@ def get_parser():
"icefall.checkpoint.save_checkpoint().",
)

parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument("--bpe-model", type=str, help="""Path to bpe.model.""")

parser.add_argument(
"--method",

@@ -39,7 +39,8 @@ import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
#from asr_datamodule import LibriSpeechAsrDataModule

# from asr_datamodule import LibriSpeechAsrDataModule
from asr_datamodule import FisherSwbdSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
@@ -187,9 +188,7 @@ def get_parser():

def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
model: nn.Module, encoder_out: torch.Tensor, streams: List[DecodeStream]
) -> List[List[int]]:

assert len(streams) == encoder_out.size(0)
@@ -236,10 +235,7 @@ def greedy_search(
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)

hyp_tokens = []
@@ -290,9 +286,7 @@ def fast_beam_search(

def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
params: AttributeDict, model: nn.Module, decode_streams: List[DecodeStream]
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
@@ -502,10 +496,7 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(hyp).split(),
)
(decode_streams[i].ground_truth.split(), sp.decode(hyp).split())
)
del decode_streams[i]

@@ -661,7 +652,7 @@ def main():
fisherswbd = FisherSwbdSpeechAsrDataModule(args)

test_eval2000_cuts = fisherswbd.test_eval2000_cuts()
test_swbd_cuts = fisherswbd.test_swbd_cuts ()
test_swbd_cuts = fisherswbd.test_swbd_cuts()
test_callhome_cuts = fisherswbd.test_callhome_cuts()

test_eval2000_dl = fisherswbd.test_dataloaders(test_eval2000_cuts)
@@ -681,9 +672,7 @@ def main():
)

save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
params=params, test_set_name=test_set, results_dict=results_dict
)

logging.info("Done!")

@@ -155,10 +155,7 @@ def get_parser():
)

parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
"--num-epochs", type=int, default=30, help="Number of epochs to train."
)

parser.add_argument(
@@ -480,10 +477,7 @@ def load_checkpoint_if_available(
assert filename.is_file(), f"{filename} does not exist!"

saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
filename, model=model, optimizer=optimizer, scheduler=scheduler
)

keys = [
@@ -646,11 +640,7 @@ def compute_validation_loss(

for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=False,
params=params, model=model, sp=sp, batch=batch, is_training=False
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
@@ -767,9 +757,7 @@ def train_one_epoch(
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank
)

if batch_idx % params.log_interval == 0:
@@ -830,8 +818,6 @@ def run(rank, world_size, args):
"""
params = get_params()
params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 1600

fix_random_seed(params.seed)
if world_size > 1:
@@ -897,11 +883,11 @@ def run(rank, world_size, args):
if params.print_diagnostics:
diagnostic = diagnostics.attach_diagnostics(model)

librispeech = FisherSwbdSpeechAsrDataModule(args)
fisherswbd = FisherSwbdSpeechAsrDataModule(args)

train_cuts = fisherswbd.train_fisher_cuts()
train_cuts += fisherswbd.train_swbd_cuts()

def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
@@ -991,9 +977,7 @@ def run(rank, world_size, args):

def display_and_save_batch(
batch: dict,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
batch: dict, params: AttributeDict, sp: spm.SentencePieceProcessor
) -> None:
"""Display the batch statistics and save the batch into disk.