This commit is contained in:
Piotr Żelasko 2022-01-17 23:08:48 +00:00
parent aad2b7940d
commit d394e88020
4 changed files with 31 additions and 10 deletions

View File

@ -40,7 +40,7 @@ from icefall.utils import str2bool
class Resample16kHz: class Resample16kHz:
def __call__(self, cuts: CutSet) -> CutSet: def __call__(self, cuts: CutSet) -> CutSet:
return cuts.resample(16000).with_recording_path_prefix('download') return cuts.resample(16000).with_recording_path_prefix("download")
class AsrDataModule: class AsrDataModule:
@ -282,5 +282,5 @@ def test():
break break
if __name__ == '__main__': if __name__ == "__main__":
test() test()

View File

@ -665,7 +665,9 @@ def main():
datamodule = AsrDataModule(args) datamodule = AsrDataModule(args)
fisher_swbd_dev_cuts = datamodule.dev_cuts() fisher_swbd_dev_cuts = datamodule.dev_cuts()
fisher_swbd_dev_dataloader = datamodule.test_dataloaders(fisher_swbd_dev_cuts) fisher_swbd_dev_dataloader = datamodule.test_dataloaders(
fisher_swbd_dev_cuts
)
test_sets = ["dev-fisher-swbd"] test_sets = ["dev-fisher-swbd"]
test_dl = [fisher_swbd_dev_dataloader] test_dl = [fisher_swbd_dev_dataloader]

View File

@ -17,6 +17,7 @@ def get_args():
return parser.parse_args() return parser.parse_args()
# fmt: off
class FisherSwbdNormalizer: class FisherSwbdNormalizer:
""" """
Note: the functions "normalize" and "keep" implement the logic similar to Note: the functions "normalize" and "keep" implement the logic similar to
@ -118,6 +119,7 @@ class FisherSwbdNormalizer:
text = self.whitespace_regexp.sub(" ", text).strip() text = self.whitespace_regexp.sub(" ", text).strip()
return text return text
# fmt: on
def keep(sup: SupervisionSegment) -> bool: def keep(sup: SupervisionSegment) -> bool:
@ -181,6 +183,7 @@ def test():
print(normalizer.normalize(text)) print(normalizer.normalize(text))
print() print()
if __name__ == "__main__": if __name__ == "__main__":
# test() # test()
main() main()

View File

@ -106,7 +106,6 @@ def get_g2p_sym2int():
return sym2int return sym2int
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file. """Write a symbol to ID mapping to a file.
@ -382,7 +381,17 @@ def main():
lexicon_filename = lang_dir / "lexicon.txt" lexicon_filename = lang_dir / "lexicon.txt"
sil_token = "SIL" sil_token = "SIL"
sil_prob = 0.5 sil_prob = 0.5
special_symbols = ["[UNK]", "[BREATH]", "[COUGH]", "[LAUGHTER]", "[LIPSMACK]", "[NOISE]", "[SIGH]", "[SNEEZE]", "[VOCALIZED-NOISE]"] special_symbols = [
"[UNK]",
"[BREATH]",
"[COUGH]",
"[LAUGHTER]",
"[LIPSMACK]",
"[NOISE]",
"[SIGH]",
"[SNEEZE]",
"[VOCALIZED-NOISE]",
]
g2p = G2p() g2p = G2p()
token2id = get_g2p_sym2int() token2id = get_g2p_sym2int()
@ -407,8 +416,15 @@ def main():
( (
word, word,
[ [
phn for phn in g2p(word) phn
if phn not in ("'", " ", "-", ",") # g2p_en has these symbols as phones for phn in g2p(word)
if phn
not in (
"'",
" ",
"-",
",",
) # g2p_en has these symbols as phones
], ],
) )
for word in tqdm(vocab, desc="Processing vocab with G2P") for word in tqdm(vocab, desc="Processing vocab with G2P")
@ -437,9 +453,9 @@ def main():
token2id = dict(sorted(token2id.items(), key=lambda tpl: tpl[1])) token2id = dict(sorted(token2id.items(), key=lambda tpl: tpl[1]))
print(token2id) print(token2id)
word2id = {"<eps>": 0} word2id = {"<eps>": 0}
word2id.update({ word2id.update(
word: int(id_) for id_, (word, pron) in enumerate(lexicon, start=1) {word: int(id_) for id_, (word, pron) in enumerate(lexicon, start=1)}
}) )
for symbol in ["<s>", "</s>", "#0"]: for symbol in ["<s>", "</s>", "#0"]:
word2id[symbol] = len(word2id) word2id[symbol] = len(word2id)