icefall/egs/librispeech/ASR/local/prepare_lm_training_data.py
2022-11-17 09:42:17 -05:00

168 lines
5.4 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey
# Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes a `bpe.model` and a text file such as
./download/lm/librispeech-lm-norm.txt
and writes the LM training data to the archive given by --lm-archive,
e.g. data/bpe_500/lm_data.pt. The format is as follows:
It creates a PyTorch archive (.pt file), which is a
representation of a dict with the following format:
'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32
containing the BPE representations of each word, indexed by
integer word ID. (These integer word IDs are present in
'sentences'). The sentencepiece object can be used to turn the
words and BPE units into string form.
'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype
torch.int32 containing all the sentences, as word-ids (we don't
output the string form of this directly but it can be worked out
together with 'words' and the bpe.model).
'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing
number of BPE tokens of each sentence.
"""
import argparse
import logging
from pathlib import Path
import k2
import sentencepiece as spm
import torch
def get_args():
    """Parse the command-line options of this script.

    All three options are mandatory: the rest of the script uses each of them
    unconditionally, so a missing one is reported here by argparse (usage
    message and exit) instead of failing later on a ``None`` path.

    Returns:
      An ``argparse.Namespace`` with the string attributes ``bpe_model``,
      ``lm_data`` and ``lm_archive``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--bpe-model",
        type=str,
        required=True,
        help="Input BPE model, e.g. data/bpe_500/bpe.model",
    )
    parser.add_argument(
        "--lm-data",
        type=str,
        required=True,
        help="""Input LM training data as text, e.g.
        download/pb.train.txt""",
    )
    parser.add_argument(
        "--lm-archive",
        type=str,
        required=True,
        help="""Path to output archive, e.g. data/bpe_500/lm_data.pt;
        look at the source of this script to see the format.""",
    )

    return parser.parse_args()
def main():
    """Convert a text corpus into an LM training archive.

    Steps:
      1. Return early (with a warning) if the output archive already exists.
      2. Read the --lm-data file line by line, give every distinct
         whitespace-separated word an integer id and BPE-encode each new
         word exactly once with the --bpe-model SentencePiece model.
      3. Build the ragged tensors ``words`` ([word][token]) and ``sentences``
         ([sentence][word]), count the BPE tokens of every sentence, and
         save the dict with keys ``words``, ``sentences`` and
         ``sentence_lengths`` to --lm-archive via ``torch.save``.
    """
    args = get_args()

    if Path(args.lm_archive).exists():
        logging.warning(f"{args.lm_archive} exists - skipping")
        return

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    # word2index is a dictionary from words to integer ids. No need to reserve
    # space for epsilon, etc.; the words are just used as a convenient way to
    # compress the sequences of BPE pieces.
    word2index = dict()

    word2bpe = []  # Will be a list-of-list-of-int, representing BPE pieces.
    sentences = []  # Will be a list-of-list-of-int, representing word-ids.

    # Progress reporting relies on hard-coded line counts selected by a
    # substring match on the input path. For any other input both values are
    # None and no progress is logged.
    if "librispeech-lm-norm" in args.lm_data:
        num_lines_in_total = 40418261.0
        step = 5000000
    elif "valid" in args.lm_data:
        num_lines_in_total = 5567.0
        step = 3000
    elif "test" in args.lm_data:
        num_lines_in_total = 5559.0
        step = 3000
    else:
        num_lines_in_total = None
        step = None

    processed = 0

    # Decode explicitly as UTF-8 so the result does not depend on the locale
    # of the machine running the script.
    with open(args.lm_data, encoding="utf-8") as f:
        for line in f:
            # `step` is only set together with `num_lines_in_total`, so the
            # division below never sees None.
            if step and processed % step == 0:
                logging.info(
                    f"Processed number of lines: {processed} "
                    f"({processed/num_lines_in_total*100: .3f}%)"
                )
            processed += 1

            line_words = line.split()
            for w in line_words:
                if w not in word2index:
                    # First occurrence: the new id is the position at which
                    # the word's BPE pieces are stored in word2bpe.
                    w_bpe = sp.encode(w)
                    word2index[w] = len(word2bpe)
                    word2bpe.append(w_bpe)
            # A blank line yields an empty sentence; it is kept so that
            # sentence indices stay aligned with input lines.
            sentences.append([word2index[w] for w in line_words])

    logging.info("Constructing ragged tensors")
    words = k2.ragged.RaggedTensor(word2bpe)
    sentences = k2.ragged.RaggedTensor(sentences)

    output = dict(words=words, sentences=sentences)

    num_sentences = sentences.dim0
    logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}")
    sentence_lengths = [0] * num_sentences
    for i in range(num_sentences):
        if step and i % step == 0:
            logging.info(
                f"Processed number of lines: {i} " f"({i/num_sentences*100: .3f}%)"
            )

        word_ids = sentences[i]

        # NOTE: If word_ids is a tensor with only 1 entry,
        # token_ids is a torch.Tensor
        token_ids = words[word_ids]
        if isinstance(token_ids, k2.RaggedTensor):
            token_ids = token_ids.values

        # token_ids is a 1-D tensor containing the BPE tokens
        # of the current sentence
        sentence_lengths[i] = token_ids.numel()

    output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32)

    torch.save(output, args.lm_archive)
    logging.info(f"Saved to {args.lm_archive}")
if __name__ == "__main__":
    # Set up root logging first so that every message emitted by main()
    # carries a timestamp, level and source location.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
    )
    main()