diff --git a/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py b/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py
index 020800c15..2b8708272 100755
--- a/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py
+++ b/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py
@@ -22,8 +22,6 @@
 import argparse
 from pathlib import Path
 
-from tqdm.auto import tqdm
-
 from icefall.utils import tokenize_by_CJK_char
 
 
diff --git a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py
new file mode 100755
index 000000000..45cc4b8d7
--- /dev/null
+++ b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py
@@ -0,0 +1,189 @@
+#!/usr/bin/env python3
+
+# Copyright (c)  2021  Xiaomi Corporation (authors: Daniel Povey
+#                                                   Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script takes a `bpe.model` and a text file such as
+./download/lm/librispeech-lm-norm.txt
+and outputs the LM training data to a supplied directory such
+as data/lm_training_bpe_500.  The format is as follows:
+
+It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a
+representation of a dict with the following format:
+
+  'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32
+             containing the BPE representations of each word, indexed by
+             integer word ID. (These integer word IDs are present in
+             'sentences'.)  The sentencepiece object can be used to turn the
+             words and BPE units into string form.
+  'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype
+                 torch.int32 containing all the sentences, as word IDs (we
+                 don't output the string form of this directly, but it can be
+                 worked out together with 'words' and the bpe.model).
+  'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing
+                        the number of BPE tokens of each sentence.
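+
+A minimal sketch of reading the archive back (for illustration only; the
+archive path below is a placeholder):
+
+    import k2
+    import sentencepiece as spm
+    import torch
+
+    sp = spm.SentencePieceProcessor()
+    sp.load("data/lang_bpe_2000/bpe.model")
+    d = torch.load("data/lm_data.pt")
+    word_ids = d["sentences"][0]  # word IDs of the first sentence
+    token_ids = d["words"][word_ids]  # BPE tokens of those words
+    if isinstance(token_ids, k2.RaggedTensor):
+        token_ids = token_ids.values
+    print(sp.decode(token_ids.tolist()))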
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import sentencepiece as spm
+import torch
+
+from icefall.utils import tokenize_by_CJK_char
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--bpe-model",
+        default="data/lang_bpe_2000/bpe.model",
+        type=str,
+        help="Input BPE model, e.g. data/bpe_500/bpe.model",
+    )
+    parser.add_argument(
+        "--lm-data",
+        type=str,
+        help="""Input LM training data as text, e.g.
+        download/pb.train.txt""",
+    )
+    parser.add_argument(
+        "--lm-archive",
+        type=str,
+        help="""Path to output archive, e.g. data/bpe_500/lm_data.pt;
+        look at the source of this script to see the format.""",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    if Path(args.lm_archive).exists():
+        logging.warning(f"{args.lm_archive} exists - skipping")
+        return
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(args.bpe_model)
+
+    # word2index is a dictionary from words to integer ids.  No need to reserve
+    # space for epsilon, etc.; the words are just used as a convenient way to
+    # compress the sequences of BPE pieces.
+    word2index = dict()
+
+    word2bpe = []  # Will be a list-of-list-of-int, representing BPE pieces.
+    sentences = []  # Will be a list-of-list-of-int, representing word-ids.
+
+    # num_lines_in_total is not known in advance; set it to the line count
+    # of --lm-data if you want progress percentages in the logs.
+    num_lines_in_total = None
+    step = 500000
+
+    processed = 0
+
+    with open(args.lm_data) as f:
+        while True:
+            line = f.readline()
+            # readline() returns "" only at EOF; check before tokenizing,
+            # since a blank line also tokenizes to "".
+            if line == "":
+                break
+            line = tokenize_by_CJK_char(line)
+
+            if step and processed % step == 0:
+                if num_lines_in_total is not None:
+                    logging.info(
+                        f"Processed number of lines: {processed} "
+                        f"({processed/num_lines_in_total*100: .3f}%)"
+                    )
+                else:
+                    logging.info(f"Processed number of lines: {processed}")
+            processed += 1
+
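+            # For example (illustrative): if the tokenized line is
+            # "你 好 WORLD 你 好", the loop below assigns a new ID to each
+            # word on first occurrence, so word2index becomes
+            # {"你": 0, "好": 1, "WORLD": 2} and the sentence is stored
+            # as [0, 1, 2, 0, 1].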
+            line_words = line.split()
+            for w in line_words:
+                if w not in word2index:
+                    w_bpe = sp.encode(w)
+                    word2index[w] = len(word2bpe)
+                    word2bpe.append(w_bpe)
+            sentences.append([word2index[w] for w in line_words])
+
+    logging.info("Constructing ragged tensors")
+    words = k2.ragged.RaggedTensor(word2bpe)
+    sentences = k2.ragged.RaggedTensor(sentences)
+
+    output = dict(words=words, sentences=sentences)
+
+    num_sentences = sentences.dim0
+    logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}")
+    sentence_lengths = [0] * num_sentences
+    for i in range(num_sentences):
+        if step and i % step == 0:
+            logging.info(
+                f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)"
+            )
+
+        word_ids = sentences[i]
+
+        # NOTE: If word_ids is a tensor with only 1 entry,
+        # token_ids is a torch.Tensor
+        token_ids = words[word_ids]
+        if isinstance(token_ids, k2.RaggedTensor):
+            token_ids = token_ids.values
+
+        # token_ids is a 1-D tensor containing the BPE tokens
+        # of the current sentence
+
+        sentence_lengths[i] = token_ids.numel()
+
+    output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32)
+
+    torch.save(output, args.lm_archive)
+    logging.info(f"Saved to {args.lm_archive}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/multi_zh-hans/ASR/local/sort_lm_training_data.py b/egs/multi_zh-hans/ASR/local/sort_lm_training_data.py
new file mode 120000
index 000000000..1d6ccbe33
--- /dev/null
+++ b/egs/multi_zh-hans/ASR/local/sort_lm_training_data.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/sort_lm_training_data.py
\ No newline at end of file
diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py
index 02cfa1346..ae1264659 100644
--- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py
+++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py
@@ -19,25 +19,20 @@
 import argparse
 import inspect
 import logging
-from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, Optional
 
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
-from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest
+from lhotse.dataset import (
     CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
-    PrecomputedFeatures,
     SimpleCutSampler,
     SpecAugment,
 )
-from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
-    AudioSamples,
-    OnTheFlyFeatures,
-)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
 from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader
 
diff --git a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py
index b1920e62e..ac90dc73e 100644
--- a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py
+++ b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py
@@ -15,11 +15,9 @@
 # limitations under the License.
 
 
-import glob
 import logging
-import re
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict
 
 import lhotse
 from lhotse import CutSet, load_manifest_lazy