Zengwei Yao b3e6bf66df
Add modified beam search decoding for streaming inference with emformer model (#327)
* Fix torch.nn.Embedding error for torch below 1.8.0

* Changes to fbank computation, use lilcom chunky writer

* Add min in q,k,v of attention

* Remove learnable offset, use relu instead.

* Experiments based on SpecAugment change

* Merge specaug change from Mingshuang.

* Use much more aggressive SpecAug setup

* Fix to num_feature_masks bug I introduced; reduce max_frames_mask_fraction 0.4->0.3

* Change p=0.5->0.9, mask_fraction 0.3->0.2

* Change p=0.9 to p=0.8 in SpecAug

* Fix num_time_masks code; revert 0.8 to 0.9

* Change max_frames from 0.2 to 0.15

* Remove ReLU in attention

* Adding diagnostics code...

* Refactor/simplify ConformerEncoder

* First version of rand-combine iterated-training-like idea.

* Improvements to diagnostics (re those with 1 dim)

* Add pelu to this good-performing setup..

* Small bug fixes/imports

* Add baseline for the PeLU expt, keeping only the small normalization-related changes.

* pelu_base->expscale, add 2xExpScale in subsampling, and in feedforward units.

* Double learning rate of exp-scale units

* Combine ExpScale and swish for memory reduction

* Add import

* Fix backprop bug

* Fix bug in diagnostics

* Increase scale on Scale from 4 to 20

* Increase scale from 20 to 50.

* Fix duplicate Swish; replace norm+swish with swish+exp-scale in convolution module

* Reduce scale from 50 to 20

* Add deriv-balancing code

* Double the threshold in brelu; slightly increase max_factor.

* Fix exp dir

* Convert swish nonlinearities to ReLU

* Replace relu with swish-squared.

* Restore ConvolutionModule to state before changes; change all Swish,Swish(Swish) to SwishOffset.

* Replace norm on input layer with scale of 0.1.

* Extensions to diagnostics code

* Update diagnostics

* Add BasicNorm module

* Replace most normalizations with scales (still have norm in conv)

* Change exp dir

* Replace norm in ConvolutionModule with a scaling factor.

* use nonzero threshold in DerivBalancer

* Add min-abs-value 0.2

* Fix dirname

* Change min-abs threshold from 0.2 to 0.5

* Scale up pos_bias_u and pos_bias_v before use.

* Reduce max_factor to 0.01

* Fix q*scaling logic

* Change max_factor in DerivBalancer from 0.025 to 0.01; fix scaling code.

* init 1st conv module to smaller variance

* Change how scales are applied; fix residual bug

* Reduce min_abs from 0.5 to 0.2

* Introduce in_scale=0.5 for SwishExpScale

* Fix scale from 0.5 to 2.0 as I really intended..

* Set scaling on SwishExpScale

* Add identity pre_norm_final for diagnostics.

* Add learnable post-scale for mha

* Fix self.post-scale-mha

* Another rework, use scales on linear/conv

* Change dir name

* Reduce initial scaling of modules

* Bug-fix RE bias

* Cosmetic change

* Reduce initial_scale.

* Replace ExpScaleRelu with DoubleSwish()

* DoubleSwish fix

* Use learnable scales for joiner and decoder

* Add max-abs-value constraint in DerivBalancer

* Add max-abs-value

* Change dir name

* Remove ExpScale in feedforward layers.

* Reduce max-abs limit from 1000 to 100; introduce 2 DerivBalancer modules in conv layer.

* Make DoubleSwish more memory efficient

* Reduce constraints from deriv-balancer in ConvModule.

* Add warmup mode

* Remove max-positive constraint in deriv-balancing; add second DerivBalancer in conv module.

* Add some extra info to diagnostics

* Add deriv-balancer at output of embedding.

* Add more stats.

* Make epsilon in BasicNorm learnable, optionally.

* Draft of 0mean changes..

* Rework of initialization

* Fix typo

* Remove dead code

* Modifying initialization from normal->uniform; add initial_scale when initializing

* bug fix re sqrt

* Remove xscale from pos_embedding

* Remove some dead code.

* Cosmetic changes/renaming things

* Start adding some files..

* Add more files..

* update decode.py file type

* Add remaining files in pruned_transducer_stateless2

* Fix diagnostics-getting code

* Scale down pruned loss in warmup mode

* Reduce warmup scale on pruned loss from 0.1 to 0.01.

* Remove scale_speed, make swish deriv more efficient.

* Cosmetic changes to swish

* Double warm_step

* Fix bug with import

* Change initial std from 0.05 to 0.025.

* Set also scale for embedding to 0.025.

* Remove logging code that broke with newer Lhotse; fix bug with pruned_loss

* Add norm+balancer to VggSubsampling

* Incorporate changes from master into pruned_transducer_stateless2.

* Add max-abs=6, debugged version

* Change 0.025,0.05 to 0.01 in initializations

* Fix balancer code

* Whitespace fix

* Reduce initial pruned_loss scale from 0.01 to 0.0

* Increase warm_step (and valid_interval)

* Change max-abs from 6 to 10

* Change how warmup works.

* Add changes from master to decode.py, train.py

* Simplify the warmup code; max_abs 10->6

* Make warmup work by scaling layer contributions; leave residual layer-drop

* Fix bug

* Fix test mode with random layer dropout

* Add random-number-setting function in dataloader

* Fix/patch how fix_random_seed() is imported.

* Reduce layer-drop prob

* Reduce layer-drop prob after warmup to 1 in 100

* Change power of lr-schedule from -0.5 to -0.333

* Increase model_warm_step to 4k

* Change max-keep-prob to 0.95

* Refactoring and simplifying conformer and frontend

* Rework conformer, remove some code.

* Reduce 1st conv channels from 64 to 32

* Add another convolutional layer

* Fix padding bug

* Remove dropout in output layer

* Reduce speed of some components

* Initial refactoring to remove unnecessary vocab_size

* Fix RE identity

* Bug-fix

* Add final dropout to conformer

* Remove some un-used code

* Replace nn.Linear with ScaledLinear in simple joiner

* Make 2 projections..

* Reduce initial_speed

* Use initial_speed=0.5

* Reduce initial_speed further from 0.5 to 0.25

* Reduce initial_speed from 0.5 to 0.25

* Change how warmup is applied.

* Bug fix to warmup_scale

* Fix test-mode

* Remove final dropout

* Make layer dropout rate 0.075, was 0.1.

* First draft of model rework

* Various bug fixes

* Change learning speed of simple_lm_proj

* Revert transducer_stateless/ to state in upstream/master

* Fix to joiner to allow different dims

* Some cleanups

* Make training more efficient, avoid redoing some projections.

* Change how warm-step is set

* First draft of new approach to learning rates + init

* Some fixes..

* Change initialization to 0.25

* Fix type of parameter

* Fix weight decay formula by adding 1/(1-beta)

* Fix weight decay formula by adding 1/(1-beta)

* Fix checkpoint-writing

* Fix to reading scheduler from optim

* Simplified optimizer, rework some things..

* Reduce model_warm_step from 4k to 3k

* Fix bug in lambda

* Bug-fix RE sign of target_rms

* Changing initial_speed from 0.25 to 0.1

* Change some defaults in LR-setting rule.

* Remove initial_speed

* Set new scheduler

* Change exponential part of lrate to be epoch based

* Fix bug

* Set 2n rule..

* Implement 2o schedule

* Make lrate rule more symmetric

* Implement 2p version of learning rate schedule.

* Refactor how learning rate is set.

* Fix import

* Modify init (#301)

* update icefall/__init__.py to import more common functions.

* update icefall/__init__.py

* make imports style consistent.

* exclude black check for icefall/__init__.py in pyproject.toml.

* Minor fixes for logging (#296)

* Minor fixes for logging

* Minor fix

* Fix dir names

* Modify beam search to be efficient with current joiner

* Fix adding learning rate to tensorboard

* Fix docs in optim.py

* Support mixed precision training on the reworked model (#305)

* Add mixed precision support

* Minor fixes

* Minor fixes

* Minor fixes

* Tedlium3 pruned transducer stateless (#261)

* update tedlium3-pruned-transducer-stateless-codes

* update README.md

* update README.md

* add fast beam search for decoding

* do a change for RESULTS.md

* do a change for RESULTS.md

* do a fix

* do some changes for pruned RNN-T

* Add mixed precision support

* Minor fixes

* Minor fixes

* Updating RESULTS.md; fix in beam_search.py

* Fix rebase

* Code style check for librispeech pruned transducer stateless2 (#308)

* Update results for tedlium3 pruned RNN-T (#307)

* Update README.md

* Fix CI errors. (#310)

* Add more results

* Fix tensorboard log location

* Add one more epoch of full expt

* fix comments

* Add results for mixed precision with max-duration 300

* Changes for pretrained.py (tedlium3 pruned RNN-T) (#311)

* GigaSpeech recipe (#120)

* initial commit

* support download, data prep, and fbank

* on-the-fly feature extraction by default

* support BPE based lang

* support HLG for BPE

* small fix

* small fix

* chunked feature extraction by default

* Compute features for GigaSpeech by splitting the manifest.

* Fixes after review.

* Split manifests into 2000 pieces.

* set audio duration mismatch tolerance to 0.01

* small fix

* add conformer training recipe

* Add conformer.py without pre-commit checking

* lazy loading and use SingleCutSampler

* DynamicBucketingSampler

* use KaldifeatFbank to compute fbank for musan

* use pretrained language model and lexicon

* use 3gram to decode, 4gram to rescore

* Add decode.py

* Update .flake8

* Delete compute_fbank_gigaspeech.py

* Use BucketingSampler for valid and test dataloader

* Update params in train.py

* Use bpe_500

* update params in decode.py

* Decrease num_paths when CUDA OOM occurs

* Added README

* Update RESULTS

* black

* Decrease num_paths when CUDA OOM occurs

* Decode with post-processing

* Update results

* Remove lazy_load option

* Use default `storage_type`

* Keep the original tolerance

* Use split-lazy

* black

* Update pretrained model

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* Add LG decoding (#277)

* Add LG decoding

* Add log weight pushing

* Minor fixes

* Support computing RNN-T loss with torchaudio (#316)

* Support modified beam search decoding for streaming inference with Emformer model.

* Formatted imports.

* Update results for torchaudio RNN-T. (#322)

* Fixed streaming decoding codes for emformer model.

* Fixed docs.

* Sorted imports for transducer_emformer/streaming_feature_extractor.py

* Minor fix for transducer_emformer/streaming_feature_extractor.py

Co-authored-by: pkufool <wkang@pku.org.cn>
Co-authored-by: Daniel Povey <dpovey@gmail.com>
Co-authored-by: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com>
Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
Co-authored-by: Guo Liyong <guonwpu@qq.com>
Co-authored-by: Wang, Guanbo <wgb14@outlook.com>
2022-04-22 18:06:07 +08:00

332 lines · 12 KiB · Python

# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union

import torch
from torch.optim import Optimizer


class Eve(Optimizer):
    r"""
    Implements the Eve algorithm.  This is a modified version of AdamW with a
    special way of setting the weight-decay / shrinkage-factor, which is
    designed to make the rms of the parameters approach a particular
    target_rms (default: 0.1).  This is for use with networks with 'scaled'
    versions of modules (see scaling.py), which will be close to invariant to
    the absolute scale on the parameter matrix.

    The original Adam algorithm was proposed in
    `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
    Eve is unpublished so far.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of the gradient and its square
            (default: (0.9, 0.98))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default:
            1e-3; with this value the weight would decay significantly after
            about 1k minibatches.  It is not multiplied by the learning rate,
            and it is applied only while the RMS value of the parameter is
            above target_rms).
        target_rms (float, optional): target root-mean-square value of the
            parameters; if a parameter falls below this, we stop applying
            weight decay to it.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    """
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.98),
        eps=1e-8,
        weight_decay=1e-3,
        target_rms=0.1,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1])
            )
        if not 0 <= weight_decay <= 0.1:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay)
            )
        if not 0 < target_rms <= 10.0:
            raise ValueError("Invalid target_rms value: {}".format(target_rms))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            target_rms=target_rms,
        )
        super(Eve, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Eve, self).__setstate__(state)
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                # Perform optimization step
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        "Eve does not support sparse gradients"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                beta1, beta2 = group["betas"]

                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Decay the first and second moment running average coefficients
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
                    group["eps"]
                )

                step_size = group["lr"] / bias_correction1
                target_rms = group["target_rms"]
                weight_decay = group["weight_decay"]

                if p.numel() > 1:
                    # Avoid applying this weight-decay on "scaling factors"
                    # (which are scalar).  The decay is applied only if the
                    # parameter's RMS value exceeds target_rms.
                    is_above_target_rms = p.norm() > (
                        target_rms * (p.numel() ** 0.5)
                    )
                    p.mul_(1 - (weight_decay * is_above_target_rms))
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
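

# Illustrative sketch, not part of the upstream file: a minimal demonstration
# of Eve's conditional weight decay.  A parameter whose RMS exceeds target_rms
# is shrunk by (1 - weight_decay) before the Adam-style update, while a
# parameter with small RMS is left alone.  The names, shapes and values below
# are invented for the example.
def _demo_eve_weight_decay():
    big = torch.nn.Parameter(torch.ones(10, 10))  # RMS 1.0 > target_rms 0.1
    small = torch.nn.Parameter(torch.full((10, 10), 0.01))  # RMS 0.01 < 0.1
    optim = Eve([big, small], lr=1e-3, weight_decay=1e-3, target_rms=0.1)
    (big.sum() + small.sum()).backward()
    optim.step()
    # `big` was shrunk because p.norm() > target_rms * sqrt(p.numel());
    # `small` received no shrinkage because its norm is below that threshold.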


class LRScheduler(object):
    """
    Base class for learning rate schedulers where the learning rate depends on
    both the batch and the epoch.
    """

    def __init__(self, optimizer: Optimizer, verbose: bool = False):
        # Attach the optimizer.
        if not isinstance(optimizer, Optimizer):
            raise TypeError(
                "{} is not an Optimizer".format(type(optimizer).__name__)
            )
        self.optimizer = optimizer
        self.verbose = verbose

        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])

        self.base_lrs = [
            group["initial_lr"] for group in optimizer.param_groups
        ]

        self.epoch = 0
        self.batch = 0
    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {
            "base_lrs": self.base_lrs,
            "epoch": self.epoch,
            "batch": self.batch,
        }

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Args:
            state_dict (dict): scheduler state.  Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self) -> List[float]:
        """Return the last learning rate computed by the current scheduler,
        as a list of floats."""
        return self._last_lr

    def get_lr(self):
        # Compute a list of learning rates from self.epoch, self.batch and
        # self.base_lrs; this must be overloaded by the user, e.g.
        #   return [some_formula(self.batch, self.epoch, base_lr)
        #           for base_lr in self.base_lrs]
        raise NotImplementedError
    def step_batch(self, batch: Optional[int] = None) -> None:
        # Step the batch index, or just set it.  If `batch` is specified, it
        # must be the batch index from the start of training, i.e. summed over
        # all epochs.
        # You can call this in any order; if you don't provide 'batch', it
        # should of course be called once per batch.
        if batch is not None:
            self.batch = batch
        else:
            self.batch = self.batch + 1
        self._set_lrs()

    def step_epoch(self, epoch: Optional[int] = None):
        # Step the epoch index, or just set it.  If you provide the 'epoch'
        # arg, you should call this at the start of the epoch; if you don't
        # provide the 'epoch' arg, you should call it at the end of the epoch.
        if epoch is not None:
            self.epoch = epoch
        else:
            self.epoch = self.epoch + 1
        self._set_lrs()

    def _set_lrs(self):
        values = self.get_lr()
        assert len(values) == len(self.optimizer.param_groups)

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group["lr"] = lr
            self.print_lr(self.verbose, i, lr)
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, is_verbose, group, lr):
        """Display the current learning rate."""
        if is_verbose:
            print(
                f"Epoch={self.epoch}, batch={self.batch}: adjusting learning"
                f" rate of group {group} to {lr:.4e}."
            )
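

# Illustrative sketch, not part of the upstream file: an example of how
# get_lr() is meant to be overloaded in a subclass.  This toy scheduler halves
# the learning rate every `halve_batches` batches and ignores the epoch
# counter; the class name and its argument are invented here.
class _ExampleHalvingScheduler(LRScheduler):
    def __init__(
        self,
        optimizer: Optimizer,
        halve_batches: int = 1000,
        verbose: bool = False,
    ):
        super().__init__(optimizer, verbose)
        self.halve_batches = halve_batches

    def get_lr(self):
        # factor is 1.0 at batch 0 and 0.5 at batch == halve_batches.
        factor = 0.5 ** (self.batch / self.halve_batches)
        return [base_lr * factor for base_lr in self.base_lrs]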


class Eden(LRScheduler):
    """
    Eden scheduler:

        lr = initial_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
                        * ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25

    E.g. a suggested initial_lr is 0.003 (passed to the optimizer).

    Args:
        optimizer: the optimizer whose learning rates we change.
        lr_batches: the number of batches after which we start significantly
            decreasing the learning rate; 5000 is suggested.
        lr_epochs: the number of epochs after which we start significantly
            decreasing the learning rate; 6 is suggested if you plan to run
            e.g. 20 to 40 epochs, but a smaller number may be needed if the
            dataset is huge and you will do only a few epochs.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_batches: Union[int, float],
        lr_epochs: Union[int, float],
        verbose: bool = False,
    ):
        super(Eden, self).__init__(optimizer, verbose)
        self.lr_batches = lr_batches
        self.lr_epochs = lr_epochs

    def get_lr(self):
        factor = (
            (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
        ) ** -0.25 * (
            ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
            ** -0.25
        )
        return [x * factor for x in self.base_lrs]
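
    # Worked example (illustrative): with lr_batches=5000 and lr_epochs=6, at
    # batch=5000 and epoch=6 both ratios equal 2, so
    #     factor = 2 ** -0.25 * 2 ** -0.25 = 2 ** -0.5 ~= 0.707,
    # i.e. the learning rate has decayed to about 71% of initial_lr.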


def _test_eden():
    m = torch.nn.Linear(100, 100)
    optim = Eve(m.parameters(), lr=0.003)
    scheduler = Eden(optim, lr_batches=30, lr_epochs=2, verbose=True)

    for epoch in range(10):
        scheduler.step_epoch(epoch)  # sets epoch to `epoch`

        for step in range(20):
            x = torch.randn(200, 100).detach()
            x.requires_grad = True
            y = m(x)
            dy = torch.randn(200, 100).detach()
            f = (y * dy).sum()
            f.backward()

            optim.step()
            scheduler.step_batch()
            optim.zero_grad()

    print("last lr = ", scheduler.get_last_lr())
    print("state dict = ", scheduler.state_dict())


if __name__ == "__main__":
    _test_eden()