From aebe9c22dd5cd519e7005d1944b432fcbbfc85d3 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 7 Jun 2022 15:59:58 +0800 Subject: [PATCH] Minor fixes --- .../ASR/pruned_transducer_stateless/decode.py | 16 +- .../ASR/pruned_transducer_stateless/export.py | 11 +- .../ASR/pruned_transducer_stateless/train.py | 15 +- .../pruned_transducer_stateless2/decode.py | 15 +- .../pruned_transducer_stateless2/export.py | 11 +- .../ASR/pruned_transducer_stateless2/train.py | 15 +- .../pruned_transducer_stateless3/decode.py | 14 +- .../pruned_transducer_stateless3/export.py | 11 +- .../ASR/pruned_transducer_stateless3/train.py | 14 +- .../pruned_transducer_stateless4/decode.py | 15 +- .../pruned_transducer_stateless4/export.py | 329 +++++++++++++++++- .../ASR/pruned_transducer_stateless4/train.py | 15 +- 12 files changed, 460 insertions(+), 21 deletions(-) mode change 120000 => 100755 egs/librispeech/ASR/pruned_transducer_stateless4/export.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 6b048519a..ee7bfac53 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -303,6 +303,15 @@ def get_parser(): """, ) + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + parser.add_argument( "--decode-chunk-size", type=int, @@ -368,7 +377,7 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - encoder_out, encoder_out_lens = model.encoder.streaming_forward( + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, states=[], @@ -639,8 +648,9 @@ def main(): params.vocab_size = sp.get_piece_size() if params.simulate_streaming: - # Decoding in streaming requires causal convolution" - params.causal_convolution = True + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index 4c4c0ee2d..8f3ec3a20 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -150,6 +150,15 @@ def get_parser(): """, ) + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + exporting a streaming model. + """, + ) + return parser @@ -175,7 +184,7 @@ def main(): params.vocab_size = sp.get_piece_size() if params.streaming_model: - params.causal_convolution = True + assert params.causal_convolution logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 7dcbc23ec..94c597693 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -37,6 +37,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --exp-dir pruned_transducer_stateless/exp \ --full-libri 1 \ --dynamic-chunk-training 1 \ + --causal-convolution 1 \ --short-chunk-size 25 \ --num-left-chunks 4 \ --max-duration 300 @@ -243,6 +244,15 @@ def get_parser(): """, ) + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + parser.add_argument( "--short-chunk-size", type=int, @@ -804,8 +814,9 @@ def run(rank, world_size, args): params.vocab_size = sp.get_piece_size() if params.dynamic_chunk_training: - # dynamic_chunk_training requires causal convolution - params.causal_convolution = True + assert ( + params.causal_convolution + ), "dynamic_chunk_training requires causal convolution" logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 96d5f12e2..9cd494b05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -59,6 +59,7 @@ Usage: --epoch 28 \ --avg 15 \ --simulate-streaming 1 \ + --causal-convolution 1 \ --decode-chunk-size 16 \ --left-context 64 \ --exp-dir ./pruned_transducer_stateless2/exp \ @@ -255,6 +256,15 @@ def get_parser(): """, ) + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + parser.add_argument( "--decode-chunk-size", type=int, @@ -584,8 +594,9 @@ def main(): params.vocab_size = sp.get_piece_size() if params.simulate_streaming: - # Decoding in streaming requires causal convolution - params.causal_convolution = True + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index 3c6424359..97f2facb2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -165,6 +165,15 @@ def get_parser(): """, ) + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + exporting a streaming model. + """, + ) + return parser @@ -189,7 +198,7 @@ def main(): params.vocab_size = sp.get_piece_size() if params.streaming_model: - params.causal_convolution = True + assert params.causal_convolution logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 465d7da0f..3a3c889dd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -48,6 +48,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --exp-dir pruned_transducer_stateless/exp \ --full-libri 1 \ --dynamic-chunk-training 1 \ + --causal_convolution 1 \ --short-chunk-size 25 \ --num-left-chunks 4 \ --max-duration 300 @@ -283,6 +284,15 @@ def get_parser(): """, ) + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + parser.add_argument( "--short-chunk-size", type=int, @@ -847,8 +857,9 @@ def run(rank, world_size, args): params.vocab_size = sp.get_piece_size() if params.dynamic_chunk_training: - # dynamic_chunk_training requires causal convolution - params.causal_convolution = True + assert ( + params.causal_convolution + ), "dynamic_chunk_training requires causal convolution" logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index a95af9ddc..154e1f074 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -265,6 +265,15 @@ def get_parser(): """, ) + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + parser.add_argument( "--decode-chunk-size", type=int, @@ -626,8 +635,9 @@ def main(): params.vocab_size = sp.get_piece_size() if params.simulate_streaming: - # Decoding in streaming requires causal convolution - params.causal_convolution = True + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index 020586e7d..bab8a9910 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -166,6 +166,15 @@ def get_parser(): """, ) + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + exporting a streaming model. + """, + ) + return parser @@ -190,7 +199,7 @@ def main(): params.vocab_size = sp.get_piece_size() if params.streaming_model: - params.causal_convolution = True + assert params.causal_convolution logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 18473167d..b2339df75 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -295,6 +295,15 @@ def get_parser(): """, ) + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + parser.add_argument( "--short-chunk-size", type=int, @@ -935,8 +944,9 @@ def run(rank, world_size, args): params.vocab_size = sp.get_piece_size() if params.dynamic_chunk_training: - # dynamic_chunk_training requires causal convolution - params.causal_convolution = True + assert ( + params.causal_convolution + ), "dynamic_chunk_training requires causal convolution" logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index d830886d3..a1ec8d046 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -60,6 +60,7 @@ Usage: --epoch 30 \ --avg 15 \ --simulate-streaming 1 \ + --causal-convolution 1 \ --decode-chunk-size 16 \ --left-context 64 \ --exp-dir ./pruned_transducer_stateless4/exp \ @@ -267,6 +268,15 @@ def get_parser(): """, ) + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + parser.add_argument( "--decode-chunk-size", type=int, @@ -599,8 +609,9 @@ def main(): params.vocab_size = sp.get_piece_size() if params.simulate_streaming: - # Decoding in streaming requires causal convolution - params.causal_convolution = True + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py deleted file mode 120000 index 19c56a722..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/export.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py new file mode 100755 index 000000000..6bc8b2a39 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: +./pruned_transducer_stateless4/export.py \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `pruned_transducer_stateless4/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless4/decode.py \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 100 \ + --bpe-model data/lang_bpe_500/bpe.model +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from train import get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + Note: not needed here, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + Note: not needed for here, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="""How many left context can be seen in chunks when calculating attention. + Note: not needed here, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--streaming-model", + type=str2bool, + default=False, + help="""Whether to export a streaming model, if the models in exp-dir + are streaming model, this should be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + exporting a streaming model. + """, + ) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.streaming_model: + assert params.causal_convolution + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 637c6214b..5cf92c77e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -49,6 +49,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --exp-dir pruned_transducer_stateless4/exp \ --full-libri 1 \ --dynamic-chunk-training 1 \ + --causal-convolution 1 \ --short-chunk-size 25 \ --num-left-chunks 4 \ --max-duration 300 @@ -301,6 +302,15 @@ def get_parser(): """, ) + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + parser.add_argument( "--short-chunk-size", type=int, @@ -888,8 +898,9 @@ def run(rank, world_size, args): params.vocab_size = sp.get_piece_size() if params.dynamic_chunk_training: - # dynamic_chunk_training requires causal convolution - params.causal_convolution = True + assert ( + params.causal_convolution + ), "dynamic_chunk_training requires causal convolution" logging.info(params)