mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
* Bug fix * Change subsampling factor from 1 to 2 * Implement AttentionCombine as replacement for RandomCombine * Decrease random_prob from 0.5 to 0.333 * Add print statement * Apply single_prob mask, so sometimes we just get one layer as output. * Introduce feature mask per frame * Include changes from Liyong about padding conformer module. * Reduce single_prob from 0.5 to 0.25 * Reduce feature_mask_dropout_prob from 0.25 to 0.15. * Remove dropout from inside ConformerEncoderLayer, for adding to residuals * Increase feature_mask_dropout_prob from 0.15 to 0.2. * Swap random_prob and single_prob, to reduce prob of being randomized. * Decrease feature_mask_dropout_prob back from 0.2 to 0.15, i.e. revert the 43->48 change. * Randomize order of some modules * Bug fix * Stop backprop bug * Introduce a scale dependent on the masking value * Implement efficient layer dropout * Simplify the learned scaling factor on the modules * Compute valid loss on batch 0. * Make the scaling factors more global and the randomness of dropout more random * Bug fix * Introduce offset in layerdrop_scales * Remove final combination; implement layer drop that drops the final layers. * Bug fixes * Fix bug RE self.training * Fix bug setting layerdrop mask * Fix eigs call * Add debug info * Remove warmup * Remove layer dropout and model-level warmup * Don't always apply the frame mask * Slight code cleanup/simplification * Various fixes, finish implementing frame masking * Remove debug info * Don't compute validation if printing diagnostics. * Apply layer bypass during warmup in a new way, including 2s and 4s of layers. * Update checkpoint.py to deal with int params * Revert initial_scale to previous values. * Remove the feature where it was bypassing groups of layers. * Implement layer dropout with probability 0.075 * Fix issue with warmup in test time * Add warmup schedule where dropout disappears from earlier layers first. 
* Have warmup that gradually removes dropout from layers; multiply initialization scales by 0.1. * Do dropout a different way * Fix bug in warmup * Remove debug print * Make the warmup mask per frame. * Implement layer dropout (in a relatively efficient way) * Decrease initial keep_prob to 0.25. * Make it start warming up from the very start, and increase warmup_batches to 6k * Change warmup schedule and increase warmup_batches from 4k to 6k * Make the bypass scale trainable. * Change the initial keep-prob back from 0.25 to 0.5 * Bug fix * Limit bypass scale to >= 0.1 * Revert "Change warmup schedule and increase warmup_batches from 4k to 6k" This reverts commit 86845bd5d859ceb6f83cd83f3719c3e6641de987. * Do warmup by dropping out whole layers. * Decrease frequency of logging variance_proportion * Make layerdrop different in different processes. * For speed, drop the same num layers per job. * Decrease initial_layerdrop_prob from 0.75 to 0.5 * Revert also the changes in scaled_adam_exp85 regarding warmup schedule * Remove unused code LearnedScale. * Reintroduce batching to the optimizer * Various fixes from debugging with nvtx, but removed the NVTX annotations. * Only apply ActivationBalancer with prob 0.25. * Fix s -> scaling for import. * Increase final layerdrop prob from 0.05 to 0.075 * Fix bug where fewer layers were dropped than should be; remove unnecesary print statement. * Fix bug in choosing layers to drop * Refactor RelPosMultiheadAttention to have 2nd forward function and introduce more modules in conformer encoder layer * Reduce final layerdrop_prob from 0.075 to 0.05. * Fix issue with diagnostics if stats is None * Remove persistent attention scores. * Make ActivationBalancer and MaxEig more efficient. * Cosmetic improvements * Change scale_factor_scale from 0.5 to 0.8 * Make the ActivationBalancer regress to the data mean, not zero, when enforcing abs constraint. 
* Remove unused config value * Fix bug when channel_dim < 0 * Fix bug when channel_dim < 0 * Simplify how the positional-embedding scores work in attention (thanks to Zengwei for this concept) * Revert dropout on attention scores to 0.0. * This should just be a cosmetic change, regularizing how we get the warmup times from the layers. * Reduce beta from 0.75 to 0.0. * Reduce stats period from 10 to 4. * Reworking of ActivationBalancer code to hopefully balance speed and effectiveness. * Add debug code for attention weihts and eigs * Remove debug statement * Add different debug info. * Penalize attention-weight entropies above a limit. * Remove debug statements * use larger delta but only penalize if small grad norm * Bug fixes; change debug freq * Change cutoff for small_grad_norm * Implement whitening of values in conformer. * Also whiten the keys in conformer. * Fix an issue with scaling of grad. * Decrease whitening limit from 2.0 to 1.1. * Fix debug stats. * Reorganize Whiten() code; configs are not the same as before. Also remove MaxEig for self_attn module * Bug fix RE float16 * Revert whitening_limit from 1.1 to 2.2. * Replace MaxEig with Whiten with limit=5.0, and move it to end of ConformerEncoderLayer * Change LR schedule to start off higher * Simplify the dropout mask, no non-dropped-out sequences * Make attention dims configurable, not embed_dim//2, trying 256. * Reduce attention_dim to 192; cherry-pick scaled_adam_exp130 which is linear_pos interacting with query * Use half the dim for values, vs. keys and queries. * Increase initial-lr from 0.04 to 0.05, plus changes for diagnostics * Cosmetic changes * Changes to avoid bug in backward hooks, affecting diagnostics. * Random clip attention scores to -5..5. * Add some random clamping in model.py * Add reflect=0.1 to invocations of random_clamp() * Remove in_balancer. * Revert model.py so there are no constraints on the output. * Implement randomized backprop for softmax. 
* Reduce min_abs from 1e-03 to 1e-04 * Add RandomGrad with min_abs=1.0e-04 * Use full precision to do softmax and store ans. * Fix bug in backprop of random_clamp() * Get the randomized backprop for softmax in autocast mode working. * Remove debug print * Reduce min_abs from 1.0e-04 to 5.0e-06 * Add hard limit of attention weights to +- 50 * Use normal implementation of softmax. * Remove use of RandomGrad * Remove the use of random_clamp in conformer.py. * Reduce the limit on attention weights from 50 to 25. * Reduce min_prob of ActivationBalancer from 0.1 to 0.05. * Penalize too large weights in softmax of AttentionDownsample() * Also apply limit on logit in SimpleCombiner * Increase limit on logit for SimpleCombiner to 25.0 * Add more diagnostics to debug gradient scale problems * Changes to grad scale logging; increase grad scale more frequently if less than one. * Add logging * Remove comparison diagnostics, which were not that useful. * Configuration changes: scores limit 5->10, min_prob 0.05->0.1, cur_grad_scale more aggressive increase * Reset optimizer state when we change loss function definition. * Make warmup period decrease scale on simple loss, leaving pruned loss scale constant. * Cosmetic change * Increase initial-lr from 0.05 to 0.06. * Increase initial-lr from 0.06 to 0.075 and decrease lr-epochs from 3.5 to 3. * Fixes to logging statements. * Introduce warmup schedule in optimizer * Increase grad_scale to Whiten module * Add inf check hooks * Renaming in optim.py; remove step() from scan_pessimistic_batches_for_oom in train.py * Change base lr to 0.1, also rename from initial lr in train.py * Adding activation balancers after simple_am_prob and simple_lm_prob * Reduce max_abs on am_balancer * Increase max_factor in final lm_balancer and am_balancer * Use penalize_abs_values_gt, not ActivationBalancer. * Trying to reduce grad_scale of Whiten() from 0.02 to 0.01. * Add hooks.py, had neglected to git add it. 
* don't do penalize_values_gt on simple_lm_proj and simple_am_proj; reduce --base-lr from 0.1 to 0.075 * Increase probs of activation balancer and make it decay slower. * Dont print out full non-finite tensor * Increase default max_factor for ActivationBalancer from 0.02 to 0.04; decrease max_abs in ConvolutionModule.deriv_balancer2 from 100.0 to 20.0 * reduce initial scale in GradScaler * Increase max_abs in ActivationBalancer of conv module from 20 to 50 * --base-lr0.075->0.5; --lr-epochs 3->3.5 * Revert 179->180 change, i.e. change max_abs for deriv_balancer2 back from 50.0 20.0 * Save some memory in the autograd of DoubleSwish. * Change the discretization of the sigmoid to be expectation preserving. * Fix randn to rand * Try a more exact way to round to uint8 that should prevent ever wrapping around to zero * Make it use float16 if in amp but use clamp to avoid wrapping error * Store only half precision output for softmax. * More memory efficient backprop for DoubleSwish. * Change to warmup schedule. * Changes to more accurately estimate OOM conditions * Reduce cutoff from 100 to 5 for estimating OOM with warmup * Make 20 the limit for warmup_count * Cast to float16 in DoubleSwish forward * Hopefully make penalize_abs_values_gt more memory efficient. * Add logging about memory used. * Change scalar_max in optim.py from 2.0 to 5.0 * Regularize how we apply the min and max to the eps of BasicNorm * Fix clamping of bypass scale; remove a couple unused variables. * Increase floor on bypass_scale from 0.1 to 0.2. * Increase bypass_scale from 0.2 to 0.4. * Increase bypass_scale min from 0.4 to 0.5 * Rename conformer.py to zipformer.py * Rename Conformer to Zipformer * Update decode.py by copying from pruned_transducer_stateless5 and changing directory name * Remove some unused variables. * Fix clamping of epsilon * Refactor zipformer for more flexibility so we can change number of encoder layers. * Have a 3rd encoder, at downsampling factor of 8. 
* Refactor how the downsampling is done so that it happens later, but the 1st encoder stack still operates after a subsampling of 2. * Fix bug RE seq lengths * Have 4 encoder stacks * Have 6 different encoder stacks, U-shaped network. * Reduce dim of linear positional encoding in attention layers. * Reduce min of bypass_scale from 0.5 to 0.3, and make it not applied in test mode. * Tuning change to num encoder layers, inspired by relative param importance. * Make decoder group size equal to 4. * Add skip connections as in normal U-net * Avoid falling off the loop for weird inputs * Apply layer-skip dropout prob * Have warmup schedule for layer-skipping * Rework how warmup count is produced; should not affect results. * Add warmup schedule for zipformer encoder layer, from 1.0 -> 0.2. * Reduce initial clamp_min for bypass_scale from 1.0 to 0.5. * Restore the changes from scaled_adam_219 and scaled_adam_exp220, accidentally lost, re layer skipping * Change to schedule of bypass_scale min: make it larger, decrease slower. * Change schedule after initial loss not promising * Implement pooling module, add it after initial feedforward. * Bug fix * Introduce dropout rate to dynamic submodules of conformer. * Introduce minimum probs in the SimpleCombiner * Add bias in weight module * Remove dynamic weights in SimpleCombine * Remove the 5th of 6 encoder stacks * Fix some typos * small fixes * small fixes * Copy files * Update decode.py * Add changes from the master * Add changes from the master * update results * Add CI * Small fixes * Small fixes Co-authored-by: Daniel Povey <dpovey@gmail.com>
335 lines
11 KiB
Python
Executable File
335 lines
11 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
#
|
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# This script converts several saved checkpoints
|
|
# to a single one using model averaging.
|
|
"""
|
|
|
|
Usage:
|
|
|
|
(1) Export to torchscript model using torch.jit.script()
|
|
|
|
./pruned_transducer_stateless8/export.py \
|
|
--exp-dir ./pruned_transducer_stateless8/exp \
|
|
--bpe-model data/lang_bpe_500/bpe.model \
|
|
--epoch 30 \
|
|
--avg 9 \
|
|
--jit 1
|
|
|
|
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
|
load it by `torch.jit.load("cpu_jit.pt")`.
|
|
|
|
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
|
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
|
|
|
Check
|
|
https://github.com/k2-fsa/sherpa
|
|
for how to use the exported models outside of icefall.
|
|
|
|
(2) Export `model.state_dict()`
|
|
|
|
./pruned_transducer_stateless8/export.py \
|
|
--exp-dir ./pruned_transducer_stateless8/exp \
|
|
--bpe-model data/lang_bpe_500/bpe.model \
|
|
--epoch 20 \
|
|
--avg 10
|
|
|
|
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
|
load it by `icefall.checkpoint.load_checkpoint()`.
|
|
|
|
To use the generated file with `pruned_transducer_stateless8/decode.py`,
|
|
you can do:
|
|
|
|
cd /path/to/exp_dir
|
|
ln -s pretrained.pt epoch-9999.pt
|
|
|
|
cd /path/to/egs/librispeech/ASR
|
|
./pruned_transducer_stateless8/decode.py \
|
|
--exp-dir ./pruned_transducer_stateless8/exp \
|
|
--epoch 9999 \
|
|
--avg 1 \
|
|
--max-duration 600 \
|
|
--decoding-method greedy_search \
|
|
--bpe-model data/lang_bpe_500/bpe.model
|
|
|
|
Check ./pretrained.py for its usage.
|
|
|
|
Note: If you don't want to train a model from scratch, we have
|
|
provided one for you. You can get it at
|
|
|
|
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14
|
|
|
|
with the following commands:
|
|
|
|
sudo apt-get install git-lfs
|
|
git lfs install
|
|
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14
|
|
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14/exp
|
|
"""
|
|
|
|
import argparse
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
import sentencepiece as spm
|
|
import torch
|
|
import torch.nn as nn
|
|
from scaling_converter import convert_scaled_to_non_scaled
|
|
from train import add_model_arguments, get_params, get_transducer_model
|
|
|
|
from icefall.checkpoint import (
|
|
average_checkpoints,
|
|
average_checkpoints_with_averaged_model,
|
|
find_checkpoints,
|
|
load_checkpoint,
|
|
)
|
|
from icefall.utils import str2bool
|
|
|
|
|
|
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this export script.

    The export-specific options (checkpoint selection, output format,
    decoder context size) are defined here. ``add_model_arguments`` from
    ``train.py`` then adds further options; presumably these are the model
    architecture options, so they must match the values used for training.

    Returns:
      An ``argparse.ArgumentParser`` using ``ArgumentDefaultsHelpFormatter``,
      so every option's default value is shown in ``--help``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=9,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        # The trailing space below separates the two sentences; the
        # adjacent string literals are concatenated without a separator.
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless8/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--jit",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.script.
        It will generate a file named cpu_jit.pt

        Check ./jit_pretrained.py for how to use it.
        """,
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )

    add_model_arguments(parser)

    return parser
|
|
|
|
|
|
def _load_checkpoint_weights(
    params, model: nn.Module, device: torch.device
) -> None:
    """Load the weights selected by the command-line options into ``model``.

    The source of the weights depends on ``params``, checked in this order:

    * ``--use-averaged-model false`` and ``--iter > 0``: plain average of the
      first ``--avg`` checkpoints returned by ``find_checkpoints``.
    * ``--use-averaged-model false`` and ``--avg 1``: load
      ``epoch-{epoch}.pt`` unchanged.
    * ``--use-averaged-model false`` otherwise: plain average of
      ``epoch-{epoch-avg+1}.pt`` .. ``epoch-{epoch}.pt``. Epoch numbers
      below 1 are skipped.
    * ``--use-averaged-model true``: ``average_checkpoints_with_averaged_model``
      from a start checkpoint (excluded) to an end checkpoint. The two
      checkpoints come either from ``--iter`` checkpoints or from
      ``epoch-{epoch-avg}.pt`` / ``epoch-{epoch}.pt``.

    Averaged state dicts are loaded with ``strict=False``, so missing or
    unexpected keys are not reported as errors.

    Args:
      params: merged training and command-line parameters (attribute access).
      model: the model to load into; it is modified in place.
      device: device passed to the checkpoint-averaging helpers.

    Raises:
      ValueError: if ``--iter`` is used and too few checkpoints are found.
      AssertionError: if ``--avg`` or the derived start epoch is invalid in
        averaged-model epoch mode.
    """
    if not params.use_averaged_model:
        if params.iter > 0:
            # NOTE(review): a negative ``iteration`` presumably selects
            # checkpoints at or before --iter; confirm in icefall.checkpoint.
            filenames = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.load_state_dict(
                average_checkpoints(filenames, device=device),
                strict=False,
            )
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            # Epochs count from 1, so non-positive epoch numbers are skipped.
            filenames = [
                f"{params.exp_dir}/epoch-{i}.pt"
                for i in range(start, params.epoch + 1)
                if i >= 1
            ]
            logging.info(f"averaging {filenames}")
            model.load_state_dict(
                average_checkpoints(filenames, device=device),
                strict=False,
            )
        return

    if params.iter > 0:
        filenames = find_checkpoints(
            params.exp_dir, iteration=-params.iter
        )[: params.avg + 1]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg + 1:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        # filenames[0] is used as the end of the range and filenames[-1] as
        # the (excluded) start, i.e. the list is treated as newest-first.
        filename_start = filenames[-1]
        filename_end = filenames[0]
        logging.info(
            "Calculating the averaged model over iteration checkpoints"
            f" from {filename_start} (excluded) to {filename_end}"
        )
    else:
        assert params.avg > 0, params.avg
        start = params.epoch - params.avg
        assert start >= 1, start
        filename_start = f"{params.exp_dir}/epoch-{start}.pt"
        filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
        logging.info(
            f"Calculating the averaged model over epoch range from "
            f"{start} (excluded) to {params.epoch}"
        )

    model.load_state_dict(
        average_checkpoints_with_averaged_model(
            filename_start=filename_start,
            filename_end=filename_end,
            device=device,
        ),
        strict=False,
    )


def _export_model(params, model: nn.Module) -> None:
    """Write ``model`` into ``params.exp_dir``.

    With ``--jit true``, the model is first passed through
    ``convert_scaled_to_non_scaled`` (in place). It is then scripted with
    ``torch.jit.script`` and saved as ``cpu_jit.pt``. Otherwise
    ``{"model": model.state_dict()}`` is saved as ``pretrained.pt``, the
    format read by ``icefall.checkpoint.load_checkpoint``.

    Args:
      params: parameters providing ``jit`` and ``exp_dir`` (a ``Path``).
      model: the model to export; expected to be on CPU and in eval mode.
    """
    if params.jit:
        convert_scaled_to_non_scaled(model, inplace=True)
        logging.info("Using torch.jit.script()")
        # We won't use the forward() method of the model in C++, so just
        # ignore it here. Otherwise, one of its arguments is a ragged tensor
        # and is not torch scriptable.
        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
        model = torch.jit.script(model)
        filename = params.exp_dir / "cpu_jit.pt"
        model.save(str(filename))
        logging.info(f"Saved to {filename}")
    else:
        logging.info("Not using torchscript. Export model.state_dict()")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = params.exp_dir / "pretrained.pt"
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")


@torch.no_grad()
def main():
    """Entry point: build the model, load (averaged) weights, and export it.

    The steps are:

    1. Parse the command line and merge it into ``get_params()``.
    2. Read ``blank_id`` and ``vocab_size`` from the BPE model.
    3. Build the transducer with ``get_transducer_model(..., enable_giga=False)``.
    4. Load checkpoint weights on CUDA if available, otherwise on CPU.
    5. Move the model to CPU, set eval mode, and save it as TorchScript or as
       a state dict.

    Everything runs under ``torch.no_grad()``.
    """
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params, enable_giga=False)
    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # The model is moved once here; the loading helpers need no further
    # .to(device) calls.
    model.to(device)
    _load_checkpoint_weights(params, model, device)

    # Exported artifacts are always produced from a CPU model in eval mode.
    model.to("cpu")
    model.eval()

    _export_model(params, model)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: configure INFO-level logging with a timestamped,
    # file:line-annotated format, then run the export.
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(level=logging.INFO, format=formatter)
    main()
|