Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-27 02:34:21 +00:00

commit d412dbb2f0 (parent 7391f4febf)

    Make some required changes.
@@ -30,7 +30,7 @@ import torch
 import torch.nn as nn
 
 from torch.utils.data import DataLoader
-from local.dataset_audio import dataset_audio
+from local.dataset_audio import AudioDataSet
 from model import AudioNet
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -467,7 +467,7 @@ def main():
     model.to(device)
     model.eval()
 
-    grid = dataset_audio(
+    grid = AudioDataSet(
         params.video_path,
         params.anno_path,
         params.val_list,
@@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -209,7 +209,7 @@ def main():
 
     logging.info("Decoding started")
     features = fbank(waves)
-    # aud_padding = 480
+    # Here, we set aud_padding to 480.
     features_new = torch.zeros(len(features), 480, params.feature_dim).to(
         device
     )
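
For context, the fixed-length padding in this decode hunk can be sketched as follows; a minimal illustration under assumed shapes (the helper name pad_features, the copy semantics, and the feature_dim value are hypothetical, not taken from the recipe):

import torch

# Sketch: copy each variable-length (T_i, feature_dim) fbank matrix into
# a zero tensor of shape (batch, 480, feature_dim). Assumes T_i <= 480.
def pad_features(features, aud_padding=480, feature_dim=80):
    out = torch.zeros(len(features), aud_padding, feature_dim)
    for i, feat in enumerate(features):
        out[i, : feat.shape[0]] = feat  # zero-pad the tail
    return out

padded = pad_features([torch.randn(300, 80), torch.randn(295, 80)])
print(padded.shape)  # torch.Size([2, 480, 80])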
@@ -32,7 +32,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 
-from local.dataset_audio import dataset_audio
+from local.dataset_audio import AudioDataSet
 from lhotse.utils import fix_random_seed
 from model import AudioNet
 from torch import Tensor
@@ -533,7 +533,7 @@ def run(rank, world_size, args):
     optimizer.load_state_dict(checkpoints["optimizer"])
     scheduler.load_state_dict(checkpoints["scheduler"])
 
-    grid = dataset_audio(
+    grid = AudioDataSet(
         params.video_path,
         params.anno_path,
         params.train_list,
@@ -14,27 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+This script encodes the supervisions as a list of tuples.
+The supervision tensor has shape ``(batch_size, 3)``.
+Its second dimension contains information about sequence index [0],
+start frames [1] and num frames [2].
+In GRID, the start frame of each audio sample is 0.
+"""
 import torch
 
 
-def encode_supervisions(nnet_output_shape, batch):
+def encode_supervisions(nnet_output_shape: tuple, batch: dict):
     """
-    Encodes the output of net and texts into
-    a pair of torch Tensor, and a list of transcription strings.
-
-    The supervision tensor has shape ``(batch_size, 3)``.
-    Its second dimension contains information about sequence index [0],
-    start frames [1] and num frames [2].
-
-    In GRID, the start frame of each audio sample is 0.
+    Args:
+      nnet_output_shape:
+        The shape of nnet_output, e.g., (N, T, D).
+      batch:
+        A batch from the dataloader; a dict
+        containing the texts and aud/vid arrays.
+    Return:
+      The supervision segments and the texts in the batch.
     """
     N, T, D = nnet_output_shape
 
-    supervisions_idx = torch.arange(0, N).to(torch.int32)
-    start_frames = [0 for _ in range(N)]
-    supervisions_start_frame = torch.tensor(start_frames).to(torch.int32)
-    num_frames = [T for _ in range(N)]
-    supervisions_num_frames = torch.tensor(num_frames).to(torch.int32)
+    supervisions_idx = torch.arange(0, N, dtype=torch.int32)
+    supervisions_start_frame = torch.zeros(N, dtype=torch.int32)
+    supervisions_num_frames = torch.full((N,), T, dtype=torch.int32)
 
     supervision_segments = torch.stack(
         (
@@ -43,7 +48,7 @@ def encode_supervisions(nnet_output_shape, batch):
             supervisions_num_frames,
         ),
         1,
-    ).to(torch.int32)
+    )
     texts = batch["txt"]
 
     return supervision_segments, texts
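
Since every GRID sample in a batch shares the same padded length, the supervision tensor the new code builds is easy to picture. A minimal sketch, assuming N = 2 sequences of T = 480 frames (example values, not from the recipe):

import torch

# Each row of supervision_segments is (sequence index, start frame,
# num frames); in GRID the start frame is always 0 and every sequence
# spans all T frames.
N, T = 2, 480
supervision_segments = torch.stack(
    (
        torch.arange(0, N, dtype=torch.int32),
        torch.zeros(N, dtype=torch.int32),
        torch.full((N,), T, dtype=torch.int32),
    ),
    1,
)
print(supervision_segments)
# tensor([[  0,   0, 480],
#         [  1,   0, 480]], dtype=torch.int32)

Because all three stacked tensors are already int32, the stacked result is int32 as well, which is why the trailing .to(torch.int32) can be dropped.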
@@ -30,7 +30,7 @@ import torch
 import torch.nn as nn
 
 from torch.utils.data import DataLoader
-from local.dataset_av import dataset_av
+from local.dataset_av import AudioVisualDataset
 from model import CombineNet
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -475,7 +475,7 @@ def main():
     model.to(device)
     model.eval()
 
-    grid = dataset_av(
+    grid = AudioVisualDataset(
         params.video_path,
         params.anno_path,
         params.val_list,
@@ -1,4 +1,4 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -32,7 +32,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 
-from local.dataset_av import dataset_av
+from local.dataset_av import AudioVisualDataset
 from lhotse.utils import fix_random_seed
 from model import CombineNet
 from torch import Tensor
@@ -544,7 +544,7 @@ def run(rank, world_size, args):
     optimizer.load_state_dict(checkpoints["optimizer"])
     scheduler.load_state_dict(checkpoints["scheduler"])
 
-    grid = dataset_av(
+    grid = AudioVisualDataset(
         params.video_path,
         params.anno_path,
         params.train_list,
@@ -19,24 +19,20 @@ import torch
 
 def encode_supervisions(nnet_output_shape, batch):
     """
-    Encodes Lhotse's ``batch["supervisions"]`` dict into
+    Encodes the output of the net and the texts into
     a pair of torch Tensor, and a list of transcription strings.
 
     The supervision tensor has shape ``(batch_size, 3)``.
     Its second dimension contains information about sequence index [0],
     start frames [1] and num frames [2].
 
-    The batch items might become re-ordered during this operation -- the
-    returned tensor and list of strings are guaranteed to be consistent with
-    each other.
+    In GRID, the start frame of each audio sample is 0.
     """
     N, T, D = nnet_output_shape
 
-    supervisions_idx = torch.arange(0, N).to(torch.int32)
-    start_frames = [0 for _ in range(N)]
-    supervisions_start_frame = torch.tensor(start_frames).to(torch.int32)
-    num_frames = [T for _ in range(N)]
-    supervisions_num_frames = torch.tensor(num_frames).to(torch.int32)
+    supervisions_idx = torch.arange(0, N, dtype=torch.int32)
+    supervisions_start_frame = torch.zeros(N, dtype=torch.int32)
+    supervisions_num_frames = torch.full((N,), T, dtype=torch.int32)
 
     supervision_segments = torch.stack(
         (
@@ -45,7 +41,7 @@ def encode_supervisions(nnet_output_shape, batch):
             supervisions_num_frames,
         ),
         1,
-    ).to(torch.int32)
+    )
     texts = batch["txt"]
 
     return supervision_segments, texts
@@ -26,13 +26,31 @@ The input for the above functions is a sequence of images.
 import random
 
 
-def HorizontalFlip(batch_img, p=0.5):
-    # (T, H, W, C)
+def horizontal_flip(batch_img, p: float = 0.5):
+    """
+    Args:
+      batch_img:
+        The float array of a sequence of images; the shape of the
+        array is (T, H, W, C).
+      p:
+        The probability of flipping horizontally; the default
+        value is 0.5.
+    Return:
+      A new float array of the sequence of images after flipping.
+    """
     if random.random() > p:
         batch_img = batch_img[:, :, ::-1, ...]
     return batch_img
 
 
-def ColorNormalize(batch_img):
+def color_normalize(batch_img):
+    """
+    Args:
+      batch_img:
+        The float array of a sequence of images; the shape of the
+        array is (T, H, W, C).
+    Return:
+      A new float array of the sequence of images after normalizing.
+    """
     batch_img = batch_img / 255.0
     return batch_img
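
A quick usage sketch of the renamed transforms; the clip shape is hypothetical and the import assumes cvtransforms.py is importable from the working directory:

import numpy as np
from cvtransforms import color_normalize, horizontal_flip

# Dummy clip: 75 frames of 64x128 RGB images, laid out as (T, H, W, C).
vid = np.random.rand(75, 64, 128, 3).astype(np.float32) * 255.0

vid = horizontal_flip(vid)  # reverses the W axis with probability ~0.5
vid = color_normalize(vid)  # scales pixel values into [0, 1]

Note that the default p=0.5 is kept on horizontal_flip above because dataset_av calls it without a p argument.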
@@ -19,7 +19,6 @@
 This script is to load the audio data in GRID.
 The class dataset_audio makes each audio batch data have the same shape.
 """
-import kaldifeat
 import numpy as np
 import os
 
@@ -27,8 +26,10 @@ import torch
 import torchaudio
 from torch.utils.data import Dataset
 
+import kaldifeat
+
 
-class dataset_audio(Dataset):
+class AudioDataSet(Dataset):
     def __init__(
         self,
         video_path: str,
@@ -46,7 +47,7 @@ class dataset_audio(Dataset):
           anno_path:
             The dir path of the texts data.
           file_list:
-            The file which listing all samples for training or testing.
+            A txt file listing all samples for training or testing.
           aud_padding:
             The padding for each audio sample.
           sample_rate:
@@ -61,6 +62,15 @@ class dataset_audio(Dataset):
         self.sample_rate = sample_rate
         self.feature_dim = feature_dim
         self.phase = phase
+
+        opts = kaldifeat.FbankOptions()
+        opts.device = "cpu"
+        opts.frame_opts.dither = 0
+        opts.frame_opts.snip_edges = False
+        opts.frame_opts.samp_freq = self.sample_rate
+        opts.mel_opts.num_bins = self.feature_dim
+        self.fbank = kaldifeat.Fbank(opts)
+
         with open(file_list, "r") as f:
             self.videos = [
                 os.path.join(video_path, line.strip()) for line in f.readlines()
@@ -92,19 +102,26 @@ class dataset_audio(Dataset):
         return len(self.data)
 
     def _load_aud(self, filename):
-        opts = kaldifeat.FbankOptions()
-        opts.device = "cpu"
-        opts.frame_opts.dither = 0
-        opts.frame_opts.snip_edges = False
-        opts.frame_opts.samp_freq = self.sample_rate
-        opts.mel_opts.num_bins = self.feature_dim
-        fbank = kaldifeat.Fbank(opts)
-        wave, sr = torchaudio.load(filename)
+        """Load the audio data.
+        Args:
+          filename:
+            The full path of a wav file.
+        Return:
+          The fbank feature array.
+        """
+        wave, _ = torchaudio.load(filename)
         wave = wave[0]
-        features = fbank(wave)
+        features = self.fbank(wave)
         return features
 
     def _load_anno(self, name):
+        """Load the text file.
+        Args:
+          name:
+            The file which records the text.
+        Return:
+          A sequence of words.
+        """
         with open(name, "r") as f:
             lines = [line.strip().split(" ") for line in f.readlines()]
             txt = [line[2] for line in lines]
@@ -113,6 +130,15 @@ class dataset_audio(Dataset):
         return txt
 
     def _padding(self, array, length):
+        """Pad zeros for the feature array.
+        Args:
+          array:
+            The feature array (audio or visual feature).
+          length:
+            The length to pad to.
+        Return:
+          A new feature array after padding.
+        """
         array = [array[_] for _ in range(array.shape[0])]
         size = array[0].shape
         for i in range(length - len(array)):
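
Moving the kaldifeat.Fbank construction from _load_aud into __init__, as the hunks above do, builds the extractor once per dataset instead of once per sample. A standalone sketch of the same pattern (the wav path and the sample_rate/feature_dim values are assumed examples):

import kaldifeat
import torchaudio

# Build the fbank extractor once, mirroring the options set in __init__.
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000  # assumed sample rate
opts.mel_opts.num_bins = 80  # assumed feature dim
fbank = kaldifeat.Fbank(opts)

# Reuse it for every sample, as self.fbank(wave) does in _load_aud.
wave, _ = torchaudio.load("sample.wav")  # hypothetical wav file
features = fbank(wave[0])  # shape: (num_frames, num_bins)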
@@ -20,7 +20,6 @@ This script is to load the pair of audio-visual data in GRID.
 The class dataset_av makes each audio-visual batch data have the same shape.
 """
 import cv2
-import kaldifeat
 import numpy as np
 import os
 
@@ -28,10 +27,11 @@ import torch
 import torchaudio
 from torch.utils.data import Dataset
 
-from .cvtransforms import HorizontalFlip, ColorNormalize
+import kaldifeat
+from .cvtransforms import horizontal_flip, color_normalize
 
 
-class dataset_av(Dataset):
+class AudioVisualDataset(Dataset):
     def __init__(
         self,
         video_path,
@@ -94,8 +94,8 @@ class dataset_av(Dataset):
         )
 
         if self.phase == "train":
-            vid = HorizontalFlip(vid)
-            vid = ColorNormalize(vid)
+            vid = horizontal_flip(vid)
+            vid = color_normalize(vid)
 
         vid = self._padding(vid, self.vid_pading)
         aud = self._padding(aud, self.aud_pading)
@@ -110,6 +110,14 @@ class dataset_av(Dataset):
         return len(self.data)
 
     def _load_vid(self, p):
+        """Load the visual data.
+        Args:
+          p:
+            A directory which contains a sequence of frames
+            for a visual sample.
+        Return:
+          The array of a visual sample.
+        """
         files = os.listdir(p)
         files = list(filter(lambda file: file.find(".jpg") != -1, files))
         files = sorted(files, key=lambda file: int(os.path.splitext(file)[0]))
@@ -123,6 +131,13 @@ class dataset_av(Dataset):
         return array
 
     def _load_aud(self, filename):
+        """Load the audio data.
+        Args:
+          filename:
+            The full path of a wav file.
+        Return:
+          The fbank feature array.
+        """
         opts = kaldifeat.FbankOptions()
         opts.frame_opts.dither = 0
         opts.frame_opts.snip_edges = False
@@ -135,6 +150,13 @@ class dataset_av(Dataset):
         return features
 
     def _load_anno(self, name):
+        """Load the text file.
+        Args:
+          name:
+            The file which records the text.
+        Return:
+          A sequence of words.
+        """
         with open(name, "r") as f:
             lines = [line.strip().split(" ") for line in f.readlines()]
             txt = [line[2] for line in lines]
@@ -143,6 +165,15 @@ class dataset_av(Dataset):
         return txt
 
     def _padding(self, array, length):
+        """Pad zeros for the feature array.
+        Args:
+          array:
+            The feature array (audio or visual feature).
+          length:
+            The length to pad to.
+        Return:
+          A new feature array after padding.
+        """
         array = [array[_] for _ in range(array.shape[0])]
         size = array[0].shape
         for i in range(length - len(array)):
@@ -24,10 +24,13 @@ import os
 import numpy as np
 import torch
 from torch.utils.data import Dataset
-from .cvtransforms import HorizontalFlip, ColorNormalize
+from .cvtransforms import (
+    color_normalize,
+    horizontal_flip,
+)
 
 
-class dataset_visual(Dataset):
+class VisualDataset(Dataset):
     def __init__(
         self,
         video_path: str,
@@ -74,8 +77,8 @@ class dataset_visual(Dataset):
         )
 
         if self.phase == "train":
-            vid = HorizontalFlip(vid)
-            vid = ColorNormalize(vid)
+            vid = horizontal_flip(vid, p=0.5)
+            vid = color_normalize(vid)
 
         vid = self._padding(vid, self.vid_padding)
 
@@ -88,6 +91,14 @@ class dataset_visual(Dataset):
         return len(self.data)
 
     def _load_vid(self, p):
+        """Load the visual data.
+        Args:
+          p:
+            A directory which contains a sequence of frames
+            for a visual sample.
+        Return:
+          The array of a visual sample.
+        """
         files = os.listdir(p)
         files = list(filter(lambda file: file.find(".jpg") != -1, files))
         files = sorted(files, key=lambda file: int(os.path.splitext(file)[0]))
@@ -101,6 +112,13 @@ class dataset_visual(Dataset):
         return array
 
     def _load_anno(self, name):
+        """Load the text file.
+        Args:
+          name:
+            The file which records the text.
+        Return:
+          A sequence of words.
+        """
         with open(name, "r") as f:
             lines = [line.strip().split(" ") for line in f.readlines()]
             txt = [line[2] for line in lines]
@@ -109,6 +127,15 @@ class dataset_visual(Dataset):
         return txt
 
     def _padding(self, array, length):
+        """Pad zeros for the feature array.
+        Args:
+          array:
+            The feature array (audio or visual feature).
+          length:
+            The length to pad to.
+        Return:
+          A new feature array after padding.
+        """
         array = [array[_] for _ in range(array.shape[0])]
         size = array[0].shape
         for i in range(length - len(array)):
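
The _padding helper documented in these hunks extends a (T, ...) array to a fixed number of frames by appending zero frames. The diff truncates the loop body, so the completion below is an assumed reconstruction (numpy, hypothetical shapes):

import numpy as np

def padding(array, length):
    # Split into per-frame arrays, as the visible code does.
    frames = [array[i] for i in range(array.shape[0])]
    size = frames[0].shape
    # Assumed completion: append zero frames until `length` is reached.
    for _ in range(length - len(frames)):
        frames.append(np.zeros(size))
    return np.stack(frames, axis=0)

vid = np.random.rand(75, 96, 96, 3)  # hypothetical visual sample
print(padding(vid, 100).shape)  # (100, 96, 96, 3)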
@@ -30,7 +30,7 @@ import torch
 import torch.nn as nn
 
 from torch.utils.data import DataLoader
-from local.dataset_visual import dataset_visual
+from local.dataset_visual import VisualDataset
 
 from model import VisualNet2
 
@@ -463,7 +463,7 @@ def main():
     model.to(device)
     model.eval()
 
-    grid = dataset_visual(
+    grid = VisualDataset(
         params.video_path,
         params.anno_path,
         params.val_list,
@@ -32,7 +32,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 
-from local.dataset_visual import dataset_visual
+from local.dataset_visual import VisualDataset
 from lhotse.utils import fix_random_seed
 
 from model import VisualNet2
@@ -529,7 +529,7 @@ def run(rank, world_size, args):
     optimizer.load_state_dict(checkpoints["optimizer"])
     scheduler.load_state_dict(checkpoints["scheduler"])
 
-    grid = dataset_visual(
+    grid = VisualDataset(
         params.video_path,
         params.anno_path,
         params.train_list,
@@ -14,22 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+This script encodes the supervisions as a list of tuples.
+The supervision tensor has shape ``(batch_size, 3)``.
+Its second dimension contains information about sequence index [0],
+start frames [1] and num frames [2].
+In GRID, the start frame of each audio sample is 0.
+"""
 import torch
 
 
-def encode_supervisions(nnet_output_shape, batch):
+def encode_supervisions(nnet_output_shape: tuple, batch: dict):
     """
-    In GRID, the lengths of all samples are same.
-    And here, we don't deploy cut operation on it.
-    So, the start frame is always 0 among all samples.
+    Args:
+      nnet_output_shape:
+        The shape of nnet_output, e.g., (N, T, D).
+      batch:
+        A batch from the dataloader; a dict
+        containing the texts and aud/vid arrays.
+    Return:
+      The supervision segments and the texts in the batch.
     """
     N, T, D = nnet_output_shape
 
-    supervisions_idx = torch.arange(0, N).to(torch.int32)
-    start_frames = [0 for _ in range(N)]
-    supervisions_start_frame = torch.tensor(start_frames).to(torch.int32)
-    num_frames = [T for _ in range(N)]
-    supervisions_num_frames = torch.tensor(num_frames).to(torch.int32)
+    supervisions_idx = torch.arange(0, N, dtype=torch.int32)
+    supervisions_start_frame = torch.zeros(N, dtype=torch.int32)
+    supervisions_num_frames = torch.full((N,), T, dtype=torch.int32)
 
     supervision_segments = torch.stack(
         (
@@ -38,8 +48,7 @@ def encode_supervisions(nnet_output_shape, batch):
             supervisions_num_frames,
         ),
         1,
-    ).to(torch.int32)
-
+    )
     texts = batch["txt"]
 
     return supervision_segments, texts
@@ -30,7 +30,7 @@ import torch
 import torch.nn as nn
 
 from torch.utils.data import DataLoader
-from local.dataset_visual import dataset_visual
+from local.dataset_visual import VisualDataset
 from model import VisualNet
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -462,7 +462,7 @@ def main():
     model.to(device)
     model.eval()
 
-    grid = dataset_visual(
+    grid = VisualDataset(
         params.video_path,
         params.anno_path,
         params.val_list,
@@ -32,7 +32,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 
-from local.dataset_visual import dataset_visual
+from local.dataset_visual import VisualDataset
 from lhotse.utils import fix_random_seed
 from model import VisualNet
 from torch import Tensor
@@ -528,7 +528,7 @@ def run(rank, world_size, args):
     optimizer.load_state_dict(checkpoints["optimizer"])
     scheduler.load_state_dict(checkpoints["scheduler"])
 
-    grid = dataset_visual(
+    grid = VisualDataset(
         params.video_path,
         params.anno_path,
         params.train_list,
@@ -14,22 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+This script encodes the supervisions as a list of tuples.
+The supervision tensor has shape ``(batch_size, 3)``.
+Its second dimension contains information about sequence index [0],
+start frames [1] and num frames [2].
+In GRID, the start frame of each audio sample is 0.
+"""
 import torch
 
 
-def encode_supervisions(nnet_output_shape, batch):
+def encode_supervisions(nnet_output_shape: tuple, batch: dict):
     """
-    In GRID, the lengths of all samples are same.
-    And here, we don't deploy cut operation on it.
-    So, the start frame is always 0 among all samples.
+    Args:
+      nnet_output_shape:
+        The shape of nnet_output, e.g., (N, T, D).
+      batch:
+        A batch from the dataloader; a dict
+        containing the texts and aud/vid arrays.
+    Return:
+      The supervision segments and the texts in the batch.
     """
     N, T, D = nnet_output_shape
 
-    supervisions_idx = torch.arange(0, N).to(torch.int32)
-    start_frames = [0 for _ in range(N)]
-    supervisions_start_frame = torch.tensor(start_frames).to(torch.int32)
-    num_frames = [T for _ in range(N)]
-    supervisions_num_frames = torch.tensor(num_frames).to(torch.int32)
+    supervisions_idx = torch.arange(0, N, dtype=torch.int32)
+    supervisions_start_frame = torch.zeros(N, dtype=torch.int32)
+    supervisions_num_frames = torch.full((N,), T, dtype=torch.int32)
 
     supervision_segments = torch.stack(
         (
@@ -38,8 +48,7 @@ def encode_supervisions(nnet_output_shape, batch):
             supervisions_num_frames,
         ),
         1,
-    ).to(torch.int32)
-
+    )
     texts = batch["txt"]
 
     return supervision_segments, texts