diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py
index dd3ab8deb..4f466a9e1 100644
--- a/egs/librispeech/ASR/conformer_lm/dataset.py
+++ b/egs/librispeech/ASR/conformer_lm/dataset.py
@@ -116,7 +116,7 @@ def mask_and_pad(sentence: List[int],
     num_pad -= max(0, sent_len + 2 + num_pad - seq_len)
 
     if num_mask + num_pad == 0:
-        num_mask += 1
+        num_pad += 1
 
     # num_split_points is the number of times we split the (masked+padded)
     # region, so the total number of (masking+padding) subsequences will be
@@ -131,10 +131,7 @@ def mask_and_pad(sentence: List[int],
     num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]),
                                           prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item())
     # Somehow this assertion failed, debugging it below.
-    # assert num_split_points <= sent_len - num_mask
-    if num_split_points > sent_len - num_mask:
-        print(f"Warning about num_split_points: {num_split_points} > {sent_len} - {num_mask}")
-        num_split_points = sent_len - num_mask
+    assert num_split_points <= sent_len - num_mask
 
     assert isinstance(num_split_points, int)
 
diff --git a/egs/librispeech/ASR/conformer_lm/test_dataset_empty.py b/egs/librispeech/ASR/conformer_lm/test_dataset_empty.py
new file mode 100644
index 000000000..7e933f07b
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_lm/test_dataset_empty.py
@@ -0,0 +1,39 @@
+import k2
+import torch
+import _k2
+import dataset
+from dataset import LmDataset
+import os
+from torch import multiprocessing as mp
+import torch.distributed as dist
+
+def local_collate_fn(sentences):
+    return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=False)
+
+x = _k2.RaggedInt('[[1]]')  # make sure library initialized?
+
+if __name__ == '__main__':
+
+    mp.set_start_method('spawn')
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12344"
+
+    dist.init_process_group(backend="nccl", group_name="main",
+                            rank=0, world_size=1)
+
+    words = k2.RaggedInt('[[0][1 2]]')
+    sentences = k2.RaggedInt('[[1][][][][][]]')
+
+    train = LmDataset(sentences, words)
+
+
+    sampler = dataset.LmBatchSampler(train, symbols_per_batch=10, world_size=1, rank=0)
+
+    a = iter(sampler)
+    print(str(next(a)))
+
+    train_dl = torch.utils.data.DataLoader(train, batch_sampler=sampler,
+                                           collate_fn=local_collate_fn,
+                                           num_workers=0)
+    x = iter(train_dl)
+    print(str(next(x)))
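
Note (not part of the patch): a minimal sketch of why the one-line fix above also allows the
assertion to be restored. For an empty sentence, sent_len == 0 and num_mask == 0, so the old
workaround (num_mask += 1) made sent_len - num_mask equal to -1, and the assertion
num_split_points <= sent_len - num_mask could never hold for a non-negative num_split_points.
Padding instead keeps num_mask at 0. The initial num_pad value and the mask probability 0.15
below are illustrative stand-ins, not values taken from the dataset code.

import torch

# Illustrative values for an empty sentence (cf. test_dataset_empty.py above).
sent_len = 0
num_mask = 0
num_pad = 0

if num_mask + num_pad == 0:
    num_pad += 1   # fixed behaviour: pad, rather than mask, the empty sentence

count = torch.tensor([float(sent_len - num_mask)])   # 0.0, a valid binomial count
num_split_points = int(torch.binomial(count=count,
                                      prob=torch.tensor([0.15])).item())
assert num_split_points <= sent_len - num_mask        # holds: 0 <= 0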