diff --git a/egs/grid/AVSR/audionet_ctc_asr/decode.py b/egs/grid/AVSR/audionet_ctc_asr/decode.py
index 7a80c70e0..af02a143f 100644
--- a/egs/grid/AVSR/audionet_ctc_asr/decode.py
+++ b/egs/grid/AVSR/audionet_ctc_asr/decode.py
@@ -30,7 +30,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
 
-from local.dataset_audio import dataset_audio
+from local.dataset_audio import AudioDataSet
 from model import AudioNet
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -467,7 +467,7 @@ def main():
     model.to(device)
     model.eval()
 
-    grid = dataset_audio(
+    grid = AudioDataSet(
         params.video_path,
         params.anno_path,
         params.val_list,
diff --git a/egs/grid/AVSR/audionet_ctc_asr/model.py b/egs/grid/AVSR/audionet_ctc_asr/model.py
index 93c442aa7..a368cd85a 100644
--- a/egs/grid/AVSR/audionet_ctc_asr/model.py
+++ b/egs/grid/AVSR/audionet_ctc_asr/model.py
@@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
diff --git a/egs/grid/AVSR/audionet_ctc_asr/pretrained.py b/egs/grid/AVSR/audionet_ctc_asr/pretrained.py
index 853a38066..f88d4e0fb 100644
--- a/egs/grid/AVSR/audionet_ctc_asr/pretrained.py
+++ b/egs/grid/AVSR/audionet_ctc_asr/pretrained.py
@@ -209,7 +209,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    # aud_padding = 480
+    # Here, we set aud_padding to 480.
     features_new = torch.zeros(len(features), 480, params.feature_dim).to(
         device
     )
diff --git a/egs/grid/AVSR/audionet_ctc_asr/train.py b/egs/grid/AVSR/audionet_ctc_asr/train.py
index a0e2a002e..f67d4c515 100644
--- a/egs/grid/AVSR/audionet_ctc_asr/train.py
+++ b/egs/grid/AVSR/audionet_ctc_asr/train.py
@@ -32,7 +32,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 
-from local.dataset_audio import dataset_audio
+from local.dataset_audio import AudioDataSet
 from lhotse.utils import fix_random_seed
 from model import AudioNet
 from torch import Tensor
@@ -533,7 +533,7 @@ def run(rank, world_size, args):
         optimizer.load_state_dict(checkpoints["optimizer"])
         scheduler.load_state_dict(checkpoints["scheduler"])
 
-    grid = dataset_audio(
+    grid = AudioDataSet(
         params.video_path,
         params.anno_path,
         params.train_list,
diff --git a/egs/grid/AVSR/audionet_ctc_asr/utils.py b/egs/grid/AVSR/audionet_ctc_asr/utils.py
index 03b1b4ec4..17889a31c 100644
--- a/egs/grid/AVSR/audionet_ctc_asr/utils.py
+++ b/egs/grid/AVSR/audionet_ctc_asr/utils.py
@@ -14,27 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+This script encodes the supervisions into a supervision tensor and a list of texts.
+The supervision tensor has shape ``(batch_size, 3)``.
+Its second dimension contains information about sequence index [0],
+start frames [1] and num frames [2].
+In GRID, the start frame of each audio sample is 0.
+"""
 import torch
 
 
-def encode_supervisions(nnet_output_shape, batch):
+def encode_supervisions(nnet_output_shape: tuple, batch: dict):
     """
-    Encodes the output of net and texts into
-    a pair of torch Tensor, and a list of transcription strings.
-
-    The supervision tensor has shape ``(batch_size, 3)``.
-    Its second dimension contains information about sequence index [0],
-    start frames [1] and num frames [2].
-
-    In GRID, the start frame of each audio sample is 0.
+    Args:
+      nnet_output_shape:
+        The shape of nnet_output, e.g., (N, T, D).
+      batch:
+        A batch from the dataloader; it is a dict
+        containing the texts and the audio/video arrays.
+    Return:
+      The supervision segments and the texts in the batch.
     """
     N, T, D = nnet_output_shape
-    supervisions_idx = torch.arange(0, N).to(torch.int32)
-    start_frames = [0 for _ in range(N)]
-    supervisions_start_frame = torch.tensor(start_frames).to(torch.int32)
-    num_frames = [T for _ in range(N)]
-    supervisions_num_frames = torch.tensor(num_frames).to(torch.int32)
+    supervisions_idx = torch.arange(0, N, dtype=torch.int32)
+    supervisions_start_frame = torch.full((1, N), 0, dtype=torch.int32)[0]
+    supervisions_num_frames = torch.full((1, N), T, dtype=torch.int32)[0]
 
     supervision_segments = torch.stack(
         (
@@ -43,7 +48,7 @@
             supervisions_idx,
             supervisions_start_frame,
             supervisions_num_frames,
         ),
         1,
-    ).to(torch.int32)
+    )
 
     texts = batch["txt"]
 
     return supervision_segments, texts
diff --git a/egs/grid/AVSR/combinenet_ctc_avsr/decode.py b/egs/grid/AVSR/combinenet_ctc_avsr/decode.py
index 579225fea..885e7e17d 100644
--- a/egs/grid/AVSR/combinenet_ctc_avsr/decode.py
+++ b/egs/grid/AVSR/combinenet_ctc_avsr/decode.py
@@ -30,7 +30,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
 
-from local.dataset_av import dataset_av
+from local.dataset_av import AudioVisualDataset
 from model import CombineNet
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -475,7 +475,7 @@ def main():
     model.to(device)
     model.eval()
 
-    grid = dataset_av(
+    grid = AudioVisualDataset(
         params.video_path,
         params.anno_path,
         params.val_list,
diff --git a/egs/grid/AVSR/combinenet_ctc_avsr/model.py b/egs/grid/AVSR/combinenet_ctc_avsr/model.py
index c0bcffa04..158727192 100644
--- a/egs/grid/AVSR/combinenet_ctc_avsr/model.py
+++ b/egs/grid/AVSR/combinenet_ctc_avsr/model.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
diff --git a/egs/grid/AVSR/combinenet_ctc_avsr/train.py b/egs/grid/AVSR/combinenet_ctc_avsr/train.py
index df476a2b0..7eac415d5 100644
--- a/egs/grid/AVSR/combinenet_ctc_avsr/train.py
+++ b/egs/grid/AVSR/combinenet_ctc_avsr/train.py
@@ -32,7 +32,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 
-from local.dataset_av import dataset_av
+from local.dataset_av import AudioVisualDataset
 from lhotse.utils import fix_random_seed
 from model import CombineNet
 from torch import Tensor
@@ -544,7 +544,7 @@ def run(rank, world_size, args):
         optimizer.load_state_dict(checkpoints["optimizer"])
         scheduler.load_state_dict(checkpoints["scheduler"])
 
-    grid = dataset_av(
+    grid = AudioVisualDataset(
         params.video_path,
         params.anno_path,
         params.train_list,
diff --git a/egs/grid/AVSR/combinenet_ctc_avsr/utils.py b/egs/grid/AVSR/combinenet_ctc_avsr/utils.py
index dbe01bb25..66823874a 100644
--- a/egs/grid/AVSR/combinenet_ctc_avsr/utils.py
+++ b/egs/grid/AVSR/combinenet_ctc_avsr/utils.py
@@ -19,24 +19,20 @@
 import torch
 
 
 def encode_supervisions(nnet_output_shape, batch):
     """
-    Encodes Lhotse's ``batch["supervisions"]`` dict into
+    Encodes the network output and the texts into
     a pair of torch Tensor, and a list of transcription strings.
 
     The supervision tensor has shape ``(batch_size, 3)``.
     Its second dimension contains information about sequence index [0],
     start frames [1] and num frames [2].
-    The batch items might become re-ordered during this operation -- the
-    returned tensor and list of strings are guaranteed to be consistent with
-    each other.
+    In GRID, the start frame of each audio sample is 0.
     """
     N, T, D = nnet_output_shape
-    supervisions_idx = torch.arange(0, N).to(torch.int32)
-    start_frames = [0 for _ in range(N)]
-    supervisions_start_frame = torch.tensor(start_frames).to(torch.int32)
-    num_frames = [T for _ in range(N)]
-    supervisions_num_frames = torch.tensor(num_frames).to(torch.int32)
+    supervisions_idx = torch.arange(0, N, dtype=torch.int32)
+    supervisions_start_frame = torch.full((1, N), 0, dtype=torch.int32)[0]
+    supervisions_num_frames = torch.full((1, N), T, dtype=torch.int32)[0]
 
     supervision_segments = torch.stack(
         (
@@ -45,7 +41,7 @@
             supervisions_idx,
             supervisions_start_frame,
             supervisions_num_frames,
         ),
         1,
-    ).to(torch.int32)
+    )
 
     texts = batch["txt"]
 
     return supervision_segments, texts
diff --git a/egs/grid/AVSR/local/cvtransforms.py b/egs/grid/AVSR/local/cvtransforms.py
index 11f47a57d..0e64c632e 100644
--- a/egs/grid/AVSR/local/cvtransforms.py
+++ b/egs/grid/AVSR/local/cvtransforms.py
@@ -26,13 +26,31 @@ The input for the above functions is a sequence of images.
 import random
 
 
-def HorizontalFlip(batch_img, p=0.5):
-    # (T, H, W, C)
+def horizontal_flip(batch_img, p: float = 0.5):
+    """
+    Args:
+      batch_img:
+        The float array of a sequence of images; the shape of the
+        array is (T, H, W, C).
+      p:
+        The probability of applying the horizontal flip; the default
+        value is 0.5.
+    Return:
+      A new float array of the sequence of images after flipping.
+    """
     if random.random() > p:
         batch_img = batch_img[:, :, ::-1, ...]
     return batch_img
 
 
-def ColorNormalize(batch_img):
+def color_normalize(batch_img):
+    """
+    Args:
+      batch_img:
+        The float array of a sequence of images; the shape of the
+        array is (T, H, W, C).
+    Return:
+      A new float array of the sequence of images after normalizing.
+    """
     batch_img = batch_img / 255.0
     return batch_img
diff --git a/egs/grid/AVSR/local/dataset_audio.py b/egs/grid/AVSR/local/dataset_audio.py
index 27868529f..1e99f7607 100644
--- a/egs/grid/AVSR/local/dataset_audio.py
+++ b/egs/grid/AVSR/local/dataset_audio.py
@@ -19,7 +19,6 @@
 This script is to load the audio data in GRID.
 The class dataset_audio makes each audio batch data have the same shape.
 """
-import kaldifeat
 import numpy as np
 import os
 
@@ -27,8 +26,10 @@
 import torch
 import torchaudio
 from torch.utils.data import Dataset
 
+import kaldifeat
 
-class dataset_audio(Dataset):
+
+class AudioDataSet(Dataset):
     def __init__(
         self,
         video_path: str,
@@ -46,7 +47,7 @@ class dataset_audio(Dataset):
           anno_path:
            The dir path of the texts data.
          file_list:
-            A txt file which listing all samples for training or testing.
+            A txt file listing all samples for training or testing.
          aud_padding:
            The padding for each audio sample.
          sample_rate:
@@ -61,6 +62,15 @@ class dataset_audio(Dataset):
         self.sample_rate = sample_rate
         self.feature_dim = feature_dim
         self.phase = phase
+
+        opts = kaldifeat.FbankOptions()
+        opts.device = "cpu"
+        opts.frame_opts.dither = 0
+        opts.frame_opts.snip_edges = False
+        opts.frame_opts.samp_freq = self.sample_rate
+        opts.mel_opts.num_bins = self.feature_dim
+        self.fbank = kaldifeat.Fbank(opts)
+
         with open(file_list, "r") as f:
             self.videos = [
                 os.path.join(video_path, line.strip()) for line in f.readlines()
             ]
@@ -92,19 +102,26 @@ class dataset_audio(Dataset):
         return len(self.data)
 
     def _load_aud(self, filename):
-        opts = kaldifeat.FbankOptions()
-        opts.device = "cpu"
-        opts.frame_opts.dither = 0
-        opts.frame_opts.snip_edges = False
-        opts.frame_opts.samp_freq = self.sample_rate
-        opts.mel_opts.num_bins = self.feature_dim
-        fbank = kaldifeat.Fbank(opts)
-        wave, sr = torchaudio.load(filename)
+        """Load the audio data.
+        Args:
+          filename:
+            The full path of a wav file.
+        Return:
+          The fbank feature array.
+        """
+        wave, _ = torchaudio.load(filename)
         wave = wave[0]
-        features = fbank(wave)
+        features = self.fbank(wave)
         return features
 
     def _load_anno(self, name):
+        """Load the text file.
+        Args:
+          name:
+            The file which records the text.
+        Return:
+          A sequence of words.
+        """
         with open(name, "r") as f:
             lines = [line.strip().split(" ") for line in f.readlines()]
             txt = [line[2] for line in lines]
@@ -113,6 +130,15 @@ class dataset_audio(Dataset):
         return txt
 
     def _padding(self, array, length):
+        """Pad zeros for the feature array.
+        Args:
+          array:
+            The feature array (audio or visual feature).
+          length:
+            The length for padding.
+        Return:
+          A new feature array after padding.
+        """
         array = [array[_] for _ in range(array.shape[0])]
         size = array[0].shape
         for i in range(length - len(array)):
diff --git a/egs/grid/AVSR/local/dataset_av.py b/egs/grid/AVSR/local/dataset_av.py
index 54e095ea0..a0faefd68 100644
--- a/egs/grid/AVSR/local/dataset_av.py
+++ b/egs/grid/AVSR/local/dataset_av.py
@@ -20,7 +20,6 @@
 This script is to load the pair of audio-visual data in GRID.
 The class dataset_av makes each audio-visual batch data have the same shape.
 """
 import cv2
-import kaldifeat
 import numpy as np
 import os
 
@@ -28,10 +27,11 @@
 import torch
 import torchaudio
 from torch.utils.data import Dataset
 
-from .cvtransforms import HorizontalFlip, ColorNormalize
+import kaldifeat
+from .cvtransforms import horizontal_flip, color_normalize
 
 
-class dataset_av(Dataset):
+class AudioVisualDataset(Dataset):
     def __init__(
         self,
         video_path,
@@ -94,8 +94,8 @@ class dataset_av(Dataset):
         )
 
         if self.phase == "train":
-            vid = HorizontalFlip(vid)
-            vid = ColorNormalize(vid)
+            vid = horizontal_flip(vid)
+            vid = color_normalize(vid)
 
         vid = self._padding(vid, self.vid_pading)
         aud = self._padding(aud, self.aud_pading)
@@ -110,6 +110,14 @@ class dataset_av(Dataset):
         return len(self.data)
 
     def _load_vid(self, p):
+        """Load the visual data.
+        Args:
+          p:
+            A directory which contains a sequence of frames
+            for a visual sample.
+        Return:
+          The array of a visual sample.
+        """
         files = os.listdir(p)
         files = list(filter(lambda file: file.find(".jpg") != -1, files))
         files = sorted(files, key=lambda file: int(os.path.splitext(file)[0]))
@@ -123,6 +131,13 @@ class dataset_av(Dataset):
         return array
 
     def _load_aud(self, filename):
+        """Load the audio data.
+        Args:
+          filename:
+            The full path of a wav file.
+        Return:
+          The fbank feature array.
+ """ opts = kaldifeat.FbankOptions() opts.frame_opts.dither = 0 opts.frame_opts.snip_edges = False @@ -135,6 +150,13 @@ class dataset_av(Dataset): return features def _load_anno(self, name): + """Load the text file. + Args: + name: + The file which records the text. + Return: + A sequence of words. + """ with open(name, "r") as f: lines = [line.strip().split(" ") for line in f.readlines()] txt = [line[2] for line in lines] @@ -143,6 +165,15 @@ class dataset_av(Dataset): return txt def _padding(self, array, length): + """Pad zeros for the feature array. + Args: + array: + The feature arry. (Audio or Visual feature) + length: + The length for padding. + Return: + A new feature array after padding. + """ array = [array[_] for _ in range(array.shape[0])] size = array[0].shape for i in range(length - len(array)): diff --git a/egs/grid/AVSR/local/dataset_visual.py b/egs/grid/AVSR/local/dataset_visual.py index f0f2f21ae..720f5d09a 100644 --- a/egs/grid/AVSR/local/dataset_visual.py +++ b/egs/grid/AVSR/local/dataset_visual.py @@ -24,10 +24,13 @@ import os import numpy as np import torch from torch.utils.data import Dataset -from .cvtransforms import HorizontalFlip, ColorNormalize +from .cvtransforms import ( + color_normalize, + horizontal_flip, +) -class dataset_visual(Dataset): +class VisualDataset(Dataset): def __init__( self, video_path: str, @@ -74,8 +77,8 @@ class dataset_visual(Dataset): ) if self.phase == "train": - vid = HorizontalFlip(vid) - vid = ColorNormalize(vid) + vid = horizontal_flip(vid, p=0.5) + vid = color_normalize(vid) vid = self._padding(vid, self.vid_padding) @@ -88,6 +91,14 @@ class dataset_visual(Dataset): return len(self.data) def _load_vid(self, p): + """Load the visual data. + Args: + p: + A directory which contains a sequence of frames + for a visual sample. + Return: + The array of a visual sample. + """ files = os.listdir(p) files = list(filter(lambda file: file.find(".jpg") != -1, files)) files = sorted(files, key=lambda file: int(os.path.splitext(file)[0])) @@ -101,6 +112,13 @@ class dataset_visual(Dataset): return array def _load_anno(self, name): + """Load the text file. + Args: + name: + The file which records the text. + Return: + A sequence of words. + """ with open(name, "r") as f: lines = [line.strip().split(" ") for line in f.readlines()] txt = [line[2] for line in lines] @@ -109,6 +127,15 @@ class dataset_visual(Dataset): return txt def _padding(self, array, length): + """Pad zeros for the feature array. + Args: + array: + The feature arry. (Audio or Visual feature) + length: + The length for padding. + Return: + A new feature array after padding. 
+ """ array = [array[_] for _ in range(array.shape[0])] size = array[0].shape for i in range(length - len(array)): diff --git a/egs/grid/AVSR/visualnet2_ctc_vsr/decode.py b/egs/grid/AVSR/visualnet2_ctc_vsr/decode.py index 1fbfd7650..a47c2b126 100644 --- a/egs/grid/AVSR/visualnet2_ctc_vsr/decode.py +++ b/egs/grid/AVSR/visualnet2_ctc_vsr/decode.py @@ -30,7 +30,7 @@ import torch import torch.nn as nn from torch.utils.data import DataLoader -from local.dataset_visual import dataset_visual +from local.dataset_visual import VisualDataset from model import VisualNet2 @@ -463,7 +463,7 @@ def main(): model.to(device) model.eval() - grid = dataset_visual( + grid = VisualDataset( params.video_path, params.anno_path, params.val_list, diff --git a/egs/grid/AVSR/visualnet2_ctc_vsr/train.py b/egs/grid/AVSR/visualnet2_ctc_vsr/train.py index af9bdec9e..35b691a5a 100644 --- a/egs/grid/AVSR/visualnet2_ctc_vsr/train.py +++ b/egs/grid/AVSR/visualnet2_ctc_vsr/train.py @@ -32,7 +32,7 @@ import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader -from local.dataset_visual import dataset_visual +from local.dataset_visual import VisualDataset from lhotse.utils import fix_random_seed from model import VisualNet2 @@ -529,7 +529,7 @@ def run(rank, world_size, args): optimizer.load_state_dict(checkpoints["optimizer"]) scheduler.load_state_dict(checkpoints["scheduler"]) - grid = dataset_visual( + grid = VisualDataset( params.video_path, params.anno_path, params.train_list, diff --git a/egs/grid/AVSR/visualnet2_ctc_vsr/utils.py b/egs/grid/AVSR/visualnet2_ctc_vsr/utils.py index cf68944bf..17889a31c 100644 --- a/egs/grid/AVSR/visualnet2_ctc_vsr/utils.py +++ b/egs/grid/AVSR/visualnet2_ctc_vsr/utils.py @@ -14,22 +14,32 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +This script is to encodes the supervisions as Tuple list. +The supervision tensor has shape ``(batch_size, 3)``. +Its second dimension contains information about sequence index [0], +start frames [1] and num frames [2]. +In GRID, the start frame of each audio sample is 0. +""" import torch -def encode_supervisions(nnet_output_shape, batch): +def encode_supervisions(nnet_output_shape: int, batch: dict): """ - In GRID, the lengths of all samples are same. - And here, we don't deploy cut operation on it. - So, the start frame is always 0 among all samples. + Args: + nnet_output_shape: + The shape of nnet_output, e.g: (N, T, D). + batch: + A batch of dataloader, it's a dict file + including text and aud/vid arrays. + Return: + The tuple list of supervisions and the text in batch. 
""" N, T, D = nnet_output_shape - supervisions_idx = torch.arange(0, N).to(torch.int32) - start_frames = [0 for _ in range(N)] - supervisions_start_frame = torch.tensor(start_frames).to(torch.int32) - num_frames = [T for _ in range(N)] - supervisions_num_frames = torch.tensor(num_frames).to(torch.int32) + supervisions_idx = torch.arange(0, N, dtype=torch.int32) + supervisions_start_frame = torch.full((1, N), 0, dtype=torch.int32)[0] + supervisions_num_frames = torch.full((1, N), T, dtype=torch.int32)[0] supervision_segments = torch.stack( ( @@ -38,8 +48,7 @@ def encode_supervisions(nnet_output_shape, batch): supervisions_num_frames, ), 1, - ).to(torch.int32) - + ) texts = batch["txt"] return supervision_segments, texts diff --git a/egs/grid/AVSR/visualnet_ctc_vsr/decode.py b/egs/grid/AVSR/visualnet_ctc_vsr/decode.py index 80b4d8b87..2373ada2c 100644 --- a/egs/grid/AVSR/visualnet_ctc_vsr/decode.py +++ b/egs/grid/AVSR/visualnet_ctc_vsr/decode.py @@ -30,7 +30,7 @@ import torch import torch.nn as nn from torch.utils.data import DataLoader -from local.dataset_visual import dataset_visual +from local.dataset_visual import VisualDataset from model import VisualNet from icefall.checkpoint import average_checkpoints, load_checkpoint @@ -462,7 +462,7 @@ def main(): model.to(device) model.eval() - grid = dataset_visual( + grid = VisualDataset( params.video_path, params.anno_path, params.val_list, diff --git a/egs/grid/AVSR/visualnet_ctc_vsr/train.py b/egs/grid/AVSR/visualnet_ctc_vsr/train.py index 55e1b3b53..df66b5cc9 100644 --- a/egs/grid/AVSR/visualnet_ctc_vsr/train.py +++ b/egs/grid/AVSR/visualnet_ctc_vsr/train.py @@ -32,7 +32,7 @@ import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader -from local.dataset_visual import dataset_visual +from local.dataset_visual import VisualDataset from lhotse.utils import fix_random_seed from model import VisualNet from torch import Tensor @@ -528,7 +528,7 @@ def run(rank, world_size, args): optimizer.load_state_dict(checkpoints["optimizer"]) scheduler.load_state_dict(checkpoints["scheduler"]) - grid = dataset_visual( + grid = VisualDataset( params.video_path, params.anno_path, params.train_list, diff --git a/egs/grid/AVSR/visualnet_ctc_vsr/utils.py b/egs/grid/AVSR/visualnet_ctc_vsr/utils.py index cf68944bf..17889a31c 100644 --- a/egs/grid/AVSR/visualnet_ctc_vsr/utils.py +++ b/egs/grid/AVSR/visualnet_ctc_vsr/utils.py @@ -14,22 +14,32 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +This script is to encodes the supervisions as Tuple list. +The supervision tensor has shape ``(batch_size, 3)``. +Its second dimension contains information about sequence index [0], +start frames [1] and num frames [2]. +In GRID, the start frame of each audio sample is 0. +""" import torch -def encode_supervisions(nnet_output_shape, batch): +def encode_supervisions(nnet_output_shape: int, batch: dict): """ - In GRID, the lengths of all samples are same. - And here, we don't deploy cut operation on it. - So, the start frame is always 0 among all samples. + Args: + nnet_output_shape: + The shape of nnet_output, e.g: (N, T, D). + batch: + A batch of dataloader, it's a dict file + including text and aud/vid arrays. + Return: + The tuple list of supervisions and the text in batch. 
""" N, T, D = nnet_output_shape - supervisions_idx = torch.arange(0, N).to(torch.int32) - start_frames = [0 for _ in range(N)] - supervisions_start_frame = torch.tensor(start_frames).to(torch.int32) - num_frames = [T for _ in range(N)] - supervisions_num_frames = torch.tensor(num_frames).to(torch.int32) + supervisions_idx = torch.arange(0, N, dtype=torch.int32) + supervisions_start_frame = torch.full((1, N), 0, dtype=torch.int32)[0] + supervisions_num_frames = torch.full((1, N), T, dtype=torch.int32)[0] supervision_segments = torch.stack( ( @@ -38,8 +48,7 @@ def encode_supervisions(nnet_output_shape, batch): supervisions_num_frames, ), 1, - ).to(torch.int32) - + ) texts = batch["txt"] return supervision_segments, texts