mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-14 04:22:21 +00:00
Use collate_fn as class. harmless but not necessary without multiple workers
This commit is contained in:
parent
0d97e689be
commit
a7b61100de
@ -130,7 +130,12 @@ def mask_and_pad(sentence: List[int],
|
||||
# length of masked regions.
|
||||
num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]),
|
||||
prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item())
|
||||
assert num_split_points <= sent_len - num_mask
|
||||
# Somehow this assertion failed, debugging it below.
|
||||
# assert num_split_points <= sent_len - num_mask
|
||||
if num_split_points > sent_len - num_mask:
|
||||
print(f"Warning about num_split_points: {num_split_points} > {sent_len} - {num_mask}")
|
||||
num_split_points = sent_len - num_mask
|
||||
|
||||
assert isinstance(num_split_points, int)
|
||||
|
||||
def split_into_subseqs(length: int , num_subseqs: int) -> List[int]:
|
||||
@ -797,6 +802,13 @@ class LmBatchSampler(torch.utils.data.Sampler):
|
||||
yield self.indices[batch_start:batch_end].tolist()
|
||||
|
||||
|
||||
class CollateFn:
|
||||
def __init__(self, **kwargs):
|
||||
self.extra_args = kwargs
|
||||
|
||||
def __call__(self, sentences: List[List[int]]):
|
||||
return collate_fn(sentences, **self.extra_args)
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -532,14 +532,14 @@ def run(rank, world_size, args):
|
||||
|
||||
train,test = dataset.load_train_test_lm_dataset(params.lm_dataset)
|
||||
|
||||
collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=params.bos_sym,
|
||||
eos_sym=params.eos_sym,
|
||||
blank_sym=params.blank_sym,
|
||||
mask_proportion=0.15,
|
||||
padding_proportion=0.15,
|
||||
randomize_proportion=0.05,
|
||||
inv_mask_length=0.25,
|
||||
unmasked_weight=0.25))
|
||||
collate_fn=dataset.CollateFn(bos_sym=params.bos_sym,
|
||||
eos_sym=params.eos_sym,
|
||||
blank_sym=params.blank_sym,
|
||||
mask_proportion=0.15,
|
||||
padding_proportion=0.15,
|
||||
randomize_proportion=0.05,
|
||||
inv_mask_length=0.25,
|
||||
unmasked_weight=0.25)
|
||||
|
||||
train_sampler = dataset.LmBatchSampler(train,
|
||||
symbols_per_batch=params.symbols_per_batch,
|
||||
@ -551,6 +551,7 @@ def run(rank, world_size, args):
|
||||
train_dl = torch.utils.data.DataLoader(train,
|
||||
batch_sampler=train_sampler,
|
||||
collate_fn=collate_fn)
|
||||
|
||||
valid_dl = torch.utils.data.DataLoader(test,
|
||||
batch_sampler=test_sampler,
|
||||
collate_fn=collate_fn)
|
||||
|
Loading…
x
Reference in New Issue
Block a user