mirror of https://github.com/k2-fsa/icefall.git

commit d394e88020 (parent aad2b7940d)

    black
@@ -40,7 +40,7 @@ from icefall.utils import str2bool
 
 class Resample16kHz:
     def __call__(self, cuts: CutSet) -> CutSet:
-        return cuts.resample(16000).with_recording_path_prefix('download')
+        return cuts.resample(16000).with_recording_path_prefix("download")
 
 
 class AsrDataModule:
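Aside from the quote normalization, the hunk above shows that `Resample16kHz` is a lhotse-style cut transform: a callable over a `CutSet`. A minimal usage sketch, assuming lhotse is installed; the manifest path is hypothetical:

    from lhotse import CutSet

    # Hypothetical manifest path; any cut manifest works here.
    cuts = CutSet.from_file("data/manifests/dev_cuts.jsonl.gz")
    resampled = Resample16kHz()(cuts)  # transform defined in the hunk above
    # Metadata reflects the new rate; audio is resampled when loaded.
    assert next(iter(resampled)).sampling_rate == 16000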
@@ -282,5 +282,5 @@ def test():
        break
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
    test()
@@ -665,7 +665,9 @@ def main():
     datamodule = AsrDataModule(args)
 
     fisher_swbd_dev_cuts = datamodule.dev_cuts()
-    fisher_swbd_dev_dataloader = datamodule.test_dataloaders(fisher_swbd_dev_cuts)
+    fisher_swbd_dev_dataloader = datamodule.test_dataloaders(
+        fisher_swbd_dev_cuts
+    )
 
     test_sets = ["dev-fisher-swbd"]
     test_dl = [fisher_swbd_dev_dataloader]
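The hunk above only re-wraps the `test_dataloaders(...)` call to fit black's line length. For orientation, a hedged sketch of how the resulting dataloader is typically consumed; the batch keys follow lhotse's `K2SpeechRecognitionDataset` convention and are an assumption here:

    for batch in fisher_swbd_dev_dataloader:
        features = batch["inputs"]             # (N, T, C) feature tensor
        texts = batch["supervisions"]["text"]  # reference transcripts
        break  # one batch is enough for a smoke test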
@@ -17,6 +17,7 @@ def get_args():
     return parser.parse_args()
 
 
+# fmt: off
 class FisherSwbdNormalizer:
     """
     Note: the functions "normalize" and "keep" implement the logic similar to
@@ -118,6 +119,7 @@ class FisherSwbdNormalizer:
         text = self.whitespace_regexp.sub(" ", text).strip()
 
         return text
+# fmt: on
 
 
 def keep(sup: SupervisionSegment) -> bool:
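The two hunks above bracket `FisherSwbdNormalizer` with `# fmt: off` / `# fmt: on`. black skips everything between these directives, which is the standard way to protect hand-formatted code (here, presumably the normalizer's aligned regex tables) from reformatting. A self-contained illustration with a made-up table:

    # fmt: off
    REPLACEMENTS = {
        "uh-huh":  "uhhuh",   # hand-aligned columns ...
        "mm-hmm":  "mhm",     # ... survive a black run unchanged
    }
    # fmt: on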
@@ -181,6 +183,7 @@ def test():
         print(normalizer.normalize(text))
         print()
 
+
 if __name__ == "__main__":
     # test()
     main()
@@ -106,7 +106,6 @@ def get_g2p_sym2int():
     return sym2int
 
 
-
 def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
     """Write a symbol to ID mapping to a file.
 
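`write_mapping` is only touched by whitespace here; its docstring says it writes a symbol-to-ID mapping. A minimal sketch of such a helper, consistent with the docstring but not necessarily identical to icefall's implementation:

    from typing import Dict

    def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
        """Write a symbol to ID mapping, one "symbol id" pair per line."""
        with open(filename, "w", encoding="utf-8") as f:
            for sym, i in sym2id.items():
                f.write(f"{sym} {i}\n")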
@@ -382,7 +381,17 @@ def main():
     lexicon_filename = lang_dir / "lexicon.txt"
     sil_token = "SIL"
     sil_prob = 0.5
-    special_symbols = ["[UNK]", "[BREATH]", "[COUGH]", "[LAUGHTER]", "[LIPSMACK]", "[NOISE]", "[SIGH]", "[SNEEZE]", "[VOCALIZED-NOISE]"]
+    special_symbols = [
+        "[UNK]",
+        "[BREATH]",
+        "[COUGH]",
+        "[LAUGHTER]",
+        "[LIPSMACK]",
+        "[NOISE]",
+        "[SIGH]",
+        "[SNEEZE]",
+        "[VOCALIZED-NOISE]",
+    ]
 
     g2p = G2p()
     token2id = get_g2p_sym2int()
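black explodes `special_symbols` because the one-line form exceeds the line limit, and it appends a trailing comma after the last element. That "magic trailing comma" is significant: on later runs black keeps any comma-terminated collection in exploded form even when it would fit on one line. For example:

    collapsed = ["[UNK]", "[NOISE]"]  # fits and has no trailing comma: stays on one line
    exploded = [
        "[UNK]",
        "[NOISE]",  # trailing comma: black preserves the multi-line layout
    ]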
@@ -407,8 +416,15 @@ def main():
         (
             word,
             [
-                phn for phn in g2p(word)
-                if phn not in ("'", " ", "-", ",")  # g2p_en has these symbols as phones
+                phn
+                for phn in g2p(word)
+                if phn
+                not in (
+                    "'",
+                    " ",
+                    "-",
+                    ",",
+                )  # g2p_en has these symbols as phones
             ],
         )
         for word in tqdm(vocab, desc="Processing vocab with G2P")
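The comprehension above filters the output of `g2p_en`, which passes separators from the input through as pseudo-phones alongside the ARPAbet symbols. A small sketch of the same filter in isolation (requires the `g2p_en` package; the example word is arbitrary):

    from g2p_en import G2p

    g2p = G2p()
    raw = g2p("mm-hmm, yeah")
    phones = [p for p in raw if p not in ("'", " ", "-", ",")]
    print(phones)  # ARPAbet phones only; "-", "," and spaces are dropped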
@@ -437,9 +453,9 @@ def main():
     token2id = dict(sorted(token2id.items(), key=lambda tpl: tpl[1]))
     print(token2id)
     word2id = {"<eps>": 0}
-    word2id.update({
-        word: int(id_) for id_, (word, pron) in enumerate(lexicon, start=1)
-    })
+    word2id.update(
+        {word: int(id_) for id_, (word, pron) in enumerate(lexicon, start=1)}
+    )
     for symbol in ["<s>", "</s>", "#0"]:
         word2id[symbol] = len(word2id)
 
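After this hunk, `word2id` reserves `<eps>` at ID 0, numbers the lexicon words from 1, and appends `<s>`, `</s>`, and `#0` (OpenFst/k2-style sentence-boundary and disambiguation symbols) at the end. A self-contained demonstration with made-up lexicon entries:

    lexicon = [("yeah", ["Y", "AE1"]), ("okay", ["OW2", "K", "EY1"])]  # hypothetical
    word2id = {"<eps>": 0}
    word2id.update(
        {word: int(id_) for id_, (word, pron) in enumerate(lexicon, start=1)}
    )
    for symbol in ["<s>", "</s>", "#0"]:
        word2id[symbol] = len(word2id)
    print(word2id)
    # {'<eps>': 0, 'yeah': 1, 'okay': 2, '<s>': 3, '</s>': 4, '#0': 5}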