From 01b52165a107081c0ab2475d3442d08981446e77 Mon Sep 17 00:00:00 2001
From: jinzr <60612200+JinZr@users.noreply.github.com>
Date: Tue, 26 Sep 2023 16:57:40 +0800
Subject: [PATCH] removed scripts existing in other recipes

---
 egs/swbd/ASR/local/compile_hlg.py             | 167 -------
 egs/swbd/ASR/local/compile_lg.py              | 147 -------
 egs/swbd/ASR/local/prepare_lang.py            | 413 ------
 .../ASR/local/prepare_lm_training_data.py     | 167 -------
 egs/swbd/ASR/local/validate_bpe_lexicon.py    |  77 ----
 5 files changed, 971 deletions(-)
 delete mode 100755 egs/swbd/ASR/local/compile_hlg.py
 delete mode 100755 egs/swbd/ASR/local/compile_lg.py
 delete mode 100755 egs/swbd/ASR/local/prepare_lang.py
 delete mode 100755 egs/swbd/ASR/local/prepare_lm_training_data.py
 delete mode 100755 egs/swbd/ASR/local/validate_bpe_lexicon.py

diff --git a/egs/swbd/ASR/local/compile_hlg.py b/egs/swbd/ASR/local/compile_hlg.py
deleted file mode 100755
index d19d50ae6..000000000
--- a/egs/swbd/ASR/local/compile_hlg.py
+++ /dev/null
@@ -1,167 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This script takes as input lang_dir and generates HLG from
-
-    - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
-
-    - L, the lexicon, built from lang_dir/L_disambig.pt
-
-        Caution: We use a lexicon that contains disambiguation symbols
-
-    - G, the LM, built from data/lm/G_n_gram.fst.txt
-
-The generated HLG is saved in $lang_dir/HLG.pt
-"""
-import argparse
-import logging
-from pathlib import Path
-
-import k2
-import torch
-
-from icefall.lexicon import Lexicon
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lm",
-        type=str,
-        default="G_3_gram",
-        help="""Stem name for LM used in HLG compiling.
-        """,
-    )
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        help="""Input and output directory.
-        """,
-    )
-
-    return parser.parse_args()
-
-
-def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
-    """
-    Args:
-      lang_dir:
-        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
-      lm:
-        The language stem base name.
-
-    Return:
-      An FSA representing HLG.
-    """
-    lexicon = Lexicon(lang_dir)
-    max_token_id = max(lexicon.tokens)
-    logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
-    H = k2.ctc_topo(max_token_id)
-    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
-
-    if Path(f"data/lm/{lm}.pt").is_file():
-        logging.info(f"Loading pre-compiled {lm}")
-        d = torch.load(f"data/lm/{lm}.pt")
-        G = k2.Fsa.from_dict(d)
-    else:
-        logging.info(f"Loading {lm}.fst.txt")
-        with open(f"data/lm/{lm}.fst.txt") as f:
-            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
-            torch.save(G.as_dict(), f"data/lm/{lm}.pt")
-
-    first_token_disambig_id = lexicon.token_table["#0"]
-    first_word_disambig_id = lexicon.word_table["#0"]
-
-    L = k2.arc_sort(L)
-    G = k2.arc_sort(G)
-
-    logging.info("Intersecting L and G")
-    LG = k2.compose(L, G)
-    logging.info(f"LG shape: {LG.shape}")
-
-    logging.info("Connecting LG")
-    LG = k2.connect(LG)
-    logging.info(f"LG shape after k2.connect: {LG.shape}")
-
-    logging.info(type(LG.aux_labels))
-    logging.info("Determinizing LG")
-
-    LG = k2.determinize(LG)
-    logging.info(type(LG.aux_labels))
-
-    logging.info("Connecting LG after k2.determinize")
-    LG = k2.connect(LG)
-
-    logging.info("Removing disambiguation symbols on LG")
-
-    # LG.labels[LG.labels >= first_token_disambig_id] = 0
-    # see https://github.com/k2-fsa/k2/pull/1140
-    labels = LG.labels
-    labels[labels >= first_token_disambig_id] = 0
-    LG.labels = labels
-
-    assert isinstance(LG.aux_labels, k2.RaggedTensor)
-    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
-
-    LG = k2.remove_epsilon(LG)
-    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
-
-    LG = k2.connect(LG)
-    LG.aux_labels = LG.aux_labels.remove_values_eq(0)
-
-    logging.info("Arc sorting LG")
-    LG = k2.arc_sort(LG)
-
-    logging.info("Composing H and LG")
-    # CAUTION: The name of the inner_labels is fixed
-    # to `tokens`. If you want to change it, please
-    # also change other places in icefall that are using
-    # it.
-    HLG = k2.compose(H, LG, inner_labels="tokens")
-
-    logging.info("Connecting HLG")
-    HLG = k2.connect(HLG)
-
-    logging.info("Arc sorting HLG")
-    HLG = k2.arc_sort(HLG)
-    logging.info(f"HLG.shape: {HLG.shape}")
-
-    return HLG
-
-
-def main():
-    args = get_args()
-    lang_dir = Path(args.lang_dir)
-
-    if (lang_dir / "HLG.pt").is_file():
-        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
-        return
-
-    logging.info(f"Processing {lang_dir}")
-
-    HLG = compile_HLG(lang_dir, args.lm)
-    logging.info(f"Saving HLG.pt to {lang_dir}")
-    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    main()
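
For readers without the file history: the removed compile_hlg.py builds the decoding graph HLG = H o (L o G) and saves it with torch.save. A minimal sketch of loading the resulting artifact in k2 — the lang-dir path is a hypothetical example, not something this patch prescribes:

    import k2
    import torch

    # Hypothetical path; any lang dir holding a compiled HLG.pt works.
    d = torch.load("data/lang_bpe_500/HLG.pt")
    HLG = k2.Fsa.from_dict(d)
    print(HLG.shape)           # (num_states, None) for a single FSA
    print(HLG.properties_str)  # e.g. confirms the graph is arc-sorted
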
diff --git a/egs/swbd/ASR/local/compile_lg.py b/egs/swbd/ASR/local/compile_lg.py
deleted file mode 100755
index 709b14070..000000000
--- a/egs/swbd/ASR/local/compile_lg.py
+++ /dev/null
@@ -1,147 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This script takes as input lang_dir and generates LG from
-
-    - L, the lexicon, built from lang_dir/L_disambig.pt
-
-        Caution: We use a lexicon that contains disambiguation symbols
-
-    - G, the LM, built from data/lm/G_3_gram.fst.txt
-
-The generated LG is saved in $lang_dir/LG.pt
-"""
-import argparse
-import logging
-from pathlib import Path
-
-import k2
-import torch
-
-from icefall.lexicon import Lexicon
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        help="""Input and output directory.
-        """,
-    )
-    parser.add_argument(
-        "--lm",
-        type=str,
-        default="G_3_gram",
-        help="""Stem name for LM used in LG compiling.
-        """,
-    )
-
-    return parser.parse_args()
-
-
-def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
-    """
-    Args:
-      lang_dir:
-        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
-
-    Return:
-      An FSA representing LG.
-    """
-    lexicon = Lexicon(lang_dir)
-    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
-
-    if Path(f"data/lm/{lm}.pt").is_file():
-        logging.info(f"Loading pre-compiled {lm}")
-        d = torch.load(f"data/lm/{lm}.pt")
-        G = k2.Fsa.from_dict(d)
-    else:
-        logging.info(f"Loading {lm}.fst.txt")
-        with open(f"data/lm/{lm}.fst.txt") as f:
-            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
-            torch.save(G.as_dict(), f"data/lm/{lm}.pt")
-
-    first_token_disambig_id = lexicon.token_table["#0"]
-    first_word_disambig_id = lexicon.word_table["#0"]
-
-    L = k2.arc_sort(L)
-    G = k2.arc_sort(G)
-
-    logging.info("Intersecting L and G")
-    LG = k2.compose(L, G)
-    logging.info(f"LG shape: {LG.shape}")
-
-    logging.info("Connecting LG")
-    LG = k2.connect(LG)
-    logging.info(f"LG shape after k2.connect: {LG.shape}")
-
-    logging.info(type(LG.aux_labels))
-    logging.info("Determinizing LG")
-
-    LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
-    logging.info(type(LG.aux_labels))
-
-    logging.info("Connecting LG after k2.determinize")
-    LG = k2.connect(LG)
-
-    logging.info("Removing disambiguation symbols on LG")
-
-    # LG.labels[LG.labels >= first_token_disambig_id] = 0
-    # see https://github.com/k2-fsa/k2/pull/1140
-    labels = LG.labels
-    labels[labels >= first_token_disambig_id] = 0
-    LG.labels = labels
-
-    assert isinstance(LG.aux_labels, k2.RaggedTensor)
-    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
-
-    LG = k2.remove_epsilon(LG)
-    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
-
-    LG = k2.connect(LG)
-    LG.aux_labels = LG.aux_labels.remove_values_eq(0)
-
-    logging.info("Arc sorting LG")
-    LG = k2.arc_sort(LG)
-
-    return LG
-
-
-def main():
-    args = get_args()
-    lang_dir = Path(args.lang_dir)
-
-    if (lang_dir / "LG.pt").is_file():
-        logging.info(f"{lang_dir}/LG.pt already exists - skipping")
-        return
-
-    logging.info(f"Processing {lang_dir}")
-
-    LG = compile_LG(lang_dir, args.lm)
-    logging.info(f"Saving LG.pt to {lang_dir}")
-    torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    main()
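
compile_lg.py differs from compile_hlg.py mainly in stopping after the L o G composition (using log-weight-pushing determinization) and skipping the CTC topology H. Both scripts share the disambiguation-symbol removal idiom referenced in their comments (k2 PR #1140). A toy, self-contained sketch of that idiom — the ID 90 standing in for "#0" is made up:

    import k2

    s = """
    0 1 1 0.5
    1 2 90 0.5
    2 3 -1 0.0
    3
    """
    fsa = k2.Fsa.from_str(s)

    first_disambig_id = 90  # hypothetical ID of "#0" in tokens.txt
    # Zero out (i.e. turn into epsilon) every disambig label, then drop them.
    labels = fsa.labels.clone()
    labels[labels >= first_disambig_id] = 0
    fsa.labels = labels
    fsa = k2.remove_epsilon(fsa)
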
diff --git a/egs/swbd/ASR/local/prepare_lang.py b/egs/swbd/ASR/local/prepare_lang.py
deleted file mode 100755
index d913756a1..000000000
--- a/egs/swbd/ASR/local/prepare_lang.py
+++ /dev/null
@@ -1,413 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
-consisting of words and tokens (i.e., phones) and does the following:
-
-1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
-
-2. Generate tokens.txt, the token table mapping a token to a unique integer.
-
-3. Generate words.txt, the word table mapping a word to a unique integer.
-
-4. Generate L.pt, in k2 format. It can be loaded by
-
-        d = torch.load("L.pt")
-        lexicon = k2.Fsa.from_dict(d)
-
-5. Generate L_disambig.pt, in k2 format.
-"""
-import argparse
-import math
-from collections import defaultdict
-from pathlib import Path
-from typing import Any, Dict, List, Tuple
-
-import k2
-import torch
-
-from icefall.lexicon import read_lexicon, write_lexicon
-from icefall.utils import str2bool
-
-Lexicon = List[Tuple[str, List[str]]]
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        help="""Input and output directory.
-        It should contain a file lexicon.txt.
-        Generated files by this script are saved into this directory.
-        """,
-    )
-
-    parser.add_argument(
-        "--debug",
-        type=str2bool,
-        default=False,
-        help="""True for debugging, which will generate
-        a visualization of the lexicon FST.
-
-        Caution: If your lexicon contains hundreds of thousands
-        of lines, please set it to False!
-        """,
-    )
-
-    return parser.parse_args()
-
-
-def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
-    """Write a symbol to ID mapping to a file.
-
-    Note:
-      No need to implement `read_mapping` as it can be done
-      through :func:`k2.SymbolTable.from_file`.
-
-    Args:
-      filename:
-        Filename to save the mapping.
-      sym2id:
-        A dict mapping symbols to IDs.
-    Returns:
-      Return None.
-    """
-    with open(filename, "w", encoding="utf-8") as f:
-        for sym, i in sym2id.items():
-            f.write(f"{sym} {i}\n")
-
-
-def get_tokens(lexicon: Lexicon) -> List[str]:
-    """Get tokens from a lexicon.
-
-    Args:
-      lexicon:
-        It is the return value of :func:`read_lexicon`.
-    Returns:
-      Return a list of unique tokens.
-    """
-    ans = set()
-    for _, tokens in lexicon:
-        ans.update(tokens)
-    sorted_ans = sorted(list(ans))
-    return sorted_ans
-
-
-def get_words(lexicon: Lexicon) -> List[str]:
-    """Get words from a lexicon.
-
-    Args:
-      lexicon:
-        It is the return value of :func:`read_lexicon`.
-    Returns:
-      Return a list of unique words.
-    """
-    ans = set()
-    for word, _ in lexicon:
-        ans.add(word)
-    sorted_ans = sorted(list(ans))
-    return sorted_ans
-
-
-def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
-    """It adds pseudo-token disambiguation symbols #1, #2 and so on
-    at the ends of tokens to ensure that all pronunciations are different,
-    and that none is a prefix of another.
-
-    See also add_lex_disambig.pl from kaldi.
-
-    Args:
-      lexicon:
-        It is returned by :func:`read_lexicon`.
-    Returns:
-      Return a tuple with two elements:
-
-        - The output lexicon with disambiguation symbols
-        - The ID of the max disambiguation symbol that appears
-          in the lexicon
-    """
-
-    # (1) Work out the count of each token-sequence in the
-    # lexicon.
-    count = defaultdict(int)
-    for _, tokens in lexicon:
-        count[" ".join(tokens)] += 1
-
-    # (2) For each left sub-sequence of each token-sequence, note down
-    # that it exists (for identifying prefixes of longer strings).
-    issubseq = defaultdict(int)
-    for _, tokens in lexicon:
-        tokens = tokens.copy()
-        tokens.pop()
-        while tokens:
-            issubseq[" ".join(tokens)] = 1
-            tokens.pop()
-
-    # (3) For each entry in the lexicon:
-    # if the token sequence is unique and is not a
-    # prefix of another word, no disambig symbol.
-    # Else output #1, or #2, #3, ... if the same token-seq
-    # has already been assigned a disambig symbol.
-    ans = []
-
-    # We start with #1 since #0 has its own purpose
-    first_allowed_disambig = 1
-    max_disambig = first_allowed_disambig - 1
-    last_used_disambig_symbol_of = defaultdict(int)
-
-    for word, tokens in lexicon:
-        tokenseq = " ".join(tokens)
-        assert tokenseq != ""
-        if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
-            ans.append((word, tokens))
-            continue
-
-        cur_disambig = last_used_disambig_symbol_of[tokenseq]
-        if cur_disambig == 0:
-            cur_disambig = first_allowed_disambig
-        else:
-            cur_disambig += 1
-
-        if cur_disambig > max_disambig:
-            max_disambig = cur_disambig
-        last_used_disambig_symbol_of[tokenseq] = cur_disambig
-        tokenseq += f" #{cur_disambig}"
-        ans.append((word, tokenseq.split()))
-    return ans, max_disambig
-
-
-def generate_id_map(symbols: List[str]) -> Dict[str, int]:
-    """Generate ID maps, i.e., map a symbol to a unique ID.
-
-    Args:
-      symbols:
-        A list of unique symbols.
-    Returns:
-      A dict containing the mapping between symbols and IDs.
-    """
-    return {sym: i for i, sym in enumerate(symbols)}
-
-
-def add_self_loops(
-    arcs: List[List[Any]], disambig_token: int, disambig_word: int
-) -> List[List[Any]]:
-    """Adds self-loops to states of an FST to propagate disambiguation symbols
-    through it. They are added on each state with non-epsilon output symbols
-    on at least one arc out of the state.
-
-    See also fstaddselfloops.pl from Kaldi. One difference is that
-    Kaldi uses OpenFst style FSTs and it has multiple final states.
-    This function uses k2 style FSTs and it does not need to add self-loops
-    to the final state.
-
-    The input label of a self-loop is `disambig_token`, while the output
-    label is `disambig_word`.
-
-    Args:
-      arcs:
-        A list-of-list. The sublist contains
-        `[src_state, dest_state, label, aux_label, score]`
-      disambig_token:
-        It is the token ID of the symbol `#0`.
-      disambig_word:
-        It is the word ID of the symbol `#0`.
-
-    Return:
-      Return new `arcs` containing self-loops.
-    """
-    states_needs_self_loops = set()
-    for arc in arcs:
-        src, dst, ilabel, olabel, score = arc
-        if olabel != 0:
-            states_needs_self_loops.add(src)
-
-    ans = []
-    for s in states_needs_self_loops:
-        ans.append([s, s, disambig_token, disambig_word, 0])
-
-    return arcs + ans
-
-
-def lexicon_to_fst(
-    lexicon: Lexicon,
-    token2id: Dict[str, int],
-    word2id: Dict[str, int],
-    sil_token: str = "SIL",
-    sil_prob: float = 0.5,
-    need_self_loops: bool = False,
-) -> k2.Fsa:
-    """Convert a lexicon to an FST (in k2 format) with optional silence at
-    the beginning and end of each word.
-
-    Args:
-      lexicon:
-        The input lexicon. See also :func:`read_lexicon`
-      token2id:
-        A dict mapping tokens to IDs.
-      word2id:
-        A dict mapping words to IDs.
-      sil_token:
-        The silence token.
-      sil_prob:
-        The probability for adding a silence at the beginning and end
-        of the word.
-      need_self_loops:
-        If True, add self-loop to states with non-epsilon output symbols
-        on at least one arc out of the state. The input label for this
-        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
-    Returns:
-      Return an instance of `k2.Fsa` representing the given lexicon.
-    """
-    assert sil_prob > 0.0 and sil_prob < 1.0
-    # CAUTION: we use score, i.e, negative cost.
-    sil_score = math.log(sil_prob)
-    no_sil_score = math.log(1.0 - sil_prob)
-
-    start_state = 0
-    loop_state = 1  # words enter and leave from here
-    sil_state = 2  # words terminate here when followed by silence; this state
-    # has a silence transition to loop_state.
-    next_state = 3  # the next un-allocated state, will be incremented as we go.
-    arcs = []
-
-    assert token2id["<eps>"] == 0
-    assert word2id["<eps>"] == 0
-
-    eps = 0
-
-    sil_token = token2id[sil_token]
-
-    arcs.append([start_state, loop_state, eps, eps, no_sil_score])
-    arcs.append([start_state, sil_state, eps, eps, sil_score])
-    arcs.append([sil_state, loop_state, sil_token, eps, 0])
-
-    for word, tokens in lexicon:
-        assert len(tokens) > 0, f"{word} has no pronunciations"
-        cur_state = loop_state
-
-        word = word2id[word]
-        tokens = [token2id[i] for i in tokens]
-
-        for i in range(len(tokens) - 1):
-            w = word if i == 0 else eps
-            arcs.append([cur_state, next_state, tokens[i], w, 0])
-
-            cur_state = next_state
-            next_state += 1
-
-        # now for the last token of this word
-        # It has two out-going arcs, one to the loop state,
-        # the other one to the sil_state.
-        i = len(tokens) - 1
-        w = word if i == 0 else eps
-        arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
-        arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
-
-    if need_self_loops:
-        disambig_token = token2id["#0"]
-        disambig_word = word2id["#0"]
-        arcs = add_self_loops(
-            arcs,
-            disambig_token=disambig_token,
-            disambig_word=disambig_word,
-        )
-
-    final_state = next_state
-    arcs.append([loop_state, final_state, -1, -1, 0])
-    arcs.append([final_state])
-
-    arcs = sorted(arcs, key=lambda arc: arc[0])
-    arcs = [[str(i) for i in arc] for arc in arcs]
-    arcs = [" ".join(arc) for arc in arcs]
-    arcs = "\n".join(arcs)
-
-    fsa = k2.Fsa.from_str(arcs, acceptor=False)
-    return fsa
-
-
-def main():
-    args = get_args()
-    lang_dir = Path(args.lang_dir)
-    lexicon_filename = lang_dir / "lexicon.txt"
-    sil_token = "SIL"
-    sil_prob = 0.5
-
-    lexicon = read_lexicon(lexicon_filename)
-    tokens = get_tokens(lexicon)
-    words = get_words(lexicon)
-
-    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
-
-    for i in range(max_disambig + 1):
-        disambig = f"#{i}"
-        assert disambig not in tokens
-        tokens.append(f"#{i}")
-
-    assert "<eps>" not in tokens
-    tokens = ["<eps>"] + tokens
-
-    assert "<eps>" not in words
-    assert "#0" not in words
-    assert "<s>" not in words
-    assert "</s>" not in words
-
-    words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
-
-    token2id = generate_id_map(tokens)
-    word2id = generate_id_map(words)
-
-    write_mapping(lang_dir / "tokens.txt", token2id)
-    write_mapping(lang_dir / "words.txt", word2id)
-    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
-
-    L = lexicon_to_fst(
-        lexicon,
-        token2id=token2id,
-        word2id=word2id,
-        sil_token=sil_token,
-        sil_prob=sil_prob,
-    )
-
-    L_disambig = lexicon_to_fst(
-        lexicon_disambig,
-        token2id=token2id,
-        word2id=word2id,
-        sil_token=sil_token,
-        sil_prob=sil_prob,
-        need_self_loops=True,
-    )
-    torch.save(L.as_dict(), lang_dir / "L.pt")
-    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
-
-    if args.debug:
-        labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
-        aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
-
-        L.labels_sym = labels_sym
-        L.aux_labels_sym = aux_labels_sym
-        L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
-
-        L_disambig.labels_sym = labels_sym
-        L_disambig.aux_labels_sym = aux_labels_sym
-        L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
-
-
-if __name__ == "__main__":
-    main()
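
The core of the removed prepare_lang.py is add_disambig_symbols(), which appends #1, #2, ... whenever a pronunciation is repeated or is a prefix of another, so the lexicon FST stays determinizable. A self-contained toy sketch of the same counting idea (not the removed implementation itself):

    from collections import defaultdict

    lexicon = [
        ("read", ["r", "eh", "d"]),  # homophone of "red"
        ("red", ["r", "eh", "d"]),
        ("re", ["r", "eh"]),         # prefix of the pronunciations above
    ]

    count = defaultdict(int)     # how often each pronunciation occurs
    issubseq = defaultdict(int)  # marks proper prefixes of pronunciations
    for _, tokens in lexicon:
        count[" ".join(tokens)] += 1
        for n in range(1, len(tokens)):
            issubseq[" ".join(tokens[:n])] = 1

    last = defaultdict(int)  # last disambig symbol used per pronunciation
    for word, tokens in lexicon:
        seq = " ".join(tokens)
        if issubseq[seq] == 0 and count[seq] == 1:
            print(word, tokens)  # unique and not a prefix: no symbol needed
            continue
        last[seq] += 1
        print(word, tokens + [f"#{last[seq]}"])
    # read ['r', 'eh', 'd', '#1']
    # red ['r', 'eh', 'd', '#2']
    # re ['r', 'eh', '#1']
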
diff --git a/egs/swbd/ASR/local/prepare_lm_training_data.py b/egs/swbd/ASR/local/prepare_lm_training_data.py
deleted file mode 100755
index 70343fef7..000000000
--- a/egs/swbd/ASR/local/prepare_lm_training_data.py
+++ /dev/null
@@ -1,167 +0,0 @@
-#!/usr/bin/env python3
-
-# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey
-#                                                 Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This script takes a `bpe.model` and a text file such as
-./download/lm/librispeech-lm-norm.txt
-and outputs the LM training data to a supplied directory such
-as data/lm_training_bpe_500. The format is as follows:
-
-It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a
-representation of a dict with the following format:
-
-  'words' -> a k2.RaggedTensor of two axes [word][token] with dtype
-             torch.int32 containing the BPE representations of each word,
-             indexed by integer word ID. (These integer word IDs are present
-             in 'lm_data'.) The sentencepiece object can be used to turn the
-             words and BPE units into string form.
-  'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype
-                 torch.int32 containing all the sentences, as word-ids (we
-                 don't output the string form of this directly but it can be
-                 worked out together with 'words' and the bpe.model).
-  'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing
-                        number of BPE tokens of each sentence.
-"""
-
-import argparse
-import logging
-from pathlib import Path
-
-import k2
-import sentencepiece as spm
-import torch
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--bpe-model",
-        type=str,
-        help="Input BPE model, e.g. data/bpe_500/bpe.model",
-    )
-    parser.add_argument(
-        "--lm-data",
-        type=str,
-        help="""Input LM training data as text, e.g.
-        download/pb.train.txt""",
-    )
-    parser.add_argument(
-        "--lm-archive",
-        type=str,
-        help="""Path to output archive, e.g. data/bpe_500/lm_data.pt;
        look at the source of this script to see the format.""",
-    )
-
-    return parser.parse_args()
-
-
-def main():
-    args = get_args()
-
-    if Path(args.lm_archive).exists():
-        logging.warning(f"{args.lm_archive} exists - skipping")
-        return
-
-    sp = spm.SentencePieceProcessor()
-    sp.load(args.bpe_model)
-
-    # word2index is a dictionary from words to integer ids.  No need to reserve
-    # space for epsilon, etc.; the words are just used as a convenient way to
-    # compress the sequences of BPE pieces.
-    word2index = dict()
-
-    word2bpe = []  # Will be a list-of-list-of-int, representing BPE pieces.
-    sentences = []  # Will be a list-of-list-of-int, representing word-ids.
-
-    if "librispeech-lm-norm" in args.lm_data:
-        num_lines_in_total = 40418261.0
-        step = 5000000
-    elif "valid" in args.lm_data:
-        num_lines_in_total = 5567.0
-        step = 3000
-    elif "test" in args.lm_data:
-        num_lines_in_total = 5559.0
-        step = 3000
-    else:
-        num_lines_in_total = None
-        step = None
-
-    processed = 0
-
-    with open(args.lm_data) as f:
-        while True:
-            line = f.readline()
-            if line == "":
-                break
-
-            if step and processed % step == 0:
-                logging.info(
-                    f"Processed number of lines: {processed} "
-                    f"({processed/num_lines_in_total*100: .3f}%)"
-                )
-            processed += 1
-
-            line_words = line.split()
-            for w in line_words:
-                if w not in word2index:
-                    w_bpe = sp.encode(w)
-                    word2index[w] = len(word2bpe)
-                    word2bpe.append(w_bpe)
-            sentences.append([word2index[w] for w in line_words])
-
-    logging.info("Constructing ragged tensors")
-    words = k2.ragged.RaggedTensor(word2bpe)
-    sentences = k2.ragged.RaggedTensor(sentences)
-
-    output = dict(words=words, sentences=sentences)
-
-    num_sentences = sentences.dim0
-    logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}")
-    sentence_lengths = [0] * num_sentences
-    for i in range(num_sentences):
-        if step and i % step == 0:
-            logging.info(
-                f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)"
-            )
-
-        word_ids = sentences[i]
-
-        # NOTE: If word_ids is a tensor with only 1 entry,
-        # token_ids is a torch.Tensor
-        token_ids = words[word_ids]
-        if isinstance(token_ids, k2.RaggedTensor):
-            token_ids = token_ids.values
-
-        # token_ids is a 1-D tensor containing the BPE tokens
-        # of the current sentence
-
-        sentence_lengths[i] = token_ids.numel()
-
-    output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32)
-
-    torch.save(output, args.lm_archive)
-    logging.info(f"Saved to {args.lm_archive}")
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    main()
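
The archive written above can be read back directly; this sketch mirrors the sentence-length loop from the script itself, with a hypothetical archive path:

    import k2
    import torch

    d = torch.load("data/lm_training_bpe_500/lm_data.pt")  # hypothetical path
    words = d["words"]          # k2.RaggedTensor, axes [word][token]
    sentences = d["sentences"]  # k2.RaggedTensor, axes [sentence][word]

    word_ids = sentences[0]      # word IDs of the first sentence
    token_ids = words[word_ids]  # their BPE pieces, usually ragged
    if isinstance(token_ids, k2.RaggedTensor):
        token_ids = token_ids.values
    print(token_ids.numel())     # BPE length of sentence 0
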
diff --git a/egs/swbd/ASR/local/validate_bpe_lexicon.py b/egs/swbd/ASR/local/validate_bpe_lexicon.py
deleted file mode 100755
index c542f2fab..000000000
--- a/egs/swbd/ASR/local/validate_bpe_lexicon.py
+++ /dev/null
@@ -1,77 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This script checks that there are no OOV tokens in the BPE-based lexicon.
-
-Usage example:
-
-    python3 ./local/validate_bpe_lexicon.py \
-            --lexicon /path/to/lexicon.txt \
-            --bpe-model /path/to/bpe.model
-"""
-
-import argparse
-from pathlib import Path
-from typing import List, Tuple
-
-import sentencepiece as spm
-
-from icefall.lexicon import read_lexicon
-
-# Map word to word pieces
-Lexicon = List[Tuple[str, List[str]]]
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "--lexicon",
-        required=True,
-        type=Path,
-        help="Path to lexicon.txt",
-    )
-
-    parser.add_argument(
-        "--bpe-model",
-        required=True,
-        type=Path,
-        help="Path to bpe.model",
-    )
-
-    return parser.parse_args()
-
-
-def main():
-    args = get_args()
-    assert args.lexicon.is_file(), args.lexicon
-    assert args.bpe_model.is_file(), args.bpe_model
-
-    lexicon = read_lexicon(args.lexicon)
-
-    sp = spm.SentencePieceProcessor()
-    sp.load(str(args.bpe_model))
-
-    word_pieces = set(sp.id_to_piece(list(range(sp.vocab_size()))))
-    for word, pieces in lexicon:
-        for p in pieces:
-            if p not in word_pieces:
-                raise ValueError(f"The word {word} contains an OOV token {p}")
-
-
-if __name__ == "__main__":
-    main()
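
As the subject line says, all five deleted files still exist in other recipes (e.g. egs/librispeech/ASR/local in icefall), which is what makes these copies redundant. For completeness, the OOV check performed by the last script reduces to a few lines — the bpe.model path and the toy lexicon entry here are assumptions:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")  # hypothetical model path

    # Every piece the model can emit; lexicon pieces must be a subset.
    valid_pieces = set(sp.id_to_piece(list(range(sp.vocab_size()))))

    lexicon = [("HELLO", ["▁HE", "LL", "O"])]  # toy entry, made-up pieces
    for word, pieces in lexicon:
        for p in pieces:
            assert p in valid_pieces, f"{word} contains OOV piece {p}"
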