Get dataset to work for empty input sentences; test it

Daniel Povey 2021-08-25 15:54:36 +08:00
parent a7b61100de
commit d045831a4f
2 changed files with 41 additions and 5 deletions


@@ -116,7 +116,7 @@ def mask_and_pad(sentence: List[int],
     num_pad -= max(0, sent_len + 2 + num_pad - seq_len)

     if num_mask + num_pad == 0:
-        num_mask += 1
+        num_pad += 1

     # num_split_points is the number of times we split the (masked+padded)
     # region, so the total number of (masking+padding) subsequences will be
@@ -131,10 +131,7 @@ def mask_and_pad(sentence: List[int],
     num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]),
                                           prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item())

     # Somehow this assertion failed, debugging it below.
-    # assert num_split_points <= sent_len - num_mask
-    if num_split_points > sent_len - num_mask:
-        print(f"Warning about num_split_points: {num_split_points} > {sent_len} - {num_mask}")
-        num_split_points = sent_len - num_mask
+    assert num_split_points <= sent_len - num_mask

     assert isinstance(num_split_points, int)
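The num_pad += 1 change is what makes empty sentences work: with sent_len == 0 there is nothing to mask, and the old num_mask += 1 made sent_len - num_mask negative, which is an invalid count for torch.binomial. Padding instead keeps the count non-negative, and since a binomial draw never exceeds its count, the restored assertion num_split_points <= sent_len - num_mask then holds. Below is a minimal standalone sketch of the corrected logic, not the real mask_and_pad; the constants mask_proportion and inv_mask_length are illustrative assumptions, not values from this commit.

import torch

def compute_num_split_points(sent_len: int, num_mask: int, num_pad: int,
                             mask_proportion: float = 0.15,
                             inv_mask_length: float = 0.25) -> int:
    # Corrected branch: pad rather than mask when both counts are zero, so
    # sent_len - num_mask can never go negative for an empty sentence.
    # (In the real code num_pad is also used later, for the padding itself.)
    if num_mask + num_pad == 0:
        num_pad += 1
    count = torch.tensor([float(sent_len - num_mask)])  # >= 0 even when empty
    prob = torch.tensor([mask_proportion * inv_mask_length
                         / (1.0 - mask_proportion)])
    return int(torch.binomial(count=count, prob=prob).item())

print(compute_num_split_points(sent_len=0, num_mask=0, num_pad=0))  # prints 0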


@@ -0,0 +1,39 @@
+import k2
+import torch
+import _k2
+import dataset
+from dataset import LmDataset
+import os
+from torch import multiprocessing as mp
+import torch.distributed as dist
+
+
+def local_collate_fn(sentences):
+    return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=False)
+
+x = _k2.RaggedInt('[[1]]')  # make sure library initialized?
+
+if __name__ == '__main__':
+
+    mp.set_start_method('spawn')
+
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12344"
+    dist.init_process_group(backend="nccl", group_name="main",
+                            rank=0, world_size=1)
+
+    words = k2.RaggedInt('[[0][1 2]]')
+    sentences = k2.RaggedInt('[[1][][][][][]]')
+
+    train = LmDataset(sentences, words)
+
+    sampler = dataset.LmBatchSampler(train, symbols_per_batch=10, world_size=1, rank=0)
+
+    a = iter(sampler)
+    print(str(next(a)))
+
+    train_dl = torch.utils.data.DataLoader(train, batch_sampler=sampler,
+                                           collate_fn=local_collate_fn,
+                                           num_workers=0)
+    x = iter(train_dl)
+    print(str(next(x)))
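The test above (its filename is not shown in this view) should be run from the directory containing dataset.py, since it imports that module directly, and the nccl backend additionally requires a CUDA-capable GPU. The sentences ragged tensor '[[1][][][][][]]' holds one one-word sentence followed by five empty ones, so the sampler and DataLoader exercise exactly the empty-input path fixed in mask_and_pad above; the two print calls should emit one batch of sentence indices from LmBatchSampler and one collated batch from the DataLoader.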