commit d412dbb2f0 (parent 7391f4febf)

    Done some changes that are required.
@@ -30,7 +30,7 @@ import torch
 import torch.nn as nn
 
 from torch.utils.data import DataLoader
-from local.dataset_audio import dataset_audio
+from local.dataset_audio import AudioDataSet
 from model import AudioNet
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -467,7 +467,7 @@ def main():
     model.to(device)
     model.eval()
 
-    grid = dataset_audio(
+    grid = AudioDataSet(
         params.video_path,
         params.anno_path,
         params.val_list,
@@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -209,7 +209,7 @@ def main():
 
     logging.info("Decoding started")
     features = fbank(waves)
-    # aud_padding = 480
+    # Here, we set aud_padding as 480.
     features_new = torch.zeros(len(features), 480, params.feature_dim).to(
         device
     )
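Note: the padding step above allocates a zero tensor of shape (batch, 480, feature_dim) and then, presumably, copies each utterance's fbank frames into it. A minimal sketch of that pattern, assuming every fbank matrix has at most 480 frames; the copy loop itself is not shown in the hunk:

    import torch

    aud_padding = 480  # every GRID utterance is padded to this many frames
    feature_dim = 23
    features = [torch.randn(75, feature_dim), torch.randn(480, feature_dim)]  # stand-in fbank outputs

    features_new = torch.zeros(len(features), aud_padding, feature_dim)
    for i, feat in enumerate(features):
        features_new[i, : feat.shape[0]] = feat  # copy the real frames, keep the zero tail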
@@ -32,7 +32,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 
-from local.dataset_audio import dataset_audio
+from local.dataset_audio import AudioDataSet
 from lhotse.utils import fix_random_seed
 from model import AudioNet
 from torch import Tensor
@@ -533,7 +533,7 @@ def run(rank, world_size, args):
         optimizer.load_state_dict(checkpoints["optimizer"])
         scheduler.load_state_dict(checkpoints["scheduler"])
 
-    grid = dataset_audio(
+    grid = AudioDataSet(
         params.video_path,
         params.anno_path,
         params.train_list,
@@ -14,27 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+This script encodes the supervisions as a tuple list.
+The supervision tensor has shape ``(batch_size, 3)``.
+Its second dimension contains information about sequence index [0],
+start frames [1] and num frames [2].
+In GRID, the start frame of each audio sample is 0.
+"""
 import torch
 
 
-def encode_supervisions(nnet_output_shape, batch):
+def encode_supervisions(nnet_output_shape: int, batch: dict):
     """
     Encodes the output of net and texts into
     a pair of torch Tensor, and a list of transcription strings.
 
     The supervision tensor has shape ``(batch_size, 3)``.
     Its second dimension contains information about sequence index [0],
     start frames [1] and num frames [2].
 
     In GRID, the start frame of each audio sample is 0.
     Args:
       nnet_output_shape:
         The shape of nnet_output, e.g. (N, T, D).
       batch:
         A batch of dataloader, it's a dict
         including text and aud/vid arrays.
     Return:
       The tuple list of supervisions and the text in batch.
     """
     N, T, D = nnet_output_shape
 
-    supervisions_idx = torch.arange(0, N).to(torch.int32)
-    start_frames = [0 for _ in range(N)]
-    supervisions_start_frame = torch.tensor(start_frames).to(torch.int32)
-    num_frames = [T for _ in range(N)]
-    supervisions_num_frames = torch.tensor(num_frames).to(torch.int32)
+    supervisions_idx = torch.arange(0, N, dtype=torch.int32)
+    supervisions_start_frame = torch.full((1, N), 0, dtype=torch.int32)[0]
+    supervisions_num_frames = torch.full((1, N), T, dtype=torch.int32)[0]
 
     supervision_segments = torch.stack(
         (
@@ -43,7 +48,7 @@ def encode_supervisions(nnet_output_shape, batch):
             supervisions_num_frames,
         ),
         1,
-    ).to(torch.int32)
+    )
     texts = batch["txt"]
 
     return supervision_segments, texts
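Note: the three replacement lines build 1-D int32 tensors of length N; `torch.full((1, N), v, ...)[0]` allocates a (1, N) tensor and immediately drops the leading dimension. A sketch of equivalent, arguably more direct forms (an observation about the code, not part of the commit):

    import torch

    N, T = 4, 75
    supervisions_idx = torch.arange(0, N, dtype=torch.int32)          # tensor([0, 1, 2, 3])
    supervisions_start_frame = torch.zeros(N, dtype=torch.int32)      # == torch.full((1, N), 0, dtype=torch.int32)[0]
    supervisions_num_frames = torch.full((N,), T, dtype=torch.int32)  # == torch.full((1, N), T, dtype=torch.int32)[0]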
@@ -30,7 +30,7 @@ import torch
 import torch.nn as nn
 
 from torch.utils.data import DataLoader
-from local.dataset_av import dataset_av
+from local.dataset_av import AudioVisualDataset
 from model import CombineNet
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -475,7 +475,7 @@ def main():
     model.to(device)
     model.eval()
 
-    grid = dataset_av(
+    grid = AudioVisualDataset(
         params.video_path,
         params.anno_path,
         params.val_list,
@@ -1,4 +1,4 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -32,7 +32,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 
-from local.dataset_av import dataset_av
+from local.dataset_av import AudioVisualDataset
 from lhotse.utils import fix_random_seed
 from model import CombineNet
 from torch import Tensor
@@ -544,7 +544,7 @@ def run(rank, world_size, args):
         optimizer.load_state_dict(checkpoints["optimizer"])
         scheduler.load_state_dict(checkpoints["scheduler"])
 
-    grid = dataset_av(
+    grid = AudioVisualDataset(
         params.video_path,
         params.anno_path,
         params.train_list,
@@ -19,24 +19,20 @@ import torch
 
 
 def encode_supervisions(nnet_output_shape, batch):
     """
-    Encodes Lhotse's ``batch["supervisions"]`` dict into
+    Encodes the output of net and texts into
     a pair of torch Tensor, and a list of transcription strings.
 
     The supervision tensor has shape ``(batch_size, 3)``.
     Its second dimension contains information about sequence index [0],
     start frames [1] and num frames [2].
 
-    The batch items might become re-ordered during this operation -- the
-    returned tensor and list of strings are guaranteed to be consistent with
-    each other.
+    In GRID, the start frame of each audio sample is 0.
     """
     N, T, D = nnet_output_shape
 
-    supervisions_idx = torch.arange(0, N).to(torch.int32)
-    start_frames = [0 for _ in range(N)]
-    supervisions_start_frame = torch.tensor(start_frames).to(torch.int32)
-    num_frames = [T for _ in range(N)]
-    supervisions_num_frames = torch.tensor(num_frames).to(torch.int32)
+    supervisions_idx = torch.arange(0, N, dtype=torch.int32)
+    supervisions_start_frame = torch.full((1, N), 0, dtype=torch.int32)[0]
+    supervisions_num_frames = torch.full((1, N), T, dtype=torch.int32)[0]
 
     supervision_segments = torch.stack(
         (
@@ -45,7 +41,7 @@ def encode_supervisions(nnet_output_shape, batch):
             supervisions_num_frames,
         ),
         1,
-    ).to(torch.int32)
+    )
     texts = batch["txt"]
 
     return supervision_segments, texts
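Note: dropping the trailing `.to(torch.int32)` is safe here because all three stacked tensors are already created with `dtype=torch.int32`, so `torch.stack` preserves that dtype and the cast had become a no-op. A quick check:

    import torch

    N, T = 4, 75
    segments = torch.stack(
        (
            torch.arange(0, N, dtype=torch.int32),
            torch.full((1, N), 0, dtype=torch.int32)[0],
            torch.full((1, N), T, dtype=torch.int32)[0],
        ),
        1,
    )
    assert segments.dtype == torch.int32 and segments.shape == (N, 3)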
@@ -26,13 +26,31 @@ The input for the above functions is a sequence of images.
 import random
 
 
-def HorizontalFlip(batch_img, p=0.5):
-    # (T, H, W, C)
+def horizontal_flip(batch_img: float, p: float):
+    """
+    Args:
+      batch_img:
+        The float array of a sequence of images, the shape of the
+        array is (T, H, W, C).
+      p:
+        The probability of implementing horizontal flip, the default
+        value is 0.5.
+    Return:
+      A new float array of the sequence of images after flipping.
+    """
     if random.random() > p:
         batch_img = batch_img[:, :, ::-1, ...]
     return batch_img
 
 
-def ColorNormalize(batch_img):
+def color_normalize(batch_img: float):
+    """
+    Args:
+      batch_img:
+        The float array of a sequence of images, the shape of the
+        array is (T, H, W, C).
+    Return:
+      A new float array of the sequence of images after normalizing.
+    """
     batch_img = batch_img / 255.0
     return batch_img
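Note: the new `batch_img: float` annotations are looser than what the functions receive; both operate on a numpy array of shape (T, H, W, C), so `np.ndarray` would be the precise hint (our observation, not part of the diff). Also, `if random.random() > p` flips with probability 1 - p, which matches the documented behaviour only because the intended default is 0.5. A usage sketch of what the two helpers compute:

    import numpy as np

    # stand-in frame sequence: T=75 frames of 64x128 RGB
    vid = np.random.randint(0, 256, size=(75, 64, 128, 3)).astype(np.float32)

    flipped = vid[:, :, ::-1, ...]  # horizontal_flip's core: reverse axis 2 (width)
    normalized = vid / 255.0        # color_normalize's core: scale pixels into [0, 1]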
@@ -19,7 +19,6 @@
 This script is to load the audio data in GRID.
 The class dataset_audio makes each audio batch data have the same shape.
 """
-import kaldifeat
 import numpy as np
 import os
 
@@ -27,8 +26,10 @@ import torch
 import torchaudio
 from torch.utils.data import Dataset
 
+import kaldifeat
+
 
-class dataset_audio(Dataset):
+class AudioDataSet(Dataset):
     def __init__(
         self,
         video_path: str,
@@ -46,7 +47,7 @@ class dataset_audio(Dataset):
         anno_path:
           The dir path of the texts data.
         file_list:
-          The file which listing all samples for training or testing.
+          A txt file listing all samples for training or testing.
         aud_padding:
           The padding for each audio sample.
         sample_rate:
@@ -61,6 +62,15 @@ class dataset_audio(Dataset):
         self.sample_rate = sample_rate
         self.feature_dim = feature_dim
         self.phase = phase
 
+        opts = kaldifeat.FbankOptions()
+        opts.device = "cpu"
+        opts.frame_opts.dither = 0
+        opts.frame_opts.snip_edges = False
+        opts.frame_opts.samp_freq = self.sample_rate
+        opts.mel_opts.num_bins = self.feature_dim
+        self.fbank = kaldifeat.Fbank(opts)
+
         with open(file_list, "r") as f:
             self.videos = [
                 os.path.join(video_path, line.strip()) for line in f.readlines()
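Note: this hunk moves the kaldifeat.Fbank construction out of `_load_aud` (next hunk) and into `__init__`, so the extractor is built once per dataset rather than once per `__getitem__` call. The pattern, reduced to plain Python with a stand-in for kaldifeat.Fbank:

    class ExpensiveExtractor:  # stand-in for kaldifeat.Fbank(opts)
        def __call__(self, wave):
            return [2.0 * w for w in wave]  # placeholder for the fbank computation

    class AudioDataSetSketch:
        def __init__(self):
            self.fbank = ExpensiveExtractor()  # built once, like self.fbank = kaldifeat.Fbank(opts)

        def _load_aud(self, wave):
            return self.fbank(wave)  # reused on every sample instead of rebuilt

    print(AudioDataSetSketch()._load_aud([1.0, 2.0]))  # [2.0, 4.0]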
@@ -92,19 +102,26 @@ class dataset_audio(Dataset):
         return len(self.data)
 
     def _load_aud(self, filename):
-        opts = kaldifeat.FbankOptions()
-        opts.device = "cpu"
-        opts.frame_opts.dither = 0
-        opts.frame_opts.snip_edges = False
-        opts.frame_opts.samp_freq = self.sample_rate
-        opts.mel_opts.num_bins = self.feature_dim
-        fbank = kaldifeat.Fbank(opts)
-        wave, sr = torchaudio.load(filename)
+        """Load the audio data.
+        Args:
+          filename:
+            The full path of a wav file.
+        Return:
+          The fbank feature array.
+        """
+        wave, _ = torchaudio.load(filename)
         wave = wave[0]
-        features = fbank(wave)
+        features = self.fbank(wave)
         return features
 
     def _load_anno(self, name):
+        """Load the text file.
+        Args:
+          name:
+            The file which records the text.
+        Return:
+          A sequence of words.
+        """
         with open(name, "r") as f:
             lines = [line.strip().split(" ") for line in f.readlines()]
             txt = [line[2] for line in lines]
@@ -113,6 +130,15 @@ class dataset_audio(Dataset):
         return txt
 
     def _padding(self, array, length):
+        """Pad zeros for the feature array.
+        Args:
+          array:
+            The feature array. (Audio or Visual feature)
+          length:
+            The length for padding.
+        Return:
+          A new feature array after padding.
+        """
         array = [array[_] for _ in range(array.shape[0])]
         size = array[0].shape
         for i in range(length - len(array)):
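Note: the hunk ends inside `_padding`; a plausible completion under the docstring's contract (append zero frames until the sequence reaches `length`, then restack), written with numpy. The appended-zeros body is a guess at the truncated code, not text from the diff:

    import numpy as np

    def _padding(array, length):
        array = [array[i] for i in range(array.shape[0])]
        size = array[0].shape
        for _ in range(length - len(array)):
            array.append(np.zeros(size))  # assumed: pad with zero frames
        return np.stack(array, axis=0)    # assumed: restack to (length, ...)

    print(_padding(np.ones((2, 3)), 4).shape)  # (4, 3)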
@@ -20,7 +20,6 @@ This script is to load the pair of audio-visual data in GRID.
 The class dataset_av makes each audio-visual batch data have the same shape.
 """
 import cv2
-import kaldifeat
 import numpy as np
 import os
 
@@ -28,10 +27,11 @@ import torch
 import torchaudio
 from torch.utils.data import Dataset
 
-from .cvtransforms import HorizontalFlip, ColorNormalize
+import kaldifeat
+from .cvtransforms import horizontal_flip, color_normalize
 
 
-class dataset_av(Dataset):
+class AudioVisualDataset(Dataset):
     def __init__(
         self,
         video_path,
@@ -94,8 +94,8 @@ class dataset_av(Dataset):
         )
 
         if self.phase == "train":
-            vid = HorizontalFlip(vid)
-            vid = ColorNormalize(vid)
+            vid = horizontal_flip(vid)
+            vid = color_normalize(vid)
 
         vid = self._padding(vid, self.vid_pading)
         aud = self._padding(aud, self.aud_pading)
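Note: as of this commit `horizontal_flip` is declared with no default for `p` (see the cvtransforms hunk above), yet this call site passes only `vid`. Read literally, that would raise a TypeError at train time; the visual recipe below passes `p=0.5` explicitly. A stand-in demonstration of the suspected mismatch:

    def horizontal_flip(batch_img, p):  # new signature: p has no default
        return batch_img

    try:
        horizontal_flip([1, 2, 3])  # the call in this hunk
    except TypeError as err:
        print("would fail:", err)

    horizontal_flip([1, 2, 3], p=0.5)  # the form used in the visual recipe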
@@ -110,6 +110,14 @@ class dataset_av(Dataset):
         return len(self.data)
 
     def _load_vid(self, p):
+        """Load the visual data.
+        Args:
+          p:
+            A directory which contains a sequence of frames
+            for a visual sample.
+        Return:
+          The array of a visual sample.
+        """
         files = os.listdir(p)
         files = list(filter(lambda file: file.find(".jpg") != -1, files))
         files = sorted(files, key=lambda file: int(os.path.splitext(file)[0]))
@@ -123,6 +131,13 @@ class dataset_av(Dataset):
         return array
 
     def _load_aud(self, filename):
+        """Load the audio data.
+        Args:
+          filename:
+            The full path of a wav file.
+        Return:
+          The fbank feature array.
+        """
         opts = kaldifeat.FbankOptions()
         opts.frame_opts.dither = 0
         opts.frame_opts.snip_edges = False
@@ -135,6 +150,13 @@ class dataset_av(Dataset):
         return features
 
     def _load_anno(self, name):
+        """Load the text file.
+        Args:
+          name:
+            The file which records the text.
+        Return:
+          A sequence of words.
+        """
         with open(name, "r") as f:
             lines = [line.strip().split(" ") for line in f.readlines()]
             txt = [line[2] for line in lines]
@@ -143,6 +165,15 @@ class dataset_av(Dataset):
         return txt
 
     def _padding(self, array, length):
+        """Pad zeros for the feature array.
+        Args:
+          array:
+            The feature array. (Audio or Visual feature)
+          length:
+            The length for padding.
+        Return:
+          A new feature array after padding.
+        """
         array = [array[_] for _ in range(array.shape[0])]
         size = array[0].shape
         for i in range(length - len(array)):
@@ -24,10 +24,13 @@ import os
 import numpy as np
 import torch
 from torch.utils.data import Dataset
-from .cvtransforms import HorizontalFlip, ColorNormalize
+from .cvtransforms import (
+    color_normalize,
+    horizontal_flip,
+)
 
 
-class dataset_visual(Dataset):
+class VisualDataset(Dataset):
     def __init__(
         self,
         video_path: str,
@@ -74,8 +77,8 @@ class dataset_visual(Dataset):
         )
 
         if self.phase == "train":
-            vid = HorizontalFlip(vid)
-            vid = ColorNormalize(vid)
+            vid = horizontal_flip(vid, p=0.5)
+            vid = color_normalize(vid)
 
         vid = self._padding(vid, self.vid_padding)
 
@@ -88,6 +91,14 @@ class dataset_visual(Dataset):
         return len(self.data)
 
     def _load_vid(self, p):
+        """Load the visual data.
+        Args:
+          p:
+            A directory which contains a sequence of frames
+            for a visual sample.
+        Return:
+          The array of a visual sample.
+        """
         files = os.listdir(p)
         files = list(filter(lambda file: file.find(".jpg") != -1, files))
         files = sorted(files, key=lambda file: int(os.path.splitext(file)[0]))
@@ -101,6 +112,13 @@ class dataset_visual(Dataset):
         return array
 
     def _load_anno(self, name):
+        """Load the text file.
+        Args:
+          name:
+            The file which records the text.
+        Return:
+          A sequence of words.
+        """
         with open(name, "r") as f:
             lines = [line.strip().split(" ") for line in f.readlines()]
             txt = [line[2] for line in lines]
@@ -109,6 +127,15 @@ class dataset_visual(Dataset):
         return txt
 
     def _padding(self, array, length):
+        """Pad zeros for the feature array.
+        Args:
+          array:
+            The feature array. (Audio or Visual feature)
+          length:
+            The length for padding.
+        Return:
+          A new feature array after padding.
+        """
         array = [array[_] for _ in range(array.shape[0])]
         size = array[0].shape
         for i in range(length - len(array)):
@@ -30,7 +30,7 @@ import torch
 import torch.nn as nn
 
 from torch.utils.data import DataLoader
-from local.dataset_visual import dataset_visual
+from local.dataset_visual import VisualDataset
 
 from model import VisualNet2
 
@@ -463,7 +463,7 @@ def main():
     model.to(device)
     model.eval()
 
-    grid = dataset_visual(
+    grid = VisualDataset(
         params.video_path,
         params.anno_path,
         params.val_list,
@@ -32,7 +32,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 
-from local.dataset_visual import dataset_visual
+from local.dataset_visual import VisualDataset
 from lhotse.utils import fix_random_seed
 
 from model import VisualNet2
@@ -529,7 +529,7 @@ def run(rank, world_size, args):
         optimizer.load_state_dict(checkpoints["optimizer"])
         scheduler.load_state_dict(checkpoints["scheduler"])
 
-    grid = dataset_visual(
+    grid = VisualDataset(
         params.video_path,
         params.anno_path,
         params.train_list,
@@ -14,22 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+This script encodes the supervisions as a tuple list.
+The supervision tensor has shape ``(batch_size, 3)``.
+Its second dimension contains information about sequence index [0],
+start frames [1] and num frames [2].
+In GRID, the start frame of each audio sample is 0.
+"""
 import torch
 
 
-def encode_supervisions(nnet_output_shape, batch):
+def encode_supervisions(nnet_output_shape: int, batch: dict):
     """
-    In GRID, the lengths of all samples are same.
-    And here, we don't deploy cut operation on it.
-    So, the start frame is always 0 among all samples.
+    Args:
+      nnet_output_shape:
+        The shape of nnet_output, e.g. (N, T, D).
+      batch:
+        A batch of dataloader, it's a dict
+        including text and aud/vid arrays.
+    Return:
+      The tuple list of supervisions and the text in batch.
     """
     N, T, D = nnet_output_shape
 
-    supervisions_idx = torch.arange(0, N).to(torch.int32)
-    start_frames = [0 for _ in range(N)]
-    supervisions_start_frame = torch.tensor(start_frames).to(torch.int32)
-    num_frames = [T for _ in range(N)]
-    supervisions_num_frames = torch.tensor(num_frames).to(torch.int32)
+    supervisions_idx = torch.arange(0, N, dtype=torch.int32)
+    supervisions_start_frame = torch.full((1, N), 0, dtype=torch.int32)[0]
+    supervisions_num_frames = torch.full((1, N), T, dtype=torch.int32)[0]
 
     supervision_segments = torch.stack(
         (
@@ -38,8 +48,7 @@ def encode_supervisions(nnet_output_shape, batch):
             supervisions_num_frames,
         ),
         1,
-    ).to(torch.int32)
-
+    )
     texts = batch["txt"]
 
     return supervision_segments, texts
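Note: in icefall's CTC-style recipes, an (N, 3) int32 `supervision_segments` tensor like the one built here is typically consumed by `k2.DenseFsaVec` together with the log-softmax network output. A hedged sketch (assumes k2 is installed; shapes are illustrative):

    import k2
    import torch

    N, T, C = 2, 75, 30
    nnet_output = torch.randn(N, T, C).log_softmax(-1)  # (N, T, C) log-probs

    supervision_segments = torch.stack(
        (
            torch.arange(0, N, dtype=torch.int32),   # sequence index
            torch.zeros(N, dtype=torch.int32),       # start frame (always 0 in GRID)
            torch.full((N,), T, dtype=torch.int32),  # num frames
        ),
        1,
    )
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)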
@@ -30,7 +30,7 @@ import torch
 import torch.nn as nn
 
 from torch.utils.data import DataLoader
-from local.dataset_visual import dataset_visual
+from local.dataset_visual import VisualDataset
 from model import VisualNet
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -462,7 +462,7 @@ def main():
     model.to(device)
     model.eval()
 
-    grid = dataset_visual(
+    grid = VisualDataset(
         params.video_path,
         params.anno_path,
         params.val_list,
@@ -32,7 +32,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 
-from local.dataset_visual import dataset_visual
+from local.dataset_visual import VisualDataset
 from lhotse.utils import fix_random_seed
 from model import VisualNet
 from torch import Tensor
@@ -528,7 +528,7 @@ def run(rank, world_size, args):
         optimizer.load_state_dict(checkpoints["optimizer"])
         scheduler.load_state_dict(checkpoints["scheduler"])
 
-    grid = dataset_visual(
+    grid = VisualDataset(
         params.video_path,
         params.anno_path,
         params.train_list,
@@ -14,22 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+This script encodes the supervisions as a tuple list.
+The supervision tensor has shape ``(batch_size, 3)``.
+Its second dimension contains information about sequence index [0],
+start frames [1] and num frames [2].
+In GRID, the start frame of each audio sample is 0.
+"""
 import torch
 
 
-def encode_supervisions(nnet_output_shape, batch):
+def encode_supervisions(nnet_output_shape: int, batch: dict):
     """
-    In GRID, the lengths of all samples are same.
-    And here, we don't deploy cut operation on it.
-    So, the start frame is always 0 among all samples.
+    Args:
+      nnet_output_shape:
+        The shape of nnet_output, e.g. (N, T, D).
+      batch:
+        A batch of dataloader, it's a dict
+        including text and aud/vid arrays.
+    Return:
+      The tuple list of supervisions and the text in batch.
    """
     N, T, D = nnet_output_shape
 
-    supervisions_idx = torch.arange(0, N).to(torch.int32)
-    start_frames = [0 for _ in range(N)]
-    supervisions_start_frame = torch.tensor(start_frames).to(torch.int32)
-    num_frames = [T for _ in range(N)]
-    supervisions_num_frames = torch.tensor(num_frames).to(torch.int32)
+    supervisions_idx = torch.arange(0, N, dtype=torch.int32)
+    supervisions_start_frame = torch.full((1, N), 0, dtype=torch.int32)[0]
+    supervisions_num_frames = torch.full((1, N), T, dtype=torch.int32)[0]
 
     supervision_segments = torch.stack(
         (
@@ -38,8 +48,7 @@ def encode_supervisions(nnet_output_shape, batch):
             supervisions_num_frames,
         ),
         1,
-    ).to(torch.int32)
-
+    )
     texts = batch["txt"]
 
     return supervision_segments, texts