diff --git a/egs/librispeech/ASR/rnn_lm/dataset.py b/egs/librispeech/ASR/rnn_lm/dataset.py index 2da9539d1..598e329c4 100644 --- a/egs/librispeech/ASR/rnn_lm/dataset.py +++ b/egs/librispeech/ASR/rnn_lm/dataset.py @@ -18,8 +18,10 @@ from typing import List, Tuple import k2 import torch +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler -from icefall.utils import AttributeDict +from icefall.utils import AttributeDict, add_eos, add_sos class LmDataset(torch.utils.data.Dataset): @@ -113,104 +115,6 @@ class LmDataset(torch.utils.data.Dataset): return sentence_tokens -def concat( - ragged: k2.RaggedTensor, value: int, direction: str -) -> k2.RaggedTensor: - """Prepend a value to the beginning of each sublist or append a value. - to the end of each sublist. - - Args: - ragged: - A ragged tensor with two axes. - value: - The value to prepend or append. - direction: - It can be either "left" or "right". If it is "left", we - prepend the value to the beginning of each sublist; - if it is "right", we append the value to the end of each - sublist. - - Returns: - Return a new ragged tensor, whose sublists either start with - or end with the given value. - - >>> a = k2.RaggedTensor([[1, 3], [5]]) - >>> a - [ [ 1 3 ] [ 5 ] ] - >>> concat(a, value=0, direction="left") - [ [ 0 1 3 ] [ 0 5 ] ] - >>> concat(a, value=0, direction="right") - [ [ 1 3 0 ] [ 5 0 ] ] - - """ - dtype = ragged.dtype - device = ragged.device - - assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}" - pad_values = torch.full( - size=(ragged.tot_size(0), 1), - fill_value=value, - device=device, - dtype=dtype, - ) - pad = k2.RaggedTensor(pad_values) - - if direction == "left": - ans = k2.ragged.cat([pad, ragged], axis=1) - elif direction == "right": - ans = k2.ragged.cat([ragged, pad], axis=1) - else: - raise ValueError( - f'Unsupported direction: {direction}. " \ - "Expect either "left" or "right"' - ) - return ans - - -def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor: - """Add SOS to each sublist. - - Args: - ragged: - A ragged tensor with two axes. - sos_id: - The ID of the SOS symbol. - - Returns: - Return a new ragged tensor, where each sublist starts with SOS. - - >>> a = k2.RaggedTensor([[1, 3], [5]]) - >>> a - [ [ 1 3 ] [ 5 ] ] - >>> add_sos(a, sos_id=0) - [ [ 0 1 3 ] [ 0 5 ] ] - - """ - return concat(ragged, sos_id, direction="left") - - -def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor: - """Add EOS to each sublist. - - Args: - ragged: - A ragged tensor with two axes. - eos_id: - The ID of the EOS symbol. - - Returns: - Return a new ragged tensor, where each sublist ends with EOS. - - >>> a = k2.RaggedTensor([[1, 3], [5]]) - >>> a - [ [ 1 3 ] [ 5 ] ] - >>> add_eos(a, eos_id=0) - [ [ 1 3 0 ] [ 5 0 ] ] - - """ - return concat(ragged, eos_id, direction="right") - - class LmDatasetCollate: def __init__(self, sos_id: int, eos_id: int, blank_id: int): """ @@ -294,9 +198,7 @@ def get_dataloader( batch_size=params.batch_size, ) if is_distributed: - sampler = torch.utils.data.distributed.DistributedSampler( - dataset, shuffle=True, drop_last=False - ) + sampler = DistributedSampler(dataset, shuffle=True, drop_last=False) else: sampler = None @@ -306,7 +208,7 @@ def get_dataloader( blank_id=params.blank_id, ) - dataloader = torch.utils.data.DataLoader( + dataloader = DataLoader( dataset, batch_size=1, collate_fn=collate_fn,