Make some required changes.

Mingshuang Luo 2022-01-06 18:00:05 +08:00
parent 7391f4febf
commit d412dbb2f0
19 changed files with 212 additions and 90 deletions

View File

@@ -30,7 +30,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
-from local.dataset_audio import dataset_audio
+from local.dataset_audio import AudioDataSet
 from model import AudioNet
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -467,7 +467,7 @@ def main():
     model.to(device)
     model.eval()
-    grid = dataset_audio(
+    grid = AudioDataSet(
         params.video_path,
         params.anno_path,
         params.val_list,

View File

@@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #

View File

@@ -209,7 +209,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
-    # aud_padding = 480
+    # Here, we set aud_padding to 480.
     features_new = torch.zeros(len(features), 480, params.feature_dim).to(
         device
     )
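The comment above documents that decoding pads every utterance to a fixed aud_padding of 480 fbank frames. A minimal sketch of the idea, assuming features is a list of per-utterance (num_frames, feature_dim) tensors with at most 480 frames; the copy loop is illustrative and not taken from the recipe:

    # Illustrative only: copy each utterance's fbank features into a
    # zero-filled tensor so the whole batch has exactly 480 frames.
    features_new = torch.zeros(len(features), 480, params.feature_dim).to(device)
    for i, feat in enumerate(features):
        features_new[i, : feat.shape[0], :] = feat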

View File

@@ -32,7 +32,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
-from local.dataset_audio import dataset_audio
+from local.dataset_audio import AudioDataSet
 from lhotse.utils import fix_random_seed
 from model import AudioNet
 from torch import Tensor
@@ -533,7 +533,7 @@ def run(rank, world_size, args):
         optimizer.load_state_dict(checkpoints["optimizer"])
         scheduler.load_state_dict(checkpoints["scheduler"])
-    grid = dataset_audio(
+    grid = AudioDataSet(
         params.video_path,
         params.anno_path,
         params.train_list,

View File

@@ -14,27 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import torch
-def encode_supervisions(nnet_output_shape, batch):
-    """
-    Encodes the output of net and texts into
-    a pair of torch Tensor, and a list of transcription strings.
-    The supervision tensor has shape ``(batch_size, 3)``.
-    Its second dimension contains information about sequence index [0],
-    start frames [1] and num frames [2].
-    In GRID, the start frame of each audio sample is 0.
-    """
+"""
+This script encodes the supervisions as a list of tuples.
+The supervision tensor has shape ``(batch_size, 3)``.
+Its second dimension contains information about sequence index [0],
+start frames [1] and num frames [2].
+In GRID, the start frame of each audio sample is 0.
+"""
+import torch
+def encode_supervisions(nnet_output_shape: int, batch: dict):
+    """
+    Args:
+      nnet_output_shape:
+        The shape of nnet_output, e.g. (N, T, D).
+      batch:
+        A batch from the dataloader; a dict containing
+        the text and the aud/vid arrays.
+    Return:
+      The supervision segments and the texts in the batch.
+    """
     N, T, D = nnet_output_shape
-    supervisions_idx = torch.arange(0, N).to(torch.int32)
-    start_frames = [0 for _ in range(N)]
-    supervisions_start_frame = torch.tensor(start_frames).to(torch.int32)
-    num_frames = [T for _ in range(N)]
-    supervisions_num_frames = torch.tensor(num_frames).to(torch.int32)
+    supervisions_idx = torch.arange(0, N, dtype=torch.int32)
+    supervisions_start_frame = torch.full((1, N), 0, dtype=torch.int32)[0]
+    supervisions_num_frames = torch.full((1, N), T, dtype=torch.int32)[0]
     supervision_segments = torch.stack(
         (
@@ -43,7 +48,7 @@ def encode_supervisions(nnet_output_shape, batch):
             supervisions_num_frames,
         ),
         1,
-    ).to(torch.int32)
+    )
     texts = batch["txt"]
     return supervision_segments, texts
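For reference, a small toy run of the reworked encode_supervisions; the network output shape and the GRID-style transcripts below are made up for illustration, and the function is assumed to be importable from this file:

    import torch

    nnet_output = torch.randn(2, 75, 128)  # hypothetical (N, T, D) network output
    batch = {"txt": ["bin blue at f two now", "lay green by s four soon"]}

    segments, texts = encode_supervisions(nnet_output.shape, batch)
    # Every stacked piece is already int32, so the trailing .to(torch.int32)
    # is no longer needed; segments is
    # tensor([[ 0,  0, 75],
    #         [ 1,  0, 75]], dtype=torch.int32)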

View File

@@ -30,7 +30,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
-from local.dataset_av import dataset_av
+from local.dataset_av import AudioVisualDataset
 from model import CombineNet
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -475,7 +475,7 @@ def main():
     model.to(device)
     model.eval()
-    grid = dataset_av(
+    grid = AudioVisualDataset(
         params.video_path,
         params.anno_path,
         params.val_list,

View File

@@ -1,4 +1,4 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #

View File

@@ -32,7 +32,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
-from local.dataset_av import dataset_av
+from local.dataset_av import AudioVisualDataset
 from lhotse.utils import fix_random_seed
 from model import CombineNet
 from torch import Tensor
@@ -544,7 +544,7 @@ def run(rank, world_size, args):
         optimizer.load_state_dict(checkpoints["optimizer"])
         scheduler.load_state_dict(checkpoints["scheduler"])
-    grid = dataset_av(
+    grid = AudioVisualDataset(
         params.video_path,
         params.anno_path,
         params.train_list,

View File

@@ -19,24 +19,20 @@ import torch
 def encode_supervisions(nnet_output_shape, batch):
     """
-    Encodes Lhotse's ``batch["supervisions"]`` dict into
+    Encodes the output of net and texts into
     a pair of torch Tensor, and a list of transcription strings.
     The supervision tensor has shape ``(batch_size, 3)``.
     Its second dimension contains information about sequence index [0],
     start frames [1] and num frames [2].
-    The batch items might become re-ordered during this operation -- the
-    returned tensor and list of strings are guaranteed to be consistent with
-    each other.
+    In GRID, the start frame of each audio sample is 0.
     """
     N, T, D = nnet_output_shape
-    supervisions_idx = torch.arange(0, N).to(torch.int32)
-    start_frames = [0 for _ in range(N)]
-    supervisions_start_frame = torch.tensor(start_frames).to(torch.int32)
-    num_frames = [T for _ in range(N)]
-    supervisions_num_frames = torch.tensor(num_frames).to(torch.int32)
+    supervisions_idx = torch.arange(0, N, dtype=torch.int32)
+    supervisions_start_frame = torch.full((1, N), 0, dtype=torch.int32)[0]
+    supervisions_num_frames = torch.full((1, N), T, dtype=torch.int32)[0]
     supervision_segments = torch.stack(
         (
@@ -45,7 +41,7 @@ def encode_supervisions(nnet_output_shape, batch):
             supervisions_num_frames,
         ),
         1,
-    ).to(torch.int32)
+    )
     texts = batch["txt"]
     return supervision_segments, texts

View File

@@ -26,13 +26,31 @@ The input for the above functions is a sequence of images.
 import random
-def HorizontalFlip(batch_img, p=0.5):
-    # (T, H, W, C)
+def horizontal_flip(batch_img: float, p: float = 0.5):
+    """
+    Args:
+      batch_img:
+        The float array of a sequence of images; the shape of the
+        array is (T, H, W, C).
+      p:
+        The probability of applying the horizontal flip; the default
+        value is 0.5.
+    Return:
+      A new float array of the sequence of images after flipping.
+    """
     if random.random() > p:
         batch_img = batch_img[:, :, ::-1, ...]
     return batch_img
-def ColorNormalize(batch_img):
+def color_normalize(batch_img: float):
+    """
+    Args:
+      batch_img:
+        The float array of a sequence of images; the shape of the
+        array is (T, H, W, C).
+    Return:
+      A new float array of the sequence of images after normalizing.
+    """
     batch_img = batch_img / 255.0
     return batch_img
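A short usage sketch for the renamed helpers; the clip size and the import path are assumed for illustration and are not part of this commit:

    import numpy as np

    from local.cvtransforms import color_normalize, horizontal_flip

    # A toy clip: 75 frames of 64x128 RGB pixels in [0, 255].
    vid = np.random.randint(0, 256, size=(75, 64, 128, 3)).astype(np.float32)
    vid = horizontal_flip(vid, p=0.5)  # randomly mirrors the clip along its width axis
    vid = color_normalize(vid)         # scales pixel values into [0, 1]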

View File

@@ -19,7 +19,6 @@
 This script is to load the audio data in GRID.
 The class dataset_audio makes each audio batch data have the same shape.
 """
-import kaldifeat
 import numpy as np
 import os
@@ -27,8 +26,10 @@ import torch
 import torchaudio
 from torch.utils.data import Dataset
+import kaldifeat
-class dataset_audio(Dataset):
+class AudioDataSet(Dataset):
     def __init__(
         self,
         video_path: str,
@@ -46,7 +47,7 @@ class dataset_audio(Dataset):
           anno_path:
            The dir path of the texts data.
          file_list:
-            The file which listing all samples for training or testing.
+            A txt file listing all samples for training or testing.
          aud_padding:
            The padding for each audio sample.
          sample_rate:
@@ -61,6 +62,15 @@ class dataset_audio(Dataset):
         self.sample_rate = sample_rate
         self.feature_dim = feature_dim
         self.phase = phase
+
+        opts = kaldifeat.FbankOptions()
+        opts.device = "cpu"
+        opts.frame_opts.dither = 0
+        opts.frame_opts.snip_edges = False
+        opts.frame_opts.samp_freq = self.sample_rate
+        opts.mel_opts.num_bins = self.feature_dim
+        self.fbank = kaldifeat.Fbank(opts)
+
         with open(file_list, "r") as f:
             self.videos = [
                 os.path.join(video_path, line.strip()) for line in f.readlines()
@@ -92,19 +102,26 @@ class dataset_audio(Dataset):
         return len(self.data)
     def _load_aud(self, filename):
-        opts = kaldifeat.FbankOptions()
-        opts.device = "cpu"
-        opts.frame_opts.dither = 0
-        opts.frame_opts.snip_edges = False
-        opts.frame_opts.samp_freq = self.sample_rate
-        opts.mel_opts.num_bins = self.feature_dim
-        fbank = kaldifeat.Fbank(opts)
-        wave, sr = torchaudio.load(filename)
+        """Load the audio data.
+        Args:
+          filename:
+            The full path of a wav file.
+        Return:
+          The fbank feature array.
+        """
+        wave, _ = torchaudio.load(filename)
         wave = wave[0]
-        features = fbank(wave)
+        features = self.fbank(wave)
         return features
     def _load_anno(self, name):
+        """Load the text file.
+        Args:
+          name:
+            The file which records the text.
+        Return:
+          A sequence of words.
+        """
         with open(name, "r") as f:
             lines = [line.strip().split(" ") for line in f.readlines()]
             txt = [line[2] for line in lines]
@@ -113,6 +130,15 @@ class dataset_audio(Dataset):
         return txt
     def _padding(self, array, length):
+        """Pad zeros for the feature array.
+        Args:
+          array:
+            The feature array (audio or visual feature).
+          length:
+            The length for padding.
+        Return:
+          A new feature array after padding.
+        """
         array = [array[_] for _ in range(array.shape[0])]
         size = array[0].shape
         for i in range(length - len(array)):
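With the fbank extractor now built once in __init__, every call to _load_aud reuses the same kaldifeat.Fbank instance. A hedged construction sketch; the paths and parameter values below are placeholders rather than the recipe's defaults:

    from torch.utils.data import DataLoader

    from local.dataset_audio import AudioDataSet

    grid = AudioDataSet(
        "/path/to/GRID/videos",      # video_path (placeholder)
        "/path/to/GRID/alignments",  # anno_path (placeholder)
        "/path/to/train_list.txt",   # file_list (placeholder)
        aud_padding=480,
        sample_rate=16000,
        feature_dim=80,
        phase="train",
    )
    loader = DataLoader(grid, batch_size=2, shuffle=True)
    # Each batch is a dict that carries the padded fbank features and the
    # transcript under "txt", per the docstrings in this commit.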

View File

@@ -20,7 +20,6 @@ This script is to load the pair of audio-visual data in GRID.
 The class dataset_av makes each audio-visual batch data have the same shape.
 """
 import cv2
-import kaldifeat
 import numpy as np
 import os
@@ -28,10 +27,11 @@ import torch
 import torchaudio
 from torch.utils.data import Dataset
-from .cvtransforms import HorizontalFlip, ColorNormalize
+import kaldifeat
+from .cvtransforms import horizontal_flip, color_normalize
-class dataset_av(Dataset):
+class AudioVisualDataset(Dataset):
     def __init__(
         self,
         video_path,
@@ -94,8 +94,8 @@ class dataset_av(Dataset):
         )
         if self.phase == "train":
-            vid = HorizontalFlip(vid)
-            vid = ColorNormalize(vid)
+            vid = horizontal_flip(vid)
+            vid = color_normalize(vid)
         vid = self._padding(vid, self.vid_pading)
         aud = self._padding(aud, self.aud_pading)
@@ -110,6 +110,14 @@ class dataset_av(Dataset):
         return len(self.data)
     def _load_vid(self, p):
+        """Load the visual data.
+        Args:
+          p:
+            A directory which contains a sequence of frames
+            for a visual sample.
+        Return:
+          The array of a visual sample.
+        """
         files = os.listdir(p)
         files = list(filter(lambda file: file.find(".jpg") != -1, files))
         files = sorted(files, key=lambda file: int(os.path.splitext(file)[0]))
@@ -123,6 +131,13 @@ class dataset_av(Dataset):
         return array
     def _load_aud(self, filename):
+        """Load the audio data.
+        Args:
+          filename:
+            The full path of a wav file.
+        Return:
+          The fbank feature array.
+        """
         opts = kaldifeat.FbankOptions()
         opts.frame_opts.dither = 0
         opts.frame_opts.snip_edges = False
@@ -135,6 +150,13 @@ class dataset_av(Dataset):
         return features
     def _load_anno(self, name):
+        """Load the text file.
+        Args:
+          name:
+            The file which records the text.
+        Return:
+          A sequence of words.
+        """
         with open(name, "r") as f:
             lines = [line.strip().split(" ") for line in f.readlines()]
             txt = [line[2] for line in lines]
@@ -143,6 +165,15 @@ class dataset_av(Dataset):
         return txt
     def _padding(self, array, length):
+        """Pad zeros for the feature array.
+        Args:
+          array:
+            The feature array (audio or visual feature).
+          length:
+            The length for padding.
+        Return:
+          A new feature array after padding.
+        """
         array = [array[_] for _ in range(array.shape[0])]
         size = array[0].shape
         for i in range(length - len(array)):
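The _padding helper appends all-zero frames until a feature array reaches the requested length. A toy numpy equivalent, illustrative only and not the recipe's implementation:

    import numpy as np

    feat = np.ones((3, 4), dtype=np.float32)                  # 3 frames, 4 features each
    pad = np.zeros((5 - feat.shape[0], 4), dtype=np.float32)  # zero frames up to length 5
    padded = np.concatenate([feat, pad])                      # shape (5, 4)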

View File

@@ -24,10 +24,13 @@ import os
 import numpy as np
 import torch
 from torch.utils.data import Dataset
-from .cvtransforms import HorizontalFlip, ColorNormalize
+from .cvtransforms import (
+    color_normalize,
+    horizontal_flip,
+)
-class dataset_visual(Dataset):
+class VisualDataset(Dataset):
     def __init__(
         self,
         video_path: str,
@@ -74,8 +77,8 @@ class dataset_visual(Dataset):
         )
         if self.phase == "train":
-            vid = HorizontalFlip(vid)
-            vid = ColorNormalize(vid)
+            vid = horizontal_flip(vid, p=0.5)
+            vid = color_normalize(vid)
         vid = self._padding(vid, self.vid_padding)
@@ -88,6 +91,14 @@ class dataset_visual(Dataset):
         return len(self.data)
     def _load_vid(self, p):
+        """Load the visual data.
+        Args:
+          p:
+            A directory which contains a sequence of frames
+            for a visual sample.
+        Return:
+          The array of a visual sample.
+        """
         files = os.listdir(p)
         files = list(filter(lambda file: file.find(".jpg") != -1, files))
         files = sorted(files, key=lambda file: int(os.path.splitext(file)[0]))
@@ -101,6 +112,13 @@ class dataset_visual(Dataset):
         return array
     def _load_anno(self, name):
+        """Load the text file.
+        Args:
+          name:
+            The file which records the text.
+        Return:
+          A sequence of words.
+        """
         with open(name, "r") as f:
             lines = [line.strip().split(" ") for line in f.readlines()]
             txt = [line[2] for line in lines]
@@ -109,6 +127,15 @@ class dataset_visual(Dataset):
         return txt
     def _padding(self, array, length):
+        """Pad zeros for the feature array.
+        Args:
+          array:
+            The feature array (audio or visual feature).
+          length:
+            The length for padding.
+        Return:
+          A new feature array after padding.
+        """
         array = [array[_] for _ in range(array.shape[0])]
         size = array[0].shape
         for i in range(length - len(array)):

View File

@@ -30,7 +30,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
-from local.dataset_visual import dataset_visual
+from local.dataset_visual import VisualDataset
 from model import VisualNet2
@@ -463,7 +463,7 @@ def main():
     model.to(device)
     model.eval()
-    grid = dataset_visual(
+    grid = VisualDataset(
         params.video_path,
         params.anno_path,
         params.val_list,

View File

@@ -32,7 +32,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
-from local.dataset_visual import dataset_visual
+from local.dataset_visual import VisualDataset
 from lhotse.utils import fix_random_seed
 from model import VisualNet2
@@ -529,7 +529,7 @@ def run(rank, world_size, args):
         optimizer.load_state_dict(checkpoints["optimizer"])
         scheduler.load_state_dict(checkpoints["scheduler"])
-    grid = dataset_visual(
+    grid = VisualDataset(
         params.video_path,
         params.anno_path,
         params.train_list,

View File

@@ -14,22 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+This script encodes the supervisions as a list of tuples.
+The supervision tensor has shape ``(batch_size, 3)``.
+Its second dimension contains information about sequence index [0],
+start frames [1] and num frames [2].
+In GRID, the start frame of each audio sample is 0.
+"""
 import torch
-def encode_supervisions(nnet_output_shape, batch):
+def encode_supervisions(nnet_output_shape: int, batch: dict):
     """
-    In GRID, the lengths of all samples are same.
-    And here, we don't deploy cut operation on it.
-    So, the start frame is always 0 among all samples.
+    Args:
+      nnet_output_shape:
+        The shape of nnet_output, e.g. (N, T, D).
+      batch:
+        A batch from the dataloader; a dict containing
+        the text and the aud/vid arrays.
+    Return:
+      The supervision segments and the texts in the batch.
     """
     N, T, D = nnet_output_shape
-    supervisions_idx = torch.arange(0, N).to(torch.int32)
-    start_frames = [0 for _ in range(N)]
-    supervisions_start_frame = torch.tensor(start_frames).to(torch.int32)
-    num_frames = [T for _ in range(N)]
-    supervisions_num_frames = torch.tensor(num_frames).to(torch.int32)
+    supervisions_idx = torch.arange(0, N, dtype=torch.int32)
+    supervisions_start_frame = torch.full((1, N), 0, dtype=torch.int32)[0]
+    supervisions_num_frames = torch.full((1, N), T, dtype=torch.int32)[0]
     supervision_segments = torch.stack(
         (
@@ -38,8 +48,7 @@ def encode_supervisions(nnet_output_shape, batch):
             supervisions_num_frames,
         ),
         1,
-    ).to(torch.int32)
+    )
     texts = batch["txt"]
     return supervision_segments, texts

View File

@@ -30,7 +30,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
-from local.dataset_visual import dataset_visual
+from local.dataset_visual import VisualDataset
 from model import VisualNet
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -462,7 +462,7 @@ def main():
     model.to(device)
     model.eval()
-    grid = dataset_visual(
+    grid = VisualDataset(
         params.video_path,
         params.anno_path,
         params.val_list,

View File

@@ -32,7 +32,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
-from local.dataset_visual import dataset_visual
+from local.dataset_visual import VisualDataset
 from lhotse.utils import fix_random_seed
 from model import VisualNet
 from torch import Tensor
@@ -528,7 +528,7 @@ def run(rank, world_size, args):
         optimizer.load_state_dict(checkpoints["optimizer"])
         scheduler.load_state_dict(checkpoints["scheduler"])
-    grid = dataset_visual(
+    grid = VisualDataset(
         params.video_path,
         params.anno_path,
         params.train_list,

View File

@@ -14,22 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+This script encodes the supervisions as a list of tuples.
+The supervision tensor has shape ``(batch_size, 3)``.
+Its second dimension contains information about sequence index [0],
+start frames [1] and num frames [2].
+In GRID, the start frame of each audio sample is 0.
+"""
 import torch
-def encode_supervisions(nnet_output_shape, batch):
+def encode_supervisions(nnet_output_shape: int, batch: dict):
     """
-    In GRID, the lengths of all samples are same.
-    And here, we don't deploy cut operation on it.
-    So, the start frame is always 0 among all samples.
+    Args:
+      nnet_output_shape:
+        The shape of nnet_output, e.g. (N, T, D).
+      batch:
+        A batch from the dataloader; a dict containing
+        the text and the aud/vid arrays.
+    Return:
+      The supervision segments and the texts in the batch.
    """
     N, T, D = nnet_output_shape
-    supervisions_idx = torch.arange(0, N).to(torch.int32)
-    start_frames = [0 for _ in range(N)]
-    supervisions_start_frame = torch.tensor(start_frames).to(torch.int32)
-    num_frames = [T for _ in range(N)]
-    supervisions_num_frames = torch.tensor(num_frames).to(torch.int32)
+    supervisions_idx = torch.arange(0, N, dtype=torch.int32)
+    supervisions_start_frame = torch.full((1, N), 0, dtype=torch.int32)[0]
+    supervisions_num_frames = torch.full((1, N), T, dtype=torch.int32)[0]
     supervision_segments = torch.stack(
         (
@@ -38,8 +48,7 @@ def encode_supervisions(nnet_output_shape, batch):
             supervisions_num_frames,
         ),
         1,
-    ).to(torch.int32)
+    )
     texts = batch["txt"]
     return supervision_segments, texts