WIP: Begin to add RNNLM.

Finish the dataset part.
Fangjun Kuang 2021-11-21 15:58:05 +08:00
parent 30c43b7f69
commit 42dcd53361
12 changed files with 997 additions and 1 deletion


@@ -38,7 +38,6 @@ def get_args():
"--lang-dir",
type=str,
help="""Input and output directory.
- It should contain the training corpus: transcript_words.txt.
The generated bpe.model is saved to this directory.
""",
)

egs/ptb/LM/README.md (new file, 18 lines)

@@ -0,0 +1,18 @@
## Description
(Note: the experiments here are only about language modeling)
ptb is short for Penn Treebank.
About the Penn Treebank corpus:
- This corpus is free for research purposes
- ptb.train.txt: train set
- ptb.valid.txt: development set (should be used just for tuning hyper-parameters, but not for training)
- ptb.test.txt: test set for reporting perplexity
You can download the dataset from one of the following URLs:
- https://github.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage
- http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
- https://deepai.org/dataset/penn-treebank


@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey
# Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes a `bpe.model` and a text file such as
`download/ptb.train.txt`,
and outputs the LM training data to a supplied directory such
as data/bpe_500. The format is as follows:
It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a
representation of a dict with the following format:
'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32
containing the BPE representations of each word, indexed by
integer word ID. (These integer word IDs are present in
'sentences'.) The sentencepiece object can be used to turn the
words and BPE units into string form.
'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype
torch.int32 containing all the sentences, as word-ids (we don't
output the string form of this directly but it can be worked out
together with 'words' and the bpe.model).
'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing
the number of BPE tokens in each sentence.
"""
import argparse
import logging
from pathlib import Path
import k2
import sentencepiece as spm
import torch
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model",
type=str,
help="Input BPE model, e.g. data/bpe_500/bpe.model",
)
parser.add_argument(
"--lm-data",
type=str,
help="""Input LM training data as text, e.g.
download/ptb.train.txt""",
)
parser.add_argument(
"--lm-archive",
type=str,
help="""Path to output archive, e.g. data/bpe_500/lm_data.pt;
look at the source of this script to see the format.""",
)
return parser.parse_args()
def main():
args = get_args()
if Path(args.lm_archive).exists():
logging.warning(f"{args.lm_archive} exists - skipping")
return
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
# word2index is a dictionary from words to integer ids. No need to reserve
# space for epsilon, etc.; the words are just used as a convenient way to
# compress the sequences of BPE pieces.
word2index = dict()
word2bpe = [] # Will be a list-of-list-of-int, representing BPE pieces.
# ptb.train.txt has already converted oov words to <unk>
word2bpe.append([sp.unk_id()])
word2index["<unk>"] = 0
sentences = [] # Will be a list-of-list-of-int, representing word-ids.
with open(args.lm_data) as f:
while True:
line = f.readline()
if line == "":
break
line_words = line.split()
for w in line_words:
if w not in word2index:
w_bpe = sp.encode(w)
word2index[w] = len(word2bpe)
word2bpe.append(w_bpe)
sentences.append([word2index[w] for w in line_words])
words = k2.ragged.RaggedTensor(word2bpe)
sentences = k2.ragged.RaggedTensor(sentences)
output = dict(words=words, sentences=sentences)
num_sentences = sentences.dim0
sentence_lengths = [0] * num_sentences
for i in range(num_sentences):
word_ids = sentences[i]
# NOTE: If word_ids is a tensor with only 1 entry,
# token_ids is a torch.Tensor
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
# token_ids is a 1-D tensor containing the BPE tokens
# of the current sentence
sentence_lengths[i] = token_ids.numel()
output["sentence_lengths"] = torch.tensor(
sentence_lengths, dtype=torch.int32
)
torch.save(output, args.lm_archive)
logging.info(f"Saved to {args.lm_archive}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
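
As a quick sanity check, the archive written by this script can be loaded back and decoded with the same bpe.model; a minimal sketch, assuming the default data/bpe_500 paths used in prepare.sh:

# Sanity-check sketch (assumed paths): load the archive produced above and
# decode the first sentence back to text.
import k2
import sentencepiece as spm
import torch

sp = spm.SentencePieceProcessor()
sp.load("data/bpe_500/bpe.model")  # assumed location

data = torch.load("data/bpe_500/lm_data.pt")  # assumed location
words = data["words"]          # ragged: [word][token]
sentences = data["sentences"]  # ragged: [sentence][word]

word_ids = sentences[0]        # 1-D tensor of word IDs of the first sentence
token_ids = words[word_ids]    # ragged, or a 1-D tensor if there is one word
if isinstance(token_ids, k2.RaggedTensor):
    token_ids = token_ids.values
# Should resemble the first line of ptb.train.txt, with OOV words shown
# as the unk piece.
print(sp.decode(token_ids.tolist()))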


@@ -0,0 +1,143 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file takes as input the filename of LM training data
generated by ./local/prepare_lm_training_data.py and sorts
it by sentence length.
Sentence length equals the number of BPE tokens in a sentence.
"""
import argparse
import logging
from pathlib import Path
import k2
import numpy as np
import torch
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--in-lm-data",
type=str,
help="Input LM training data, e.g., data/bpe_500/lm_data.pt",
)
parser.add_argument(
"--out-lm-data",
type=str,
help="Input LM training data, e.g., data/bpe_500/sorted_lm_data.pt",
)
parser.add_argument(
"--out-statistics",
type=str,
help="Statistics about LM training data., data/bpe_500/statistics.txt",
)
return parser.parse_args()
def main():
args = get_args()
in_lm_data = Path(args.in_lm_data)
out_lm_data = Path(args.out_lm_data)
assert in_lm_data.is_file(), f"{in_lm_data}"
if out_lm_data.is_file():
logging.warning(f"{out_lm_data} exists - skipping")
return
data = torch.load(in_lm_data)
words2bpe = data["words"]
sentences = data["sentences"]
sentence_lengths = data["sentence_lengths"]
num_sentences = sentences.dim0
assert num_sentences == sentence_lengths.numel(), (
num_sentences,
sentence_lengths.numel(),
)
indices = torch.argsort(sentence_lengths, descending=True)
sorted_sentences = sentences[indices.to(torch.int32)]
sorted_sentence_lengths = sentence_lengths[indices]
# Check that sentences are ordered by length
assert num_sentences == sorted_sentences.dim0, (
num_sentences,
sorted_sentences.dim0,
)
cur = None
for i in range(num_sentences):
word_ids = sorted_sentences[i]
token_ids = words2bpe[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
if cur is not None:
assert cur >= token_ids.numel(), (cur, token_ids.numel())
cur = token_ids.numel()
assert cur == sorted_sentence_lengths[i]
data["sentences"] = sorted_sentences
data["sentence_lengths"] = sorted_sentence_lengths
torch.save(data, args.out_lm_data)
logging.info(f"Saved to {args.out_lm_data}")
statistics = Path(args.out_statistics)
# Write statistics
num_words = sorted_sentences.numel()
num_tokens = sentence_lengths.sum().item()
max_sentence_length = sentence_lengths[indices[0]].item()
min_sentence_length = sentence_lengths[indices[-1]].item()
step = 10
hist, bins = np.histogram(
sentence_lengths.numpy(),
bins=np.arange(1, max_sentence_length + step, step),
)
histogram = np.stack((bins[:-1], hist)).transpose()
with open(statistics, "w") as f:
f.write(f"num_sentences: {num_sentences}\n")
f.write(f"num_words: {num_words}\n")
f.write(f"num_tokens: {num_tokens}\n")
f.write(f"max_sentence_length: {max_sentence_length}\n")
f.write(f"min_sentence_length: {min_sentence_length}\n")
f.write("histogram:\n")
f.write(" bin count percent\n")
for row in histogram:
f.write(
f"{int(row[0]):>5} {int(row[1]):>5} "
f"{100.*row[1]/num_sentences:.3f}%\n"
)
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
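
A minimal sketch for spot-checking the sorted archive, assuming the default data/bpe_500 path; the sentence lengths should be in non-increasing order:

# Spot-check sketch (assumed path).
import torch

data = torch.load("data/bpe_500/sorted_lm_data.pt")  # assumed location
lengths = data["sentence_lengths"]
assert torch.all(lengths[:-1] >= lengths[1:])  # sorted in descending order
print("num_sentences:", lengths.numel())
print("longest:", lengths[0].item(), "shortest:", lengths[-1].item())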


@@ -0,0 +1,62 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
import sentencepiece as spm
import torch
def main():
lm_training_data = Path("./data/bpe_500/lm_data.pt")
bpe_model = Path("./data/bpe_500/bpe.model")
if not lm_training_data.exists():
logging.warning(f"{lm_training_data} does not exist - skipping")
return
if not bpe_model.exists():
logging.warning(f"{bpe_model} does not exist - skipping")
return
sp = spm.SentencePieceProcessor()
sp.load(str(bpe_model))
data = torch.load(lm_training_data)
words2bpe = data["words"]
sentences = data["sentences"]
ss = []
unk = sp.decode(sp.unk_id()).strip()
for i in range(10):
s = sp.decode(words2bpe[sentences[i]].values.tolist())
s = s.replace(unk, "<unk>")
ss.append(s)
for s in ss:
print(s)
# You can compare the output with the first 10 lines of ptb.train.txt
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@@ -0,0 +1,95 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# You can install sentencepiece via:
#
# pip install sentencepiece
#
# Due to an issue reported in
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
#
# Please install a version >=0.1.96
import argparse
import shutil
from pathlib import Path
import sentencepiece as spm
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--out-dir",
type=str,
help="""Input and output directory.
The generated bpe.model is saved to this directory.
""",
)
parser.add_argument(
"--transcript",
type=str,
help="Training transcript.",
)
parser.add_argument(
"--vocab-size",
type=int,
help="Vocabulary size for BPE training",
)
return parser.parse_args()
def main():
args = get_args()
vocab_size = args.vocab_size
model_type = "unigram"
model_prefix = f"{args.out_dir}/{model_type}_{vocab_size}"
train_text = args.transcript
character_coverage = 1.0
input_sentence_size = 100000000
user_defined_symbols = ["<blk>", "<sos/eos>"]
unk_id = len(user_defined_symbols)
# Note: unk_id is fixed to 2.
# If you change it, you should also change other
# places that are using it.
model_file = Path(model_prefix + ".model")
if not model_file.is_file():
spm.SentencePieceTrainer.train(
input=train_text,
vocab_size=vocab_size,
model_type=model_type,
model_prefix=model_prefix,
input_sentence_size=input_sentence_size,
character_coverage=character_coverage,
user_defined_symbols=user_defined_symbols,
unk_id=unk_id,
bos_id=-1,
eos_id=-1,
)
shutil.copyfile(model_file, f"{args.out_dir}/bpe.model")
if __name__ == "__main__":
main()
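
A minimal sketch for checking the trained model, assuming the default output path and a made-up example sentence; with the settings above, <blk> should map to 0, <sos/eos> to 1, and <unk> to 2:

# Verification sketch (assumed path and example sentence).
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/bpe_500/bpe.model")  # assumed location

assert sp.piece_to_id("<blk>") == 0
assert sp.piece_to_id("<sos/eos>") == 1
assert sp.unk_id() == 2
print(sp.encode("the quick brown fox", out_type=str))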

egs/ptb/LM/prepare.sh (new executable file, 95 lines)

@@ -0,0 +1,95 @@
#!/usr/bin/env bash
set -eou pipefail
nj=15
stage=-1
stop_stage=100
dl_dir=$PWD/download
# The following files will be downloaded to $dl_dir
# - ptb.train.txt
# - ptb.valid.txt
# - ptb.test.txt
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate data/bpe_xxx, data/bpe_yyy
# if the array contains xxx, yyy
vocab_sizes=(
500
1000
2000
5000
)
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
mkdir -p $dl_dir
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: Download data"
if [ ! -f $dl_dir/.complete ]; then
url=https://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data/
wget --no-verbose --directory-prefix $dl_dir $url/ptb.train.txt
wget --no-verbose --directory-prefix $dl_dir $url/ptb.valid.txt
wget --no-verbose --directory-prefix $dl_dir $url/ptb.test.txt
touch $dl_dir/.complete
fi
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Train BPE model"
for vocab_size in ${vocab_sizes[@]}; do
out_dir=data/bpe_${vocab_size}
mkdir -p $out_dir
./local/train_bpe_model.py \
--out-dir $out_dir \
--vocab-size $vocab_size \
--transcript $dl_dir/ptb.train.txt
done
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Generate LM training data"
# Note: ptb.train.txt has already been normalized
for vocab_size in ${vocab_sizes[@]}; do
out_dir=data/bpe_${vocab_size}
mkdir -p $out_dir
./local/prepare_lm_training_data.py \
--bpe-model $out_dir/bpe.model \
--lm-data $dl_dir/ptb.train.txt \
--lm-archive $out_dir/lm_data.pt
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Sort LM training data"
# Sort LM training data generated in stage 1
# by sentence length in descending order
# for ease of training.
#
# Sentence length equals the number of BPE tokens
# in a sentence.
for vocab_size in ${vocab_sizes[@]}; do
out_dir=data/bpe_${vocab_size}
mkdir -p $out_dir
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data.pt \
--out-lm-data $out_dir/sorted_lm_data.pt \
--out-statistics $out_dir/statistics.txt
done
fi
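
The sorted archives produced in stage 2 are intended to be consumed by the LmDataset and LmDatasetCollate added below in rnn_lm/dataset.py. A rough usage sketch, in which the paths, max_sent_len, batch_size, and symbol IDs are assumptions:

# Usage sketch (assumed paths and values).
import torch
from rnn_lm.dataset import LmDataset, LmDatasetCollate

data = torch.load("data/bpe_500/sorted_lm_data.pt")  # assumed location
dataset = LmDataset(
    sentences=data["sentences"],
    words=data["words"],
    sentence_lengths=data["sentence_lengths"],
    max_sent_len=100,  # assumed value
    batch_size=50,     # assumed value
)
# The BPE models trained by this recipe use <blk>=0 and <sos/eos>=1.
collate_fn = LmDatasetCollate(sos_id=1, eos_id=1, blank_id=0)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=1, collate_fn=collate_fn
)
x, y, lengths = next(iter(dataloader))
print(x.shape, y.shape, lengths.shape)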


@@ -0,0 +1,260 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import k2
import torch
class LmDataset(torch.utils.data.Dataset):
def __init__(
self,
sentences: k2.RaggedTensor,
words: k2.RaggedTensor,
sentence_lengths: torch.Tensor,
max_sent_len: int,
batch_size: int,
):
"""
Args:
sentences:
A ragged tensor of dtype torch.int32 with 2 axes [sentence][word].
words:
A ragged tensor of dtype torch.int32 with 2 axes [word][token].
sentence_lengths:
A 1-D tensor of dtype torch.int32 containing the number of
tokens in each sentence.
max_sent_len:
Maximum sentence length. It is used to change the batch size
dynamically. In general, we try to keep the product of
"max_sent_len in a batch" and "num_of_sent in a batch" being
a constant.
batch_size:
The expected batch size. It is changed dynamically according
to the "max_sent_len".
See `../local/prepare_lm_training_data.py` for how `sentences` and
`words` are generated. We assume that `sentences` are sorted by length.
See `../local/sort_lm_training_data.py`.
"""
super().__init__()
self.sentences = sentences
self.words = words
sentence_lengths = sentence_lengths.tolist()
assert batch_size > 0, batch_size
assert max_sent_len > 1, max_sent_len
batch_indexes = []
num_sentences = sentences.dim0
cur = 0
while cur < num_sentences:
sz = sentence_lengths[cur] // max_sent_len + 1
# Suppose the current sentence has 3 * max_sent_len tokens. In the
# worst case, the following sentences in this batch are just as long,
# so we reduce the batch size to keep the number of tokens per batch
# roughly constant.
actual_batch_size = batch_size // sz + 1
actual_batch_size = min(actual_batch_size, batch_size)
end = cur + actual_batch_size
end = min(end, num_sentences)
this_batch_indexes = torch.arange(cur, end).tolist()
batch_indexes.append(this_batch_indexes)
cur = end
assert batch_indexes[-1][-1] == num_sentences - 1
self.batch_indexes = k2.RaggedTensor(batch_indexes)
def __len__(self) -> int:
"""Return number of batches in this dataset"""
return self.batch_indexes.dim0
def __getitem__(self, i: int) -> k2.RaggedTensor:
"""Get the i'th batch in this dataset
Return a ragged tensor with 2 axes [sentence][token].
"""
assert 0 <= i < len(self), i
# indexes is a 1-D tensor containing sentence indexes
indexes = self.batch_indexes[i]
# sentence_words is a ragged tensor with 2 axes
# [sentence][word]
sentence_words = self.sentences[indexes]
# In case indexes contains only one entry, the returned
# sentence_words is a 1-D tensor, so we have to convert it
# to a ragged tensor
if isinstance(sentence_words, torch.Tensor):
sentence_words = k2.RaggedTensor(sentence_words.unsqueeze(0))
# sentence_word_tokens is a ragged tensor with 3 axes
# [sentence][word][token]
sentence_word_tokens = self.words.index(sentence_words)
assert sentence_word_tokens.num_axes == 3
sentence_tokens = sentence_word_tokens.remove_axis(1)
return sentence_tokens
def concat(
ragged: k2.RaggedTensor, value: int, direction: str
) -> k2.RaggedTensor:
"""Prepend a value to the beginning of each sublist or append a value.
to the end of each sublist.
Args:
ragged:
A ragged tensor with two axes.
value:
The value to prepend or append.
direction:
It can be either "left" or "right". If it is "left", we
prepend the value to the beginning of each sublist;
if it is "right", we append the value to the end of each
sublist.
Returns:
Return a new ragged tensor, whose sublists either start with
or end with the given value.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> concat(a, value=0, direction="left")
[ [ 0 1 3 ] [ 0 5 ] ]
>>> concat(a, value=0, direction="right")
[ [ 1 3 0 ] [ 5 0 ] ]
"""
dtype = ragged.dtype
device = ragged.device
assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
pad_values = torch.full(
size=(ragged.tot_size(0), 1),
fill_value=value,
device=device,
dtype=dtype,
)
pad = k2.RaggedTensor(pad_values)
if direction == "left":
ans = k2.ragged.cat([pad, ragged], axis=1)
elif direction == "right":
ans = k2.ragged.cat([ragged, pad], axis=1)
else:
raise ValueError(
f"Unsupported direction: {direction}. "
'Expect either "left" or "right"'
)
return ans
def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor:
"""Add SOS to each sublist.
Args:
ragged:
A ragged tensor with two axes.
sos_id:
The ID of the SOS symbol.
Returns:
Return a new ragged tensor, where each sublist starts with SOS.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> add_sos(a, sos_id=0)
[ [ 0 1 3 ] [ 0 5 ] ]
"""
return concat(ragged, sos_id, direction="left")
def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
"""Add EOS to each sublist.
Args:
ragged:
A ragged tensor with two axes.
eos_id:
The ID of the EOS symbol.
Returns:
Return a new ragged tensor, where each sublist ends with EOS.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> add_eos(a, eos_id=0)
[ [ 1 3 0 ] [ 5 0 ] ]
"""
return concat(ragged, eos_id, direction="right")
class LmDatasetCollate:
def __init__(self, sos_id: int, eos_id: int, blank_id: int):
"""
Args:
sos_id:
Token ID of the SOS symbol.
eos_id:
Token ID of the EOS symbol.
blank_id:
Token ID of the blank symbol.
"""
self.sos_id = sos_id
self.eos_id = eos_id
self.blank_id = blank_id
def __call__(
self, batch: List[k2.RaggedTensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return a tuple containing 3 tensors:
- x, a 2-D tensor of dtype torch.int32; each row contains tokens
for a sentence starting with `self.sos_id`. It is padded to
the max sentence length with `self.blank_id`.
- y, a 2-D tensor of dtype torch.int32; each row contains tokens
for a sentence ending with `self.eos_id` before padding.
Then it is padded to the max sentence length with
`self.blank_id`.
- lengths, a 1-D tensor of dtype torch.int32, containing the number of
tokens in each sentence before padding (including the added SOS).
"""
# The batching stuff has already been done in LmDataset
assert len(batch) == 1
sentence_tokens = batch[0]
row_splits = sentence_tokens.shape.row_splits(1)
sentence_token_lengths = row_splits[1:] - row_splits[:-1]
sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
x = sentence_tokens_with_sos.pad(
mode="constant", padding_value=self.blank_id
)
y = sentence_tokens_with_eos.pad(
mode="constant", padding_value=self.blank_id
)
sentence_token_lengths += 1 # plus 1 since we added a SOS
return x, y, sentence_token_lengths
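
The dynamic batch size computed in LmDataset.__init__ can be illustrated with toy numbers (the values below are made up, not taken from the recipe):

# Toy illustration: longer leading sentences shrink the batch so that the
# number of tokens per batch stays roughly constant.
batch_size = 4
max_sent_len = 3
for length in [2, 3, 7, 12]:
    sz = length // max_sent_len + 1
    actual_batch_size = min(batch_size // sz + 1, batch_size)
    print(f"leading sentence of {length} tokens -> batch of {actual_batch_size}")
# 2 tokens  -> batch of 4
# 3 tokens  -> batch of 3
# 7 tokens  -> batch of 2
# 12 tokens -> batch of 1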


@@ -0,0 +1,74 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import torch
from rnn_lm.dataset import LmDataset, LmDatasetCollate
def main():
sentences = k2.RaggedTensor(
[[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]]
)
words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]])
num_sentences = sentences.dim0
sentence_lengths = [0] * num_sentences
for i in range(num_sentences):
word_ids = sentences[i]
# NOTE: If word_ids is a tensor with only 1 entry,
# token_ids is a torch.Tensor
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
# token_ids is a 1-D tensor containing the BPE tokens
# of the current sentence
sentence_lengths[i] = token_ids.numel()
sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32)
indices = torch.argsort(sentence_lengths, descending=True)
sentences = sentences[indices.to(torch.int32)]
sentence_lengths = sentence_lengths[indices]
dataset = LmDataset(
sentences=sentences,
words=words,
sentence_lengths=sentence_lengths,
max_sent_len=3,
batch_size=4,
)
print(dataset.sentences)
print(dataset.words)
print(dataset.batch_indexes)
print(len(dataset))
collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=1, collate_fn=collate_fn
)
for i in dataloader:
print(i)
# I've checked the output manually; the output is as expected.
if __name__ == "__main__":
main()


@@ -0,0 +1,103 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import k2
import torch
import torch.multiprocessing as mp
from rnn_lm.dataset import LmDataset, LmDatasetCollate
from torch import distributed as dist
def generate_data():
sentences = k2.RaggedTensor(
[[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]]
)
words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]])
num_sentences = sentences.dim0
sentence_lengths = [0] * num_sentences
for i in range(num_sentences):
word_ids = sentences[i]
# NOTE: If word_ids is a tensor with only 1 entry,
# token_ids is a torch.Tensor
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
# token_ids is a 1-D tensor containing the BPE tokens
# of the current sentence
sentence_lengths[i] = token_ids.numel()
sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32)
indices = torch.argsort(sentence_lengths, descending=True)
sentences = sentences[indices.to(torch.int32)]
sentence_lengths = sentence_lengths[indices]
return sentences, words, sentence_lengths
def run(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12352"
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
sentences, words, sentence_lengths = generate_data()
dataset = LmDataset(
sentences=sentences,
words=words,
sentence_lengths=sentence_lengths,
max_sent_len=3,
batch_size=4,
)
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, shuffle=True, drop_last=False
)
collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
collate_fn=collate_fn,
sampler=sampler,
shuffle=False,
)
for i in dataloader:
print(f"rank: {rank}", i)
dist.destroy_process_group()
def main():
world_size = 2
mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

egs/ptb/LM/shared (new symbolic link)

@@ -0,0 +1 @@
../../../icefall/shared/