This commit is contained in:
Piotr Żelasko 2022-01-17 23:08:48 +00:00
parent aad2b7940d
commit d394e88020
4 changed files with 31 additions and 10 deletions

View File

@@ -40,7 +40,7 @@ from icefall.utils import str2bool
class Resample16kHz:
def __call__(self, cuts: CutSet) -> CutSet:
return cuts.resample(16000).with_recording_path_prefix('download')
return cuts.resample(16000).with_recording_path_prefix("download")
class AsrDataModule:
@@ -282,5 +282,5 @@ def test():
break
if __name__ == '__main__':
if __name__ == "__main__":
test()

View File

@@ -665,7 +665,9 @@ def main():
datamodule = AsrDataModule(args)
fisher_swbd_dev_cuts = datamodule.dev_cuts()
fisher_swbd_dev_dataloader = datamodule.test_dataloaders(fisher_swbd_dev_cuts)
fisher_swbd_dev_dataloader = datamodule.test_dataloaders(
fisher_swbd_dev_cuts
)
test_sets = ["dev-fisher-swbd"]
test_dl = [fisher_swbd_dev_dataloader]

View File

@@ -17,6 +17,7 @@ def get_args():
return parser.parse_args()
# fmt: off
class FisherSwbdNormalizer:
"""
Note: the functions "normalize" and "keep" implement the logic similar to
@@ -118,6 +119,7 @@ class FisherSwbdNormalizer:
text = self.whitespace_regexp.sub(" ", text).strip()
return text
# fmt: on
def keep(sup: SupervisionSegment) -> bool:
@@ -181,6 +183,7 @@ def test():
print(normalizer.normalize(text))
print()
if __name__ == "__main__":
# test()
main()

View File

@@ -106,7 +106,6 @@ def get_g2p_sym2int():
return sym2int
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
@@ -382,7 +381,17 @@ def main():
lexicon_filename = lang_dir / "lexicon.txt"
sil_token = "SIL"
sil_prob = 0.5
special_symbols = ["[UNK]", "[BREATH]", "[COUGH]", "[LAUGHTER]", "[LIPSMACK]", "[NOISE]", "[SIGH]", "[SNEEZE]", "[VOCALIZED-NOISE]"]
special_symbols = [
"[UNK]",
"[BREATH]",
"[COUGH]",
"[LAUGHTER]",
"[LIPSMACK]",
"[NOISE]",
"[SIGH]",
"[SNEEZE]",
"[VOCALIZED-NOISE]",
]
g2p = G2p()
token2id = get_g2p_sym2int()
@@ -407,8 +416,15 @@ def main():
(
word,
[
phn for phn in g2p(word)
if phn not in ("'", " ", "-", ",") # g2p_en has these symbols as phones
phn
for phn in g2p(word)
if phn
not in (
"'",
" ",
"-",
",",
) # g2p_en has these symbols as phones
],
)
for word in tqdm(vocab, desc="Processing vocab with G2P")
@@ -437,9 +453,9 @@ def main():
token2id = dict(sorted(token2id.items(), key=lambda tpl: tpl[1]))
print(token2id)
word2id = {"<eps>": 0}
word2id.update({
word: int(id_) for id_, (word, pron) in enumerate(lexicon, start=1)
})
word2id.update(
{word: int(id_) for id_, (word, pron) in enumerate(lexicon, start=1)}
)
for symbol in ["<s>", "</s>", "#0"]:
word2id[symbol] = len(word2id)