From d55cec6b8741a6007eea53150f2ef4cdf82e19cb Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 30 Sep 2022 16:54:11 +0800 Subject: [PATCH] Add lstm-transducer model for the wenetspeech dataset --- .../ASR/lstm_transducer_stateless/export.py | 403 ++++++++++++++++++ .../lstm_transducer_stateless/ncnn-decode.py | 296 +++++++++++++ .../ASR/lstm_transducer_stateless/train.py | 3 + 3 files changed, 702 insertions(+) create mode 100755 egs/wenetspeech/ASR/lstm_transducer_stateless/export.py create mode 100755 egs/wenetspeech/ASR/lstm_transducer_stateless/ncnn-decode.py diff --git a/egs/wenetspeech/ASR/lstm_transducer_stateless/export.py b/egs/wenetspeech/ASR/lstm_transducer_stateless/export.py new file mode 100755 index 000000000..3529cbfe4 --- /dev/null +++ b/egs/wenetspeech/ASR/lstm_transducer_stateless/export.py @@ -0,0 +1,403 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.trace() + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --lang-dir data/lang_char \ + --epoch 35 \ + --avg 10 \ + --jit-trace 1 + +It will generate 3 files: `encoder_jit_trace.pt`, +`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`. + +(2) Export `model.state_dict()` + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --lang-dir data/lang_char \ + --epoch 35 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `lstm_transducer_stateless/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./lstm_transducer_stateless/decode.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --lang-dir data/lang_char \ + +Check ./pretrained.py for its usage. +""" + +import argparse +import logging +from pathlib import Path + +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model +from icefall.lexicon import Lexicon + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--jit-trace", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.trace. + It will generate 3 files: + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + + Check ./jit_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--pnnx", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.trace for later + converting to PNNX. It will generate 3 files: + - encoder_jit_trace-pnnx.pt + - decoder_jit_trace-pnnx.pt + - joiner_jit_trace-pnnx.pt + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + states = encoder_model.get_init_states() + + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + lexicon = Lexicon(params.lang_dir) + + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + if params.pnnx: + params.is_pnnx = params.pnnx + logging.info("For PNNX") + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.pnnx: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + elif params.jit_trace is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + else: + logging.info("Not using torchscript") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/wenetspeech/ASR/lstm_transducer_stateless/ncnn-decode.py b/egs/wenetspeech/ASR/lstm_transducer_stateless/ncnn-decode.py new file mode 100755 index 000000000..50901e6da --- /dev/null +++ b/egs/wenetspeech/ASR/lstm_transducer_stateless/ncnn-decode.py @@ -0,0 +1,296 @@ +#!/usr/bin/env python3 +# flake8: noqa +# +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + ./lstm_transducer_stateless/ncnn-decode.py \ + --token-filename ./data/lang_char/tokens.txt \ + --encoder-param-filename ./lstm_transducer_stateless/exp-new/encoder_jit_trace-epoch-11-avg-2-pnnx.ncnn.param \ + --encoder-bin-filename ./lstm_transducer_stateless/exp-new/encoder_jit_trace-epoch-11-avg-2-pnnx.ncnn.bin \ + --decoder-param-filename ./lstm_transducer_stateless/exp-new/decoder_jit_trace-epoch-11-avg-2-pnnx.ncnn.param \ + --decoder-bin-filename ./lstm_transducer_stateless/exp-new/decoder_jit_trace-epoch-11-avg-2-pnnx.ncnn.bin \ + --joiner-param-filename ./lstm_transducer_stateless/exp-new/joiner_jit_trace-epoch-11-avg-2-pnnx.ncnn.param \ + --joiner-bin-filename ./lstm_transducer_stateless/exp-new/joiner_jit_trace-epoch-11-avg-2-pnnx.ncnn.bin \ + ./test_wavs/DEV_T0000000001.wav +""" + +import argparse +import logging +from typing import List + +import k2 +import kaldifeat +import ncnn +import torch +import torchaudio + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--token-filename", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder-param-filename", + type=str, + help="Path to encoder.ncnn.param", + ) + + parser.add_argument( + "--encoder-bin-filename", + type=str, + help="Path to encoder.ncnn.bin", + ) + + parser.add_argument( + "--decoder-param-filename", + type=str, + help="Path to decoder.ncnn.param", + ) + + parser.add_argument( + "--decoder-bin-filename", + type=str, + help="Path to decoder.ncnn.bin", + ) + + parser.add_argument( + "--joiner-param-filename", + type=str, + help="Path to joiner.ncnn.param", + ) + + parser.add_argument( + "--joiner-bin-filename", + type=str, + help="Path to joiner.ncnn.bin", + ) + + parser.add_argument( + "sound_filename", + type=str, + help="Path to foo.wav", + ) + + return parser.parse_args() + + +class Model: + def __init__(self, args): + self.init_encoder(args) + self.init_decoder(args) + self.init_joiner(args) + + def init_encoder(self, args): + encoder_net = ncnn.Net() + encoder_net.opt.use_packing_layout = False + encoder_net.opt.use_fp16_storage = False + encoder_param = args.encoder_param_filename + encoder_model = args.encoder_bin_filename + + encoder_net.load_param(encoder_param) + encoder_net.load_model(encoder_model) + + self.encoder_net = encoder_net + + def init_decoder(self, args): + decoder_param = args.decoder_param_filename + decoder_model = args.decoder_bin_filename + + decoder_net = ncnn.Net() + decoder_net.opt.use_packing_layout = False + + decoder_net.load_param(decoder_param) + decoder_net.load_model(decoder_model) + + self.decoder_net = decoder_net + + def init_joiner(self, args): + joiner_param = args.joiner_param_filename + joiner_model = args.joiner_bin_filename + joiner_net = ncnn.Net() + joiner_net.opt.use_packing_layout = False + joiner_net.load_param(joiner_param) + joiner_net.load_model(joiner_model) + + self.joiner_net = joiner_net + + def run_encoder(self, x, states): + with self.encoder_net.create_extractor() as ex: + ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(x.numpy()).clone()) + x_lens = torch.tensor([x.size(0)], dtype=torch.float32) + ex.input("in1", ncnn.Mat(x_lens.numpy()).clone()) + ex.input("in2", ncnn.Mat(states[0].numpy()).clone()) + ex.input("in3", ncnn.Mat(states[1].numpy()).clone()) + + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + + ret, ncnn_out1 = ex.extract("out1") + assert ret == 0, ret + + ret, ncnn_out2 = ex.extract("out2") + assert ret == 0, ret + + ret, ncnn_out3 = ex.extract("out3") + assert ret == 0, ret + + encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( + torch.int32 + ) + hx = torch.from_numpy(ncnn_out2.numpy()).clone() + cx = torch.from_numpy(ncnn_out3.numpy()).clone() + return encoder_out, encoder_out_lens, hx, cx + + def run_decoder(self, decoder_input): + assert decoder_input.dtype == torch.int32 + + with self.decoder_net.create_extractor() as ex: + ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return decoder_out + + def run_joiner(self, encoder_out, decoder_out): + with self.joiner_net.create_extractor() as ex: + ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) + ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return joiner_out + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search(model: Model, encoder_out: torch.Tensor): + assert encoder_out.ndim == 2 + T = encoder_out.size(0) + + context_size = 2 + blank_id = 0 # hard-code to 0 + hyp = [blank_id] * context_size + + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) + + decoder_out = model.run_decoder(decoder_input).squeeze(0) + # print(decoder_out.shape) # (512,) + + for t in range(T): + encoder_out_t = encoder_out[t] + joiner_out = model.run_joiner(encoder_out_t, decoder_out) + # print(joiner_out.shape) # [500] + y = joiner_out.argmax(dim=0).tolist() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor(decoder_input, dtype=torch.int32) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + return hyp[context_size:] + + +def main(): + args = get_args() + logging.info(vars(args)) + + model = Model(args) + + sound_file = args.sound_filename + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {sound_file}") + wave_samples = read_sound_files( + filenames=[sound_file], + expected_sample_rate=sample_rate, + )[0] + + logging.info("Decoding started") + features = fbank(wave_samples) + + num_encoder_layers = 12 + d_model = 512 + rnn_hidden_size = 1024 + + states = ( + torch.zeros(num_encoder_layers, d_model), + torch.zeros( + num_encoder_layers, + rnn_hidden_size, + ), + ) + + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(features, states) + hyp = greedy_search(model, encoder_out) + + logging.info(sound_file) + + token_table = k2.SymbolTable.from_file(args.token_filename) + words = [token_table[i] for i in hyp] + logging.info("".join(words)) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/wenetspeech/ASR/lstm_transducer_stateless/train.py b/egs/wenetspeech/ASR/lstm_transducer_stateless/train.py index 770833ef5..5b9712fda 100755 --- a/egs/wenetspeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/wenetspeech/ASR/lstm_transducer_stateless/train.py @@ -405,6 +405,8 @@ def get_params() -> AttributeDict: "decoder_dim": 512, # parameters for joiner "joiner_dim": 512, + # True to generate a model that can be exported via PNNX + "is_pnnx": False, "env_info": get_env_info(), } ) @@ -421,6 +423,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, aux_layer_period=params.aux_layer_period, + is_pnnx=params.is_pnnx, ) return encoder