diff --git a/egs/mucs/ASR/conformer_ctc/asr_datamodule.py b/egs/mucs/ASR/conformer_ctc/asr_datamodule.py
index 85c39b91b..e631665b2 100644
--- a/egs/mucs/ASR/conformer_ctc/asr_datamodule.py
+++ b/egs/mucs/ASR/conformer_ctc/asr_datamodule.py
@@ -418,56 +418,3 @@ class LibriSpeechAsrDataModule:
         return load_manifest_lazy(
             self.args.manifest_dir / "mucs_cuts_train.jsonl.gz"
         )
-
-
-    @lru_cache()
-    def train_clean_360_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-360 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
-        )
-
-    @lru_cache()
-    def train_other_500_cuts(self) -> CutSet:
-        logging.info("About to get train-other-500 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
-        )
-
-    @lru_cache()
-    def train_all_shuf_cuts(self) -> CutSet:
-        logging.info(
-            "About to get the shuffled train-clean-100, \
-            train-clean-360 and train-other-500 cuts"
-        )
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
-        )
-
-    @lru_cache()
-    def dev_clean_cuts(self) -> CutSet:
-        logging.info("About to get dev-clean cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
-        )
-
-    @lru_cache()
-    def dev_other_cuts(self) -> CutSet:
-        logging.info("About to get dev-other cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
-        )
-
-    @lru_cache()
-    def test_clean_cuts(self) -> CutSet:
-        logging.info("About to get test-clean cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
-        )
-
-    @lru_cache()
-    def test_other_cuts(self) -> CutSet:
-        logging.info("About to get test-other cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
-        )
diff --git a/egs/mucs/ASR/local/compile_hlg.py b/egs/mucs/ASR/local/compile_hlg.py
new file mode 120000
index 000000000..471aa7fb4
--- /dev/null
+++ b/egs/mucs/ASR/local/compile_hlg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/mucs/ASR/local/compute_fbank_mucs.py b/egs/mucs/ASR/local/compute_fbank_mucs.py
old mode 100755
new mode 100644
index b987248cd..c7f33ec1d
--- a/egs/mucs/ASR/local/compute_fbank_mucs.py
+++ b/egs/mucs/ASR/local/compute_fbank_mucs.py
@@ -19,7 +19,6 @@
 """
 This file computes fbank features of the LibriSpeech dataset.
 It looks for manifests in the directory data/manifests.
-
 The generated fbank features are saved in data/fbank.
 """
 
@@ -83,9 +82,7 @@ def compute_fbank_mucs(
         "test",
         "dev",
     )
-    # dataset_parts = (
-    #     "test",
-    # )
+
    prefix = "mucs"
    suffix = "jsonl.gz"
    manifests = read_manifests_if_cached(
@@ -107,8 +104,7 @@
 
    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
-            # print(m["recordings"])
-            # exit()
+
            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
            if (output_dir / cuts_filename).is_file():
                logging.info(f"{partition} already exists - skipping.")
@@ -118,18 +114,7 @@ cut_set = CutSet.from_manifests(
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
-            # print(len(m["supervisions"]))
-            # for s in m["supervisions"]:
-            #     # print(s)
-            #     if s.channel != 0:
-            #         print(s)
-            # exit()
-            # if "train" in partition:
-            #     if bpe_model:
-            #         cut_set = filter_cuts(cut_set, sp)
-            #     cut_set = (
-            #         cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
-            #     )
+
            cut_set = cut_set.compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
@@ -151,4 +136,4 @@ if __name__ == "__main__":
    logging.basicConfig(format=formatter, level=logging.INFO)
    args = get_args()
    logging.info(vars(args))
-    compute_fbank_mucs(bpe_model=args.bpe_model, dataset=args.dataset)
+    compute_fbank_mucs(bpe_model=args.bpe_model, dataset=args.dataset)
\ No newline at end of file
diff --git a/egs/mucs/ASR/local/convert_transcript_words_to_tokens.py b/egs/mucs/ASR/local/convert_transcript_words_to_tokens.py
new file mode 120000
index 000000000..2ce13fd69
--- /dev/null
+++ b/egs/mucs/ASR/local/convert_transcript_words_to_tokens.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
\ No newline at end of file
diff --git a/egs/mucs/ASR/local/filter_cuts.py b/egs/mucs/ASR/local/filter_cuts.py
deleted file mode 100644
index fbcc9e24a..000000000
--- a/egs/mucs/ASR/local/filter_cuts.py
+++ /dev/null
@@ -1,160 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This script removes short and long utterances from a cutset.
-
-Caution:
-  You may need to tune the thresholds for your own dataset.
- -Usage example: - - python3 ./local/filter_cuts.py \ - --bpe-model data/lang_bpe_500/bpe.model \ - --in-cuts data/fbank/librispeech_cuts_test-clean.jsonl.gz \ - --out-cuts data/fbank-filtered/librispeech_cuts_test-clean.jsonl.gz -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -from lhotse import CutSet, load_manifest_lazy -from lhotse.cut import Cut - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--bpe-model", - type=Path, - help="Path to the bpe.model", - ) - - parser.add_argument( - "--in-cuts", - type=Path, - help="Path to the input cutset", - ) - - parser.add_argument( - "--out-cuts", - type=Path, - help="Path to the output cutset", - ) - - return parser.parse_args() - - -def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): - total = 0 # number of total utterances before removal - removed = 0 # number of removed utterances - - def remove_short_and_long_utterances(c: Cut): - """Return False to exclude the input cut""" - nonlocal removed, total - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ./display_manifest_statistics.py - # - # You should use ./display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - total += 1 - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - removed += 1 - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./pruned_transducer_stateless2/conformer.py, the - # conv module uses the following expression - # for subsampling - if c.num_frames is None: - num_frames = c.duration * 100 # approximate - else: - num_frames = c.num_frames - - T = ((num_frames - 1) // 2 - 1) // 2 - # Note: for ./lstm_transducer_stateless/lstm.py, the formula is - # T = ((num_frames - 3) // 2 - 1) // 2 - - # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is - # T = ((num_frames - 7) // 2 + 1) // 2 - - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - removed += 1 - return False - - return True - - # We use to_eager() here so that we can print out the value of total - # and removed below. - ans = cut_set.filter(remove_short_and_long_utterances).to_eager() - ratio = removed / total * 100 - logging.info( - f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." 
- ) - return ans - - -def main(): - args = get_args() - logging.info(vars(args)) - - if args.out_cuts.is_file(): - logging.info(f"{args.out_cuts} already exists - skipping") - return - - assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist" - assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist" - - sp = spm.SentencePieceProcessor() - sp.load(str(args.bpe_model)) - - cut_set = load_manifest_lazy(args.in_cuts) - assert isinstance(cut_set, CutSet) - - cut_set = filter_cuts(cut_set, sp) - logging.info(f"Saving to {args.out_cuts}") - args.out_cuts.parent.mkdir(parents=True, exist_ok=True) - cut_set.to_file(args.out_cuts) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/mucs/ASR/local/filter_cuts.py b/egs/mucs/ASR/local/filter_cuts.py new file mode 120000 index 000000000..27aca1729 --- /dev/null +++ b/egs/mucs/ASR/local/filter_cuts.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/filter_cuts.py \ No newline at end of file diff --git a/egs/mucs/ASR/local/prepare_lang.py b/egs/mucs/ASR/local/prepare_lang.py deleted file mode 100755 index e00b92aad..000000000 --- a/egs/mucs/ASR/local/prepare_lang.py +++ /dev/null @@ -1,412 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input a lexicon file "data/lang_phone/lexicon.txt" -consisting of words and tokens (i.e., phones) and does the following: - -1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt - -2. Generate tokens.txt, the token table mapping a token to a unique integer. - -3. Generate words.txt, the word table mapping a word to a unique integer. - -4. Generate L.pt, in k2 format. It can be loaded by - - d = torch.load("L.pt") - lexicon = k2.Fsa.from_dict(d) - -5. Generate L_disambig.pt, in k2 format. -""" -import argparse -import math -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import k2 -import torch - -from icefall.lexicon import read_lexicon, write_lexicon -from icefall.utils import str2bool - -Lexicon = List[Tuple[str, List[str]]] - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain a file lexicon.txt. - Generated files by this script are saved into this directory. - """, - ) - - parser.add_argument( - "--debug", - type=str2bool, - default=False, - help="""True for debugging, which will generate - a visualization of the lexicon FST. - - Caution: If your lexicon contains hundreds of thousands - of lines, please set it to False! 
- """, - ) - - return parser.parse_args() - - -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. - - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ - with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_tokens(lexicon: Lexicon) -> List[str]: - """Get tokens from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique tokens. - """ - ans = set() - for _, tokens in lexicon: - ans.update(tokens) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def get_words(lexicon: Lexicon) -> List[str]: - """Get words from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique words. - """ - ans = set() - for word, _ in lexicon: - ans.add(word) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: - """It adds pseudo-token disambiguation symbols #1, #2 and so on - at the ends of tokens to ensure that all pronunciations are different, - and that none is a prefix of another. - - See also add_lex_disambig.pl from kaldi. - - Args: - lexicon: - It is returned by :func:`read_lexicon`. - Returns: - Return a tuple with two elements: - - - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbol that appears - in the lexicon - """ - - # (1) Work out the count of each token-sequence in the - # lexicon. - count = defaultdict(int) - for _, tokens in lexicon: - count[" ".join(tokens)] += 1 - - # (2) For each left sub-sequence of each token-sequence, note down - # that it exists (for identifying prefixes of longer strings). - issubseq = defaultdict(int) - for _, tokens in lexicon: - tokens = tokens.copy() - tokens.pop() - while tokens: - issubseq[" ".join(tokens)] = 1 - tokens.pop() - - # (3) For each entry in the lexicon: - # if the token sequence is unique and is not a - # prefix of another word, no disambig symbol. - # Else output #1, or #2, #3, ... if the same token-seq - # has already been assigned a disambig symbol. - ans = [] - - # We start with #1 since #0 has its own purpose - first_allowed_disambig = 1 - max_disambig = first_allowed_disambig - 1 - last_used_disambig_symbol_of = defaultdict(int) - - for word, tokens in lexicon: - tokenseq = " ".join(tokens) - assert tokenseq != "" - if issubseq[tokenseq] == 0 and count[tokenseq] == 1: - ans.append((word, tokens)) - continue - - cur_disambig = last_used_disambig_symbol_of[tokenseq] - if cur_disambig == 0: - cur_disambig = first_allowed_disambig - else: - cur_disambig += 1 - - if cur_disambig > max_disambig: - max_disambig = cur_disambig - last_used_disambig_symbol_of[tokenseq] = cur_disambig - tokenseq += f" #{cur_disambig}" - ans.append((word, tokenseq.split())) - return ans, max_disambig - - -def generate_id_map(symbols: List[str]) -> Dict[str, int]: - """Generate ID maps, i.e., map a symbol to a unique ID. - - Args: - symbols: - A list of unique symbols. - Returns: - A dict containing the mapping between symbols and IDs. 
- """ - return {sym: i for i, sym in enumerate(symbols)} - - -def add_self_loops( - arcs: List[List[Any]], disambig_token: int, disambig_word: int -) -> List[List[Any]]: - """Adds self-loops to states of an FST to propagate disambiguation symbols - through it. They are added on each state with non-epsilon output symbols - on at least one arc out of the state. - - See also fstaddselfloops.pl from Kaldi. One difference is that - Kaldi uses OpenFst style FSTs and it has multiple final states. - This function uses k2 style FSTs and it does not need to add self-loops - to the final state. - - The input label of a self-loop is `disambig_token`, while the output - label is `disambig_word`. - - Args: - arcs: - A list-of-list. The sublist contains - `[src_state, dest_state, label, aux_label, score]` - disambig_token: - It is the token ID of the symbol `#0`. - disambig_word: - It is the word ID of the symbol `#0`. - - Return: - Return new `arcs` containing self-loops. - """ - states_needs_self_loops = set() - for arc in arcs: - src, dst, ilabel, olabel, score = arc - if olabel != 0: - states_needs_self_loops.add(src) - - ans = [] - for s in states_needs_self_loops: - ans.append([s, s, disambig_token, disambig_word, 0]) - - return arcs + ans - - -def lexicon_to_fst( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - sil_token: str = "SIL", - sil_prob: float = 0.5, - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format) with optional silence at - the beginning and end of each word. - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - sil_token: - The silence token. - sil_prob: - The probability for adding a silence at the beginning and end - of the word. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - assert sil_prob > 0.0 and sil_prob < 1.0 - # CAUTION: we use score, i.e, negative cost. - sil_score = math.log(sil_prob) - no_sil_score = math.log(1.0 - sil_prob) - - start_state = 0 - loop_state = 1 # words enter and leave from here - sil_state = 2 # words terminate here when followed by silence; this state - # has a silence transition to loop_state. - next_state = 3 # the next un-allocated state, will be incremented as we go. - arcs = [] - - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - sil_token = token2id[sil_token] - - arcs.append([start_state, loop_state, eps, eps, no_sil_score]) - arcs.append([start_state, sil_state, eps, eps, sil_score]) - arcs.append([sil_state, loop_state, sil_token, eps, 0]) - - for word, tokens in lexicon: - assert len(tokens) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - tokens = [token2id[i] for i in tokens] - - for i in range(len(tokens) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, tokens[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last token of this word - # It has two out-going arcs, one to the loop state, - # the other one to the sil_state. 
- i = len(tokens) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) - arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - lexicon_filename = lang_dir / "lexicon.txt" - sil_token = "SIL" - sil_prob = 0.5 - lexicon = read_lexicon(lexicon_filename) - tokens = get_tokens(lexicon) - words = get_words(lexicon) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in tokens - tokens.append(f"#{i}") - - assert "" not in tokens - tokens = [""] + tokens - - assert "" not in words - assert "#0" not in words - assert "" not in words - assert "" not in words - - words = [""] + words + ["#0", "", ""] - - token2id = generate_id_map(tokens) - word2id = generate_id_map(words) - - write_mapping(lang_dir / "tokens.txt", token2id) - write_mapping(lang_dir / "words.txt", word2id) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst( - lexicon, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - ) - - L_disambig = lexicon_to_fst( - lexicon_disambig, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - if args.debug: - labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") - aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") - - L.labels_sym = labels_sym - L.aux_labels_sym = aux_labels_sym - L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") - - L_disambig.labels_sym = labels_sym - L_disambig.aux_labels_sym = aux_labels_sym - L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/mucs/ASR/local/prepare_lang.py b/egs/mucs/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/mucs/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/mucs/ASR/local/prepare_lang_bpe.py b/egs/mucs/ASR/local/prepare_lang_bpe.py deleted file mode 100755 index 2a2d9c219..000000000 --- a/egs/mucs/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1,266 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -""" - -This script takes as input `lang_dir`, which should contain:: - - - lang_dir/bpe.model, - - lang_dir/words.txt - -and generates the following files in the directory `lang_dir`: - - - lexicon.txt - - lexicon_disambig.txt - - L.pt - - L_disambig.pt - - tokens.txt -""" - -import argparse -from pathlib import Path -from typing import Dict, List, Tuple - -import k2 -import sentencepiece as spm -import torch -from prepare_lang import ( - Lexicon, - add_disambig_symbols, - add_self_loops, - write_lexicon, - write_mapping, -) - -from icefall.utils import str2bool - - -def lexicon_to_fst_no_sil( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format). - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go - - arcs = [] - - # The blank symbol is defined in local/train_bpe_model.py - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - for word, pieces in lexicon: - assert len(pieces) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - pieces = [token2id[i] for i in pieces] - - for i in range(len(pieces) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, pieces[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last piece of this word - i = len(pieces) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, pieces[i], w, 0]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def generate_lexicon( - model_file: str, words: List[str], oov: str -) -> Tuple[Lexicon, Dict[str, int]]: - """Generate a lexicon from a BPE model. - - Args: - model_file: - Path to a sentencepiece model. - words: - A list of strings representing words. - oov: - The out of vocabulary word in lexicon. - Returns: - Return a tuple with two elements: - - A dict whose keys are words and values are the corresponding - word pieces. - - A dict representing the token symbol, mapping from tokens to IDs. 
- """ - sp = spm.SentencePieceProcessor() - sp.load(str(model_file)) - - # Convert word to word piece IDs instead of word piece strings - # to avoid OOV tokens. - words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) - - # Now convert word piece IDs back to word piece strings. - words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] - - lexicon = [] - for word, pieces in zip(words, words_pieces): - lexicon.append((word, pieces)) - - lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) - - token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} - - return lexicon, token2id - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain the bpe.model and words.txt - """, - ) - - parser.add_argument( - "--oov", - type=str, - default="", - help="The out of vocabulary word in lexicon.", - ) - - parser.add_argument( - "--debug", - type=str2bool, - default=False, - help="""True for debugging, which will generate - a visualization of the lexicon FST. - - Caution: If your lexicon contains hundreds of thousands - of lines, please set it to False! - - See "test/test_bpe_lexicon.py" for usage. - """, - ) - - return parser.parse_args() - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - model_file = lang_dir / "bpe.model" - - word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") - - words = word_sym_table.symbols - - excluded = ["", "!SIL", "", args.oov, "#0", "", ""] - - for w in excluded: - if w in words: - words.remove(w) - - lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - next_token_id = max(token_sym_table.values()) + 1 - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in token_sym_table - token_sym_table[disambig] = next_token_id - next_token_id += 1 - - word_sym_table.add("#0") - word_sym_table.add("") - word_sym_table.add("") - - write_mapping(lang_dir / "tokens.txt", token_sym_table) - - write_lexicon(lang_dir / "lexicon.txt", lexicon) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, - ) - - L_disambig = lexicon_to_fst_no_sil( - lexicon_disambig, - token2id=token_sym_table, - word2id=word_sym_table, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - if args.debug: - labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") - aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") - - L.labels_sym = labels_sym - L.aux_labels_sym = aux_labels_sym - L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") - - L_disambig.labels_sym = labels_sym - L_disambig.aux_labels_sym = aux_labels_sym - L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/mucs/ASR/local/prepare_lang_bpe.py b/egs/mucs/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/mucs/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/mucs/ASR/local/train_bpe_model.py b/egs/mucs/ASR/local/train_bpe_model.py deleted file mode 100755 index 2a3166dbd..000000000 --- 
a/egs/mucs/ASR/local/train_bpe_model.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# You can install sentencepiece via: -# -# pip install sentencepiece -# -# Due to an issue reported in -# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 -# -# Please install a version >=0.1.96 - -import argparse -import shutil -from pathlib import Path - -import sentencepiece as spm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - The generated bpe.model is saved to this directory. - """, - ) - - parser.add_argument( - "--transcript", - type=str, - help="Training transcript.", - ) - - parser.add_argument( - "--vocab-size", - type=int, - help="Vocabulary size for BPE training", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - vocab_size = args.vocab_size - lang_dir = Path(args.lang_dir) - - model_type = "unigram" - - model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" - train_text = args.transcript - character_coverage = 1.0 - input_sentence_size = 50000 - - user_defined_symbols = ["", ""] - unk_id = len(user_defined_symbols) - # Note: unk_id is fixed to 2. - # If you change it, you should also change other - # places that are using it. - - model_file = Path(model_prefix + ".model") - if not model_file.is_file(): - spm.SentencePieceTrainer.train( - input=train_text, - vocab_size=vocab_size, - model_type=model_type, - model_prefix=model_prefix, - input_sentence_size=input_sentence_size, - character_coverage=character_coverage, - user_defined_symbols=user_defined_symbols, - unk_id=unk_id, - bos_id=-1, - eos_id=-1, - ) - else: - print(f"{model_file} exists - skipping") - return - - shutil.copyfile(model_file, f"{lang_dir}/bpe.model") - - -if __name__ == "__main__": - main() diff --git a/egs/mucs/ASR/local/train_bpe_model.py b/egs/mucs/ASR/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/mucs/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/mucs/ASR/local/validate_bpe_lexicon.py b/egs/mucs/ASR/local/validate_bpe_lexicon.py deleted file mode 100755 index c542f2fab..000000000 --- a/egs/mucs/ASR/local/validate_bpe_lexicon.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script checks that there are no OOV tokens in the BPE-based lexicon. - -Usage example: - - python3 ./local/validate_bpe_lexicon.py \ - --lexicon /path/to/lexicon.txt \ - --bpe-model /path/to/bpe.model -""" - -import argparse -from pathlib import Path -from typing import List, Tuple - -import sentencepiece as spm - -from icefall.lexicon import read_lexicon - -# Map word to word pieces -Lexicon = List[Tuple[str, List[str]]] - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--lexicon", - required=True, - type=Path, - help="Path to lexicon.txt", - ) - - parser.add_argument( - "--bpe-model", - required=True, - type=Path, - help="Path to bpe.model", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - assert args.lexicon.is_file(), args.lexicon - assert args.bpe_model.is_file(), args.bpe_model - - lexicon = read_lexicon(args.lexicon) - - sp = spm.SentencePieceProcessor() - sp.load(str(args.bpe_model)) - - word_pieces = set(sp.id_to_piece(list(range(sp.vocab_size())))) - for word, pieces in lexicon: - for p in pieces: - if p not in word_pieces: - raise ValueError(f"The word {word} contains an OOV token {p}") - - -if __name__ == "__main__": - main() diff --git a/egs/mucs/ASR/local/validate_bpe_lexicon.py b/egs/mucs/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/mucs/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/mucs/ASR/local/validate_manifest.py b/egs/mucs/ASR/local/validate_manifest.py deleted file mode 100755 index 2452dbc43..000000000 --- a/egs/mucs/ASR/local/validate_manifest.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script checks the following assumptions of the generated manifest: - -- Single supervision per cut -- Supervision time bounds are within cut time bounds - -We will add more checks later if needed. 
-
-Usage example:
-
-    python3 ./local/validate_manifest.py \
-            ./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz
-
-"""
-
-import argparse
-import logging
-from pathlib import Path
-
-from lhotse import CutSet, load_manifest_lazy
-from lhotse.cut import Cut
-from lhotse.dataset.speech_recognition import validate_for_asr
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "manifest",
-        type=Path,
-        help="Path to the manifest file",
-    )
-
-    return parser.parse_args()
-
-
-def validate_one_supervision_per_cut(c: Cut):
-    if len(c.supervisions) != 1:
-        raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions")
-
-
-def validate_supervision_and_cut_time_bounds(c: Cut):
-    tol = 2e-3  # same tolerance as in 'validate_for_asr()'
-    s = c.supervisions[0]
-
-    # Supervision start time is relative to Cut ...
-    # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
-    # print(s.start, )
-    if s.start < -tol:
-        raise ValueError(
-            f"{c.id}: Supervision start time {s.start} must not be negative."
-        )
-    if s.start > tol:
-        raise ValueError(
-            f"{c.id}: Supervision start time {s.start} is not at the beginning of the Cut. Please apply `lhotse cut trim-to-supervisions`."
-        )
-    if c.start + s.end > c.end + tol:
-        raise ValueError(
-            f"{c.id}: Supervision end time {c.start+s.end} is larger "
-            f"than cut end time {c.end}"
-        )
-
-
-def main():
-    args = get_args()
-
-    manifest = args.manifest
-    logging.info(f"Validating {manifest}")
-
-    assert manifest.is_file(), f"{manifest} does not exist"
-    # print(manifest)
-    cut_set = load_manifest_lazy(manifest)
-    # print(cut_set)
-    assert isinstance(cut_set, CutSet)
-
-    for c in cut_set:
-        # print(len(c.supervisions))
-        # print(c.supervisions)
-        validate_one_supervision_per_cut(c)
-        validate_supervision_and_cut_time_bounds(c)
-
-    # Validation from K2 training
-    # - checks supervision start is 0
-    # - checks supervision.duration is not longer than cut.duration
-    # - there is tolerance 2ms
-    validate_for_asr(cut_set)
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    main()
diff --git a/egs/mucs/ASR/local/validate_manifest.py b/egs/mucs/ASR/local/validate_manifest.py
new file mode 120000
index 000000000..0a9725e87
--- /dev/null
+++ b/egs/mucs/ASR/local/validate_manifest.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_manifest.py
\ No newline at end of file
diff --git a/egs/mucs/ASR/run.sh b/egs/mucs/ASR/run.sh
index d232d3f04..9df057330 100755
--- a/egs/mucs/ASR/run.sh
+++ b/egs/mucs/ASR/run.sh
@@ -4,14 +4,14 @@
 export CUDA_VISIBLE_DEVICES="0"
 ./conformer_ctc/train.py \
   --num-epochs 60 \
   --max-duration 300 \
-  --exp-dir ./conformer_ctc/exp_with_devset_split \
-  --lang-dir data/lang_bpe_200 \
+  --exp-dir ./conformer_ctc/exp_with_devset_split_bpe400 \
+  --lang-dir data/lang_bpe_400 \
   --enable-musan False \
 
 ./conformer_ctc/decode.py \
   --epoch 59 \
   --avg 10 \
-  --exp-dir ./conformer_ctc/exp_with_devset_split \
+  --exp-dir ./conformer_ctc/exp_with_devset_split_bpe400 \
   --max-duration 100 \
-  --lang-dir ./data/lang_bpe_200
+  --lang-dir ./data/lang_bpe_400
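Note on the run.sh change: the updated script points --lang-dir at data/lang_bpe_400, but this patch does not show how that directory is produced. Assuming it is built with the local scripts that this patch turns into symlinks (whose arguments appear in the deleted copies above), a minimal preparation sketch could look like the following; the transcript path and the pre-existing words.txt are assumptions, not something fixed by this diff:

    # Sketch only: transcript_words.txt and words.txt locations are assumed.
    lang_dir=data/lang_bpe_400
    mkdir -p $lang_dir
    # words.txt is expected to already exist in $lang_dir (see prepare_lang_bpe.py).
    ./local/train_bpe_model.py \
      --lang-dir $lang_dir \
      --vocab-size 400 \
      --transcript $lang_dir/transcript_words.txt
    ./local/prepare_lang_bpe.py --lang-dir $lang_dir
    ./local/validate_bpe_lexicon.py \
      --lexicon $lang_dir/lexicon.txt \
      --bpe-model $lang_dir/bpe.model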