mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00

Get dataset to work for empty input sentences; test it

This commit is contained in:
parent a7b61100de
commit d045831a4f
@@ -116,7 +116,7 @@ def mask_and_pad(sentence: List[int],

```python
num_pad -= max(0, sent_len + 2 + num_pad - seq_len)

if num_mask + num_pad == 0:
    num_mask += 1
    num_pad += 1

# num_split_points is the number of times we split the (masked+padded)
# region, so the total number of (masking+padding) subsequences will be
```
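The new guard handles the empty-sentence case named in the commit message: with `sent_len == 0`, the sampled mask and pad counts can both come out as zero, leaving no (masked+padded) region at all. A minimal standalone sketch of that arithmetic, with hypothetical values for `sent_len`, `seq_len`, `num_mask` and `num_pad` (in `mask_and_pad` these are computed earlier in the function):

```python
sent_len, seq_len = 0, 12   # hypothetical: empty sentence, fixed output length
num_mask, num_pad = 0, 0    # hypothetical draws that would leave nothing masked

# Cap the padding so sentence + bos + eos + padding still fits in seq_len.
num_pad -= max(0, sent_len + 2 + num_pad - seq_len)

# The guard from the diff: ensure the masked+padded region is non-empty
# even for an empty sentence, so there is always something to predict.
if num_mask + num_pad == 0:
    num_mask += 1
    num_pad += 1

assert num_mask + num_pad > 0
```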
@@ -131,10 +131,7 @@ def mask_and_pad(sentence: List[int],

```python
num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]),
                                      prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item())
# Somehow this assertion failed, debugging it below.
# assert num_split_points <= sent_len - num_mask
if num_split_points > sent_len - num_mask:
    print(f"Warning about num_split_points: {num_split_points} > {sent_len} - {num_mask}")
    num_split_points = sent_len - num_mask
assert num_split_points <= sent_len - num_mask

assert isinstance(num_split_points, int)
```
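For context, `torch.binomial(count, prob)` samples from a Binomial distribution, so the draw should already be bounded above by `count`; the print-and-clamp above is a defensive workaround for the assertion the author saw failing. A hedged sketch of the same sampling and clamping, with hypothetical values for `sent_len`, `num_mask`, `mask_proportion` and `inv_mask_length`:

```python
import torch

sent_len, num_mask = 10, 3                      # hypothetical
mask_proportion, inv_mask_length = 0.15, 0.25   # hypothetical

# Sample the number of split points; mathematically 0 <= sample <= count.
count = torch.tensor([float(sent_len - num_mask)])
prob = torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])
num_split_points = int(torch.binomial(count=count, prob=prob).item())

# The commit's workaround, expressed as a clamp: defend against the
# (so far unexplained) case where the sample exceeded the count.
num_split_points = min(num_split_points, sent_len - num_mask)
assert 0 <= num_split_points <= sent_len - num_mask
```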
egs/librispeech/ASR/conformer_lm/test_dataset_empty.py (new file, 39 lines)
@@ -0,0 +1,39 @@

```python
import k2
import torch
import _k2
import dataset
from dataset import LmDataset
import os
from torch import multiprocessing as mp
import torch.distributed as dist


def local_collate_fn(sentences):
    return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=False)


x = _k2.RaggedInt('[[1]]')  # make sure library initialized?


if __name__ == '__main__':

    mp.set_start_method('spawn')
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12344"

    dist.init_process_group(backend="nccl", group_name="main",
                            rank=0, world_size=1)

    words = k2.RaggedInt('[[0][1 2]]')
    sentences = k2.RaggedInt('[[1][][][][][]]')

    train = LmDataset(sentences, words)

    sampler = dataset.LmBatchSampler(train, symbols_per_batch=10, world_size=1, rank=0)

    a = iter(sampler)
    print(str(next(a)))

    train_dl = torch.utils.data.DataLoader(train, batch_sampler=sampler,
                                           collate_fn=local_collate_fn,
                                           num_workers=0)
    x = iter(train_dl)
    print(str(next(x)))
```
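The test initializes a single-process NCCL group (rank 0, world size 1), which requires a CUDA-capable GPU; the script is run directly, e.g. `python egs/librispeech/ASR/conformer_lm/test_dataset_empty.py`. On a CPU-only machine the same setup could presumably use the gloo backend instead; a minimal sketch, not part of the commit:

```python
import os
import torch.distributed as dist

# Hypothetical CPU-only variant of the process-group setup in
# test_dataset_empty.py: gloo, unlike nccl, does not require a GPU.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12344"
dist.init_process_group(backend="gloo", rank=0, world_size=1)
```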