From d8907769ee2818edc2d522ba8514a276365a0415 Mon Sep 17 00:00:00 2001 From: luomingshuang <739314837@qq.com> Date: Fri, 8 Apr 2022 19:00:19 +0800 Subject: [PATCH] add and update some files --- .../ASR/local/display_manifest_statistics.py | 135 +++++++++ egs/wenetspeech/ASR/local/prepare_char.py | 246 +++++++++++++++++ egs/wenetspeech/ASR/local/prepare_pinyin.py | 260 ++++++++++++++++++ egs/wenetspeech/ASR/local/prepare_words.py | 84 ++++++ egs/wenetspeech/ASR/local/text2segments.py | 83 ++++++ egs/wenetspeech/ASR/local/text2token.py | 2 +- egs/wenetspeech/ASR/prepare.sh | 144 +++------- .../asr_datamodule.py | 8 +- .../ASR/pruned_transducer_stateless/train.py | 67 +++-- icefall/pinyin_graph_compiler.py | 219 +++++++++++++++ 10 files changed, 1117 insertions(+), 131 deletions(-) create mode 100644 egs/wenetspeech/ASR/local/display_manifest_statistics.py create mode 100755 egs/wenetspeech/ASR/local/prepare_char.py create mode 100755 egs/wenetspeech/ASR/local/prepare_pinyin.py create mode 100644 egs/wenetspeech/ASR/local/prepare_words.py create mode 100644 egs/wenetspeech/ASR/local/text2segments.py create mode 100644 icefall/pinyin_graph_compiler.py diff --git a/egs/wenetspeech/ASR/local/display_manifest_statistics.py b/egs/wenetspeech/ASR/local/display_manifest_statistics.py new file mode 100644 index 000000000..0706361b4 --- /dev/null +++ b/egs/wenetspeech/ASR/local/display_manifest_statistics.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. +See the function `remove_short_and_long_utt()` +in ../../../librispeech/ASR/transducer/train.py +for usage. 
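+
+For instance, based on the statistics shown below, the training script
+in this recipe keeps only utterances between 1 and 15 seconds. A
+minimal sketch of such a filter (mirroring `remove_short_and_long_utt()`
+in ../pruned_transducer_stateless/train.py of this recipe):
+
+    def remove_short_and_long_utt(c):
+        # Keep only utterances with duration between 1.0 and 15.0 seconds.
+        return 1.0 <= c.duration <= 15.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)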
+""" + + +from lhotse import load_manifest + + +def main(): + paths = [ # "./data/fbank/cuts_L_100_pieces.jsonl.gz", + "./data/fbank/cuts_L_50_pieces.jsonl.gz", + # "./data/fbank/cuts_DEV.jsonl.gz", + # "./data/fbank/cuts_TEST_NET.jsonl.gz", + # "./data/fbank/cuts_TEST_MEETING.jsonl.gz" + ] + + for path in paths: + print(f"Starting display the statistics for {path}") + cuts = load_manifest(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +Starting display the statistics for ./data/fbank/cuts_L_50_pieces.jsonl.gz +Cuts count: 2241476 +Total duration (hours): 1475.0 +Speech duration (hours): 1475.0 (100.0%) +*** +Duration statistics (seconds): +mean 2.4 +std 1.6 +min 0.3 +25% 1.3 +50% 2.0 +75% 2.9 +99% 8.2 +99.5% 9.3 +99.9% 13.5 +max 87.0 + +Starting display the statistics for ./data/fbank/cuts_L_100_pieces.jsonl.gz +Cuts count: 4929619 +Total duration (hours): 3361.1 +Speech duration (hours): 3361.1 (100.0%) +*** +Duration statistics (seconds): +mean 2.5 +std 1.7 +min 0.3 +25% 1.4 +50% 2.0 +75% 3.0 +99% 8.1 +99.5% 8.8 +99.9% 14.7 +max 87.0 + +Starting display the statistics for ./data/fbank/cuts_DEV.jsonl.gz +Cuts count: 13825 +Total duration (hours): 20.0 +Speech duration (hours): 20.0 (100.0%) +*** +Duration statistics (seconds): +mean 5.2 +std 2.2 +min 1.0 +25% 3.3 +50% 4.9 +75% 7.0 +99% 9.6 +99.5% 9.8 +99.9% 10.0 +max 10.0 + +Starting display the statistics for ./data/fbank/cuts_TEST_NET.jsonl.gz +Cuts count: 24774 +Total duration (hours): 23.1 +Speech duration (hours): 23.1 (100.0%) +*** +Duration statistics (seconds): +mean 3.4 +std 2.6 +min 0.1 +25% 1.4 +50% 2.4 +75% 4.8 +99% 13.1 +99.5% 14.5 +99.9% 18.5 +max 33.3 + +Starting display the statistics for ./data/fbank/cuts_TEST_MEETING.jsonl.gz +Cuts count: 8370 +Total duration (hours): 15.2 +Speech duration (hours): 15.2 (100.0%) +*** +Duration statistics (seconds): +mean 6.5 +std 3.5 +min 0.8 +25% 3.7 +50% 5.8 +75% 8.8 +99% 15.2 +99.5% 16.0 +99.9% 18.8 +max 24.6 + +""" diff --git a/egs/wenetspeech/ASR/local/prepare_char.py b/egs/wenetspeech/ASR/local/prepare_char.py new file mode 100755 index 000000000..8bc073c75 --- /dev/null +++ b/egs/wenetspeech/ASR/local/prepare_char.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
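+
+# Usage (a sketch; the --lang-dir value below follows prepare.sh
+# stage 10 of this recipe and may differ in other setups):
+#
+#   python ./local/prepare_char.py --lang-dir data/lang_char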
+ + +""" +This script takes as input `lang_dir`, which should contain:: + - lang_dir/text, + - lang_dir/words.txt +and generates the following files in the directory `lang_dir`: + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" + +import argparse +import re +from pathlib import Path +from typing import Dict, List + +import k2 +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [ + token2id[i] if i in token2id else token2id[""] for i in pieces + ] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: + """Check if all the given tokens are in token symbol table. + Args: + token_sym_table: + Token symbol table that contains all the valid tokens. + tokens: + A list of tokens. + Returns: + Return True if there is any token not in the token_sym_table, + otherwise False. + """ + for tok in tokens: + if tok not in token_sym_table: + return True + return False + + +def generate_lexicon( + token_sym_table: Dict[str, int], words: List[str] +) -> Lexicon: + """Generate a lexicon from a word list and token_sym_table. + Args: + token_sym_table: + Token symbol table that mapping token to token ids. + words: + A list of strings representing words. + Returns: + Return a dict whose keys are words and values are the corresponding + tokens. 
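+
+        Each entry pairs a word with its token sequence. For the
+        character-based lexicon, a word such as "你好" simply maps to
+        its character sequence ["你", "好"] (illustrative example,
+        not taken from the data).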
+ """ + lexicon = [] + for word in words: + chars = list(word.strip(" \t")) + if contain_oov(token_sym_table, chars): + continue + lexicon.append((word, chars)) + + # The OOV word is + lexicon.append(("", [""])) + return lexicon + + +def generate_tokens(text_file: str) -> Dict[str, int]: + """Generate tokens from the given text file. + Args: + text_file: + A file that contains text lines to generate tokens. + Returns: + Return a dict whose keys are tokens and values are token ids ranged + from 0 to len(keys) - 1. + """ + tokens: Dict[str, int] = dict() + tokens[""] = 0 + tokens[""] = 1 + tokens[""] = 2 + whitespace = re.compile(r"([ \t\r\n]+)") + with open(text_file, "r", encoding="utf-8") as f: + for line in f: + line = re.sub(whitespace, "", line) + tokens_list = list(line) + for token in tokens_list: + if token not in tokens: + tokens[token] = len(tokens) + return tokens + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--lang-dir", type=str, help="The lang directory.") + args = parser.parse_args() + + lang_dir = Path(args.lang_dir) + text_file = lang_dir / "text" + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = ["", "!SIL", "", "", "#0", "", ""] + for w in excluded: + if w in words: + words.remove(w) + + token_sym_table = generate_tokens(text_file) + + lexicon = generate_lexicon(token_sym_table, words) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/ASR/local/prepare_pinyin.py b/egs/wenetspeech/ASR/local/prepare_pinyin.py new file mode 100755 index 000000000..128cab0b7 --- /dev/null +++ b/egs/wenetspeech/ASR/local/prepare_pinyin.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
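+
+# Usage (a sketch; the --lang-dir value below follows prepare.sh
+# stage 11 of this recipe):
+#
+#   python ./local/prepare_pinyin.py --lang-dir data/lang_pinyin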
+ + +""" +This script takes as input `lang_dir`, which should contain:: + - lang_dir/text, + - lang_dir/words.txt +and generates the following files in the directory `lang_dir`: + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" + +import argparse +from pathlib import Path +from typing import Dict, List + +import k2 +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) +from pypinyin import pinyin +from tqdm import tqdm + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [ + token2id[i] if i in token2id else token2id[""] for i in pieces + ] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: + """Check if all the given tokens are in token symbol table. + Args: + token_sym_table: + Token symbol table that contains all the valid tokens. + tokens: + A list of tokens. + Returns: + Return True if there is any token not in the token_sym_table, + otherwise False. + """ + for tok in tokens: + if tok not in token_sym_table: + return True + return False + + +def generate_lexicon( + token_sym_table: Dict[str, int], words: List[str] +) -> Lexicon: + """Generate a lexicon from a word list and token_sym_table. + Args: + token_sym_table: + Token symbol table that mapping token to token ids. + words: + A list of strings representing words. + Returns: + Return a dict whose keys are words and values are the corresponding + tokens. 
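+
+        Each entry pairs a word with its pinyin token sequence. With
+        pypinyin's default tone style, a word such as "中国" would map
+        to ["zhōng", "guó"], while an all-uppercase word such as "ABC"
+        is split into single letters (illustrative examples, not taken
+        from the data).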
+ """ + lexicon = [] + for i in tqdm(range(len(words))): + word = words[i] + tokens = [] + pinyins = pinyin(word.strip(" \t")) + for pinyin_one in pinyins: + if pinyin_one[0].isupper(): + tokens.extend(list(pinyin_one[0])) + else: + tokens.append(pinyin_one[0]) + if contain_oov(token_sym_table, tokens): + continue + lexicon.append((word, tokens)) + + # The OOV word is + lexicon.append(("", [""])) + return lexicon + + +def generate_tokens(words: List[str]) -> Dict[str, int]: + """Generate tokens from the given text file. + Args: + words: + The list of words after removing , !SIL, and so on. + Returns: + Return a dict whose keys are tokens and values are token ids ranged + from 0 to len(keys) - 1. + """ + tokens: Dict[str, int] = dict() + tokens[""] = 0 + tokens[""] = 1 + tokens[""] = 2 + for i in tqdm(range(len(words))): + word = words[i] + pinyins_list = pinyin(word) + for pinyin_one in pinyins_list: + if pinyin_one[0].isupper(): + tokens_list = list(pinyin_one[0]) + else: + tokens_list = pinyin_one + for token in tokens_list: + if token not in tokens: + tokens[token] = len(tokens) + tokens = sorted(tokens.items(), key=lambda item: item[0]) + tokens = dict(tokens) + + return tokens + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--lang-dir", type=str, help="The lang directory.") + args = parser.parse_args() + + lang_dir = Path(args.lang_dir) + words_file = lang_dir / "words.txt" + + word_sym_table = k2.SymbolTable.from_file(words_file) + + words = word_sym_table.symbols + + excluded = ["", "!SIL", "", "", "#0", "", ""] + for w in excluded: + if w in words: + words.remove(w) + + token_sym_table = generate_tokens(words) + + lexicon = generate_lexicon(token_sym_table, words) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/ASR/local/prepare_words.py b/egs/wenetspeech/ASR/local/prepare_words.py new file mode 100644 index 000000000..65aca2983 --- /dev/null +++ b/egs/wenetspeech/ASR/local/prepare_words.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input words.txt without ids: + - words_no_ids.txt +and generates the new words.txt with related ids. + - words.txt +""" + + +import argparse +import logging + +from tqdm import tqdm + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Prepare words.txt", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--input-file", + default="data/lang_char/words_no_ids.txt", + type=str, + help="the words file without ids for WenetSpeech", + ) + parser.add_argument( + "--output-file", + default="data/lang_char/words.txt", + type=str, + help="the words file with ids for WenetSpeech", + ) + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + input_file = args.input_file + output_file = args.output_file + + f = open(input_file, "r", encoding="utf-8") + lines = f.readlines() + new_lines = [] + add_words = [" 0", "!SIL 1", " 2", " 3"] + new_lines.extend(add_words) + + logging.info("Starting reading the input file") + for i in tqdm(range(len(lines))): + x = lines[i] + idx = 4 + i + new_line = str(x.strip("\n")) + " " + str(idx) + new_lines.append(new_line) + + logging.info("Starting writing the words.txt") + f_out = open(output_file, "w", encoding="utf-8") + for line in new_lines: + f_out.write(line) + f_out.write("\n") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/ASR/local/text2segments.py b/egs/wenetspeech/ASR/local/text2segments.py new file mode 100644 index 000000000..b55277e95 --- /dev/null +++ b/egs/wenetspeech/ASR/local/text2segments.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
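+
+# Usage (a sketch; the paths below follow prepare.sh stage 9 of this
+# recipe):
+#
+#   python ./local/text2segments.py \
+#     --input data/lang_char/text \
+#     --output data/lang_char/text_words_segmentation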
+ + +""" +This script takes as input "text", which refers to the transcript file for +WenetSpeech: + - text +and generates the output file text_word_segmentation which is implemented +with word segmenting: + - text_words_segmentation +""" + + +import argparse + +import jieba +from tqdm import tqdm + +jieba.enable_paddle() + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Chinese Word Segmentation for text", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--input", + default="data/lang_char/text", + type=str, + help="the input text file for WenetSpeech", + ) + parser.add_argument( + "--output", + default="data/lang_char/text_words_segmentation", + type=str, + help="the text implemented with words segmenting for WenetSpeech", + ) + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + input_file = args.input + output_file = args.output + + f = open(input_file, "r", encoding="utf-8") + lines = f.readlines() + new_lines = [] + for i in tqdm(range(len(lines))): + x = lines[i].rstrip() + seg_list = jieba.cut(x, use_paddle=True) + new_line = " ".join(seg_list) + new_lines.append(new_line) + + f_new = open(output_file, "w", encoding="utf-8") + for line in new_lines: + f_new.write(line) + f_new.write("\n") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py index 5ecea4009..1c463cf1c 100755 --- a/egs/wenetspeech/ASR/local/text2token.py +++ b/egs/wenetspeech/ASR/local/text2token.py @@ -188,7 +188,7 @@ def main(): a_chars = [z.replace(" ", args.space) for z in a_flat] - print(" ".join(a_chars)) + print("".join(a_chars)) line = f.readline() diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index 1406600d1..52ef077c1 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -117,9 +117,9 @@ fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then log "Stage 7: Combine features for L" - if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then - pieces=$(find data/fbank/L_split_${num_splits} -name "cuts_L.*.jsonl.gz") - lhotse combine $pieces data/fbank/cuts_L.jsonl.gz + if [ ! -f data/fbank/cuts_L_50.jsonl.gz ]; then + pieces=$(find data/fbank/L_split_50 -name "cuts_L.*.jsonl.gz") + lhotse combine $pieces data/fbank/cuts_L_50.jsonl.gz fi fi @@ -134,120 +134,50 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then lang_char_dir=data/lang_char mkdir -p $lang_char_dir - gunzip -c data/manifests/supervisions_L.jsonl.gz \ - | jq '.text' | sed 's/"//g' \ - | ./local/text2token.py -t "char" > $lang_char_dir/text + # Prepare text. + if [ ! -f $lang_char_dir/text ]; then + gunzip -c data/manifests/supervisions_L.jsonl.gz \ + | jq '.text' | sed 's/"//g' \ + | ./local/text2token.py -t "char" > $lang_char_dir/text + fi - cat $lang_char_dir/text | sed 's/ /\n/g' \ - | sort -u | sed '/^$/d' > $lang_char_dir/words.txt - (echo ''; echo ''; echo ''; ) | - cat - $lang_char_dir/words.txt | sort | uniq | awk ' - BEGIN { - print " 0"; - } - { - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; - } - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; - } - printf("%s %d\n", $1, NR); - } - END { - printf("#0 %d\n", NR+1); - printf(" %d\n", NR+2); - printf(" %d\n", NR+3); - }' > $lang_char_dir/words || exit 1; + # The implementation of chinese word segmentation for text, + # and it will take about 15 minutes. + if [ ! 
-f $lang_char_dir/text_words_segmentation ]; then + python ./local/text2segments.py \ + --input $lang_char_dir/text \ + --output $lang_char_dir/text_words_segmentation + fi - mv $lang_char_dir/words $lang_char_dir/words.txt + cat $lang_char_dir/text_words_segmentation | sed 's/ /\n/g' \ + | sort -u | sed '/^$/d' | uniq > $lang_char_dir/words_no_ids.txt + + if [ ! -f $lang_char_dir/words.txt ]; then + python ./local/prepare_words.py \ + --input-file $lang_char_dir/words_no_ids.txt \ + --output-file $lang_char_dir/words.txt + fi fi if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Prepare pinyin based lang" - lang_pinyin_dir=data/lang_pinyin - mkdir -p $lang_pinyin_dir - - gunzip -c data/manifests/supervisions_L.jsonl.gz \ - | jq '.text' | sed 's/"//g' \ - | ./local/text2token.py -t "pinyin" > $lang_pinyin_dir/text - - cat $lang_pinyin_dir/text | sed 's/ /\n/g' \ - | sort -u | sed '/^$/d' > $lang_pinyin_dir/words.txt - (echo ''; echo ''; echo ''; ) | - cat - $lang_pinyin_dir/words.txt | sort | uniq | awk ' - BEGIN { - print " 0"; - } - { - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; - } - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; - } - printf("%s %d\n", $1, NR); - } - END { - printf("#0 %d\n", NR+1); - printf(" %d\n", NR+2); - printf(" %d\n", NR+3); - }' > $lang_pinyin_dir/words || exit 1; - - mv $lang_pinyin_dir/words $lang_pinyin_dir/words.txt + log "Stage 10: Prepare char based L_disambig.pt" + if [ ! -f data/lang_char/L_disambig.pt ]; then + python ./local/prepare_char.py \ + --lang-dir data/lang_char + fi fi if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Prepare lazy_pinyin based lang" - lang_lazy_pinyin_dir=data/lang_lazy_pinyin - mkdir -p $lang_lazy_pinyin_dir + log "Stage 11: Prepare pinyin based L_disambig.pt" + lang_pinyin_dir=data/lang_pinyin + mkdir -p $lang_pinyin_dir - gunzip -c data/manifests/supervisions_L.jsonl.gz \ - | jq '.text' | sed 's/"//g' \ - | ./local/text2token.py -t "lazy_pinyin" > $lang_lazy_pinyin_dir/text - - cat $lang_lazy_pinyin_dir/text | sed 's/ /\n/g' \ - | sort -u | sed '/^$/d' > $lang_lazy_pinyin_dir/words.txt - (echo ''; echo ''; echo ''; ) | - cat - $lang_lazy_pinyin_dir/words.txt | sort | uniq | awk ' - BEGIN { - print " 0"; - } - { - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; - } - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; - } - printf("%s %d\n", $1, NR); - } - END { - printf("#0 %d\n", NR+1); - printf(" %d\n", NR+2); - printf(" %d\n", NR+3); - }' > $lang_lazy_pinyin_dir/words || exit 1; - - mv $lang_lazy_pinyin_dir/words $lang_lazy_pinyin_dir/words.txt -fi - -if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 12: Prepare L_disambig.pt" - if [ ! -f data/lang_char/L_disambig.pt ]; then - python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_char - fi + cp -r data/lang_char/words.txt $lang_pinyin_dir/ + cp -r data/lang_char/text $lang_pinyin_dir/ + cp -r data/lang_char/text_words_segmentation $lang_pinyin_dir/ if [ ! -f data/lang_pinyin/L_disambig.pt ]; then - python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_pinyin - fi - - if [ ! 
-f data/lang_lazy_pinyin/L_disambig.pt ]; then - python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_lazy_pinyin + python ./local/prepare_pinyin.py \ + --lang-dir data/lang_pinyin fi fi diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless/asr_datamodule.py index 027d5cd1f..92c0d8c13 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless/asr_datamodule.py @@ -375,11 +375,15 @@ class WenetSpeechAsrDataModule: if self.args.lazy_load: logging.info("use lazy cuts") cuts_train = CutSet.from_jsonl_lazy( - self.args.manifest_dir / "cuts_L.jsonl.gz" + self.args.manifest_dir + / "cuts_L_50_pieces.jsonl.gz" + # use cuts_L_50_pieces.jsonl.gz for original experiments ) else: cuts_train = CutSet.from_file( - self.args.manifest_dir / "cuts_L.jsonl.gz" + self.args.manifest_dir + / "cuts_L_50_pieces.jsonl.gz" + # use cuts_L_50_pieces.jsonl.gz for original experiments ) return cuts_train diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless/train.py index 183d47034..52cad6b52 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless/train.py @@ -25,8 +25,10 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ - --exp-dir pruned_transducer_stateless/exp \ - --max-duration 300 + --lang-dir data/lang_char \ + --exp-dir pruned_transducer_stateless/exp-char \ + --token-type char \ + --max-duration 200 """ @@ -44,8 +46,8 @@ from asr_datamodule import WenetSpeechAsrDataModule from conformer import Conformer from decoder import Decoder from joiner import Joiner +from lhotse.cut import Cut from lhotse.utils import fix_random_seed -from local.text2token import token2id from model import Transducer from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP @@ -59,6 +61,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.lexicon import Lexicon +from icefall.pinyin_graph_compiler import PinyinCtcTrainingGraphCompiler from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -108,7 +111,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless/exp", + default="pruned_transducer_stateless_pinyin/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -118,7 +121,7 @@ def get_parser(): parser.add_argument( "--lang-dir", type=str, - default="data/lang_lazy_pinyin", + default="data/lang_char", help="""The lang dir It contains language related input files such as "lexicon.txt" @@ -128,10 +131,9 @@ def get_parser(): parser.add_argument( "--token-type", type=str, - default="lazy_pinyin", - help="""The token type - It refers to the token type for modeling, such as - char, pinyin, lazy_pinyin. + default="char", + help="""The type of token + It must be in ["char", "pinyin", "lazy_pinyin"]. 
""", ) @@ -435,16 +437,12 @@ def compute_loss( feature_lens = supervisions["num_frames"].to(device) texts = batch["supervisions"]["text"] - y = "" - if params.token_type == "char": - y = graph_compiler.texts_to_ids(texts) + + y = graph_compiler.texts_to_ids(texts) + if type(y) == list: + y = k2.RaggedTensor(y).to(device) else: - y = token2id( - texts=texts, - token_table=graph_compiler.token_table, - token_type=params.token_type, - ) - y = k2.RaggedTensor(y).to(device) + y = y.to(device) with torch.set_grad_enabled(is_training): simple_loss, pruned_loss = model( @@ -635,9 +633,19 @@ def run(rank, world_size, args): logging.info(f"Device: {device}") lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, device=device, oov="" - ) + graph_compiler = "" + if params.token_type == "char": + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + if params.token_type == "pinyin": + graph_compiler = PinyinCtcTrainingGraphCompiler( + lang_dir=params.lang_dir, + lexicon=lexicon, + device=device, + ) + params.blank_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 @@ -672,6 +680,23 @@ def run(rank, world_size, args): train_cuts = wenetspeech.train_cuts() + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 15.0 seconds + # You can get the statistics by local/display_manifest_statistics.py. + return 1.0 <= c.duration <= 15.0 + + def text_to_words(c: Cut): + # Convert text to words_segments. + text = c.supervisions[0].text + text = text.strip("\n").strip("\t") + words_cut = graph_compiler.text2words[text] + c.supervisions[0].text = words_cut + return c + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + if params.token_type == "pinyin": + train_cuts = train_cuts.map(text_to_words) + train_dl = wenetspeech.train_dataloaders(train_cuts) valid_cuts = wenetspeech.valid_cuts() valid_dl = wenetspeech.valid_dataloaders(valid_cuts) diff --git a/icefall/pinyin_graph_compiler.py b/icefall/pinyin_graph_compiler.py new file mode 100644 index 000000000..f8e68d01c --- /dev/null +++ b/icefall/pinyin_graph_compiler.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from pathlib import Path +from typing import Dict, List + +import k2 +import torch +from tqdm import tqdm + +from icefall.lexicon import Lexicon, read_lexicon + + +class PinyinCtcTrainingGraphCompiler(object): + def __init__( + self, + lang_dir: Path, + lexicon: Lexicon, + device: torch.device, + sos_token: str = "", + eos_token: str = "", + oov: str = "", + ): + """ + Args: + lexicon: + It is built from `data/lang_char/lexicon.txt`. + device: + The device to use for operations compiling transcripts to FSAs. + oov: + Out of vocabulary token. 
When a word(token) in the transcript + does not exist in the token list, it is replaced with `oov`. + """ + + assert oov in lexicon.token_table + + self.lang_dir = lang_dir + self.oov_id = lexicon.token_table[oov] + self.token_table = lexicon.token_table + + self.device = device + + self.sos_id = self.token_table[sos_token] + self.eos_id = self.token_table[eos_token] + + self.word_table = lexicon.word_table + self.token_table = lexicon.token_table + + self.text2words = convert_text_to_word_segments( + text_filename=self.lang_dir / "text", + words_segments_filename=self.lang_dir / "text_words_segmentation", + ) + self.ragged_lexicon = convert_lexicon_to_ragged( + filename=self.lang_dir / "lexicon.txt", + word_table=self.word_table, + token_table=self.token_table, + ) + + def texts_to_ids(self, texts: List[str]) -> List[List[int]]: + """Convert a list of texts to a list-of-list of pinyin-based token IDs. + + Args: + texts: + It is a list of strings. + An example containing two strings is given below: + + ['你好中国', '北京欢迎您'] + Returns: + Return a list-of-list of pinyin-based token IDs. + """ + word_ids_list = [] + for i in range(len(texts)): + word_ids = [] + text = texts[i].strip("\n").strip("\t") + for word in text.split(" "): + if word in self.word_table: + word_ids.append(self.word_table[word]) + else: + word_ids.append(self.oov_id) + word_ids_list.append(word_ids) + + ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32) + ans = self.ragged_lexicon.index(ragged_indexes) + ans = ans.remove_axis(ans.num_axes - 2) + + return ans + + def compile( + self, + token_ids: List[List[int]], + modified: bool = False, + ) -> k2.Fsa: + """Build a ctc graph from a list-of-list token IDs. + + Args: + piece_ids: + It is a list-of-list integer IDs. + modified: + See :func:`k2.ctc_graph` for its meaning. + Return: + Return an FsaVec, which is the result of composing a + CTC topology with linear FSAs constructed from the given + piece IDs. + """ + return k2.ctc_graph(token_ids, modified=modified, device=self.device) + + +def convert_lexicon_to_ragged( + filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable +) -> k2.RaggedTensor: + """Read a lexicon and convert lexicon to a ragged tensor. + + Args: + filename: + Path to the lexicon file. + word_table: + The word symbol table. + token_table: + The token symbol table. + Returns: + A k2 ragged tensor with two axes [word][token]. + """ + num_words = len(word_table.symbols) + excluded_words = [ + "", + "!SIL", + "", + "", + "#0", + "", + "", + ] + + row_splits = [0] + token_ids_list = [] + + lexicon_tmp = read_lexicon(filename) + lexicon = dict(lexicon_tmp) + if len(lexicon_tmp) != len(lexicon): + raise RuntimeError( + "It's assumed that each word has a unique pronunciation" + ) + + for i in range(num_words): + w = word_table[i] + if w in excluded_words: + row_splits.append(row_splits[-1]) + continue + tokens = lexicon[w] + token_ids = [token_table[k] for k in tokens] + + row_splits.append(row_splits[-1] + len(token_ids)) + token_ids_list.extend(token_ids) + + cached_tot_size = row_splits[-1] + row_splits = torch.tensor(row_splits, dtype=torch.int32) + + shape = k2.ragged.create_ragged_shape2( + row_splits, + None, + cached_tot_size, + ) + values = torch.tensor(token_ids_list, dtype=torch.int32) + + return k2.RaggedTensor(shape, values) + + +def convert_text_to_word_segments( + text_filename: str, words_segments_filename: str +) -> Dict[str, str]: + """Convert text to word-based segments. 
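+
+    The returned dict maps each original transcript line to its
+    space-separated word segmentation as produced by
+    local/text2segments.py, e.g. "你好中国" -> "你好 中国"
+    (illustrative example).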
+ + Args: + text_filename: + The file for the original transcripts. + words_segments_filename: + The file after implementing chinese word segmentation + for the original transcripts. + Returns: + A dictionary about text and words_segments. + """ + text2words = {} + + f_text = open(text_filename, "r", encoding="utf-8") + text_lines = f_text.readlines() + text_lines = [line.strip("\t") for line in text_lines] + + f_words = open(words_segments_filename, "r", encoding="utf-8") + words_lines = f_words.readlines() + words_lines = [line.strip("\t") for line in words_lines] + + if len(text_lines) != len(words_lines): + raise RuntimeError( + "The lengths of text and words_segments should be equal." + ) + + for i in tqdm(range(len(text_lines))): + text = text_lines[i].strip(" ").strip("\n") + words_segments = words_lines[i].strip(" ").strip("\n") + text2words[text] = words_segments + + return text2words
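+
+
+# A minimal usage sketch (illustration only; the lang dir layout is
+# assumed to match what prepare.sh generates, and the sample text is
+# hypothetical):
+#
+#   from pathlib import Path
+#   import torch
+#   from icefall.lexicon import Lexicon
+#
+#   lang_dir = Path("data/lang_pinyin")
+#   compiler = PinyinCtcTrainingGraphCompiler(
+#       lang_dir=lang_dir,
+#       lexicon=Lexicon(lang_dir),
+#       device=torch.device("cpu"),
+#   )
+#   # Texts must already be word-segmented (train.py does this via
+#   # compiler.text2words before calling texts_to_ids).
+#   token_ids = compiler.texts_to_ids(["你好 中国"])
+#   ctc_graphs = compiler.compile(token_ids.tolist())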