mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-13 12:02:21 +00:00)

Get dataset to work for empty input sentences; test it

parent a7b61100de
commit d045831a4f
@@ -116,7 +116,7 @@ def mask_and_pad(sentence: List[int],
     num_pad -= max(0, sent_len + 2 + num_pad - seq_len)
 
     if num_mask + num_pad == 0:
-        num_mask += 1
+        num_pad += 1
 
     # num_split_points is the number of times we split the (masked+padded)
     # region, so the total number of (masking+padding) subsequences will be
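
For an empty sentence sent_len is 0 and nothing can be masked, so the guard that guarantees at least one masked-or-padded position has to grow num_pad rather than num_mask; bumping num_mask instead drives sent_len - num_mask to -1, which plausibly explains the binomial-count assertion failure addressed in the next hunk. A minimal sketch of the invariant (values are illustrative, not from the repo):

    # Sketch: the empty-sentence guard, with illustrative values.
    sent_len, num_mask, num_pad = 0, 0, 0   # an empty input sentence

    if num_mask + num_pad == 0:
        # The old `num_mask += 1` left sent_len - num_mask == -1, an invalid
        # count for the torch.binomial call later in mask_and_pad.
        num_pad += 1                        # padding is always possible

    assert sent_len - num_mask >= 0         # binomial count stays valid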
@@ -131,10 +131,7 @@ def mask_and_pad(sentence: List[int],
     num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]),
                                           prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item())
     # Somehow this assertion failed, debugging it below.
-    # assert num_split_points <= sent_len - num_mask
-    if num_split_points > sent_len - num_mask:
-        print(f"Warning about num_split_points: {num_split_points} > {sent_len} - {num_mask}")
-        num_split_points = sent_len - num_mask
+    assert num_split_points <= sent_len - num_mask
 
     assert isinstance(num_split_points, int)
 
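
With the guard fixed, sent_len - num_mask can no longer go negative, so the assertion is restored and the warning-and-clamp workaround is dropped. torch.binomial draws at most `count` successes, so the bound holds by construction whenever the count is a valid non-negative number. A standalone sketch of the restored invariant (the probability value is assumed for illustration):

    import torch

    # An empty sentence after the fix: num_mask stays 0, so count is a valid 0.0.
    sent_len, num_mask = 0, 0
    count = torch.tensor([float(sent_len - num_mask)])
    prob = torch.tensor([0.15])   # illustrative masking probability

    num_split_points = int(torch.binomial(count=count, prob=prob).item())
    assert num_split_points <= sent_len - num_mask   # 0 <= 0 always holds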
egs/librispeech/ASR/conformer_lm/test_dataset_empty.py  (new file, 39 lines)
@@ -0,0 +1,39 @@
+import k2
+import torch
+import _k2
+import dataset
+from dataset import LmDataset
+import os
+from torch import multiprocessing as mp
+import torch.distributed as dist
+
+def local_collate_fn(sentences):
+    return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=False)
+
+x = _k2.RaggedInt('[[1]]') # make sure library initialized?
+
+if __name__ == '__main__':
+
+    mp.set_start_method('spawn')
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12344"
+
+    dist.init_process_group(backend="nccl", group_name="main",
+                            rank=0, world_size=1)
+
+    words = k2.RaggedInt('[[0][1 2]]')
+    sentences = k2.RaggedInt('[[1][][][][][]]')
+
+    train = LmDataset(sentences, words)
+
+
+    sampler = dataset.LmBatchSampler(train, symbols_per_batch=10, world_size=1, rank=0)
+
+    a = iter(sampler)
+    print(str(next(a)))
+
+    train_dl = torch.utils.data.DataLoader(train, batch_sampler=sampler,
+                                           collate_fn=local_collate_fn,
+                                           num_workers=0)
+    x = iter(train_dl)
+    print(str(next(x)))