# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Tuple

import k2
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from icefall.utils import AttributeDict, add_eos, add_sos


class LmDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        sentences: k2.RaggedTensor,
        words: k2.RaggedTensor,
        sentence_lengths: torch.Tensor,
        max_sent_len: int,
        batch_size: int,
    ):
        """
        Args:
          sentences:
            A ragged tensor of dtype torch.int32 with 2 axes [sentence][word].
          words:
            A ragged tensor of dtype torch.int32 with 2 axes [word][token].
          sentence_lengths:
            A 1-D tensor of dtype torch.int32 containing the number of tokens
            of each sentence.
          max_sent_len:
            Maximum sentence length. It is used to change the batch size
            dynamically: in general, we try to keep the product of
            "max_sent_len in a batch" and "num_of_sent in a batch"
            approximately constant.
          batch_size:
            The expected batch size. It is changed dynamically according
            to "max_sent_len".

        See `../local/prepare_lm_training_data.py` for how `sentences` and
        `words` are generated. We assume that `sentences` are sorted by
        length. See `../local/sort_lm_training_data.py`.
        """
        super().__init__()
        self.sentences = sentences
        self.words = words

        sentence_lengths = sentence_lengths.tolist()

        assert batch_size > 0, batch_size
        assert max_sent_len > 1, max_sent_len
        batch_indexes = []
        num_sentences = sentences.dim0
        cur = 0
        while cur < num_sentences:
            sz = sentence_lengths[cur] // max_sent_len + 1
            # If the current sentence has, say, 3 * max_sent_len tokens,
            # then in the worst case the subsequent sentences in this batch
            # are just as long, so we reduce the batch size to avoid putting
            # too many tokens into one batch.
            actual_batch_size = batch_size // sz + 1
            actual_batch_size = min(actual_batch_size, batch_size)
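            # Illustrative numbers (not from the recipe): with
            # max_sent_len=100 and batch_size=64, a sentence of 250 tokens
            # gives sz = 250 // 100 + 1 = 3, so
            # actual_batch_size = min(64 // 3 + 1, 64) = 22.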
            end = cur + actual_batch_size
            end = min(end, num_sentences)
            this_batch_indexes = torch.arange(cur, end).tolist()
            batch_indexes.append(this_batch_indexes)
            cur = end
        assert batch_indexes[-1][-1] == num_sentences - 1

        self.batch_indexes = k2.RaggedTensor(batch_indexes)

    def __len__(self) -> int:
        """Return the number of batches in this dataset."""
        return self.batch_indexes.dim0

    def __getitem__(self, i: int) -> k2.RaggedTensor:
        """Get the i-th batch in this dataset.

        Return a ragged tensor with 2 axes [sentence][token].
        """
        assert 0 <= i < len(self), i

        # indexes is a 1-D tensor containing sentence indexes
        indexes = self.batch_indexes[i]

        # sentence_words is a ragged tensor with 2 axes
        # [sentence][word]
        sentence_words = self.sentences[indexes]

        # In case indexes contains only 1 entry, the returned
        # sentence_words is a 1-D tensor; we have to convert
        # it to a ragged tensor.
        if isinstance(sentence_words, torch.Tensor):
            sentence_words = k2.RaggedTensor(sentence_words.unsqueeze(0))

        # sentence_word_tokens is a ragged tensor with 3 axes
        # [sentence][word][token]
        sentence_word_tokens = self.words.index(sentence_words)
        assert sentence_word_tokens.num_axes == 3
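
        # remove_axis(1) below merges the word axis into the token axis;
        # e.g., (illustrative) [[[3, 4], [5]], [[6]]] becomes [[3, 4, 5], [6]]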
        sentence_tokens = sentence_word_tokens.remove_axis(1)
        return sentence_tokens


class LmDatasetCollate:
    def __init__(self, sos_id: int, eos_id: int, blank_id: int):
        """
        Args:
          sos_id:
            Token ID of the SOS symbol.
          eos_id:
            Token ID of the EOS symbol.
          blank_id:
            Token ID of the blank symbol; it is used for padding.
        """
        self.sos_id = sos_id
        self.eos_id = eos_id
        self.blank_id = blank_id

    def __call__(
        self, batch: List[k2.RaggedTensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return a tuple containing 3 tensors:

        - x, a 2-D tensor of dtype torch.int64; each row contains tokens
          for a sentence starting with `self.sos_id`. It is padded to
          the max sentence length with `self.blank_id`.

        - y, a 2-D tensor of dtype torch.int64; each row contains tokens
          for a sentence ending with `self.eos_id` before padding.
          Then it is padded to the max sentence length with
          `self.blank_id`.

        - lengths, a 1-D tensor of dtype torch.int32, containing the number
          of tokens of each sentence before padding, including the added
          SOS (equivalently, the added EOS).
        """
        # The batching stuff has already been done in LmDataset
        assert len(batch) == 1
        sentence_tokens = batch[0]
        row_splits = sentence_tokens.shape.row_splits(1)
        sentence_token_lengths = row_splits[1:] - row_splits[:-1]
        sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
        sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)

        x = sentence_tokens_with_sos.pad(
            mode="constant", padding_value=self.blank_id
        )
        y = sentence_tokens_with_eos.pad(
            mode="constant", padding_value=self.blank_id
        )
        sentence_token_lengths += 1  # plus 1 since we added a SOS
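
        # For example (illustrative numbers): if the batch rows are
        # [[1, 2], [3]] and sos_id=10, eos_id=20, blank_id=0, then
        #   x = [[10, 1, 2], [10, 3, 0]]
        #   y = [[1, 2, 20], [3, 20, 0]]
        #   lengths = [3, 2]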
        return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths


def get_dataloader(
    filename: str,
    is_distributed: bool,
    params: AttributeDict,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for LM training.

    Args:
      filename:
        Path to the file containing LM data. The file is assumed to
        be generated by `../local/sort_lm_training_data.py`.
      is_distributed:
        True if using DDP training; False otherwise.
      params:
        Options for the dataloader; they are set by `get_params()`
        from `rnn_lm/train.py`.
    Returns:
      Return a dataloader containing the LM data.
    """
    lm_data = torch.load(filename)

    words = lm_data["words"]
    sentences = lm_data["sentences"]
    sentence_lengths = lm_data["sentence_lengths"]

    dataset = LmDataset(
        sentences=sentences,
        words=words,
        sentence_lengths=sentence_lengths,
        max_sent_len=params.max_sent_len,
        batch_size=params.batch_size,
    )
    if is_distributed:
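        # Each item of LmDataset is already a whole batch, so the sampler
        # distributes pre-formed batches, not individual sentences, across
        # the DDP ranks.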
        sampler = DistributedSampler(dataset, shuffle=True, drop_last=False)
    else:
        sampler = None

    collate_fn = LmDatasetCollate(
        sos_id=params.sos_id,
        eos_id=params.eos_id,
        blank_id=params.blank_id,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=1,  # each item of LmDataset is already a full batch
        collate_fn=collate_fn,
        sampler=sampler,
        shuffle=sampler is None,
    )
    return dataloader
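

if __name__ == "__main__":
    # A minimal smoke test; it is not part of the training recipe. All the
    # tensors and symbol IDs below are made-up illustrative values; real
    # inputs come from ../local/prepare_lm_training_data.py and
    # ../local/sort_lm_training_data.py.
    words = k2.RaggedTensor([[3, 4], [5], [6, 7, 8]])  # [word][token]
    # [sentence][word]; token counts are 3, 6, 7, i.e., sorted by length
    sentences = k2.RaggedTensor([[0, 1], [1, 2, 0], [2, 0, 1, 1]])
    sentence_lengths = torch.tensor([3, 6, 7], dtype=torch.int32)

    dataset = LmDataset(
        sentences=sentences,
        words=words,
        sentence_lengths=sentence_lengths,
        max_sent_len=5,
        batch_size=2,
    )
    collate_fn = LmDatasetCollate(sos_id=1, eos_id=2, blank_id=0)
    dataloader = DataLoader(
        dataset, batch_size=1, collate_fn=collate_fn, shuffle=False
    )
    for x, y, lengths in dataloader:
        print(x, y, lengths, sep="\n")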