diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py
index fcf7f39f0..dd3ab8deb 100644
--- a/egs/librispeech/ASR/conformer_lm/dataset.py
+++ b/egs/librispeech/ASR/conformer_lm/dataset.py
@@ -130,7 +130,12 @@ def mask_and_pad(sentence: List[int],
     # length of masked regions.
     num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]),
                                           prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item())
-    assert num_split_points <= sent_len - num_mask
+    # Somehow this assertion failed, debugging it below.
+    # assert num_split_points <= sent_len - num_mask
+    if num_split_points > sent_len - num_mask:
+        print(f"Warning about num_split_points: {num_split_points} > {sent_len} - {num_mask}")
+        num_split_points = sent_len - num_mask
+
     assert isinstance(num_split_points, int)
 
     def split_into_subseqs(length: int , num_subseqs: int) -> List[int]:
@@ -797,6 +802,13 @@ class LmBatchSampler(torch.utils.data.Sampler):
             yield self.indices[batch_start:batch_end].tolist()
 
 
+class CollateFn:
+    def __init__(self, **kwargs):
+        self.extra_args = kwargs
+
+    def __call__(self, sentences: List[List[int]]):
+        return collate_fn(sentences, **self.extra_args)
+
 
 
 
diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py
index e8a5c8888..0b7e49db5 100755
--- a/egs/librispeech/ASR/conformer_lm/train.py
+++ b/egs/librispeech/ASR/conformer_lm/train.py
@@ -532,14 +532,14 @@ def run(rank, world_size, args):
 
     train,test = dataset.load_train_test_lm_dataset(params.lm_dataset)
 
-    collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=params.bos_sym,
-                                            eos_sym=params.eos_sym,
-                                            blank_sym=params.blank_sym,
-                                            mask_proportion=0.15,
-                                            padding_proportion=0.15,
-                                            randomize_proportion=0.05,
-                                            inv_mask_length=0.25,
-                                            unmasked_weight=0.25))
+    collate_fn=dataset.CollateFn(bos_sym=params.bos_sym,
+                                 eos_sym=params.eos_sym,
+                                 blank_sym=params.blank_sym,
+                                 mask_proportion=0.15,
+                                 padding_proportion=0.15,
+                                 randomize_proportion=0.05,
+                                 inv_mask_length=0.25,
+                                 unmasked_weight=0.25)
 
     train_sampler = dataset.LmBatchSampler(train,
                                            symbols_per_batch=params.symbols_per_batch,
@@ -551,6 +551,7 @@ def run(rank, world_size, args):
     train_dl = torch.utils.data.DataLoader(train,
                                            batch_sampler=train_sampler,
                                            collate_fn=collate_fn)
+
     valid_dl = torch.utils.data.DataLoader(test,
                                            batch_sampler=test_sampler,
                                            collate_fn=collate_fn)
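A note on the collate_fn change in train.py: the patch does not state the motivation, but the most likely reason for replacing the lambda with dataset.CollateFn is that a lambda cannot be pickled, which breaks DataLoader workers started with the "spawn" method and distributed setups that serialize the loader's arguments. The sketch below is not part of the patch; the collate_fn stub and the symbol values are hypothetical stand-ins, and it only demonstrates why a module-level callable class survives pickling while the equivalent lambda does not.

import pickle
from typing import Any, Dict, List, Tuple


def collate_fn(sentences: List[List[int]], **kwargs) -> Tuple[List[List[int]], Dict[str, Any]]:
    # Hypothetical stand-in for dataset.collate_fn: the real function builds
    # masked/padded tensors; here we only echo the inputs and options.
    return sentences, kwargs


class CollateFn:
    # Picklable wrapper mirroring the class added in dataset.py: it binds
    # keyword arguments at construction time and forwards them on each call.
    def __init__(self, **kwargs):
        self.extra_args = kwargs

    def __call__(self, sentences: List[List[int]]):
        return collate_fn(sentences, **self.extra_args)


if __name__ == "__main__":
    good = CollateFn(bos_sym=1, eos_sym=1, blank_sym=0, mask_proportion=0.15)
    pickle.dumps(good)  # succeeds: CollateFn is defined at module level

    bad = lambda x: collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0)
    try:
        pickle.dumps(bad)
    except (pickle.PicklingError, AttributeError) as exc:
        print(f"lambda is not picklable: {exc}")

A functools.partial over dataset.collate_fn would be another picklable alternative; a small named class has the same effect while keeping the bound arguments inspectable via self.extra_args.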