From d394e88020c1b6eaf2dd999b7406417e866e760e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 17 Jan 2022 23:08:48 +0000 Subject: [PATCH] black --- .../ASR/conformer_ctc/asr_datamodule.py | 4 +-- egs/fisher_swbd/ASR/conformer_ctc/decode.py | 4 ++- .../normalize_and_filter_supervisions.py | 3 ++ .../ASR/local/prepare_lang_g2pen.py | 30 ++++++++++++++----- 4 files changed, 31 insertions(+), 10 deletions(-) diff --git a/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py index 7abe169d1..f9d25b7de 100644 --- a/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py @@ -40,7 +40,7 @@ from icefall.utils import str2bool class Resample16kHz: def __call__(self, cuts: CutSet) -> CutSet: - return cuts.resample(16000).with_recording_path_prefix('download') + return cuts.resample(16000).with_recording_path_prefix("download") class AsrDataModule: @@ -282,5 +282,5 @@ def test(): break -if __name__ == '__main__': +if __name__ == "__main__": test() diff --git a/egs/fisher_swbd/ASR/conformer_ctc/decode.py b/egs/fisher_swbd/ASR/conformer_ctc/decode.py index 3185ff777..467ece3df 100755 --- a/egs/fisher_swbd/ASR/conformer_ctc/decode.py +++ b/egs/fisher_swbd/ASR/conformer_ctc/decode.py @@ -665,7 +665,9 @@ def main(): datamodule = AsrDataModule(args) fisher_swbd_dev_cuts = datamodule.dev_cuts() - fisher_swbd_dev_dataloader = datamodule.test_dataloaders(fisher_swbd_dev_cuts) + fisher_swbd_dev_dataloader = datamodule.test_dataloaders( + fisher_swbd_dev_cuts + ) test_sets = ["dev-fisher-swbd"] test_dl = [fisher_swbd_dev_dataloader] diff --git a/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py b/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py index 22f265c9d..cd65f1c86 100644 --- a/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py +++ b/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py @@ -17,6 +17,7 @@ def get_args(): return parser.parse_args() +# fmt: off class FisherSwbdNormalizer: """ Note: the functions "normalize" and "keep" implement the logic similar to @@ -118,6 +119,7 @@ class FisherSwbdNormalizer: text = self.whitespace_regexp.sub(" ", text).strip() return text +# fmt: on def keep(sup: SupervisionSegment) -> bool: @@ -181,6 +183,7 @@ def test(): print(normalizer.normalize(text)) print() + if __name__ == "__main__": # test() main() diff --git a/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py b/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py index c771dac4b..4768e1dc0 100755 --- a/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py +++ b/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py @@ -106,7 +106,6 @@ def get_g2p_sym2int(): return sym2int - def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: """Write a symbol to ID mapping to a file. @@ -382,7 +381,17 @@ def main(): lexicon_filename = lang_dir / "lexicon.txt" sil_token = "SIL" sil_prob = 0.5 - special_symbols = ["[UNK]", "[BREATH]", "[COUGH]", "[LAUGHTER]", "[LIPSMACK]", "[NOISE]", "[SIGH]", "[SNEEZE]", "[VOCALIZED-NOISE]"] + special_symbols = [ + "[UNK]", + "[BREATH]", + "[COUGH]", + "[LAUGHTER]", + "[LIPSMACK]", + "[NOISE]", + "[SIGH]", + "[SNEEZE]", + "[VOCALIZED-NOISE]", + ] g2p = G2p() token2id = get_g2p_sym2int() @@ -407,8 +416,15 @@ def main(): ( word, [ - phn for phn in g2p(word) - if phn not in ("'", " ", "-", ",") # g2p_en has these symbols as phones + phn + for phn in g2p(word) + if phn + not in ( + "'", + " ", + "-", + ",", + ) # g2p_en has these symbols as phones ], ) for word in tqdm(vocab, desc="Processing vocab with G2P") @@ -437,9 +453,9 @@ def main(): token2id = dict(sorted(token2id.items(), key=lambda tpl: tpl[1])) print(token2id) word2id = {"": 0} - word2id.update({ - word: int(id_) for id_, (word, pron) in enumerate(lexicon, start=1) - }) + word2id.update( + {word: int(id_) for id_, (word, pron) in enumerate(lexicon, start=1)} + ) for symbol in ["", "", "#0"]: word2id[symbol] = len(word2id)