mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
* Fix torch.nn.Embedding error for torch below 1.8.0 * Changes to fbank computation, use lilcom chunky writer * Add min in q,k,v of attention * Remove learnable offset, use relu instead. * Experiments based on SpecAugment change * Merge specaug change from Mingshuang. * Use much more aggressive SpecAug setup * Fix to num_feature_masks bug I introduced; reduce max_frames_mask_fraction 0.4->0.3 * Change p=0.5->0.9, mask_fraction 0.3->0.2 * Change p=0.9 to p=0.8 in SpecAug * Fix num_time_masks code; revert 0.8 to 0.9 * Change max_frames from 0.2 to 0.15 * Remove ReLU in attention * Adding diagnostics code... * Refactor/simplify ConformerEncoder * First version of rand-combine iterated-training-like idea. * Improvements to diagnostics (RE those with 1 dim * Add pelu to this good-performing setup.. * Small bug fixes/imports * Add baseline for the PeLU expt, keeping only the small normalization-related changes. * pelu_base->expscale, add 2xExpScale in subsampling, and in feedforward units. * Double learning rate of exp-scale units * Combine ExpScale and swish for memory reduction * Add import * Fix backprop bug * Fix bug in diagnostics * Increase scale on Scale from 4 to 20 * Increase scale from 20 to 50. * Fix duplicate Swish; replace norm+swish with swish+exp-scale in convolution module * Reduce scale from 50 to 20 * Add deriv-balancing code * Double the threshold in brelu; slightly increase max_factor. * Fix exp dir * Convert swish nonlinearities to ReLU * Replace relu with swish-squared. * Restore ConvolutionModule to state before changes; change all Swish,Swish(Swish) to SwishOffset. * Replace norm on input layer with scale of 0.1. * Extensions to diagnostics code * Update diagnostics * Add BasicNorm module * Replace most normalizations with scales (still have norm in conv) * Change exp dir * Replace norm in ConvolutionModule with a scaling factor. * use nonzero threshold in DerivBalancer * Add min-abs-value 0.2 * Fix dirname * Change min-abs threshold from 0.2 to 0.5 * Scale up pos_bias_u and pos_bias_v before use. * Reduce max_factor to 0.01 * Fix q*scaling logic * Change max_factor in DerivBalancer from 0.025 to 0.01; fix scaling code. * init 1st conv module to smaller variance * Change how scales are applied; fix residual bug * Reduce min_abs from 0.5 to 0.2 * Introduce in_scale=0.5 for SwishExpScale * Fix scale from 0.5 to 2.0 as I really intended.. * Set scaling on SwishExpScale * Add identity pre_norm_final for diagnostics. * Add learnable post-scale for mha * Fix self.post-scale-mha * Another rework, use scales on linear/conv * Change dir name * Reduce initial scaling of modules * Bug-fix RE bias * Cosmetic change * Reduce initial_scale. * Replace ExpScaleRelu with DoubleSwish() * DoubleSwish fix * Use learnable scales for joiner and decoder * Add max-abs-value constraint in DerivBalancer * Add max-abs-value * Change dir name * Remove ExpScale in feedforward layes. * Reduce max-abs limit from 1000 to 100; introduce 2 DerivBalancer modules in conv layer. * Make DoubleSwish more memory efficient * Reduce constraints from deriv-balancer in ConvModule. * Add warmup mode * Remove max-positive constraint in deriv-balancing; add second DerivBalancer in conv module. * Add some extra info to diagnostics * Add deriv-balancer at output of embedding. * Add more stats. * Make epsilon in BasicNorm learnable, optionally. * Draft of 0mean changes.. * Rework of initialization * Fix typo * Remove dead code * Modifying initialization from normal->uniform; add initial_scale when initializing * bug fix re sqrt * Remove xscale from pos_embedding * Remove some dead code. * Cosmetic changes/renaming things * Start adding some files.. * Add more files.. * update decode.py file type * Add remaining files in pruned_transducer_stateless2 * Fix diagnostics-getting code * Scale down pruned loss in warmup mode * Reduce warmup scale on pruned loss form 0.1 to 0.01. * Remove scale_speed, make swish deriv more efficient. * Cosmetic changes to swish * Double warm_step * Fix bug with import * Change initial std from 0.05 to 0.025. * Set also scale for embedding to 0.025. * Remove logging code that broke with newer Lhotse; fix bug with pruned_loss * Add norm+balancer to VggSubsampling * Incorporate changes from master into pruned_transducer_stateless2. * Add max-abs=6, debugged version * Change 0.025,0.05 to 0.01 in initializations * Fix balancer code * Whitespace fix * Reduce initial pruned_loss scale from 0.01 to 0.0 * Increase warm_step (and valid_interval) * Change max-abs from 6 to 10 * Change how warmup works. * Add changes from master to decode.py, train.py * Simplify the warmup code; max_abs 10->6 * Make warmup work by scaling layer contributions; leave residual layer-drop * Fix bug * Fix test mode with random layer dropout * Add random-number-setting function in dataloader * Fix/patch how fix_random_seed() is imported. * Reduce layer-drop prob * Reduce layer-drop prob after warmup to 1 in 100 * Change power of lr-schedule from -0.5 to -0.333 * Increase model_warm_step to 4k * Change max-keep-prob to 0.95 * Refactoring and simplifying conformer and frontend * Rework conformer, remove some code. * Reduce 1st conv channels from 64 to 32 * Add another convolutional layer * Fix padding bug * Remove dropout in output layer * Reduce speed of some components * Initial refactoring to remove unnecessary vocab_size * Fix RE identity * Bug-fix * Add final dropout to conformer * Remove some un-used code * Replace nn.Linear with ScaledLinear in simple joiner * Make 2 projections.. * Reduce initial_speed * Use initial_speed=0.5 * Reduce initial_speed further from 0.5 to 0.25 * Reduce initial_speed from 0.5 to 0.25 * Change how warmup is applied. * Bug fix to warmup_scale * Fix test-mode * Remove final dropout * Make layer dropout rate 0.075, was 0.1. * First draft of model rework * Various bug fixes * Change learning speed of simple_lm_proj * Revert transducer_stateless/ to state in upstream/master * Fix to joiner to allow different dims * Some cleanups * Make training more efficient, avoid redoing some projections. * Change how warm-step is set * First draft of new approach to learning rates + init * Some fixes.. * Change initialization to 0.25 * Fix type of parameter * Fix weight decay formula by adding 1/1-beta * Fix weight decay formula by adding 1/1-beta * Fix checkpoint-writing * Fix to reading scheudler from optim * Simplified optimizer, rework somet things.. * Reduce model_warm_step from 4k to 3k * Fix bug in lambda * Bug-fix RE sign of target_rms * Changing initial_speed from 0.25 to 01 * Change some defaults in LR-setting rule. * Remove initial_speed * Set new scheduler * Change exponential part of lrate to be epoch based * Fix bug * Set 2n rule.. * Implement 2o schedule * Make lrate rule more symmetric * Implement 2p version of learning rate schedule. * Refactor how learning rate is set. * Fix import * Modify init (#301) * update icefall/__init__.py to import more common functions. * update icefall/__init__.py * make imports style consistent. * exclude black check for icefall/__init__.py in pyproject.toml. * Minor fixes for logging (#296) * Minor fixes for logging * Minor fix * Fix dir names * Modify beam search to be efficient with current joienr * Fix adding learning rate to tensorboard * Fix docs in optim.py * Support mix precision training on the reworked model (#305) * Add mix precision support * Minor fixes * Minor fixes * Minor fixes * Tedlium3 pruned transducer stateless (#261) * update tedlium3-pruned-transducer-stateless-codes * update README.md * update README.md * add fast beam search for decoding * do a change for RESULTS.md * do a change for RESULTS.md * do a fix * do some changes for pruned RNN-T * Add mix precision support * Minor fixes * Minor fixes * Updating RESULTS.md; fix in beam_search.py * Fix rebase * Code style check for librispeech pruned transducer stateless2 (#308) * Update results for tedlium3 pruned RNN-T (#307) * Update README.md * Fix CI errors. (#310) * Add more results * Fix tensorboard log location * Add one more epoch of full expt * fix comments * Add results for mixed precision with max-duration 300 * Changes for pretrained.py (tedlium3 pruned RNN-T) (#311) * GigaSpeech recipe (#120) * initial commit * support download, data prep, and fbank * on-the-fly feature extraction by default * support BPE based lang * support HLG for BPE * small fix * small fix * chunked feature extraction by default * Compute features for GigaSpeech by splitting the manifest. * Fixes after review. * Split manifests into 2000 pieces. * set audio duration mismatch tolerance to 0.01 * small fix * add conformer training recipe * Add conformer.py without pre-commit checking * lazy loading and use SingleCutSampler * DynamicBucketingSampler * use KaldifeatFbank to compute fbank for musan * use pretrained language model and lexicon * use 3gram to decode, 4gram to rescore * Add decode.py * Update .flake8 * Delete compute_fbank_gigaspeech.py * Use BucketingSampler for valid and test dataloader * Update params in train.py * Use bpe_500 * update params in decode.py * Decrease num_paths while CUDA OOM * Added README * Update RESULTS * black * Decrease num_paths while CUDA OOM * Decode with post-processing * Update results * Remove lazy_load option * Use default `storage_type` * Keep the original tolerance * Use split-lazy * black * Update pretrained model Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> * Add LG decoding (#277) * Add LG decoding * Add log weight pushing * Minor fixes * Support computing RNN-T loss with torchaudio (#316) * Support modified beam search decoding for streaming inference with Emformer model. * Formatted imports. * Update results for torchaudio RNN-T. (#322) * Fixed streaming decoding codes for emformer model. * Fixed docs. * Sorted imports for transducer_emformer/streaming_feature_extractor.py * Minor fix for transducer_emformer/streaming_feature_extractor.py Co-authored-by: pkufool <wkang@pku.org.cn> Co-authored-by: Daniel Povey <dpovey@gmail.com> Co-authored-by: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> Co-authored-by: Guo Liyong <guonwpu@qq.com> Co-authored-by: Wang, Guanbo <wgb14@outlook.com>
185 lines
5.0 KiB
Python
185 lines
5.0 KiB
Python
#!/usr/bin/env python3
|
|
#
|
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
|
|
# Mingshuang Luo)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# This script converts several saved checkpoints
|
|
# to a single one using model averaging.
|
|
"""
|
|
Usage:
|
|
./pruned_transducer_stateless/export.py \
|
|
--exp-dir ./pruned_transducer_stateless/exp \
|
|
--bpe-model data/lang_bpe_500/bpe.model \
|
|
--epoch 29 \
|
|
--avg 13
|
|
|
|
It will generate a file exp_dir/pretrained.pt
|
|
|
|
To use the generated file with `pruned_transducer_stateless/decode.py`,
|
|
you can do:
|
|
|
|
cd /path/to/exp_dir
|
|
ln -s pretrained.pt epoch-9999.pt
|
|
|
|
cd /path/to/egs/tedlium3/ASR
|
|
./pruned_transducer_stateless/decode.py \
|
|
--exp-dir ./pruned_transducer_stateless/exp \
|
|
--epoch 9999 \
|
|
--avg 1 \
|
|
--max-duration 1 \
|
|
--bpe-model data/lang_bpe_500/bpe.model
|
|
"""
|
|
|
|
import argparse
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
import sentencepiece as spm
|
|
import torch
|
|
from train import get_params, get_transducer_model
|
|
|
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
|
from icefall.utils import str2bool
|
|
|
|
|
|
def get_parser():
|
|
parser = argparse.ArgumentParser(
|
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--epoch",
|
|
type=int,
|
|
default=30,
|
|
help="It specifies the checkpoint to use for decoding."
|
|
"Note: Epoch counts from 0.",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--avg",
|
|
type=int,
|
|
default=13,
|
|
help="Number of checkpoints to average. Automatically select "
|
|
"consecutive checkpoints before the checkpoint specified by "
|
|
"'--epoch'. ",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--exp-dir",
|
|
type=str,
|
|
default="pruned_transducer_stateless/exp",
|
|
help="""It specifies the directory where all training related
|
|
files, e.g., checkpoints, log, etc, are saved
|
|
""",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--bpe-model",
|
|
type=str,
|
|
default="data/lang_bpe_500/bpe.model",
|
|
help="Path to the BPE model",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--jit",
|
|
type=str2bool,
|
|
default=False,
|
|
help="""True to save a model after applying torch.jit.script.
|
|
""",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--context-size",
|
|
type=int,
|
|
default=2,
|
|
help="The context size in the decoder. 1 means bigram; "
|
|
"2 means tri-gram",
|
|
)
|
|
|
|
return parser
|
|
|
|
|
|
def main():
|
|
args = get_parser().parse_args()
|
|
args.exp_dir = Path(args.exp_dir)
|
|
|
|
assert args.jit is False, "Support torchscript will be added later"
|
|
|
|
params = get_params()
|
|
params.update(vars(args))
|
|
|
|
device = torch.device("cpu")
|
|
if torch.cuda.is_available():
|
|
device = torch.device("cuda", 0)
|
|
|
|
logging.info(f"device: {device}")
|
|
|
|
sp = spm.SentencePieceProcessor()
|
|
sp.load(params.bpe_model)
|
|
|
|
# <blk> and <unk> are defined in local/train_bpe_model.py
|
|
params.blank_id = sp.piece_to_id("<blk>")
|
|
params.unk_id = sp.piece_to_id("<unk>")
|
|
params.vocab_size = sp.get_piece_size()
|
|
|
|
logging.info(params)
|
|
|
|
logging.info("About to create model")
|
|
model = get_transducer_model(params)
|
|
|
|
model.to(device)
|
|
|
|
if params.avg == 1:
|
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
|
else:
|
|
start = params.epoch - params.avg + 1
|
|
filenames = []
|
|
for i in range(start, params.epoch + 1):
|
|
if start >= 0:
|
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
|
logging.info(f"averaging {filenames}")
|
|
model.to(device)
|
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
|
|
|
model.eval()
|
|
|
|
model.to("cpu")
|
|
model.eval()
|
|
|
|
if params.jit:
|
|
logging.info("Using torch.jit.script")
|
|
model = torch.jit.script(model)
|
|
filename = params.exp_dir / "cpu_jit.pt"
|
|
model.save(str(filename))
|
|
logging.info(f"Saved to {filename}")
|
|
else:
|
|
logging.info("Not using torch.jit.script")
|
|
# Save it using a format so that it can be loaded
|
|
# by :func:`load_checkpoint`
|
|
filename = params.exp_dir / "pretrained.pt"
|
|
torch.save({"model": model.state_dict()}, str(filename))
|
|
logging.info(f"Saved to {filename}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
formatter = (
|
|
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
)
|
|
|
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
|
main()
|