mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 10:44:19 +00:00
add pretrained.py
This commit is contained in:
parent
2b5881f4d9
commit
283bd126c5
274
egs/grid/AVSR/audionet_ctc_asr/pretrained.py
Normal file
274
egs/grid/AVSR/audionet_ctc_asr/pretrained.py
Normal file
@ -0,0 +1,274 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Wei Kang
|
||||||
|
# Mingshuang Luo)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import kaldifeat
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from model import AudioNet
|
||||||
|
|
||||||
|
from icefall.decode import (
|
||||||
|
get_lattice,
|
||||||
|
one_best_decoding,
|
||||||
|
rescore_with_whole_lattice,
|
||||||
|
)
|
||||||
|
from icefall.utils import AttributeDict, get_texts
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the checkpoint. "
|
||||||
|
"The checkpoint is assumed to be saved by "
|
||||||
|
"icefall.checkpoint.save_checkpoint().",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--words-file",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to words.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--HLG", type=str, required=True, help="Path to HLG.pt."
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--method",
|
||||||
|
type=str,
|
||||||
|
default="1best",
|
||||||
|
help="""Decoding method.
|
||||||
|
Possible values are:
|
||||||
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
|
the transformer encoder output is used for decoding.
|
||||||
|
We call it HLG decoding.
|
||||||
|
(2) whole-lattice-rescoring - Use an LM to rescore the
|
||||||
|
decoding lattice and then use 1best to decode the
|
||||||
|
rescored lattice.
|
||||||
|
We call it HLG decoding + n-gram LM rescoring.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--G",
|
||||||
|
type=str,
|
||||||
|
help="""An LM for rescoring.
|
||||||
|
Used only when method is
|
||||||
|
whole-lattice-rescoring.
|
||||||
|
It's usually a 4-gram LM.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="""
|
||||||
|
Used only when method is whole-lattice-rescoring.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
(Note: You need to tune it on a dataset.)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_files",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input sound file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
"feature_dim": 80,
|
||||||
|
"subsampling_factor": 3,
|
||||||
|
"num_classes": 28,
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"search_beam": 20,
|
||||||
|
"output_beam": 5,
|
||||||
|
"min_active_states": 30,
|
||||||
|
"max_active_states": 10000,
|
||||||
|
"use_double_scores": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
|
||||||
|
filenames: List[str], expected_sample_rate: float
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||||
|
Args:
|
||||||
|
filenames:
|
||||||
|
A list of sound filenames.
|
||||||
|
expected_sample_rate:
|
||||||
|
The expected sample rate of the sound files.
|
||||||
|
Returns:
|
||||||
|
Return a list of 1-D float32 torch tensors.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
for f in filenames:
|
||||||
|
wave, sample_rate = torchaudio.load(f)
|
||||||
|
# We use only the first channel
|
||||||
|
ans.append(wave[0])
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
logging.info(f"{params}")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
logging.info("Creating model")
|
||||||
|
model = AudioNet(
|
||||||
|
num_features=params.feature_dim,
|
||||||
|
num_classes=params.num_classes,
|
||||||
|
subsampling_factor=params.subsampling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||||
|
model.load_state_dict(checkpoint["model"])
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
logging.info(f"Loading HLG from {params.HLG}")
|
||||||
|
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||||
|
HLG = HLG.to(device)
|
||||||
|
if not hasattr(HLG, "lm_scores"):
|
||||||
|
# For whole-lattice-rescoring and attention-decoder
|
||||||
|
HLG.lm_scores = HLG.scores.clone()
|
||||||
|
|
||||||
|
if params.method == "whole-lattice-rescoring":
|
||||||
|
logging.info(f"Loading G from {params.G}")
|
||||||
|
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||||
|
# Add epsilon self-loops to G as we will compose
|
||||||
|
# it with the whole lattice later
|
||||||
|
G = G.to(device)
|
||||||
|
G = k2.add_epsilon_self_loops(G)
|
||||||
|
G = k2.arc_sort(G)
|
||||||
|
G.lm_scores = G.scores.clone()
|
||||||
|
|
||||||
|
logging.info("Constructing Fbank computer")
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.device = device
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = params.sample_rate
|
||||||
|
opts.mel_opts.num_bins = params.feature_dim
|
||||||
|
|
||||||
|
fbank = kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {params.sound_files}")
|
||||||
|
waves = read_sound_files(
|
||||||
|
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||||
|
)
|
||||||
|
waves = [w.to(device) for w in waves]
|
||||||
|
|
||||||
|
logging.info("Decoding started")
|
||||||
|
features = fbank(waves)
|
||||||
|
|
||||||
|
features_new = torch.zeros(len(features), 480, params.feature_dim).to(
|
||||||
|
device
|
||||||
|
)
|
||||||
|
for i in range(len(features)):
|
||||||
|
length = features[i].shape[0]
|
||||||
|
features_new[i][:length] = features[i]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
nnet_output = model(features_new.permute(0, 2, 1))
|
||||||
|
# nnet_output is (N, T, C)
|
||||||
|
|
||||||
|
batch_size = nnet_output.shape[0]
|
||||||
|
supervision_segments = torch.tensor(
|
||||||
|
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
|
||||||
|
lattice = get_lattice(
|
||||||
|
nnet_output=nnet_output,
|
||||||
|
decoding_graph=HLG,
|
||||||
|
supervision_segments=supervision_segments,
|
||||||
|
search_beam=params.search_beam,
|
||||||
|
output_beam=params.output_beam,
|
||||||
|
min_active_states=params.min_active_states,
|
||||||
|
max_active_states=params.max_active_states,
|
||||||
|
subsampling_factor=params.subsampling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.method == "1best":
|
||||||
|
logging.info("Use HLG decoding")
|
||||||
|
best_path = one_best_decoding(
|
||||||
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
|
)
|
||||||
|
elif params.method == "whole-lattice-rescoring":
|
||||||
|
logging.info("Use HLG decoding + LM rescoring")
|
||||||
|
best_path_dict = rescore_with_whole_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
G_with_epsilon_loops=G,
|
||||||
|
lm_scale_list=[params.ngram_lm_scale],
|
||||||
|
)
|
||||||
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||||
|
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||||
|
|
||||||
|
s = "\n"
|
||||||
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
|
words = " ".join(hyp)
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = (
|
||||||
|
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
270
egs/grid/AVSR/combinenet_ctc_avsr/pretrained.py
Normal file
270
egs/grid/AVSR/combinenet_ctc_avsr/pretrained.py
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Wei Kang
|
||||||
|
# Mingshuang Luo)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import cv2
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import kaldifeat
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from model import TdnnLstm
|
||||||
|
|
||||||
|
from icefall.decode import (
|
||||||
|
get_lattice,
|
||||||
|
one_best_decoding,
|
||||||
|
rescore_with_whole_lattice,
|
||||||
|
)
|
||||||
|
from icefall.utils import AttributeDict, get_texts
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the checkpoint. "
|
||||||
|
"The checkpoint is assumed to be saved by "
|
||||||
|
"icefall.checkpoint.save_checkpoint().",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--words-file",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to words.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--HLG", type=str, required=True, help="Path to HLG.pt."
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--method",
|
||||||
|
type=str,
|
||||||
|
default="1best",
|
||||||
|
help="""Decoding method.
|
||||||
|
Possible values are:
|
||||||
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
|
the transformer encoder output is used for decoding.
|
||||||
|
We call it HLG decoding.
|
||||||
|
(2) whole-lattice-rescoring - Use an LM to rescore the
|
||||||
|
decoding lattice and then use 1best to decode the
|
||||||
|
rescored lattice.
|
||||||
|
We call it HLG decoding + n-gram LM rescoring.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--G",
|
||||||
|
type=str,
|
||||||
|
help="""An LM for rescoring.
|
||||||
|
Used only when method is
|
||||||
|
whole-lattice-rescoring.
|
||||||
|
It's usually a 4-gram LM.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="""
|
||||||
|
Used only when method is whole-lattice-rescoring.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
(Note: You need to tune it on a dataset.)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lipframes-dirs",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input visual file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by cv2.imread(). "
|
||||||
|
"The frames sample rate is 25fps.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
"num_classes": 28,
|
||||||
|
"search_beam": 20,
|
||||||
|
"output_beam": 5,
|
||||||
|
"min_active_states": 30,
|
||||||
|
"max_active_states": 10000,
|
||||||
|
"use_double_scores": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
logging.info(f"{params}")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
logging.info("Creating model")
|
||||||
|
model = TdnnLstm(num_features=80, num_classes=28, subsampling_factor=3)
|
||||||
|
|
||||||
|
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||||
|
model.load_state_dict(checkpoint["model"])
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
logging.info(f"Loading HLG from {params.HLG}")
|
||||||
|
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||||
|
HLG = HLG.to(device)
|
||||||
|
if not hasattr(HLG, "lm_scores"):
|
||||||
|
# For whole-lattice-rescoring and attention-decoder
|
||||||
|
HLG.lm_scores = HLG.scores.clone()
|
||||||
|
|
||||||
|
if params.method == "whole-lattice-rescoring":
|
||||||
|
logging.info(f"Loading G from {params.G}")
|
||||||
|
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||||
|
# Add epsilon self-loops to G as we will compose
|
||||||
|
# it with the whole lattice later
|
||||||
|
G = G.to(device)
|
||||||
|
G = k2.add_epsilon_self_loops(G)
|
||||||
|
G = k2.arc_sort(G)
|
||||||
|
G.lm_scores = G.scores.clone()
|
||||||
|
|
||||||
|
logging.info("Loading lip roi frames and audio wav files")
|
||||||
|
aud = []
|
||||||
|
vid = []
|
||||||
|
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.device = device
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = 16000
|
||||||
|
opts.mel_opts.num_bins = 80
|
||||||
|
fbank = kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
for sample_dir in params.lipframes_dirs:
|
||||||
|
wave, sr = torchaudio.load(
|
||||||
|
sample_dir.replace("lip", "audio_25k").replace(
|
||||||
|
"video/mpg_6000/", ""
|
||||||
|
)
|
||||||
|
+ ".wav"
|
||||||
|
)
|
||||||
|
wave = wave[0]
|
||||||
|
aud.append(fbank(wave))
|
||||||
|
|
||||||
|
files = os.listdir(sample_dir)
|
||||||
|
files = list(filter(lambda file: file.find(".jpg") != -1, files))
|
||||||
|
files = sorted(files, key=lambda file: int(os.path.splitext(file)[0]))
|
||||||
|
array = [cv2.imread(os.path.join(sample_dir, file)) for file in files]
|
||||||
|
array = list(filter(lambda im: im is not None, array))
|
||||||
|
array = [
|
||||||
|
cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4)
|
||||||
|
for im in array
|
||||||
|
]
|
||||||
|
array = np.stack(array, axis=0).astype(np.float32)
|
||||||
|
vid.append(array)
|
||||||
|
|
||||||
|
L, H, W, C = vid[0].shape
|
||||||
|
features_v = torch.zeros(len(vid), 75, H, W, C).to(device)
|
||||||
|
for i in range(len(vid)):
|
||||||
|
length = vid[i].shape[0]
|
||||||
|
features_v[i][:length] = torch.FloatTensor(vid[i]).to(device)
|
||||||
|
|
||||||
|
features_a = torch.zeros(len(aud), 450, 80).to(device)
|
||||||
|
for i in range(len(aud)):
|
||||||
|
length = aud[i].shape[0]
|
||||||
|
features_a[i][:length] = torch.FloatTensor(aud[i]).to(device)
|
||||||
|
|
||||||
|
logging.info("Decoding started")
|
||||||
|
with torch.no_grad():
|
||||||
|
nnet_output = model(
|
||||||
|
features_v.permute(0, 4, 1, 2, 3) / 255.0,
|
||||||
|
features_a.permute(0, 2, 1),
|
||||||
|
)
|
||||||
|
# nnet_output is (N, T, C)
|
||||||
|
|
||||||
|
batch_size = nnet_output.shape[0]
|
||||||
|
supervision_segments = torch.tensor(
|
||||||
|
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
|
||||||
|
lattice = get_lattice(
|
||||||
|
nnet_output=nnet_output,
|
||||||
|
decoding_graph=HLG,
|
||||||
|
supervision_segments=supervision_segments,
|
||||||
|
search_beam=params.search_beam,
|
||||||
|
output_beam=params.output_beam,
|
||||||
|
min_active_states=params.min_active_states,
|
||||||
|
max_active_states=params.max_active_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.method == "1best":
|
||||||
|
logging.info("Use HLG decoding")
|
||||||
|
best_path = one_best_decoding(
|
||||||
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
|
)
|
||||||
|
elif params.method == "whole-lattice-rescoring":
|
||||||
|
logging.info("Use HLG decoding + LM rescoring")
|
||||||
|
best_path_dict = rescore_with_whole_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
G_with_epsilon_loops=G,
|
||||||
|
lm_scale_list=[params.ngram_lm_scale],
|
||||||
|
)
|
||||||
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||||
|
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||||
|
|
||||||
|
s = "\n"
|
||||||
|
for filename, hyp in zip(params.lipframes_dirs, hyps):
|
||||||
|
words = " ".join(hyp)
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = (
|
||||||
|
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
@ -32,8 +32,7 @@ import torch.nn as nn
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from local.dataset_visual import dataset_visual
|
from local.dataset_visual import dataset_visual
|
||||||
|
|
||||||
# from model import LipNet
|
from model import VisualNet2
|
||||||
from model import visual_frontend
|
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.decode import (
|
from icefall.decode import (
|
||||||
@ -131,7 +130,7 @@ def get_parser():
|
|||||||
def get_params() -> AttributeDict:
|
def get_params() -> AttributeDict:
|
||||||
params = AttributeDict(
|
params = AttributeDict(
|
||||||
{
|
{
|
||||||
"exp_dir": Path("visualnet_ctc_vsr2/exp"),
|
"exp_dir": Path("visualnet2_ctc_vsr/exp"),
|
||||||
"lang_dir": Path("data/lang_character"),
|
"lang_dir": Path("data/lang_character"),
|
||||||
"lm_dir": Path("data/lm"),
|
"lm_dir": Path("data/lm"),
|
||||||
"search_beam": 20,
|
"search_beam": 20,
|
||||||
@ -388,6 +387,7 @@ def main():
|
|||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
max_token_id = max(lexicon.tokens)
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@ -441,7 +441,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
G = None
|
G = None
|
||||||
|
|
||||||
model = visual_frontend()
|
model = VisualNet2(num_classes=max_token_id + 1)
|
||||||
if params.avg == 1:
|
if params.avg == 1:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
else:
|
else:
|
||||||
|
@ -115,9 +115,10 @@ class ResNet(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class VisualNet2(nn.Module):
|
class VisualNet2(nn.Module):
|
||||||
def __init__(self, inputDim=512):
|
def __init__(self, num_classes):
|
||||||
super(VisualNet2, self).__init__()
|
super(VisualNet2, self).__init__()
|
||||||
self.inputDim = inputDim
|
self.num_classes = num_classes
|
||||||
|
self.inputDim = 512
|
||||||
self.conv3d = nn.Conv3d(
|
self.conv3d = nn.Conv3d(
|
||||||
3,
|
3,
|
||||||
64,
|
64,
|
||||||
@ -143,7 +144,7 @@ class VisualNet2(nn.Module):
|
|||||||
self.dropout = nn.Dropout(p=0.5)
|
self.dropout = nn.Dropout(p=0.5)
|
||||||
|
|
||||||
# fc
|
# fc
|
||||||
self.linear = nn.Linear(1024, 28)
|
self.linear = nn.Linear(1024, self.num_classes)
|
||||||
|
|
||||||
# initialize
|
# initialize
|
||||||
self._initialize_weights()
|
self._initialize_weights()
|
||||||
|
243
egs/grid/AVSR/visualnet2_ctc_vsr/pretrained.py
Normal file
243
egs/grid/AVSR/visualnet2_ctc_vsr/pretrained.py
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Wei Kang
|
||||||
|
# Mingshuang Luo)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import cv2
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
from model import VisualNet2
|
||||||
|
|
||||||
|
from icefall.decode import (
|
||||||
|
get_lattice,
|
||||||
|
one_best_decoding,
|
||||||
|
rescore_with_whole_lattice,
|
||||||
|
)
|
||||||
|
from icefall.utils import AttributeDict, get_texts
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the checkpoint. "
|
||||||
|
"The checkpoint is assumed to be saved by "
|
||||||
|
"icefall.checkpoint.save_checkpoint().",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--words-file",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to words.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--HLG", type=str, required=True, help="Path to HLG.pt."
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--method",
|
||||||
|
type=str,
|
||||||
|
default="1best",
|
||||||
|
help="""Decoding method.
|
||||||
|
Possible values are:
|
||||||
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
|
the transformer encoder output is used for decoding.
|
||||||
|
We call it HLG decoding.
|
||||||
|
(2) whole-lattice-rescoring - Use an LM to rescore the
|
||||||
|
decoding lattice and then use 1best to decode the
|
||||||
|
rescored lattice.
|
||||||
|
We call it HLG decoding + n-gram LM rescoring.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--G",
|
||||||
|
type=str,
|
||||||
|
help="""An LM for rescoring.
|
||||||
|
Used only when method is
|
||||||
|
whole-lattice-rescoring.
|
||||||
|
It's usually a 4-gram LM.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="""
|
||||||
|
Used only when method is whole-lattice-rescoring.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
(Note: You need to tune it on a dataset.)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lipframes-dirs",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input visual file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by cv2.imread(). "
|
||||||
|
"The frames sample rate is 25fps.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
"num_classes": 28,
|
||||||
|
"search_beam": 20,
|
||||||
|
"output_beam": 5,
|
||||||
|
"min_active_states": 30,
|
||||||
|
"max_active_states": 10000,
|
||||||
|
"use_double_scores": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
logging.info(f"{params}")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
logging.info("Creating model")
|
||||||
|
model = VisualNet2(num_classes=params.num_classes)
|
||||||
|
|
||||||
|
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||||
|
model.load_state_dict(checkpoint["model"])
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
logging.info(f"Loading HLG from {params.HLG}")
|
||||||
|
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||||
|
HLG = HLG.to(device)
|
||||||
|
if not hasattr(HLG, "lm_scores"):
|
||||||
|
# For whole-lattice-rescoring and attention-decoder
|
||||||
|
HLG.lm_scores = HLG.scores.clone()
|
||||||
|
|
||||||
|
if params.method == "whole-lattice-rescoring":
|
||||||
|
logging.info(f"Loading G from {params.G}")
|
||||||
|
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||||
|
# Add epsilon self-loops to G as we will compose
|
||||||
|
# it with the whole lattice later
|
||||||
|
G = G.to(device)
|
||||||
|
G = k2.add_epsilon_self_loops(G)
|
||||||
|
G = k2.arc_sort(G)
|
||||||
|
G.lm_scores = G.scores.clone()
|
||||||
|
|
||||||
|
logging.info("Loading lip roi frames")
|
||||||
|
|
||||||
|
vid = []
|
||||||
|
for sample_dir in params.lipframes_dirs:
|
||||||
|
files = os.listdir(sample_dir)
|
||||||
|
files = list(filter(lambda file: file.find(".jpg") != -1, files))
|
||||||
|
files = sorted(files, key=lambda file: int(os.path.splitext(file)[0]))
|
||||||
|
array = [cv2.imread(os.path.join(sample_dir, file)) for file in files]
|
||||||
|
array = list(filter(lambda im: im is not None, array))
|
||||||
|
array = [
|
||||||
|
cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4)
|
||||||
|
for im in array
|
||||||
|
]
|
||||||
|
array = np.stack(array, axis=0).astype(np.float32)
|
||||||
|
vid.append(array)
|
||||||
|
|
||||||
|
_, H, W, C = vid[0].shape
|
||||||
|
features = torch.zeros(len(vid), 75, H, W, C).to(device)
|
||||||
|
for i in range(len(vid)):
|
||||||
|
length = vid[i].shape[0]
|
||||||
|
features[i][:length] = torch.FloatTensor(vid[i]).to(device)
|
||||||
|
|
||||||
|
logging.info("Decoding started")
|
||||||
|
features = features / 255.0
|
||||||
|
with torch.no_grad():
|
||||||
|
nnet_output = model(features.permute(0, 4, 1, 2, 3))
|
||||||
|
# nnet_output is (N, T, C)
|
||||||
|
|
||||||
|
batch_size = nnet_output.shape[0]
|
||||||
|
supervision_segments = torch.tensor(
|
||||||
|
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
|
||||||
|
lattice = get_lattice(
|
||||||
|
nnet_output=nnet_output,
|
||||||
|
decoding_graph=HLG,
|
||||||
|
supervision_segments=supervision_segments,
|
||||||
|
search_beam=params.search_beam,
|
||||||
|
output_beam=params.output_beam,
|
||||||
|
min_active_states=params.min_active_states,
|
||||||
|
max_active_states=params.max_active_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.method == "1best":
|
||||||
|
logging.info("Use HLG decoding")
|
||||||
|
best_path = one_best_decoding(
|
||||||
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
|
)
|
||||||
|
elif params.method == "whole-lattice-rescoring":
|
||||||
|
logging.info("Use HLG decoding + LM rescoring")
|
||||||
|
best_path_dict = rescore_with_whole_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
G_with_epsilon_loops=G,
|
||||||
|
lm_scale_list=[params.ngram_lm_scale],
|
||||||
|
)
|
||||||
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||||
|
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||||
|
|
||||||
|
s = "\n"
|
||||||
|
for filename, hyp in zip(params.lipframes_dirs, hyps):
|
||||||
|
words = " ".join(hyp)
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = (
|
||||||
|
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
@ -503,13 +503,14 @@ def run(rank, world_size, args):
|
|||||||
tb_writer = None
|
tb_writer = None
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
max_token_id = max(lexicon.tokens)
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", rank)
|
||||||
|
|
||||||
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
|
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
|
||||||
model = VisualNet2()
|
model = VisualNet2(num_classes=max_token_id + 1)
|
||||||
|
|
||||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||||
|
|
||||||
|
243
egs/grid/AVSR/visualnet_ctc_vsr/pretrained.py
Normal file
243
egs/grid/AVSR/visualnet_ctc_vsr/pretrained.py
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Wei Kang
|
||||||
|
# Mingshuang Luo)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import cv2
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
from model import VisualNet
|
||||||
|
|
||||||
|
from icefall.decode import (
|
||||||
|
get_lattice,
|
||||||
|
one_best_decoding,
|
||||||
|
rescore_with_whole_lattice,
|
||||||
|
)
|
||||||
|
from icefall.utils import AttributeDict, get_texts
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the checkpoint. "
|
||||||
|
"The checkpoint is assumed to be saved by "
|
||||||
|
"icefall.checkpoint.save_checkpoint().",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--words-file",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to words.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--HLG", type=str, required=True, help="Path to HLG.pt."
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--method",
|
||||||
|
type=str,
|
||||||
|
default="1best",
|
||||||
|
help="""Decoding method.
|
||||||
|
Possible values are:
|
||||||
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
|
the transformer encoder output is used for decoding.
|
||||||
|
We call it HLG decoding.
|
||||||
|
(2) whole-lattice-rescoring - Use an LM to rescore the
|
||||||
|
decoding lattice and then use 1best to decode the
|
||||||
|
rescored lattice.
|
||||||
|
We call it HLG decoding + n-gram LM rescoring.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--G",
|
||||||
|
type=str,
|
||||||
|
help="""An LM for rescoring.
|
||||||
|
Used only when method is
|
||||||
|
whole-lattice-rescoring.
|
||||||
|
It's usually a 4-gram LM.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="""
|
||||||
|
Used only when method is whole-lattice-rescoring.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
(Note: You need to tune it on a dataset.)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lipframes-dirs",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input visual file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by cv2.imread(). "
|
||||||
|
"The frames sample rate is 25fps.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
"num_classes": 28,
|
||||||
|
"search_beam": 20,
|
||||||
|
"output_beam": 5,
|
||||||
|
"min_active_states": 30,
|
||||||
|
"max_active_states": 10000,
|
||||||
|
"use_double_scores": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
logging.info(f"{params}")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
logging.info("Creating model")
|
||||||
|
model = VisualNet(num_classes=params.num_classes)
|
||||||
|
|
||||||
|
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||||
|
model.load_state_dict(checkpoint["model"])
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
logging.info(f"Loading HLG from {params.HLG}")
|
||||||
|
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||||
|
HLG = HLG.to(device)
|
||||||
|
if not hasattr(HLG, "lm_scores"):
|
||||||
|
# For whole-lattice-rescoring and attention-decoder
|
||||||
|
HLG.lm_scores = HLG.scores.clone()
|
||||||
|
|
||||||
|
if params.method == "whole-lattice-rescoring":
|
||||||
|
logging.info(f"Loading G from {params.G}")
|
||||||
|
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||||
|
# Add epsilon self-loops to G as we will compose
|
||||||
|
# it with the whole lattice later
|
||||||
|
G = G.to(device)
|
||||||
|
G = k2.add_epsilon_self_loops(G)
|
||||||
|
G = k2.arc_sort(G)
|
||||||
|
G.lm_scores = G.scores.clone()
|
||||||
|
|
||||||
|
logging.info("Loading lip roi frames")
|
||||||
|
|
||||||
|
vid = []
|
||||||
|
for sample_dir in params.lipframes_dirs:
|
||||||
|
files = os.listdir(sample_dir)
|
||||||
|
files = list(filter(lambda file: file.find(".jpg") != -1, files))
|
||||||
|
files = sorted(files, key=lambda file: int(os.path.splitext(file)[0]))
|
||||||
|
array = [cv2.imread(os.path.join(sample_dir, file)) for file in files]
|
||||||
|
array = list(filter(lambda im: im is not None, array))
|
||||||
|
array = [
|
||||||
|
cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4)
|
||||||
|
for im in array
|
||||||
|
]
|
||||||
|
array = np.stack(array, axis=0).astype(np.float32)
|
||||||
|
vid.append(array)
|
||||||
|
|
||||||
|
_, H, W, C = vid[0].shape
|
||||||
|
features = torch.zeros(len(vid), 75, H, W, C).to(device)
|
||||||
|
for i in range(len(vid)):
|
||||||
|
length = vid[i].shape[0]
|
||||||
|
features[i][:length] = torch.FloatTensor(vid[i]).to(device)
|
||||||
|
|
||||||
|
logging.info("Decoding started")
|
||||||
|
features = features / 255.0
|
||||||
|
with torch.no_grad():
|
||||||
|
nnet_output = model(features.permute(0, 4, 1, 2, 3))
|
||||||
|
# nnet_output is (N, T, C)
|
||||||
|
|
||||||
|
batch_size = nnet_output.shape[0]
|
||||||
|
supervision_segments = torch.tensor(
|
||||||
|
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
|
||||||
|
lattice = get_lattice(
|
||||||
|
nnet_output=nnet_output,
|
||||||
|
decoding_graph=HLG,
|
||||||
|
supervision_segments=supervision_segments,
|
||||||
|
search_beam=params.search_beam,
|
||||||
|
output_beam=params.output_beam,
|
||||||
|
min_active_states=params.min_active_states,
|
||||||
|
max_active_states=params.max_active_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.method == "1best":
|
||||||
|
logging.info("Use HLG decoding")
|
||||||
|
best_path = one_best_decoding(
|
||||||
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
|
)
|
||||||
|
elif params.method == "whole-lattice-rescoring":
|
||||||
|
logging.info("Use HLG decoding + LM rescoring")
|
||||||
|
best_path_dict = rescore_with_whole_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
G_with_epsilon_loops=G,
|
||||||
|
lm_scale_list=[params.ngram_lm_scale],
|
||||||
|
)
|
||||||
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||||
|
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||||
|
|
||||||
|
s = "\n"
|
||||||
|
for filename, hyp in zip(params.lipframes_dirs, hyps):
|
||||||
|
words = " ".join(hyp)
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = (
|
||||||
|
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
Loading…
x
Reference in New Issue
Block a user