Use collate_fn as a class. Harmless but not necessary without multiple workers

This commit is contained in:
Daniel Povey 2021-08-25 11:27:47 +08:00
parent 0d97e689be
commit a7b61100de
2 changed files with 22 additions and 9 deletions

View File

@ -130,7 +130,12 @@ def mask_and_pad(sentence: List[int],
# length of masked regions.
num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]),
prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item())
assert num_split_points <= sent_len - num_mask
# Somehow this assertion failed, debugging it below.
# assert num_split_points <= sent_len - num_mask
if num_split_points > sent_len - num_mask:
print(f"Warning about num_split_points: {num_split_points} > {sent_len} - {num_mask}")
num_split_points = sent_len - num_mask
assert isinstance(num_split_points, int)
def split_into_subseqs(length: int , num_subseqs: int) -> List[int]:
@ -797,6 +802,13 @@ class LmBatchSampler(torch.utils.data.Sampler):
yield self.indices[batch_start:batch_end].tolist()
class CollateFn:
    """Callable wrapper that binds fixed keyword arguments to ``collate_fn``.

    Construct it once with the collation options; each call then forwards a
    batch of sentences, plus those stored options, to ``collate_fn``.
    NOTE(review): presumably this exists because a plain lambda cannot be
    pickled when the DataLoader uses multiple worker processes — the commit
    message suggests as much; confirm against the training script.
    """

    def __init__(self, **kwargs):
        # Keyword arguments captured here are forwarded verbatim on every call.
        self.extra_args = kwargs

    def __call__(self, sentences: List[List[int]]):
        # Delegate to the module-level collate_fn, merging in the stored options.
        bound_kwargs = self.extra_args
        return collate_fn(sentences, **bound_kwargs)

View File

@ -532,14 +532,14 @@ def run(rank, world_size, args):
train,test = dataset.load_train_test_lm_dataset(params.lm_dataset)
collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=params.bos_sym,
eos_sym=params.eos_sym,
blank_sym=params.blank_sym,
mask_proportion=0.15,
padding_proportion=0.15,
randomize_proportion=0.05,
inv_mask_length=0.25,
unmasked_weight=0.25))
collate_fn=dataset.CollateFn(bos_sym=params.bos_sym,
eos_sym=params.eos_sym,
blank_sym=params.blank_sym,
mask_proportion=0.15,
padding_proportion=0.15,
randomize_proportion=0.05,
inv_mask_length=0.25,
unmasked_weight=0.25)
train_sampler = dataset.LmBatchSampler(train,
symbols_per_batch=params.symbols_per_batch,
@ -551,6 +551,7 @@ def run(rank, world_size, args):
train_dl = torch.utils.data.DataLoader(train,
batch_sampler=train_sampler,
collate_fn=collate_fn)
valid_dl = torch.utils.data.DataLoader(test,
batch_sampler=test_sampler,
collate_fn=collate_fn)