mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Refactoring LM loader
This commit is contained in:
parent
25d540a758
commit
00b1c291a6
@ -18,8 +18,10 @@ from typing import List, Tuple
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
from icefall.utils import AttributeDict
|
from icefall.utils import AttributeDict, add_eos, add_sos
|
||||||
|
|
||||||
|
|
||||||
class LmDataset(torch.utils.data.Dataset):
|
class LmDataset(torch.utils.data.Dataset):
|
||||||
@ -113,104 +115,6 @@ class LmDataset(torch.utils.data.Dataset):
|
|||||||
return sentence_tokens
|
return sentence_tokens
|
||||||
|
|
||||||
|
|
||||||
def concat(
|
|
||||||
ragged: k2.RaggedTensor, value: int, direction: str
|
|
||||||
) -> k2.RaggedTensor:
|
|
||||||
"""Prepend a value to the beginning of each sublist or append a value.
|
|
||||||
to the end of each sublist.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ragged:
|
|
||||||
A ragged tensor with two axes.
|
|
||||||
value:
|
|
||||||
The value to prepend or append.
|
|
||||||
direction:
|
|
||||||
It can be either "left" or "right". If it is "left", we
|
|
||||||
prepend the value to the beginning of each sublist;
|
|
||||||
if it is "right", we append the value to the end of each
|
|
||||||
sublist.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Return a new ragged tensor, whose sublists either start with
|
|
||||||
or end with the given value.
|
|
||||||
|
|
||||||
>>> a = k2.RaggedTensor([[1, 3], [5]])
|
|
||||||
>>> a
|
|
||||||
[ [ 1 3 ] [ 5 ] ]
|
|
||||||
>>> concat(a, value=0, direction="left")
|
|
||||||
[ [ 0 1 3 ] [ 0 5 ] ]
|
|
||||||
>>> concat(a, value=0, direction="right")
|
|
||||||
[ [ 1 3 0 ] [ 5 0 ] ]
|
|
||||||
|
|
||||||
"""
|
|
||||||
dtype = ragged.dtype
|
|
||||||
device = ragged.device
|
|
||||||
|
|
||||||
assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
|
|
||||||
pad_values = torch.full(
|
|
||||||
size=(ragged.tot_size(0), 1),
|
|
||||||
fill_value=value,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
pad = k2.RaggedTensor(pad_values)
|
|
||||||
|
|
||||||
if direction == "left":
|
|
||||||
ans = k2.ragged.cat([pad, ragged], axis=1)
|
|
||||||
elif direction == "right":
|
|
||||||
ans = k2.ragged.cat([ragged, pad], axis=1)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f'Unsupported direction: {direction}. " \
|
|
||||||
"Expect either "left" or "right"'
|
|
||||||
)
|
|
||||||
return ans
|
|
||||||
|
|
||||||
|
|
||||||
def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor:
|
|
||||||
"""Add SOS to each sublist.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ragged:
|
|
||||||
A ragged tensor with two axes.
|
|
||||||
sos_id:
|
|
||||||
The ID of the SOS symbol.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Return a new ragged tensor, where each sublist starts with SOS.
|
|
||||||
|
|
||||||
>>> a = k2.RaggedTensor([[1, 3], [5]])
|
|
||||||
>>> a
|
|
||||||
[ [ 1 3 ] [ 5 ] ]
|
|
||||||
>>> add_sos(a, sos_id=0)
|
|
||||||
[ [ 0 1 3 ] [ 0 5 ] ]
|
|
||||||
|
|
||||||
"""
|
|
||||||
return concat(ragged, sos_id, direction="left")
|
|
||||||
|
|
||||||
|
|
||||||
def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
|
|
||||||
"""Add EOS to each sublist.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ragged:
|
|
||||||
A ragged tensor with two axes.
|
|
||||||
eos_id:
|
|
||||||
The ID of the EOS symbol.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Return a new ragged tensor, where each sublist ends with EOS.
|
|
||||||
|
|
||||||
>>> a = k2.RaggedTensor([[1, 3], [5]])
|
|
||||||
>>> a
|
|
||||||
[ [ 1 3 ] [ 5 ] ]
|
|
||||||
>>> add_eos(a, eos_id=0)
|
|
||||||
[ [ 1 3 0 ] [ 5 0 ] ]
|
|
||||||
|
|
||||||
"""
|
|
||||||
return concat(ragged, eos_id, direction="right")
|
|
||||||
|
|
||||||
|
|
||||||
class LmDatasetCollate:
|
class LmDatasetCollate:
|
||||||
def __init__(self, sos_id: int, eos_id: int, blank_id: int):
|
def __init__(self, sos_id: int, eos_id: int, blank_id: int):
|
||||||
"""
|
"""
|
||||||
@ -294,9 +198,7 @@ def get_dataloader(
|
|||||||
batch_size=params.batch_size,
|
batch_size=params.batch_size,
|
||||||
)
|
)
|
||||||
if is_distributed:
|
if is_distributed:
|
||||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
sampler = DistributedSampler(dataset, shuffle=True, drop_last=False)
|
||||||
dataset, shuffle=True, drop_last=False
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
sampler = None
|
sampler = None
|
||||||
|
|
||||||
@ -306,7 +208,7 @@ def get_dataloader(
|
|||||||
blank_id=params.blank_id,
|
blank_id=params.blank_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user