mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Refactoring LM loader
This commit is contained in:
parent
25d540a758
commit
00b1c291a6
@ -18,8 +18,10 @@ from typing import List, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from icefall.utils import AttributeDict
|
||||
from icefall.utils import AttributeDict, add_eos, add_sos
|
||||
|
||||
|
||||
class LmDataset(torch.utils.data.Dataset):
|
||||
@ -113,104 +115,6 @@ class LmDataset(torch.utils.data.Dataset):
|
||||
return sentence_tokens
|
||||
|
||||
|
||||
def concat(
|
||||
ragged: k2.RaggedTensor, value: int, direction: str
|
||||
) -> k2.RaggedTensor:
|
||||
"""Prepend a value to the beginning of each sublist or append a value.
|
||||
to the end of each sublist.
|
||||
|
||||
Args:
|
||||
ragged:
|
||||
A ragged tensor with two axes.
|
||||
value:
|
||||
The value to prepend or append.
|
||||
direction:
|
||||
It can be either "left" or "right". If it is "left", we
|
||||
prepend the value to the beginning of each sublist;
|
||||
if it is "right", we append the value to the end of each
|
||||
sublist.
|
||||
|
||||
Returns:
|
||||
Return a new ragged tensor, whose sublists either start with
|
||||
or end with the given value.
|
||||
|
||||
>>> a = k2.RaggedTensor([[1, 3], [5]])
|
||||
>>> a
|
||||
[ [ 1 3 ] [ 5 ] ]
|
||||
>>> concat(a, value=0, direction="left")
|
||||
[ [ 0 1 3 ] [ 0 5 ] ]
|
||||
>>> concat(a, value=0, direction="right")
|
||||
[ [ 1 3 0 ] [ 5 0 ] ]
|
||||
|
||||
"""
|
||||
dtype = ragged.dtype
|
||||
device = ragged.device
|
||||
|
||||
assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
|
||||
pad_values = torch.full(
|
||||
size=(ragged.tot_size(0), 1),
|
||||
fill_value=value,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
pad = k2.RaggedTensor(pad_values)
|
||||
|
||||
if direction == "left":
|
||||
ans = k2.ragged.cat([pad, ragged], axis=1)
|
||||
elif direction == "right":
|
||||
ans = k2.ragged.cat([ragged, pad], axis=1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Unsupported direction: {direction}. " \
|
||||
"Expect either "left" or "right"'
|
||||
)
|
||||
return ans
|
||||
|
||||
|
||||
def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor:
|
||||
"""Add SOS to each sublist.
|
||||
|
||||
Args:
|
||||
ragged:
|
||||
A ragged tensor with two axes.
|
||||
sos_id:
|
||||
The ID of the SOS symbol.
|
||||
|
||||
Returns:
|
||||
Return a new ragged tensor, where each sublist starts with SOS.
|
||||
|
||||
>>> a = k2.RaggedTensor([[1, 3], [5]])
|
||||
>>> a
|
||||
[ [ 1 3 ] [ 5 ] ]
|
||||
>>> add_sos(a, sos_id=0)
|
||||
[ [ 0 1 3 ] [ 0 5 ] ]
|
||||
|
||||
"""
|
||||
return concat(ragged, sos_id, direction="left")
|
||||
|
||||
|
||||
def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
|
||||
"""Add EOS to each sublist.
|
||||
|
||||
Args:
|
||||
ragged:
|
||||
A ragged tensor with two axes.
|
||||
eos_id:
|
||||
The ID of the EOS symbol.
|
||||
|
||||
Returns:
|
||||
Return a new ragged tensor, where each sublist ends with EOS.
|
||||
|
||||
>>> a = k2.RaggedTensor([[1, 3], [5]])
|
||||
>>> a
|
||||
[ [ 1 3 ] [ 5 ] ]
|
||||
>>> add_eos(a, eos_id=0)
|
||||
[ [ 1 3 0 ] [ 5 0 ] ]
|
||||
|
||||
"""
|
||||
return concat(ragged, eos_id, direction="right")
|
||||
|
||||
|
||||
class LmDatasetCollate:
|
||||
def __init__(self, sos_id: int, eos_id: int, blank_id: int):
|
||||
"""
|
||||
@ -294,9 +198,7 @@ def get_dataloader(
|
||||
batch_size=params.batch_size,
|
||||
)
|
||||
if is_distributed:
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
dataset, shuffle=True, drop_last=False
|
||||
)
|
||||
sampler = DistributedSampler(dataset, shuffle=True, drop_last=False)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
@ -306,7 +208,7 @@ def get_dataloader(
|
||||
blank_id=params.blank_id,
|
||||
)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
collate_fn=collate_fn,
|
||||
|
Loading…
x
Reference in New Issue
Block a user