From 283bd126c599692d064c77b92a2733a6c772ba10 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Wed, 29 Dec 2021 19:10:56 +0800 Subject: [PATCH] add pretrained.py --- egs/grid/AVSR/audionet_ctc_asr/pretrained.py | 274 ++++++++++++++++++ .../AVSR/combinenet_ctc_avsr/pretrained.py | 270 +++++++++++++++++ egs/grid/AVSR/visualnet2_ctc_vsr/decode.py | 8 +- egs/grid/AVSR/visualnet2_ctc_vsr/model.py | 7 +- .../AVSR/visualnet2_ctc_vsr/pretrained.py | 243 ++++++++++++++++ egs/grid/AVSR/visualnet2_ctc_vsr/train.py | 3 +- egs/grid/AVSR/visualnet_ctc_vsr/pretrained.py | 243 ++++++++++++++++ 7 files changed, 1040 insertions(+), 8 deletions(-) create mode 100644 egs/grid/AVSR/audionet_ctc_asr/pretrained.py create mode 100644 egs/grid/AVSR/combinenet_ctc_avsr/pretrained.py create mode 100644 egs/grid/AVSR/visualnet2_ctc_vsr/pretrained.py create mode 100644 egs/grid/AVSR/visualnet_ctc_vsr/pretrained.py diff --git a/egs/grid/AVSR/audionet_ctc_asr/pretrained.py b/egs/grid/AVSR/audionet_ctc_asr/pretrained.py new file mode 100644 index 000000000..fe81ded6b --- /dev/null +++ b/egs/grid/AVSR/audionet_ctc_asr/pretrained.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from model import AudioNet + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_whole_lattice, +) +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument( + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.1, + help=""" + Used only when method is whole-lattice-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "subsampling_factor": 3, + "num_classes": 28, + "sample_rate": 16000, + "search_beam": 20, + "output_beam": 5, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = AudioNet( + num_features=params.feature_dim, + num_classes=params.num_classes, + subsampling_factor=params.subsampling_factor, + ) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + model.to(device) + model.eval() + + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method == "whole-lattice-rescoring": + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features_new = torch.zeros(len(features), 480, params.feature_dim).to( + device + ) + for i in range(len(features)): + length = features[i].shape[0] + features_new[i][:length] = features[i] + + with torch.no_grad(): + nnet_output = model(features_new.permute(0, 2, 1)) + # nnet_output is (N, T, C) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/grid/AVSR/combinenet_ctc_avsr/pretrained.py b/egs/grid/AVSR/combinenet_ctc_avsr/pretrained.py new file mode 100644 index 000000000..121d0cdd7 --- /dev/null +++ b/egs/grid/AVSR/combinenet_ctc_avsr/pretrained.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import cv2 +import logging +import numpy as np +import os + +import k2 +import kaldifeat +import torch +import torchaudio +from model import TdnnLstm + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_whole_lattice, +) +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument( + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.1, + help=""" + Used only when method is whole-lattice-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--lipframes-dirs", + type=str, + nargs="+", + help="The input visual file(s) to transcribe. " + "Supported formats are those supported by cv2.imread(). " + "The frames sample rate is 25fps.", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "num_classes": 28, + "search_beam": 20, + "output_beam": 5, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = TdnnLstm(num_features=80, num_classes=28, subsampling_factor=3) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + model.to(device) + model.eval() + + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method == "whole-lattice-rescoring": + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + logging.info("Loading lip roi frames and audio wav files") + aud = [] + vid = [] + + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + fbank = kaldifeat.Fbank(opts) + + for sample_dir in params.lipframes_dirs: + wave, sr = torchaudio.load( + sample_dir.replace("lip", "audio_25k").replace( + "video/mpg_6000/", "" + ) + + ".wav" + ) + wave = wave[0] + aud.append(fbank(wave)) + + files = os.listdir(sample_dir) + files = list(filter(lambda file: file.find(".jpg") != -1, files)) + files = sorted(files, key=lambda file: int(os.path.splitext(file)[0])) + array = [cv2.imread(os.path.join(sample_dir, file)) for file in files] + array = list(filter(lambda im: im is not None, array)) + array = [ + cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4) + for im in array + ] + array = np.stack(array, axis=0).astype(np.float32) + vid.append(array) + + L, H, W, C = vid[0].shape + features_v = torch.zeros(len(vid), 75, H, W, C).to(device) + for i in range(len(vid)): + length = vid[i].shape[0] + features_v[i][:length] = torch.FloatTensor(vid[i]).to(device) + + features_a = torch.zeros(len(aud), 450, 80).to(device) + for i in range(len(aud)): + length = aud[i].shape[0] + features_a[i][:length] = torch.FloatTensor(aud[i]).to(device) + + logging.info("Decoding started") + with torch.no_grad(): + nnet_output = model( + features_v.permute(0, 4, 1, 2, 3) / 255.0, + features_a.permute(0, 2, 1), + ) + # nnet_output is (N, T, C) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.lipframes_dirs, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/grid/AVSR/visualnet2_ctc_vsr/decode.py b/egs/grid/AVSR/visualnet2_ctc_vsr/decode.py index a8fe0a515..1fbfd7650 100644 --- a/egs/grid/AVSR/visualnet2_ctc_vsr/decode.py +++ b/egs/grid/AVSR/visualnet2_ctc_vsr/decode.py @@ -32,8 +32,7 @@ import torch.nn as nn from torch.utils.data import DataLoader from local.dataset_visual import dataset_visual -# from model import LipNet -from model import visual_frontend +from model import VisualNet2 from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.decode import ( @@ -131,7 +130,7 @@ def get_parser(): def get_params() -> AttributeDict: params = AttributeDict( { - "exp_dir": Path("visualnet_ctc_vsr2/exp"), + "exp_dir": Path("visualnet2_ctc_vsr/exp"), "lang_dir": Path("data/lang_character"), "lm_dir": Path("data/lm"), "search_beam": 20, @@ -388,6 +387,7 @@ def main(): logging.info(params) lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) device = torch.device("cpu") if torch.cuda.is_available(): @@ -441,7 +441,7 @@ def main(): else: G = None - model = visual_frontend() + model = VisualNet2(num_classes=max_token_id + 1) if params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: diff --git a/egs/grid/AVSR/visualnet2_ctc_vsr/model.py b/egs/grid/AVSR/visualnet2_ctc_vsr/model.py index 14f102108..bf3ceadd0 100644 --- a/egs/grid/AVSR/visualnet2_ctc_vsr/model.py +++ b/egs/grid/AVSR/visualnet2_ctc_vsr/model.py @@ -115,9 +115,10 @@ class ResNet(nn.Module): class VisualNet2(nn.Module): - def __init__(self, inputDim=512): + def __init__(self, num_classes): super(VisualNet2, self).__init__() - self.inputDim = inputDim + self.num_classes = num_classes + self.inputDim = 512 self.conv3d = nn.Conv3d( 3, 64, @@ -143,7 +144,7 @@ class VisualNet2(nn.Module): self.dropout = nn.Dropout(p=0.5) # fc - self.linear = nn.Linear(1024, 28) + self.linear = nn.Linear(1024, self.num_classes) # initialize self._initialize_weights() diff --git a/egs/grid/AVSR/visualnet2_ctc_vsr/pretrained.py b/egs/grid/AVSR/visualnet2_ctc_vsr/pretrained.py new file mode 100644 index 000000000..08589db57 --- /dev/null +++ b/egs/grid/AVSR/visualnet2_ctc_vsr/pretrained.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import cv2 +import logging +import numpy as np +import os + +import k2 +import torch +from model import VisualNet2 + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_whole_lattice, +) +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument( + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.1, + help=""" + Used only when method is whole-lattice-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--lipframes-dirs", + type=str, + nargs="+", + help="The input visual file(s) to transcribe. " + "Supported formats are those supported by cv2.imread(). " + "The frames sample rate is 25fps.", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "num_classes": 28, + "search_beam": 20, + "output_beam": 5, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = VisualNet2(num_classes=params.num_classes) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + model.to(device) + model.eval() + + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method == "whole-lattice-rescoring": + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + logging.info("Loading lip roi frames") + + vid = [] + for sample_dir in params.lipframes_dirs: + files = os.listdir(sample_dir) + files = list(filter(lambda file: file.find(".jpg") != -1, files)) + files = sorted(files, key=lambda file: int(os.path.splitext(file)[0])) + array = [cv2.imread(os.path.join(sample_dir, file)) for file in files] + array = list(filter(lambda im: im is not None, array)) + array = [ + cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4) + for im in array + ] + array = np.stack(array, axis=0).astype(np.float32) + vid.append(array) + + _, H, W, C = vid[0].shape + features = torch.zeros(len(vid), 75, H, W, C).to(device) + for i in range(len(vid)): + length = vid[i].shape[0] + features[i][:length] = torch.FloatTensor(vid[i]).to(device) + + logging.info("Decoding started") + features = features / 255.0 + with torch.no_grad(): + nnet_output = model(features.permute(0, 4, 1, 2, 3)) + # nnet_output is (N, T, C) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.lipframes_dirs, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/grid/AVSR/visualnet2_ctc_vsr/train.py b/egs/grid/AVSR/visualnet2_ctc_vsr/train.py index 91a1b024a..af9bdec9e 100644 --- a/egs/grid/AVSR/visualnet2_ctc_vsr/train.py +++ b/egs/grid/AVSR/visualnet2_ctc_vsr/train.py @@ -503,13 +503,14 @@ def run(rank, world_size, args): tb_writer = None lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", rank) graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device) - model = VisualNet2() + model = VisualNet2(num_classes=max_token_id + 1) checkpoints = load_checkpoint_if_available(params=params, model=model) diff --git a/egs/grid/AVSR/visualnet_ctc_vsr/pretrained.py b/egs/grid/AVSR/visualnet_ctc_vsr/pretrained.py new file mode 100644 index 000000000..e479b46f2 --- /dev/null +++ b/egs/grid/AVSR/visualnet_ctc_vsr/pretrained.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import cv2 +import logging +import numpy as np +import os + +import k2 +import torch +from model import VisualNet + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_whole_lattice, +) +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument( + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.1, + help=""" + Used only when method is whole-lattice-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--lipframes-dirs", + type=str, + nargs="+", + help="The input visual file(s) to transcribe. " + "Supported formats are those supported by cv2.imread(). " + "The frames sample rate is 25fps.", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "num_classes": 28, + "search_beam": 20, + "output_beam": 5, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = VisualNet(num_classes=params.num_classes) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + model.to(device) + model.eval() + + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method == "whole-lattice-rescoring": + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + logging.info("Loading lip roi frames") + + vid = [] + for sample_dir in params.lipframes_dirs: + files = os.listdir(sample_dir) + files = list(filter(lambda file: file.find(".jpg") != -1, files)) + files = sorted(files, key=lambda file: int(os.path.splitext(file)[0])) + array = [cv2.imread(os.path.join(sample_dir, file)) for file in files] + array = list(filter(lambda im: im is not None, array)) + array = [ + cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4) + for im in array + ] + array = np.stack(array, axis=0).astype(np.float32) + vid.append(array) + + _, H, W, C = vid[0].shape + features = torch.zeros(len(vid), 75, H, W, C).to(device) + for i in range(len(vid)): + length = vid[i].shape[0] + features[i][:length] = torch.FloatTensor(vid[i]).to(device) + + logging.info("Decoding started") + features = features / 255.0 + with torch.no_grad(): + nnet_output = model(features.permute(0, 4, 1, 2, 3)) + # nnet_output is (N, T, C) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.lipframes_dirs, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main()