diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py
index bc5812810..42aba9572 100755
--- a/egs/librispeech/ASR/local/train_bpe_model.py
+++ b/egs/librispeech/ASR/local/train_bpe_model.py
@@ -38,7 +38,6 @@ def get_args():
         "--lang-dir",
         type=str,
         help="""Input and output directory.
-        It should contain the training corpus: transcript_words.txt.
        The generated bpe.model is saved to this directory.
        """,
    )
diff --git a/egs/ptb/LM/README.md b/egs/ptb/LM/README.md
new file mode 100644
index 000000000..7629a950d
--- /dev/null
+++ b/egs/ptb/LM/README.md
@@ -0,0 +1,18 @@
+## Description
+
+(Note: the experiments here are only about language modeling.)
+
+ptb is short for Penn Treebank.
+
+
+About the Penn Treebank corpus:
+ - This corpus is free for research purposes.
+ - ptb.train.txt: the training set
+ - ptb.valid.txt: the development set (use it only for tuning hyper-parameters, not for training)
+ - ptb.test.txt: the test set for reporting perplexity
+
+You can download the dataset from one of the following URLs:
+
+- https://github.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage
+- http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
+- https://deepai.org/dataset/penn-treebank
diff --git a/egs/ptb/LM/local/prepare_lm_training_data.py b/egs/ptb/LM/local/prepare_lm_training_data.py
new file mode 100755
index 000000000..bc7555209
--- /dev/null
+++ b/egs/ptb/LM/local/prepare_lm_training_data.py
@@ -0,0 +1,146 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey
+#                                                 Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script takes a `bpe.model` and a text file such as
+`download/ptb.train.txt`, and outputs the LM training data to a supplied
+directory such as `data/bpe_500`.
+
+It creates a PyTorch archive (.pt file), e.g. data/bpe_500/lm_data.pt,
+which is a representation of a dict with the following format:
+
+  'words' -> a k2.RaggedTensor of two axes [word][token] with dtype
+             torch.int32 containing the BPE representations of each word,
+             indexed by integer word ID. (These integer word IDs are
+             present in 'sentences'.) The sentencepiece object can be used
+             to turn the words and BPE units into string form.
+  'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype
+             torch.int32 containing all the sentences, as word IDs (we
+             don't output the string form of this directly, but it can be
+             worked out from 'words' and the bpe.model).
+  'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing
+             the number of BPE tokens of each sentence.
+"""
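+
+# A typical invocation, matching what ../prepare.sh (stage 1) does:
+#
+#   ./local/prepare_lm_training_data.py \
+#     --bpe-model data/bpe_500/bpe.model \
+#     --lm-data download/ptb.train.txt \
+#     --lm-archive data/bpe_500/lm_data.pt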
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import sentencepiece as spm
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="Input BPE model, e.g. data/bpe_500/bpe.model",
+    )
+    parser.add_argument(
+        "--lm-data",
+        type=str,
+        help="""Input LM training data as text, e.g.
+        download/ptb.train.txt""",
+    )
+    parser.add_argument(
+        "--lm-archive",
+        type=str,
+        help="""Path to output archive, e.g. data/bpe_500/lm_data.pt;
+        look at the source of this script to see the format.""",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    if Path(args.lm_archive).exists():
+        logging.warning(f"{args.lm_archive} exists - skipping")
+        return
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(args.bpe_model)
+
+    # word2index is a dictionary from words to integer IDs. No need to
+    # reserve space for epsilon, etc.; the words are just used as a
+    # convenient way to compress the sequences of BPE pieces.
+    word2index = dict()
+
+    word2bpe = []  # Will be a list-of-list-of-int, representing BPE pieces.
+
+    # ptb.train.txt has already mapped OOV words to <unk>
+    word2bpe.append([sp.unk_id()])
+    word2index["<unk>"] = 0
+
+    sentences = []  # Will be a list-of-list-of-int, representing word IDs.
+
+    with open(args.lm_data) as f:
+        while True:
+            line = f.readline()
+            if line == "":
+                break
+            line_words = line.split()
+            for w in line_words:
+                if w not in word2index:
+                    w_bpe = sp.encode(w)
+                    word2index[w] = len(word2bpe)
+                    word2bpe.append(w_bpe)
+            sentences.append([word2index[w] for w in line_words])
+
+    words = k2.ragged.RaggedTensor(word2bpe)
+    sentences = k2.ragged.RaggedTensor(sentences)
+
+    output = dict(words=words, sentences=sentences)
+
+    num_sentences = sentences.dim0
+    sentence_lengths = [0] * num_sentences
+    for i in range(num_sentences):
+        word_ids = sentences[i]
+
+        # NOTE: If word_ids is a tensor with only 1 entry,
+        # token_ids is a torch.Tensor
+        token_ids = words[word_ids]
+        if isinstance(token_ids, k2.RaggedTensor):
+            token_ids = token_ids.values
+
+        # token_ids is a 1-D tensor containing the BPE tokens
+        # of the current sentence
+
+        sentence_lengths[i] = token_ids.numel()
+
+    output["sentence_lengths"] = torch.tensor(
+        sentence_lengths, dtype=torch.int32
+    )
+
+    torch.save(output, args.lm_archive)
+    logging.info(f"Saved to {args.lm_archive}")
+
+
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py
new file mode 100755
index 000000000..af54dbd07
--- /dev/null
+++ b/egs/ptb/LM/local/sort_lm_training_data.py
@@ -0,0 +1,143 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file takes as input the filename of LM training data
+generated by ./local/prepare_lm_training_data.py and sorts
+it by sentence length, in descending order.
+
+Sentence length is the number of BPE tokens in a sentence.
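+
+A typical invocation, matching what ../prepare.sh (stage 2) does:
+
+  ./local/sort_lm_training_data.py \
+    --in-lm-data data/bpe_500/lm_data.pt \
+    --out-lm-data data/bpe_500/sorted_lm_data.pt \
+    --out-statistics data/bpe_500/statistics.txt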
+""" + +import argparse +import logging +from pathlib import Path + +import k2 +import numpy as np +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--in-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/lm_data.pt", + ) + + parser.add_argument( + "--out-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/sorted_lm_data.pt", + ) + + parser.add_argument( + "--out-statistics", + type=str, + help="Statistics about LM training data., data/bpe_500/statistics.txt", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + in_lm_data = Path(args.in_lm_data) + out_lm_data = Path(args.out_lm_data) + assert in_lm_data.is_file(), f"{in_lm_data}" + if out_lm_data.is_file(): + logging.warning(f"{out_lm_data} exists - skipping") + return + data = torch.load(in_lm_data) + words2bpe = data["words"] + sentences = data["sentences"] + sentence_lengths = data["sentence_lengths"] + + num_sentences = sentences.dim0 + assert num_sentences == sentence_lengths.numel(), ( + num_sentences, + sentence_lengths.numel(), + ) + + indices = torch.argsort(sentence_lengths, descending=True) + + sorted_sentences = sentences[indices.to(torch.int32)] + sorted_sentence_lengths = sentence_lengths[indices] + + # Check that sentences are ordered by length + assert num_sentences == sorted_sentences.dim0, ( + num_sentences, + sorted_sentences.dim0, + ) + + cur = None + for i in range(num_sentences): + word_ids = sorted_sentences[i] + token_ids = words2bpe[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + if cur is not None: + assert cur >= token_ids.numel(), (cur, token_ids.numel()) + + cur = token_ids.numel() + assert cur == sorted_sentence_lengths[i] + + data["sentences"] = sorted_sentences + data["sentence_lengths"] = sorted_sentence_lengths + torch.save(data, args.out_lm_data) + logging.info(f"Saved to {args.out_lm_data}") + + statistics = Path(args.out_statistics) + + # Write statistics + num_words = sorted_sentences.numel() + num_tokens = sentence_lengths.sum().item() + max_sentence_length = sentence_lengths[indices[0]] + min_sentence_length = sentence_lengths[indices[-1]] + + step = 10 + hist, bins = np.histogram( + sentence_lengths.numpy(), + bins=np.arange(1, max_sentence_length + step, step), + ) + + histogram = np.stack((bins[:-1], hist)).transpose() + + with open(statistics, "w") as f: + f.write(f"num_sentences: {num_sentences}\n") + f.write(f"num_words: {num_words}\n") + f.write(f"num_tokens: {num_tokens}\n") + f.write(f"max_sentence_length: {max_sentence_length}\n") + f.write(f"min_sentence_length: {min_sentence_length}\n") + f.write("histogram:\n") + f.write(" bin count percent\n") + for row in histogram: + f.write( + f"{int(row[0]):>5} {int(row[1]):>5} " + f"{100.*row[1]/num_sentences:.3f}%\n" + ) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py new file mode 100755 index 000000000..877720e7b --- /dev/null +++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may 
+    sorted_sentences = sentences[indices.to(torch.int32)]
+    sorted_sentence_lengths = sentence_lengths[indices]
+
+    assert num_sentences == sorted_sentences.dim0, (
+        num_sentences,
+        sorted_sentences.dim0,
+    )
+
+    # Check that the sentences are really ordered by length
+    cur = None
+    for i in range(num_sentences):
+        word_ids = sorted_sentences[i]
+        token_ids = words2bpe[word_ids]
+        if isinstance(token_ids, k2.RaggedTensor):
+            token_ids = token_ids.values
+        if cur is not None:
+            assert cur >= token_ids.numel(), (cur, token_ids.numel())
+
+        cur = token_ids.numel()
+        assert cur == sorted_sentence_lengths[i]
+
+    data["sentences"] = sorted_sentences
+    data["sentence_lengths"] = sorted_sentence_lengths
+    torch.save(data, args.out_lm_data)
+    logging.info(f"Saved to {args.out_lm_data}")
+
+    statistics = Path(args.out_statistics)
+
+    # Write statistics
+    num_words = sorted_sentences.numel()
+    num_tokens = sentence_lengths.sum().item()
+    max_sentence_length = sentence_lengths[indices[0]]
+    min_sentence_length = sentence_lengths[indices[-1]]
+
+    step = 10
+    hist, bins = np.histogram(
+        sentence_lengths.numpy(),
+        bins=np.arange(1, max_sentence_length + step, step),
+    )
+
+    histogram = np.stack((bins[:-1], hist)).transpose()
+
+    with open(statistics, "w") as f:
+        f.write(f"num_sentences: {num_sentences}\n")
+        f.write(f"num_words: {num_words}\n")
+        f.write(f"num_tokens: {num_tokens}\n")
+        f.write(f"max_sentence_length: {max_sentence_length}\n")
+        f.write(f"min_sentence_length: {min_sentence_length}\n")
+        f.write("histogram:\n")
+        f.write("  bin  count  percent\n")
+        for row in histogram:
+            f.write(
+                f"{int(row[0]):>5} {int(row[1]):>5} "
+                f"{100.*row[1]/num_sentences:.3f}%\n"
+            )
+
+
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py
new file mode 100755
index 000000000..877720e7b
--- /dev/null
+++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+
+
+def main():
+    lm_training_data = Path("./data/bpe_500/lm_data.pt")
+    bpe_model = Path("./data/bpe_500/bpe.model")
+    if not lm_training_data.exists():
+        logging.warning(f"{lm_training_data} does not exist - skipping")
+        return
+
+    if not bpe_model.exists():
+        logging.warning(f"{bpe_model} does not exist - skipping")
+        return
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(str(bpe_model))
+
+    data = torch.load(lm_training_data)
+    words2bpe = data["words"]
+    sentences = data["sentences"]
+
+    ss = []
+    unk = sp.decode(sp.unk_id()).strip()
+    for i in range(10):
+        s = sp.decode(words2bpe[sentences[i]].values.tolist())
+        s = s.replace(unk, "<unk>")
+        ss.append(s)
+
+    for s in ss:
+        print(s)
+    # You can compare the output with the first 10 lines of ptb.train.txt
+
+
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/ptb/LM/local/train_bpe_model.py b/egs/ptb/LM/local/train_bpe_model.py
new file mode 100755
index 000000000..8d87707a9
--- /dev/null
+++ b/egs/ptb/LM/local/train_bpe_model.py
@@ -0,0 +1,95 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# You can install sentencepiece via:
+#
+#   pip install sentencepiece
+#
+# Due to an issue reported in
+# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
+#
+# please install a version >= 0.1.96
+
+import argparse
+import shutil
+from pathlib import Path
+
+import sentencepiece as spm
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--out-dir",
+        type=str,
+        help="""Input and output directory.
+        The generated bpe.model is saved to this directory.
+        """,
+    )
+
+    parser.add_argument(
+        "--transcript",
+        type=str,
+        help="Training transcript.",
+    )
+
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        help="Vocabulary size for BPE training",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    vocab_size = args.vocab_size
+    model_type = "unigram"
+
+    model_prefix = f"{args.out_dir}/{model_type}_{vocab_size}"
+    train_text = args.transcript
+    character_coverage = 1.0
+    input_sentence_size = 100000000
+
+    user_defined_symbols = ["<blk>", "<sos/eos>"]
+    unk_id = len(user_defined_symbols)
+    # Note: unk_id is fixed to 2.
+    # If you change it, you should also change other
+    # places that are using it.
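+    # Here <blk> and <sos/eos> are expected to get IDs 0 and 1
+    # (this matches blank_id=0 and sos_id=1 used in
+    # ../rnn_lm/test_dataset.py), leaving ID 2 free for <unk>.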
+
+    model_file = Path(model_prefix + ".model")
+    if not model_file.is_file():
+        spm.SentencePieceTrainer.train(
+            input=train_text,
+            vocab_size=vocab_size,
+            model_type=model_type,
+            model_prefix=model_prefix,
+            input_sentence_size=input_sentence_size,
+            character_coverage=character_coverage,
+            user_defined_symbols=user_defined_symbols,
+            unk_id=unk_id,
+            bos_id=-1,
+            eos_id=-1,
+        )
+
+    shutil.copyfile(model_file, f"{args.out_dir}/bpe.model")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/ptb/LM/prepare.sh b/egs/ptb/LM/prepare.sh
new file mode 100755
index 000000000..33b1e405a
--- /dev/null
+++ b/egs/ptb/LM/prepare.sh
@@ -0,0 +1,95 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+nj=15
+stage=-1
+stop_stage=100
+
+dl_dir=$PWD/download
+# The following files will be downloaded to $dl_dir:
+# - ptb.train.txt
+# - ptb.valid.txt
+# - ptb.test.txt
+
+. shared/parse_options.sh || exit 1
+
+# Vocabulary sizes for the sentencepiece models.
+# It will generate data/bpe_xxx, data/bpe_yyy
+# if the array contains xxx, yyy.
+vocab_sizes=(
+  500
+  1000
+  2000
+  5000
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+mkdir -p $dl_dir
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+  log "Stage -1: Download data"
+  if [ ! -f $dl_dir/.complete ]; then
+    url=https://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data
+    wget --no-verbose --directory-prefix $dl_dir $url/ptb.train.txt
+    wget --no-verbose --directory-prefix $dl_dir $url/ptb.valid.txt
+    wget --no-verbose --directory-prefix $dl_dir $url/ptb.test.txt
+    touch $dl_dir/.complete
+  fi
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Train BPE model"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    out_dir=data/bpe_${vocab_size}
+    mkdir -p $out_dir
+    ./local/train_bpe_model.py \
+      --out-dir $out_dir \
+      --vocab-size $vocab_size \
+      --transcript $dl_dir/ptb.train.txt
+  done
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Generate LM training data"
+  # Note: ptb.train.txt has already been normalized
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    out_dir=data/bpe_${vocab_size}
+    mkdir -p $out_dir
+    ./local/prepare_lm_training_data.py \
+      --bpe-model $out_dir/bpe.model \
+      --lm-data $dl_dir/ptb.train.txt \
+      --lm-archive $out_dir/lm_data.pt
+  done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Sort LM training data"
+  # Sort the LM training data generated in stage 1
+  # by sentence length (i.e., the number of BPE tokens
+  # in a sentence), in descending order, for ease of training.
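+  #
+  # Batches are later formed from consecutive sentences
+  # (see rnn_lm/dataset.py), so sorting keeps sentences of
+  # similar length together and reduces padding.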
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    out_dir=data/bpe_${vocab_size}
+    mkdir -p $out_dir
+    ./local/sort_lm_training_data.py \
+      --in-lm-data $out_dir/lm_data.pt \
+      --out-lm-data $out_dir/sorted_lm_data.pt \
+      --out-statistics $out_dir/statistics.txt
+  done
+fi
diff --git a/egs/ptb/LM/rnn_lm/__init__.py b/egs/ptb/LM/rnn_lm/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/ptb/LM/rnn_lm/dataset.py b/egs/ptb/LM/rnn_lm/dataset.py
new file mode 100644
index 000000000..a7aaf37ac
--- /dev/null
+++ b/egs/ptb/LM/rnn_lm/dataset.py
@@ -0,0 +1,260 @@
+# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Tuple
+
+import k2
+import torch
+
+
+class LmDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        sentences: k2.RaggedTensor,
+        words: k2.RaggedTensor,
+        sentence_lengths: torch.Tensor,
+        max_sent_len: int,
+        batch_size: int,
+    ):
+        """
+        Args:
+          sentences:
+            A ragged tensor of dtype torch.int32 with 2 axes [sentence][word].
+          words:
+            A ragged tensor of dtype torch.int32 with 2 axes [word][token].
+          sentence_lengths:
+            A 1-D tensor of dtype torch.int32 containing the number of
+            tokens of each sentence.
+          max_sent_len:
+            Maximum sentence length. It is used to change the batch size
+            dynamically. In general, we try to keep the product of
+            "max_sent_len in a batch" and "num_of_sent in a batch"
+            constant.
+          batch_size:
+            The expected batch size. It is changed dynamically according
+            to the "max_sent_len".
+
+        See `../local/prepare_lm_training_data.py` for how `sentences` and
+        `words` are generated. We assume that `sentences` are sorted by
+        length. See `../local/sort_lm_training_data.py`.
+        """
+        super().__init__()
+        self.sentences = sentences
+        self.words = words
+
+        sentence_lengths = sentence_lengths.tolist()
+
+        assert batch_size > 0, batch_size
+        assert max_sent_len > 1, max_sent_len
+        batch_indexes = []
+        num_sentences = sentences.dim0
+        cur = 0
+        while cur < num_sentences:
+            sz = sentence_lengths[cur] // max_sent_len + 1
+            # Assume the current sentence has 3 * max_sent_len tokens.
+            # In the worst case, the subsequent sentences also have
+            # this many tokens, so we reduce the batch size to keep
+            # this batch from containing too many tokens.
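+            # For example, with batch_size=4 and max_sent_len=3, a
+            # leading sentence of length 7 gives sz = 7 // 3 + 1 = 3,
+            # so the batch is capped at 4 // 3 + 1 = 2 sentences.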
+            actual_batch_size = batch_size // sz + 1
+            actual_batch_size = min(actual_batch_size, batch_size)
+            end = cur + actual_batch_size
+            end = min(end, num_sentences)
+            this_batch_indexes = torch.arange(cur, end).tolist()
+            batch_indexes.append(this_batch_indexes)
+            cur = end
+        assert batch_indexes[-1][-1] == num_sentences - 1
+
+        self.batch_indexes = k2.RaggedTensor(batch_indexes)
+
+    def __len__(self) -> int:
+        """Return the number of batches in this dataset."""
+        return self.batch_indexes.dim0
+
+    def __getitem__(self, i: int) -> k2.RaggedTensor:
+        """Get the i-th batch in this dataset.
+
+        Return a ragged tensor with 2 axes [sentence][token].
+        """
+        assert 0 <= i < len(self), i
+
+        # indexes is a 1-D tensor containing sentence indexes
+        indexes = self.batch_indexes[i]
+
+        # sentence_words is a ragged tensor with 2 axes
+        # [sentence][word]
+        sentence_words = self.sentences[indexes]
+
+        # In case indexes contains only 1 entry, the returned
+        # sentence_words is a 1-D tensor; we have to convert
+        # it to a ragged tensor
+        if isinstance(sentence_words, torch.Tensor):
+            sentence_words = k2.RaggedTensor(sentence_words.unsqueeze(0))
+
+        # sentence_word_tokens is a ragged tensor with 3 axes
+        # [sentence][word][token]
+        sentence_word_tokens = self.words.index(sentence_words)
+        assert sentence_word_tokens.num_axes == 3
+
+        # Remove the word axis, so that each sublist is the BPE
+        # token sequence of one sentence
+        sentence_tokens = sentence_word_tokens.remove_axis(1)
+        return sentence_tokens
+
+
+def concat(
+    ragged: k2.RaggedTensor, value: int, direction: str
+) -> k2.RaggedTensor:
+    """Prepend a value to the beginning of each sublist or append it
+    to the end of each sublist.
+
+    Args:
+      ragged:
+        A ragged tensor with two axes.
+      value:
+        The value to prepend or append.
+      direction:
+        It can be either "left" or "right". If it is "left", we
+        prepend the value to the beginning of each sublist;
+        if it is "right", we append the value to the end of each
+        sublist.
+
+    Returns:
+      Return a new ragged tensor whose sublists either start with
+      or end with the given value.
+
+    >>> a = k2.RaggedTensor([[1, 3], [5]])
+    >>> a
+    [ [ 1 3 ] [ 5 ] ]
+    >>> concat(a, value=0, direction="left")
+    [ [ 0 1 3 ] [ 0 5 ] ]
+    >>> concat(a, value=0, direction="right")
+    [ [ 1 3 0 ] [ 5 0 ] ]
+
+    """
+    dtype = ragged.dtype
+    device = ragged.device
+
+    assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
+    pad_values = torch.full(
+        size=(ragged.tot_size(0), 1),
+        fill_value=value,
+        device=device,
+        dtype=dtype,
+    )
+    pad = k2.RaggedTensor(pad_values)
+
+    if direction == "left":
+        ans = k2.ragged.cat([pad, ragged], axis=1)
+    elif direction == "right":
+        ans = k2.ragged.cat([ragged, pad], axis=1)
+    else:
+        raise ValueError(
+            f"Unsupported direction: {direction}. "
+            'Expect either "left" or "right"'
+        )
+    return ans
+
+
+def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor:
+    """Add SOS to each sublist.
+
+    Args:
+      ragged:
+        A ragged tensor with two axes.
+      sos_id:
+        The ID of the SOS symbol.
+
+    Returns:
+      Return a new ragged tensor where each sublist starts with SOS.
+
+    >>> a = k2.RaggedTensor([[1, 3], [5]])
+    >>> a
+    [ [ 1 3 ] [ 5 ] ]
+    >>> add_sos(a, sos_id=0)
+    [ [ 0 1 3 ] [ 0 5 ] ]
+
+    """
+    return concat(ragged, sos_id, direction="left")
+
+
+def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
+    """Add EOS to each sublist.
+
+    Args:
+      ragged:
+        A ragged tensor with two axes.
+      eos_id:
+        The ID of the EOS symbol.
+
+    Returns:
+      Return a new ragged tensor where each sublist ends with EOS.
+
+    >>> a = k2.RaggedTensor([[1, 3], [5]])
+    >>> a
+    [ [ 1 3 ] [ 5 ] ]
+    >>> add_eos(a, eos_id=0)
+    [ [ 1 3 0 ] [ 5 0 ] ]
+
+    """
+    return concat(ragged, eos_id, direction="right")
+
+
+class LmDatasetCollate:
+    def __init__(self, sos_id: int, eos_id: int, blank_id: int):
+        """
+        Args:
+          sos_id:
+            Token ID of the SOS symbol.
+          eos_id:
+            Token ID of the EOS symbol.
+          blank_id:
+            Token ID of the blank symbol; it is used for padding.
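+            (With the BPE model trained by ../local/train_bpe_model.py,
+            the blank <blk> has ID 0.)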
+        """
+        self.sos_id = sos_id
+        self.eos_id = eos_id
+        self.blank_id = blank_id
+
+    def __call__(
+        self, batch: List[k2.RaggedTensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Return a tuple containing 3 tensors:
+
+        - x, a 2-D tensor of dtype torch.int32; each row contains the
+          tokens of a sentence starting with `self.sos_id`. It is padded
+          to the max sentence length with `self.blank_id`.
+
+        - y, a 2-D tensor of dtype torch.int32; each row contains the
+          tokens of a sentence ending with `self.eos_id` before padding.
+          It is then padded to the max sentence length with
+          `self.blank_id`.
+
+        - lengths, a 1-D tensor of dtype torch.int32, containing the
+          number of tokens of each sentence before padding.
+        """
+        # The batching stuff has already been done in LmDataset
+        assert len(batch) == 1
+        sentence_tokens = batch[0]
+        row_splits = sentence_tokens.shape.row_splits(1)
+        sentence_token_lengths = row_splits[1:] - row_splits[:-1]
+        sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
+        sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
+
+        x = sentence_tokens_with_sos.pad(
+            mode="constant", padding_value=self.blank_id
+        )
+        y = sentence_tokens_with_eos.pad(
+            mode="constant", padding_value=self.blank_id
+        )
+        sentence_token_lengths += 1  # plus 1 since we added a SOS
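+
+        # x and y are shifted by one position: for each sentence, x is
+        # (sos, t_0, ..., t_{n-1}) and y is (t_0, ..., t_{n-1}, eos),
+        # so y[i][k] is the prediction target for x[i][k].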
+
+        return x, y, sentence_token_lengths
diff --git a/egs/ptb/LM/rnn_lm/test_dataset.py b/egs/ptb/LM/rnn_lm/test_dataset.py
new file mode 100755
index 000000000..7d08515e0
--- /dev/null
+++ b/egs/ptb/LM/rnn_lm/test_dataset.py
@@ -0,0 +1,74 @@
+#!/usr/bin/env python3
+# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import k2
+import torch
+from rnn_lm.dataset import LmDataset, LmDatasetCollate
+
+
+def main():
+    sentences = k2.RaggedTensor(
+        [[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]]
+    )
+    words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]])
+
+    num_sentences = sentences.dim0
+
+    sentence_lengths = [0] * num_sentences
+    for i in range(num_sentences):
+        word_ids = sentences[i]
+
+        # NOTE: If word_ids is a tensor with only 1 entry,
+        # token_ids is a torch.Tensor
+        token_ids = words[word_ids]
+        if isinstance(token_ids, k2.RaggedTensor):
+            token_ids = token_ids.values
+
+        # token_ids is a 1-D tensor containing the BPE tokens
+        # of the current sentence
+
+        sentence_lengths[i] = token_ids.numel()
+
+    sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32)
+
+    indices = torch.argsort(sentence_lengths, descending=True)
+    sentences = sentences[indices.to(torch.int32)]
+    sentence_lengths = sentence_lengths[indices]
+
+    dataset = LmDataset(
+        sentences=sentences,
+        words=words,
+        sentence_lengths=sentence_lengths,
+        max_sent_len=3,
+        batch_size=4,
+    )
+    print(dataset.sentences)
+    print(dataset.words)
+    print(dataset.batch_indexes)
+    print(len(dataset))
+    collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0)
+    dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=1, collate_fn=collate_fn
+    )
+
+    for i in dataloader:
+        print(i)
+    # I've checked the output manually; the output is as expected.
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/ptb/LM/rnn_lm/test_dataset_ddp.py b/egs/ptb/LM/rnn_lm/test_dataset_ddp.py
new file mode 100755
index 000000000..48fbb19cb
--- /dev/null
+++ b/egs/ptb/LM/rnn_lm/test_dataset_ddp.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
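+
+# Note: this test uses the NCCL backend with world_size = 2, so it
+# needs a machine with at least 2 CUDA devices.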
+
+import os
+
+import k2
+import torch
+import torch.multiprocessing as mp
+from rnn_lm.dataset import LmDataset, LmDatasetCollate
+from torch import distributed as dist
+
+
+def generate_data():
+    sentences = k2.RaggedTensor(
+        [[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]]
+    )
+    words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]])
+
+    num_sentences = sentences.dim0
+
+    sentence_lengths = [0] * num_sentences
+    for i in range(num_sentences):
+        word_ids = sentences[i]
+
+        # NOTE: If word_ids is a tensor with only 1 entry,
+        # token_ids is a torch.Tensor
+        token_ids = words[word_ids]
+        if isinstance(token_ids, k2.RaggedTensor):
+            token_ids = token_ids.values
+
+        # token_ids is a 1-D tensor containing the BPE tokens
+        # of the current sentence
+
+        sentence_lengths[i] = token_ids.numel()
+
+    sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32)
+
+    indices = torch.argsort(sentence_lengths, descending=True)
+    sentences = sentences[indices.to(torch.int32)]
+    sentence_lengths = sentence_lengths[indices]
+
+    return sentences, words, sentence_lengths
+
+
+def run(rank, world_size):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12352"
+
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+    torch.cuda.set_device(rank)
+
+    sentences, words, sentence_lengths = generate_data()
+
+    dataset = LmDataset(
+        sentences=sentences,
+        words=words,
+        sentence_lengths=sentence_lengths,
+        max_sent_len=3,
+        batch_size=4,
+    )
+    sampler = torch.utils.data.distributed.DistributedSampler(
+        dataset, shuffle=True, drop_last=False
+    )
+
+    collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=1,
+        collate_fn=collate_fn,
+        sampler=sampler,
+        shuffle=False,
+    )
+
+    for i in dataloader:
+        print(f"rank: {rank}", i)
+
+    dist.destroy_process_group()
+
+
+def main():
+    world_size = 2
+    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/ptb/LM/shared b/egs/ptb/LM/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/ptb/LM/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file