mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-10 01:24:19 +00:00
* Fix torch.nn.Embedding error for torch below 1.8.0 * Changes to fbank computation, use lilcom chunky writer * Add min in q,k,v of attention * Remove learnable offset, use relu instead. * Experiments based on SpecAugment change * Merge specaug change from Mingshuang. * Use much more aggressive SpecAug setup * Fix to num_feature_masks bug I introduced; reduce max_frames_mask_fraction 0.4->0.3 * Change p=0.5->0.9, mask_fraction 0.3->0.2 * Change p=0.9 to p=0.8 in SpecAug * Fix num_time_masks code; revert 0.8 to 0.9 * Change max_frames from 0.2 to 0.15 * Remove ReLU in attention * Adding diagnostics code... * Refactor/simplify ConformerEncoder * First version of rand-combine iterated-training-like idea. * Improvements to diagnostics (RE those with 1 dim * Add pelu to this good-performing setup.. * Small bug fixes/imports * Add baseline for the PeLU expt, keeping only the small normalization-related changes. * pelu_base->expscale, add 2xExpScale in subsampling, and in feedforward units. * Double learning rate of exp-scale units * Combine ExpScale and swish for memory reduction * Add import * Fix backprop bug * Fix bug in diagnostics * Increase scale on Scale from 4 to 20 * Increase scale from 20 to 50. * Fix duplicate Swish; replace norm+swish with swish+exp-scale in convolution module * Reduce scale from 50 to 20 * Add deriv-balancing code * Double the threshold in brelu; slightly increase max_factor. * Fix exp dir * Convert swish nonlinearities to ReLU * Replace relu with swish-squared. * Restore ConvolutionModule to state before changes; change all Swish,Swish(Swish) to SwishOffset. * Replace norm on input layer with scale of 0.1. * Extensions to diagnostics code * Update diagnostics * Add BasicNorm module * Replace most normalizations with scales (still have norm in conv) * Change exp dir * Replace norm in ConvolutionModule with a scaling factor. 
* use nonzero threshold in DerivBalancer * Add min-abs-value 0.2 * Fix dirname * Change min-abs threshold from 0.2 to 0.5 * Scale up pos_bias_u and pos_bias_v before use. * Reduce max_factor to 0.01 * Fix q*scaling logic * Change max_factor in DerivBalancer from 0.025 to 0.01; fix scaling code. * init 1st conv module to smaller variance * Change how scales are applied; fix residual bug * Reduce min_abs from 0.5 to 0.2 * Introduce in_scale=0.5 for SwishExpScale * Fix scale from 0.5 to 2.0 as I really intended.. * Set scaling on SwishExpScale * Add identity pre_norm_final for diagnostics. * Add learnable post-scale for mha * Fix self.post-scale-mha * Another rework, use scales on linear/conv * Change dir name * Reduce initial scaling of modules * Bug-fix RE bias * Cosmetic change * Reduce initial_scale. * Replace ExpScaleRelu with DoubleSwish() * DoubleSwish fix * Use learnable scales for joiner and decoder * Add max-abs-value constraint in DerivBalancer * Add max-abs-value * Change dir name * Remove ExpScale in feedforward layes. * Reduce max-abs limit from 1000 to 100; introduce 2 DerivBalancer modules in conv layer. * Make DoubleSwish more memory efficient * Reduce constraints from deriv-balancer in ConvModule. * Add warmup mode * Remove max-positive constraint in deriv-balancing; add second DerivBalancer in conv module. * Add some extra info to diagnostics * Add deriv-balancer at output of embedding. * Add more stats. * Make epsilon in BasicNorm learnable, optionally. * Draft of 0mean changes.. * Rework of initialization * Fix typo * Remove dead code * Modifying initialization from normal->uniform; add initial_scale when initializing * bug fix re sqrt * Remove xscale from pos_embedding * Remove some dead code. * Cosmetic changes/renaming things * Start adding some files.. * Add more files.. 
* update decode.py file type * Add remaining files in pruned_transducer_stateless2 * Fix diagnostics-getting code * Scale down pruned loss in warmup mode * Reduce warmup scale on pruned loss from 0.1 to 0.01. * Remove scale_speed, make swish deriv more efficient. * Cosmetic changes to swish * Double warm_step * Fix bug with import * Change initial std from 0.05 to 0.025. * Set also scale for embedding to 0.025. * Remove logging code that broke with newer Lhotse; fix bug with pruned_loss * Add norm+balancer to VggSubsampling * Incorporate changes from master into pruned_transducer_stateless2. * Add max-abs=6, debugged version * Change 0.025,0.05 to 0.01 in initializations * Fix balancer code * Whitespace fix * Reduce initial pruned_loss scale from 0.01 to 0.0 * Increase warm_step (and valid_interval) * Change max-abs from 6 to 10 * Change how warmup works. * Add changes from master to decode.py, train.py * Simplify the warmup code; max_abs 10->6 * Make warmup work by scaling layer contributions; leave residual layer-drop * Fix bug * Fix test mode with random layer dropout * Add random-number-setting function in dataloader * Fix/patch how fix_random_seed() is imported. * Reduce layer-drop prob * Reduce layer-drop prob after warmup to 1 in 100 * Change power of lr-schedule from -0.5 to -0.333 * Increase model_warm_step to 4k * Change max-keep-prob to 0.95 * Refactoring and simplifying conformer and frontend * Rework conformer, remove some code. * Reduce 1st conv channels from 64 to 32 * Add another convolutional layer * Fix padding bug * Remove dropout in output layer * Reduce speed of some components * Initial refactoring to remove unnecessary vocab_size * Fix RE identity * Bug-fix * Add final dropout to conformer * Remove some un-used code * Replace nn.Linear with ScaledLinear in simple joiner * Make 2 projections.. 
* Reduce initial_speed * Use initial_speed=0.5 * Reduce initial_speed further from 0.5 to 0.25 * Reduce initial_speed from 0.5 to 0.25 * Change how warmup is applied. * Bug fix to warmup_scale * Fix test-mode * Remove final dropout * Make layer dropout rate 0.075, was 0.1. * First draft of model rework * Various bug fixes * Change learning speed of simple_lm_proj * Revert transducer_stateless/ to state in upstream/master * Fix to joiner to allow different dims * Some cleanups * Make training more efficient, avoid redoing some projections. * Change how warm-step is set * First draft of new approach to learning rates + init * Some fixes.. * Change initialization to 0.25 * Fix type of parameter * Fix weight decay formula by adding 1/(1-beta) * Fix weight decay formula by adding 1/(1-beta) * Fix checkpoint-writing * Fix to reading scheduler from optim * Simplified optimizer, rework some things.. * Reduce model_warm_step from 4k to 3k * Fix bug in lambda * Bug-fix RE sign of target_rms * Changing initial_speed from 0.25 to 01 * Change some defaults in LR-setting rule. * Remove initial_speed * Set new scheduler * Change exponential part of lrate to be epoch based * Fix bug * Set 2n rule.. * Implement 2o schedule * Make lrate rule more symmetric * Implement 2p version of learning rate schedule. * Refactor how learning rate is set. * Fix import * Modify init (#301) * update icefall/__init__.py to import more common functions. * update icefall/__init__.py * make imports style consistent. * exclude black check for icefall/__init__.py in pyproject.toml. 
* Minor fixes for logging (#296) * Minor fixes for logging * Minor fix * Fix dir names * Modify beam search to be efficient with current joiner * Fix adding learning rate to tensorboard * Fix docs in optim.py * Support mixed precision training on the reworked model (#305) * Add mixed precision support * Minor fixes * Minor fixes * Minor fixes * Tedlium3 pruned transducer stateless (#261) * update tedlium3-pruned-transducer-stateless-codes * update README.md * update README.md * add fast beam search for decoding * do a change for RESULTS.md * do a change for RESULTS.md * do a fix * do some changes for pruned RNN-T * Add mixed precision support * Minor fixes * Minor fixes * Updating RESULTS.md; fix in beam_search.py * Fix rebase * Code style check for librispeech pruned transducer stateless2 (#308) * Update results for tedlium3 pruned RNN-T (#307) * Update README.md * Fix CI errors. (#310) * Add more results * Fix tensorboard log location * Add one more epoch of full expt * fix comments * Add results for mixed precision with max-duration 300 * Changes for pretrained.py (tedlium3 pruned RNN-T) (#311) * GigaSpeech recipe (#120) * initial commit * support download, data prep, and fbank * on-the-fly feature extraction by default * support BPE based lang * support HLG for BPE * small fix * small fix * chunked feature extraction by default * Compute features for GigaSpeech by splitting the manifest. * Fixes after review. * Split manifests into 2000 pieces. 
* set audio duration mismatch tolerance to 0.01 * small fix * add conformer training recipe * Add conformer.py without pre-commit checking * lazy loading and use SingleCutSampler * DynamicBucketingSampler * use KaldifeatFbank to compute fbank for musan * use pretrained language model and lexicon * use 3gram to decode, 4gram to rescore * Add decode.py * Update .flake8 * Delete compute_fbank_gigaspeech.py * Use BucketingSampler for valid and test dataloader * Update params in train.py * Use bpe_500 * update params in decode.py * Decrease num_paths while CUDA OOM * Added README * Update RESULTS * black * Decrease num_paths while CUDA OOM * Decode with post-processing * Update results * Remove lazy_load option * Use default `storage_type` * Keep the original tolerance * Use split-lazy * black * Update pretrained model Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> * Add LG decoding (#277) * Add LG decoding * Add log weight pushing * Minor fixes * Support computing RNN-T loss with torchaudio (#316) * Support modified beam search decoding for streaming inference with Emformer model. * Formatted imports. * Update results for torchaudio RNN-T. (#322) * Fixed streaming decoding codes for emformer model. * Fixed docs. * Sorted imports for transducer_emformer/streaming_feature_extractor.py * Minor fix for transducer_emformer/streaming_feature_extractor.py Co-authored-by: pkufool <wkang@pku.org.cn> Co-authored-by: Daniel Povey <dpovey@gmail.com> Co-authored-by: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> Co-authored-by: Guo Liyong <guonwpu@qq.com> Co-authored-by: Wang, Guanbo <wgb14@outlook.com>
317 lines
9.6 KiB
Python
317 lines
9.6 KiB
Python
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
#
|
|
# See ../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import glob
|
|
import logging
|
|
import os
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from lhotse.dataset.sampling.base import CutSampler
|
|
from torch.cuda.amp import GradScaler
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
from torch.optim import Optimizer
|
|
|
|
# Use duck typing for LRScheduler since we have different possibilities,
# see our class LRScheduler (defined elsewhere in the project).
# The checkpoint helpers in this file only ever call `state_dict()` and
# `load_state_dict()` on a scheduler, so any object providing those two
# methods is acceptable. The alias exists purely for type annotations.
LRSchedulerType = object
|
|
|
|
|
|
def save_checkpoint(
    filename: Path,
    model: Union[nn.Module, DDP],
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[CutSampler] = None,
    rank: int = 0,
) -> None:
    """Save training information to a file.

    Args:
      filename:
        The checkpoint filename.
      model:
        The model to be saved. We only save its `state_dict()`. If it is
        wrapped in DDP, the state dict of the wrapped module is saved, so
        the saved keys carry no "module." prefix.
      params:
        User defined parameters, e.g., epoch, loss. Each item is stored as
        a top-level entry of the checkpoint; keys must not clash with the
        reserved keys "model", "optimizer", "scheduler", "grad_scaler"
        and "sampler".
      optimizer:
        The optimizer to be saved. We only save its `state_dict()`.
      scheduler:
        The scheduler to be saved. We only save its `state_dict()`.
      scaler:
        The GradScaler to be saved. We only save its `state_dict()`.
      sampler:
        The sampler to be saved. We only save its `state_dict()`.
      rank:
        Used in DDP. We save checkpoint only for the node whose rank is 0.
    Returns:
      Return None.
    Raises:
      ValueError:
        If a key of `params` clashes with one of the reserved keys.
    """
    if rank != 0:
        return

    logging.info(f"Saving checkpoint to {filename}")

    if isinstance(model, DDP):
        model = model.module

    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "grad_scaler": scaler.state_dict() if scaler is not None else None,
        "sampler": sampler.state_dict() if sampler is not None else None,
    }

    if params:
        for k, v in params.items():
            # An explicit check instead of `assert`: asserts are stripped
            # under `python -O`, which would let `params` silently
            # overwrite e.g. the saved model weights.
            if k in checkpoint:
                raise ValueError(
                    f"params key '{k}' clashes with a reserved checkpoint key"
                )
            checkpoint[k] = v

    torch.save(checkpoint, filename)
|
|
|
|
|
|
def load_checkpoint(
    filename: Path,
    model: nn.Module,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[CutSampler] = None,
    strict: bool = False,
) -> Dict[str, Any]:
    """Load a checkpoint written by :func:`save_checkpoint`.

    Args:
      filename:
        The checkpoint filename. Its content is loaded onto the CPU.
      model:
        The model into which the saved weights are loaded. A DDP-wrapped
        model is unwrapped first, mirroring :func:`save_checkpoint`.
        Checkpoints whose keys carry the "module." prefix (i.e., saved
        from a DDP-wrapped model) are accepted as well; the prefix is
        stripped before loading.
      optimizer:
        If not None, its state is restored from the "optimizer" entry.
      scheduler:
        If not None, its state is restored from the "scheduler" entry.
      scaler:
        If not None, its state is restored from the "grad_scaler" entry.
      sampler:
        If not None, its state is restored from the "sampler" entry.
      strict:
        Forwarded to `model.load_state_dict()`. If False, missing and
        unexpected keys are ignored -- for DDP and non-DDP checkpoints
        alike.
    Returns:
      The remaining entries of the checkpoint: everything except "model"
      and the entries that were actually loaded into one of the given
      objects. An entry is loaded (and removed) only if the corresponding
      object is not None and the saved state is truthy; otherwise it is
      left in the returned dict so that the caller can still use it.
    """
    logging.info(f"Loading checkpoint from {filename}")
    checkpoint = torch.load(filename, map_location="cpu")

    if isinstance(model, DDP):
        model = model.module

    state_dict = checkpoint.pop("model")
    prefix = "module."
    # The "" default guards against an empty state dict.
    if next(iter(state_dict), "").startswith(prefix):
        logging.info("Loading checkpoint saved by DDP")
        # Strip the DDP prefix and let load_state_dict() deal with missing
        # or unexpected keys, so that `strict` is honored here as well.
        state_dict = {
            (k[len(prefix) :] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    model.load_state_dict(state_dict, strict=strict)

    def load(name, obj):
        s = checkpoint.get(name, None)
        # Compare obj against None explicitly: plain truth-testing would
        # call obj.__len__()/__bool__() when defined (e.g. on a sampler),
        # which is not a reliable "was it passed?" check. A None or empty
        # saved state is skipped and left in the returned dict.
        if obj is not None and s:
            obj.load_state_dict(s)
            checkpoint.pop(name)

    load("optimizer", optimizer)
    load("scheduler", scheduler)
    load("grad_scaler", scaler)
    load("sampler", sampler)

    return checkpoint
|
|
|
|
|
|
def average_checkpoints(
    filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
    """Average a list of checkpoints.

    Args:
      filenames:
        Filenames of the checkpoints to be averaged. We assume all
        checkpoints are saved by :func:`save_checkpoint`.
      device:
        Move checkpoints to this device before averaging.
    Returns:
      Return a dict (i.e., state_dict) which is the average of all
      model state dicts contained in the checkpoints. Floating-point
      entries are divided by the number of checkpoints; other (integer)
      entries use floor division.
    Raises:
      ValueError:
        If `filenames` is empty.
    """
    n = len(filenames)
    if n == 0:
        raise ValueError("average_checkpoints() needs at least one file")

    avg = torch.load(filenames[0], map_location=device)["model"]

    # Tied parameters (one tensor registered under several names) come
    # back from torch.load() as entries sharing the same storage. Because
    # the updates below are in-place, such a tensor must be visited only
    # once; otherwise it would be accumulated and divided several times.
    uniqued: Dict[int, str] = {}
    for k, v in avg.items():
        uniqued.setdefault(v.data_ptr(), k)
    uniqued_names = list(uniqued.values())

    for filename in filenames[1:]:
        state_dict = torch.load(filename, map_location=device)["model"]
        for k in uniqued_names:
            avg[k] += state_dict[k]

    for k in uniqued_names:
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] //= n

    return avg
|
|
|
|
|
|
def save_checkpoint_with_global_batch_idx(
    out_dir: Path,
    global_batch_idx: int,
    model: Union[nn.Module, DDP],
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[CutSampler] = None,
    rank: int = 0,
):
    """Write a checkpoint named after the number of processed batches.

    The target directory is created (including parents) if it does not
    exist yet; this happens on every rank. The file itself is written by
    :func:`save_checkpoint`, which only writes on rank 0. The resulting
    path is::

        out_dir / f"checkpoint-{global_batch_idx}.pt"

    Args:
      out_dir:
        Directory that receives the checkpoint.
      global_batch_idx:
        Number of batches processed since the very start of training;
        it becomes part of the filename.
      model:
        The network whose `state_dict` is stored.
      params:
        Training configuration to store alongside the model.
      optimizer:
        The optimizer used in training; its `state_dict` is stored.
      scheduler:
        The learning-rate scheduler used in training; its `state_dict`
        is stored.
      scaler:
        The scaler used for mixed precision training; its `state_dict`
        is stored.
      sampler:
        The sampler of the training dataset; its `state_dict` is stored.
      rank:
        DDP rank of the current process. Use 0 when DDP is not used.
    """
    checkpoint_dir = Path(out_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    target = checkpoint_dir / f"checkpoint-{global_batch_idx}.pt"
    save_checkpoint(
        filename=target,
        model=model,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        sampler=sampler,
        rank=rank,
    )
|
|
|
|
|
|
def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
    """Find all available checkpoints in a directory.

    The checkpoint filenames have the form: `checkpoint-xxx.pt`
    where xxx is a numerical value. Other files in the directory,
    including look-alikes such as `checkpoint-xxx-bak.pt`, are ignored.

    Assume you have the following checkpoints in the folder `foo`:

        - checkpoint-1.pt
        - checkpoint-20.pt
        - checkpoint-300.pt
        - checkpoint-4000.pt

    Case 1 (Return all checkpoints)::

      find_checkpoints(out_dir='foo')

    Case 2 (Return checkpoints newer than checkpoint-20.pt, i.e.,
    checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt)::

      find_checkpoints(out_dir='foo', iteration=20)

    Case 3 (Return checkpoints older than checkpoint-20.pt, i.e.,
    checkpoint-20.pt, checkpoint-1.pt)::

      find_checkpoints(out_dir='foo', iteration=-20)

    Args:
      out_dir:
        The directory where to search for checkpoints.
      iteration:
        If it is 0, return all available checkpoints.
        If it is positive, return the checkpoints whose iteration number is
        greater than or equal to `iteration`.
        If it is negative, return the checkpoints whose iteration number is
        less than or equal to `-iteration`.
    Returns:
      Return a list of checkpoint filenames, sorted in descending
      order by the numerical value in the filename.
    """
    # Escape the directory so that glob metacharacters in it (e.g. "[")
    # are matched literally; only the filename part is a pattern.
    candidates = glob.glob(
        f"{glob.escape(str(out_dir))}/checkpoint-[0-9]*.pt"
    )

    # The glob also matches names like "checkpoint-10-bak.pt", so check
    # the exact form on the basename and skip anything that does not fit.
    pattern = re.compile(r"checkpoint-([0-9]+)\.pt")

    # Each tuple is (iteration_number, path of checkpoint-iteration_number.pt).
    iter_checkpoints = []
    for c in candidates:
        m = pattern.fullmatch(os.path.basename(c))
        if m is not None:
            iter_checkpoints.append((int(m.group(1)), c))

    iter_checkpoints.sort(reverse=True, key=lambda x: x[0])

    if iteration >= 0:
        return [c for it, c in iter_checkpoints if it >= iteration]
    return [c for it, c in iter_checkpoints if it <= -iteration]
|
|
|
|
|
|
def remove_checkpoints(
    out_dir: Path,
    topk: int,
    rank: int = 0,
):
    """Remove checkpoints from the given directory.

    We assume that checkpoint filename has the form `checkpoint-xxx.pt`
    where xxx is a number, representing the number of processed batches
    when saving that checkpoint. We sort checkpoints by that number and
    keep only the `topk` checkpoints with the highest `xxx`.

    Args:
      out_dir:
        The directory containing checkpoints to be removed.
      topk:
        Number of checkpoints to keep. Must be at least 1.
      rank:
        If using DDP for training, it is the rank of the current node.
        Use 0 if no DDP is used for training. Only rank 0 removes files.
    """
    assert topk >= 1, topk
    if rank != 0:
        return

    checkpoints = find_checkpoints(out_dir)
    if not checkpoints:
        # logging.warn is a deprecated alias of logging.warning.
        logging.warning(f"No checkpoints found in {out_dir}")
        return

    # find_checkpoints() returns the newest checkpoint first, so everything
    # after the first `topk` entries is older and gets removed. The slice
    # is empty when there are at most `topk` checkpoints.
    for c in checkpoints[topk:]:
        os.remove(c)
|