From ef56cd87e4194dee298442961ae4e6b7d7cd31c6 Mon Sep 17 00:00:00 2001
From: Gaurav Shegokar
Date: Mon, 16 May 2022 17:04:57 +0000
Subject: [PATCH] Add pretrained.py decoding script for gigaspeech conformer_ctc
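
Add a standalone script, conformer_ctc/pretrained.py, that transcribes
sound files (or directories of sound files) in batches with a pretrained
checkpoint and writes one hypothesis file per input under --output-dir.
decode.py now loads pretrained.pt when --epoch is 0, and env.sh exports
the local CUDA 10.1 paths.

A rough usage sketch for HLG decoding (all paths below are illustrative
placeholders):

    ./conformer_ctc/pretrained.py \
        --checkpoint ./conformer_ctc/exp/pretrained.pt \
        --HLG ./data/lang_bpe_500/HLG.pt \
        --words-file ./data/lang_bpe_500/words.txt \
        --method 1best \
        /path/to/foo.wav \
        /path/to/bar.wav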
" + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + (3) attention-decoder - Extract n paths from the rescored + lattice and use the transformer attention decoder for + rescoring. + We call it HLG decoding + n-gram LM rescoring + attention + decoder rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or attention-decoder. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.2, + help=""" + Used only when method is whole-lattice-rescoring and attention-decoder. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--attention-decoder-scale", + type=float, + default=1.2, + help=""" + Used only when method is attention-decoder. + It specifies the scale for attention decoder scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is attention-decoder. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--sos-id", + type=int, + default=1, + help=""" + Used only when method is attention-decoder. + It specifies ID for the SOS token. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--eos-id", + type=int, + default=1, + help=""" + Used only when method is attention-decoder. + It specifies ID for the EOS token. + """, + ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help=""" + sampling rate for input sound files + """, + ) + parser.add_argument( + "--output-dir", + type=str, + default="output", + help=""" + Output directory name, we store output hypothesis + """ + + ) + + parser.add_argument( + "--batch-size", + type=int, + default=10, + help=""" + Number of input files in one batch, defaulted to 10 + """ + ) + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. 
" + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "sample_rate": 16000, + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, # default 20 + "output_beam": 8, # default 8 + "min_active_states": 30, + "max_active_states": 3000, + "use_double_scores": True, + } + ) + return params + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + wave_names = [] + + def loadfile(filename): + wave, sample_rate = torchaudio.load(filename) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + wave_names.append(str(filename)) + + for f in filenames: + file_path = Path(f) + if file_path.is_file(): + loadfile(file_path) + elif file_path.is_dir(): + for filename in file_path.iterdir(): + loadfile(filename) + else: + logging.error(f"{f} must be a filename or a dirname") + return ans, wave_names + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + decoding_graph: Optional[k2.Fsa], + feature: torch.Tensor, + bpe_model: Optional[spm.SentencePieceProcessor] = None, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + + device = decoding_graph.device + assert feature.ndim == 3 + feature = feature.to(device) + with torch.no_grad(): + nnet_output, memory, memory_key_padding_mask = model(feature) + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + + elif params.method in [ + "1best", + "whole-lattice-rescoring", + "attention-decoder", + ]: + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "attention-decoder": + logging.info("Use HLG + LM rescoring + attention decoder rescoring") + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, 
+                G_with_epsilon_loops=G,
+                lm_scale_list=None,
+            )
+            best_path_dict = rescore_with_attention_decoder(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=params.sos_id,
+                eos_id=params.eos_id,
+                nbest_scale=params.nbest_scale,
+                ngram_lm_scale=params.ngram_lm_scale,
+                attention_scale=params.attention_decoder_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+
+    return hyps
+
+
+class TestDataset(Dataset):
+    """A trivial dataset that yields (feature, dummy_label) pairs so the
+    padded features can be batched with a DataLoader."""
+
+    def __init__(self, features: torch.Tensor):
+        self.features = features
+
+    def __len__(self):
+        return len(self.features)
+
+    def __getitem__(self, idx):
+        # The second element is a dummy label; only the features are used.
+        return (self.features[idx], 0)
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    if args.method != "attention-decoder":
+        # to save memory as the attention decoder
+        # will not be used
+        params.num_decoder_layers = 0
+
+    params.update(vars(args))
+
+    su = ShortUUID(alphabet=string.ascii_lowercase + string.digits)
+    output_dir = Path(args.output_dir) / (
+        su.random(length=5)
+        + "-"
+        + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
+    )
+    output_dir.mkdir(exist_ok=True, parents=True)
+
+    with (output_dir / "command").open("w") as f_cmd:
+        f_cmd.write(f"{args}\n")
+        f_cmd.write(f"{params}\n")
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=params.num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves, wave_names = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
+
+    G, bpe_model = None, None
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = params.num_classes - 1
+
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        decoding_graph = H
+    elif params.method in [
+        "1best",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]:
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = HLG.to(device)
+        decoding_graph = HLG
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
+            HLG.lm_scores = HLG.scores.clone()
+
+        if params.method in [
+            "whole-lattice-rescoring",
+            "attention-decoder",
+        ]:
+            logging.info(f"Loading G from {params.G}")
+            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = G.to(device)
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G.lm_scores = G.scores.clone()
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    testdata = TestDataset(features)
+    tl = DataLoader(testdata, batch_size=args.batch_size)
+    hyps = []
+    num_batches = len(tl)
+
+    for batch_idx, batch in enumerate(tl):
+        hyps.extend(
+            decode_one_batch(
+                params, model, decoding_graph, batch[0], bpe_model, G
+            )
+        )
+        logging.info(
+            f"batch {batch_idx + 1}/{num_batches}, "
+            f"files processed so far: {len(hyps)}"
+        )
+
+    logging.info(f"Writing hypotheses to output dir {output_dir}")
+    s = "\n"
+    for filename, hyp in zip(wave_names, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+        out_name = os.path.basename(filename).replace(".wav", ".txt")
+        with (output_dir / out_name).open("w") as f_hyp:
+            f_hyp.write(words)
+    logging.info(s)
+    logging.info("Decoding done")
+
+
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/env.sh b/env.sh
new file mode 100644
index 000000000..f6e5611a7
--- /dev/null
+++ b/env.sh
@@ -0,0 +1,4 @@
+export PYTHONPATH=$PYTHONPATH:/usr/local/cuda-10.1/targets/x86_64-linux/lib/
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-10.1/targets/x86_64-linux/lib/
+export PATH=$PATH:/usr/local/cuda-10.1/bin/
+export CUDA_HOME=/usr/local/cuda-10.1/
\ No newline at end of file