mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
* Remove ReLU in attention * Adding diagnostics code... * Refactor/simplify ConformerEncoder * First version of rand-combine iterated-training-like idea. * Improvements to diagnostics (RE those with 1 dim * Add pelu to this good-performing setup.. * Small bug fixes/imports * Add baseline for the PeLU expt, keeping only the small normalization-related changes. * pelu_base->expscale, add 2xExpScale in subsampling, and in feedforward units. * Double learning rate of exp-scale units * Combine ExpScale and swish for memory reduction * Add import * Fix backprop bug * Fix bug in diagnostics * Increase scale on Scale from 4 to 20 * Increase scale from 20 to 50. * Fix duplicate Swish; replace norm+swish with swish+exp-scale in convolution module * Reduce scale from 50 to 20 * Add deriv-balancing code * Double the threshold in brelu; slightly increase max_factor. * Fix exp dir * Convert swish nonlinearities to ReLU * Replace relu with swish-squared. * Restore ConvolutionModule to state before changes; change all Swish,Swish(Swish) to SwishOffset. * Replace norm on input layer with scale of 0.1. * Extensions to diagnostics code * Update diagnostics * Add BasicNorm module * Replace most normalizations with scales (still have norm in conv) * Change exp dir * Replace norm in ConvolutionModule with a scaling factor. * use nonzero threshold in DerivBalancer * Add min-abs-value 0.2 * Fix dirname * Change min-abs threshold from 0.2 to 0.5 * Scale up pos_bias_u and pos_bias_v before use. * Reduce max_factor to 0.01 * Fix q*scaling logic * Change max_factor in DerivBalancer from 0.025 to 0.01; fix scaling code. * init 1st conv module to smaller variance * Change how scales are applied; fix residual bug * Reduce min_abs from 0.5 to 0.2 * Introduce in_scale=0.5 for SwishExpScale * Fix scale from 0.5 to 2.0 as I really intended.. * Set scaling on SwishExpScale * Add identity pre_norm_final for diagnostics. * Add learnable post-scale for mha * Fix self.post-scale-mha * Another rework, use scales on linear/conv * Change dir name * Reduce initial scaling of modules * Bug-fix RE bias * Cosmetic change * Reduce initial_scale. * Replace ExpScaleRelu with DoubleSwish() * DoubleSwish fix * Use learnable scales for joiner and decoder * Add max-abs-value constraint in DerivBalancer * Add max-abs-value * Change dir name * Remove ExpScale in feedforward layes. * Reduce max-abs limit from 1000 to 100; introduce 2 DerivBalancer modules in conv layer. * Make DoubleSwish more memory efficient * Reduce constraints from deriv-balancer in ConvModule. * Add warmup mode * Remove max-positive constraint in deriv-balancing; add second DerivBalancer in conv module. * Add some extra info to diagnostics * Add deriv-balancer at output of embedding. * Add more stats. * Make epsilon in BasicNorm learnable, optionally. * Draft of 0mean changes.. * Rework of initialization * Fix typo * Remove dead code * Modifying initialization from normal->uniform; add initial_scale when initializing * bug fix re sqrt * Remove xscale from pos_embedding * Remove some dead code. * Cosmetic changes/renaming things * Start adding some files.. * Add more files.. * update decode.py file type * Add remaining files in pruned_transducer_stateless2 * Fix diagnostics-getting code * Scale down pruned loss in warmup mode * Reduce warmup scale on pruned loss form 0.1 to 0.01. * Remove scale_speed, make swish deriv more efficient. * Cosmetic changes to swish * Double warm_step * Fix bug with import * Change initial std from 0.05 to 0.025. * Set also scale for embedding to 0.025. * Remove logging code that broke with newer Lhotse; fix bug with pruned_loss * Add norm+balancer to VggSubsampling * Incorporate changes from master into pruned_transducer_stateless2. * Add max-abs=6, debugged version * Change 0.025,0.05 to 0.01 in initializations * Fix balancer code * Whitespace fix * Reduce initial pruned_loss scale from 0.01 to 0.0 * Increase warm_step (and valid_interval) * Change max-abs from 6 to 10 * Change how warmup works. * Add changes from master to decode.py, train.py * Simplify the warmup code; max_abs 10->6 * Make warmup work by scaling layer contributions; leave residual layer-drop * Fix bug * Fix test mode with random layer dropout * Add random-number-setting function in dataloader * Fix/patch how fix_random_seed() is imported. * Reduce layer-drop prob * Reduce layer-drop prob after warmup to 1 in 100 * Change power of lr-schedule from -0.5 to -0.333 * Increase model_warm_step to 4k * Change max-keep-prob to 0.95 * Refactoring and simplifying conformer and frontend * Rework conformer, remove some code. * Reduce 1st conv channels from 64 to 32 * Add another convolutional layer * Fix padding bug * Remove dropout in output layer * Reduce speed of some components * Initial refactoring to remove unnecessary vocab_size * Fix RE identity * Bug-fix * Add final dropout to conformer * Remove some un-used code * Replace nn.Linear with ScaledLinear in simple joiner * Make 2 projections.. * Reduce initial_speed * Use initial_speed=0.5 * Reduce initial_speed further from 0.5 to 0.25 * Reduce initial_speed from 0.5 to 0.25 * Change how warmup is applied. * Bug fix to warmup_scale * Fix test-mode * Remove final dropout * Make layer dropout rate 0.075, was 0.1. * First draft of model rework * Various bug fixes * Change learning speed of simple_lm_proj * Revert transducer_stateless/ to state in upstream/master * Fix to joiner to allow different dims * Some cleanups * Make training more efficient, avoid redoing some projections. * Change how warm-step is set * First draft of new approach to learning rates + init * Some fixes.. * Change initialization to 0.25 * Fix type of parameter * Fix weight decay formula by adding 1/1-beta * Fix weight decay formula by adding 1/1-beta * Fix checkpoint-writing * Fix to reading scheudler from optim * Simplified optimizer, rework somet things.. * Reduce model_warm_step from 4k to 3k * Fix bug in lambda * Bug-fix RE sign of target_rms * Changing initial_speed from 0.25 to 01 * Change some defaults in LR-setting rule. * Remove initial_speed * Set new scheduler * Change exponential part of lrate to be epoch based * Fix bug * Set 2n rule.. * Implement 2o schedule * Make lrate rule more symmetric * Implement 2p version of learning rate schedule. * Refactor how learning rate is set. * Fix import * Modify init (#301) * update icefall/__init__.py to import more common functions. * update icefall/__init__.py * make imports style consistent. * exclude black check for icefall/__init__.py in pyproject.toml. * Minor fixes for logging (#296) * Minor fixes for logging * Minor fix * Fix dir names * Modify beam search to be efficient with current joienr * Fix adding learning rate to tensorboard * Fix docs in optim.py * Support mix precision training on the reworked model (#305) * Add mix precision support * Minor fixes * Minor fixes * Minor fixes * Tedlium3 pruned transducer stateless (#261) * update tedlium3-pruned-transducer-stateless-codes * update README.md * update README.md * add fast beam search for decoding * do a change for RESULTS.md * do a change for RESULTS.md * do a fix * do some changes for pruned RNN-T * Add mix precision support * Minor fixes * Minor fixes * Updating RESULTS.md; fix in beam_search.py * Fix rebase * Code style check for librispeech pruned transducer stateless2 (#308) * Update results for tedlium3 pruned RNN-T (#307) * Update README.md * Fix CI errors. (#310) * Add more results * Fix tensorboard log location * Add one more epoch of full expt * fix comments * Add results for mixed precision with max-duration 300 * Changes for pretrained.py (tedlium3 pruned RNN-T) (#311) * GigaSpeech recipe (#120) * initial commit * support download, data prep, and fbank * on-the-fly feature extraction by default * support BPE based lang * support HLG for BPE * small fix * small fix * chunked feature extraction by default * Compute features for GigaSpeech by splitting the manifest. * Fixes after review. * Split manifests into 2000 pieces. * set audio duration mismatch tolerance to 0.01 * small fix * add conformer training recipe * Add conformer.py without pre-commit checking * lazy loading and use SingleCutSampler * DynamicBucketingSampler * use KaldifeatFbank to compute fbank for musan * use pretrained language model and lexicon * use 3gram to decode, 4gram to rescore * Add decode.py * Update .flake8 * Delete compute_fbank_gigaspeech.py * Use BucketingSampler for valid and test dataloader * Update params in train.py * Use bpe_500 * update params in decode.py * Decrease num_paths while CUDA OOM * Added README * Update RESULTS * black * Decrease num_paths while CUDA OOM * Decode with post-processing * Update results * Remove lazy_load option * Use default `storage_type` * Keep the original tolerance * Use split-lazy * black * Update pretrained model Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> * Add LG decoding (#277) * Add LG decoding * Add log weight pushing * Minor fixes * Support computing RNN-T loss with torchaudio (#316) * Update results for torchaudio RNN-T. (#322) * Fix some typos. (#329) * fix fp16 option in example usage (#332) * Support averaging models with weight tying. (#333) * Support specifying iteration number of checkpoints for decoding. (#336) See also #289 * Modified conformer with multi datasets (#312) * Copy files for editing. * Use librispeech + gigaspeech with modified conformer. * Support specifying number of workers for on-the-fly feature extraction. * Feature extraction code for GigaSpeech. * Combine XL splits lazily during training. * Fix warnings in decoding. * Add decoding code for GigaSpeech. * Fix decoding the gigaspeech dataset. We have to use the decoder/joiner networks for the GigaSpeech dataset. * Disable speed perturbe for XL subset. * Compute the Nbest oracle WER for RNN-T decoding. * Minor fixes. * Minor fixes. * Add results. * Update results. * Update CI. * Update results. * Fix style issues. * Update results. * Fix style issues. * Update results. (#340) * Update results. * Typo fixes. * Validate generated manifest files. (#338) * Validate generated manifest files. (#338) * Save batch to disk on OOM. (#343) * Save batch to disk on OOM. * minor fixes * Fixes after review. * Fix style issues. * Fix decoding for gigaspeech in the libri + giga setup. (#345) * Model average (#344) * First upload of model average codes. * minor fix * update decode file * update .flake8 * rename pruned_transducer_stateless3 to pruned_transducer_stateless4 * change epoch number counter starting from 1 instead of 0 * minor fix of pruned_transducer_stateless4/train.py * refactor the checkpoint.py * minor fix, update docs, and modify the epoch number to count from 1 in the pruned_transducer_stateless4/decode.py * update author info * add docs of the scaling in function average_checkpoints_with_averaged_model * Save batch to disk on exception. (#350) * Bug fix (#352) * Keep model_avg on cpu (#348) * keep model_avg on cpu * explicitly convert model_avg to cpu * minor fix * remove device convertion for model_avg * modify usage of the model device in train.py * change model.device to next(model.parameters()).device for decoding * assert params.start_epoch>0 * assert params.start_epoch>0, params.start_epoch * Do some changes for aishell/ASR/transducer stateless/export.py (#347) * do some changes for aishell/ASR/transducer_stateless/export.py * Support decoding with averaged model when using --iter (#353) * support decoding with averaged model when using --iter * minor fix * monir fix of copyright date * Stringify torch.__version__ before serializing it. (#354) * Run decode.py in GitHub actions. (#356) * Ignore padding frames during RNN-T decoding. (#358) * Ignore padding frames during RNN-T decoding. * Fix outdated decoding code. * Minor fixes. * Support --iter in export.py (#360) * GigaSpeech RNN-T experiments (#318) * Copy RNN-T recipe from librispeech * flake8 * flake8 * Update params * gigaspeech decode * black * Update results * syntax highlight * Update RESULTS.md * typo * Update decoding script for gigaspeech and remove duplicate files. (#361) * Validate that there are no OOV tokens in BPE-based lexicons. (#359) * Validate that there are no OOV tokens in BPE-based lexicons. * Typo fixes. * Decode gigaspeech in GitHub actions (#362) * Add CI for gigaspeech. * Update results for libri+giga multi dataset setup. (#363) * Update results for libri+giga multi dataset setup. * Update GigaSpeech reults (#364) * Update decode.py * Update export.py * Update results * Update README.md * Fix GitHub CI for decoding GigaSpeech dev/test datasets (#366) * modify .flake8 * minor fix * minor fix Co-authored-by: Daniel Povey <dpovey@gmail.com> Co-authored-by: Wei Kang <wkang@pku.org.cn> Co-authored-by: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> Co-authored-by: Guo Liyong <guonwpu@qq.com> Co-authored-by: Wang, Guanbo <wgb14@outlook.com> Co-authored-by: whsqkaak <whsqkaak@naver.com> Co-authored-by: pehonnet <pe.honnet@gmail.com>
1127 lines
34 KiB
Python
Executable File
1127 lines
34 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
|
# Wei Kang
|
|
# Mingshuang Luo)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Usage:
|
|
|
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|
|
|
cd egs/librispeech/ASR/
|
|
./prepare.sh
|
|
./prepare_giga_speech.sh
|
|
|
|
./pruned_transducer_stateless3/train.py \
|
|
--world-size 4 \
|
|
--num-epochs 30 \
|
|
--start-epoch 0 \
|
|
--exp-dir pruned_transducer_stateless3/exp \
|
|
--full-libri 1 \
|
|
--max-duration 300
|
|
|
|
# For mix precision training:
|
|
|
|
./pruned_transducer_stateless3/train.py \
|
|
--world-size 4 \
|
|
--num-epochs 30 \
|
|
--start-epoch 0 \
|
|
--use_fp16 1 \
|
|
--exp-dir pruned_transducer_stateless3/exp \
|
|
--full-libri 1 \
|
|
--max-duration 550
|
|
|
|
"""
|
|
|
|
|
|
import argparse
|
|
import logging
|
|
import random
|
|
import warnings
|
|
from pathlib import Path
|
|
from shutil import copyfile
|
|
from typing import Any, Dict, Optional, Tuple, Union
|
|
|
|
import k2
|
|
import optim
|
|
import sentencepiece as spm
|
|
import torch
|
|
import torch.multiprocessing as mp
|
|
import torch.nn as nn
|
|
from asr_datamodule import AsrDataModule
|
|
from conformer import Conformer
|
|
from decoder import Decoder
|
|
from gigaspeech import GigaSpeech
|
|
from joiner import Joiner
|
|
from lhotse import CutSet, load_manifest
|
|
from lhotse.cut import Cut
|
|
from lhotse.dataset.sampling.base import CutSampler
|
|
from lhotse.utils import fix_random_seed
|
|
from librispeech import LibriSpeech
|
|
from model import Transducer
|
|
from optim import Eden, Eve
|
|
from torch import Tensor
|
|
from torch.cuda.amp import GradScaler
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from icefall import diagnostics
|
|
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
|
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
|
from icefall.dist import cleanup_dist, setup_dist
|
|
from icefall.env import get_env_info
|
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
|
|
|
LRSchedulerType = Union[
|
|
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
|
]
|
|
|
|
|
|
def get_parser():
|
|
parser = argparse.ArgumentParser(
|
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--world-size",
|
|
type=int,
|
|
default=1,
|
|
help="Number of GPUs for DDP training.",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--master-port",
|
|
type=int,
|
|
default=12354,
|
|
help="Master port to use for DDP training.",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--tensorboard",
|
|
type=str2bool,
|
|
default=True,
|
|
help="Should various information be logged in tensorboard.",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--full-libri",
|
|
type=str2bool,
|
|
default=True,
|
|
help="When enabled, use 960h LibriSpeech. "
|
|
"Otherwise, use 100h subset.",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--num-epochs",
|
|
type=int,
|
|
default=30,
|
|
help="Number of epochs to train.",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--start-epoch",
|
|
type=int,
|
|
default=0,
|
|
help="""Resume training from from this epoch.
|
|
If it is positive, it will load checkpoint from
|
|
transducer_stateless3/exp/epoch-{start_epoch-1}.pt
|
|
""",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--start-batch",
|
|
type=int,
|
|
default=0,
|
|
help="""If positive, --start-epoch is ignored and
|
|
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
|
|
""",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--exp-dir",
|
|
type=str,
|
|
default="pruned_transducer_stateless3/exp",
|
|
help="""The experiment dir.
|
|
It specifies the directory where all training related
|
|
files, e.g., checkpoints, log, etc, are saved
|
|
""",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--bpe-model",
|
|
type=str,
|
|
default="data/lang_bpe_500/bpe.model",
|
|
help="Path to the BPE model",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--initial-lr",
|
|
type=float,
|
|
default=0.003,
|
|
help="The initial learning rate. This value should not need "
|
|
"to be changed.",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--lr-batches",
|
|
type=float,
|
|
default=5000,
|
|
help="""Number of steps that affects how rapidly the learning rate decreases.
|
|
We suggest not to change this.""",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--lr-epochs",
|
|
type=float,
|
|
default=4,
|
|
help="""Number of epochs that affects how rapidly the learning rate decreases.
|
|
""",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--context-size",
|
|
type=int,
|
|
default=2,
|
|
help="The context size in the decoder. 1 means bigram; "
|
|
"2 means tri-gram",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--prune-range",
|
|
type=int,
|
|
default=5,
|
|
help="The prune range for rnnt loss, it means how many symbols(context)"
|
|
"we are using to compute the loss",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--lm-scale",
|
|
type=float,
|
|
default=0.25,
|
|
help="The scale to smooth the loss with lm "
|
|
"(output of prediction network) part.",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--am-scale",
|
|
type=float,
|
|
default=0.0,
|
|
help="The scale to smooth the loss with am (output of encoder network)"
|
|
"part.",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--simple-loss-scale",
|
|
type=float,
|
|
default=0.5,
|
|
help="To get pruning ranges, we will calculate a simple version"
|
|
"loss(joiner is just addition), this simple loss also uses for"
|
|
"training (as a regularization item). We will scale the simple loss"
|
|
"with this parameter before adding to the final loss.",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--seed",
|
|
type=int,
|
|
default=42,
|
|
help="The seed for random generators intended for reproducibility",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--print-diagnostics",
|
|
type=str2bool,
|
|
default=False,
|
|
help="Accumulate stats on activations, print them and exit.",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--save-every-n",
|
|
type=int,
|
|
default=8000,
|
|
help="""Save checkpoint after processing this number of batches"
|
|
periodically. We save checkpoint to exp-dir/ whenever
|
|
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
|
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
|
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
|
end of each epoch where `xxx` is the epoch number counting from 0.
|
|
""",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--keep-last-k",
|
|
type=int,
|
|
default=20,
|
|
help="""Only keep this number of checkpoints on disk.
|
|
For instance, if it is 3, there are only 3 checkpoints
|
|
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
|
It does not affect checkpoints with name `epoch-xxx.pt`.
|
|
""",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--use-fp16",
|
|
type=str2bool,
|
|
default=False,
|
|
help="Whether to use half precision training.",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--giga-prob",
|
|
type=float,
|
|
default=0.5,
|
|
help="The probability to select a batch from the GigaSpeech dataset",
|
|
)
|
|
|
|
return parser
|
|
|
|
|
|
def get_params() -> AttributeDict:
|
|
"""Return a dict containing training parameters.
|
|
|
|
All training related parameters that are not passed from the commandline
|
|
are saved in the variable `params`.
|
|
|
|
Commandline options are merged into `params` after they are parsed, so
|
|
you can also access them via `params`.
|
|
|
|
Explanation of options saved in `params`:
|
|
|
|
- best_train_loss: Best training loss so far. It is used to select
|
|
the model that has the lowest training loss. It is
|
|
updated during the training.
|
|
|
|
- best_valid_loss: Best validation loss so far. It is used to select
|
|
the model that has the lowest validation loss. It is
|
|
updated during the training.
|
|
|
|
- best_train_epoch: It is the epoch that has the best training loss.
|
|
|
|
- best_valid_epoch: It is the epoch that has the best validation loss.
|
|
|
|
- batch_idx_train: Used to writing statistics to tensorboard. It
|
|
contains number of batches trained so far across
|
|
epochs.
|
|
|
|
- log_interval: Print training loss if batch_idx % log_interval` is 0
|
|
|
|
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
|
|
|
|
- valid_interval: Run validation if batch_idx % valid_interval is 0
|
|
|
|
- feature_dim: The model input dim. It has to match the one used
|
|
in computing features.
|
|
|
|
- subsampling_factor: The subsampling factor for the model.
|
|
|
|
- encoder_dim: Hidden dim for multi-head attention model.
|
|
|
|
- num_decoder_layers: Number of decoder layer of transformer decoder.
|
|
|
|
- warm_step: The warm_step for Noam optimizer.
|
|
"""
|
|
params = AttributeDict(
|
|
{
|
|
"best_train_loss": float("inf"),
|
|
"best_valid_loss": float("inf"),
|
|
"best_train_epoch": -1,
|
|
"best_valid_epoch": -1,
|
|
"batch_idx_train": 0,
|
|
"log_interval": 50,
|
|
"reset_interval": 200,
|
|
"valid_interval": 3000, # For the 100h subset, use 800
|
|
# parameters for conformer
|
|
"feature_dim": 80,
|
|
"subsampling_factor": 4,
|
|
"encoder_dim": 512,
|
|
"nhead": 8,
|
|
"dim_feedforward": 2048,
|
|
"num_encoder_layers": 12,
|
|
# parameters for decoder
|
|
"decoder_dim": 512,
|
|
# parameters for joiner
|
|
"joiner_dim": 512,
|
|
# parameters for Noam
|
|
"model_warm_step": 3000, # arg given to model, not for lrate
|
|
"env_info": get_env_info(),
|
|
}
|
|
)
|
|
|
|
return params
|
|
|
|
|
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|
# TODO: We can add an option to switch between Conformer and Transformer
|
|
encoder = Conformer(
|
|
num_features=params.feature_dim,
|
|
subsampling_factor=params.subsampling_factor,
|
|
d_model=params.encoder_dim,
|
|
nhead=params.nhead,
|
|
dim_feedforward=params.dim_feedforward,
|
|
num_encoder_layers=params.num_encoder_layers,
|
|
)
|
|
return encoder
|
|
|
|
|
|
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
|
decoder = Decoder(
|
|
vocab_size=params.vocab_size,
|
|
decoder_dim=params.decoder_dim,
|
|
blank_id=params.blank_id,
|
|
context_size=params.context_size,
|
|
)
|
|
return decoder
|
|
|
|
|
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|
joiner = Joiner(
|
|
encoder_dim=params.encoder_dim,
|
|
decoder_dim=params.decoder_dim,
|
|
joiner_dim=params.joiner_dim,
|
|
vocab_size=params.vocab_size,
|
|
)
|
|
return joiner
|
|
|
|
|
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|
encoder = get_encoder_model(params)
|
|
decoder = get_decoder_model(params)
|
|
joiner = get_joiner_model(params)
|
|
|
|
decoder_giga = get_decoder_model(params)
|
|
joiner_giga = get_joiner_model(params)
|
|
|
|
model = Transducer(
|
|
encoder=encoder,
|
|
decoder=decoder,
|
|
joiner=joiner,
|
|
decoder_giga=decoder_giga,
|
|
joiner_giga=joiner_giga,
|
|
encoder_dim=params.encoder_dim,
|
|
decoder_dim=params.decoder_dim,
|
|
joiner_dim=params.joiner_dim,
|
|
vocab_size=params.vocab_size,
|
|
)
|
|
return model
|
|
|
|
|
|
def load_checkpoint_if_available(
|
|
params: AttributeDict,
|
|
model: nn.Module,
|
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
|
scheduler: Optional[LRSchedulerType] = None,
|
|
) -> Optional[Dict[str, Any]]:
|
|
"""Load checkpoint from file.
|
|
|
|
If params.start_batch is positive, it will load the checkpoint from
|
|
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
|
|
params.start_epoch is positive, it will load the checkpoint from
|
|
`params.start_epoch - 1`.
|
|
|
|
Apart from loading state dict for `model` and `optimizer` it also updates
|
|
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
|
and `best_valid_loss` in `params`.
|
|
|
|
Args:
|
|
params:
|
|
The return value of :func:`get_params`.
|
|
model:
|
|
The training model.
|
|
optimizer:
|
|
The optimizer that we are using.
|
|
scheduler:
|
|
The scheduler that we are using.
|
|
Returns:
|
|
Return a dict containing previously saved training info.
|
|
"""
|
|
if params.start_batch > 0:
|
|
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
|
|
elif params.start_epoch > 0:
|
|
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
|
else:
|
|
return None
|
|
|
|
assert filename.is_file(), f"{filename} does not exist!"
|
|
|
|
saved_params = load_checkpoint(
|
|
filename,
|
|
model=model,
|
|
optimizer=optimizer,
|
|
scheduler=scheduler,
|
|
)
|
|
|
|
keys = [
|
|
"best_train_epoch",
|
|
"best_valid_epoch",
|
|
"batch_idx_train",
|
|
"best_train_loss",
|
|
"best_valid_loss",
|
|
]
|
|
for k in keys:
|
|
params[k] = saved_params[k]
|
|
|
|
if params.start_batch > 0:
|
|
if "cur_epoch" in saved_params:
|
|
params["start_epoch"] = saved_params["cur_epoch"]
|
|
|
|
return saved_params
|
|
|
|
|
|
def save_checkpoint(
|
|
params: AttributeDict,
|
|
model: nn.Module,
|
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
|
scheduler: Optional[LRSchedulerType] = None,
|
|
sampler: Optional[CutSampler] = None,
|
|
scaler: Optional[GradScaler] = None,
|
|
rank: int = 0,
|
|
) -> None:
|
|
"""Save model, optimizer, scheduler and training stats to file.
|
|
|
|
Args:
|
|
params:
|
|
It is returned by :func:`get_params`.
|
|
model:
|
|
The training model.
|
|
optimizer:
|
|
The optimizer used in the training.
|
|
sampler:
|
|
The sampler for the training dataset.
|
|
scaler:
|
|
The scaler used for mix precision training.
|
|
"""
|
|
if rank != 0:
|
|
return
|
|
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
|
save_checkpoint_impl(
|
|
filename=filename,
|
|
model=model,
|
|
params=params,
|
|
optimizer=optimizer,
|
|
scheduler=scheduler,
|
|
sampler=sampler,
|
|
scaler=scaler,
|
|
rank=rank,
|
|
)
|
|
|
|
if params.best_train_epoch == params.cur_epoch:
|
|
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
|
copyfile(src=filename, dst=best_train_filename)
|
|
|
|
if params.best_valid_epoch == params.cur_epoch:
|
|
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
|
copyfile(src=filename, dst=best_valid_filename)
|
|
|
|
|
|
def is_libri(c: Cut) -> bool:
|
|
"""Return True if this cut is from the LibriSpeech dataset.
|
|
|
|
Note:
|
|
During data preparation, we set the custom field in
|
|
the supervision segment of GigaSpeech to dict(origin='giga')
|
|
See ../local/preprocess_gigaspeech.py.
|
|
"""
|
|
return c.supervisions[0].custom is None
|
|
|
|
|
|
def compute_loss(
|
|
params: AttributeDict,
|
|
model: nn.Module,
|
|
sp: spm.SentencePieceProcessor,
|
|
batch: dict,
|
|
is_training: bool,
|
|
warmup: float = 1.0,
|
|
) -> Tuple[Tensor, MetricsTracker]:
|
|
"""
|
|
Compute CTC loss given the model and its inputs.
|
|
|
|
Args:
|
|
params:
|
|
Parameters for training. See :func:`get_params`.
|
|
model:
|
|
The model for training. It is an instance of Conformer in our case.
|
|
batch:
|
|
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
|
for the content in it.
|
|
is_training:
|
|
True for training. False for validation. When it is True, this
|
|
function enables autograd during computation; when it is False, it
|
|
disables autograd.
|
|
warmup: a floating point value which increases throughout training;
|
|
values >= 1.0 are fully warmed up and have all modules present.
|
|
"""
|
|
device = model.device
|
|
feature = batch["inputs"]
|
|
# at entry, feature is (N, T, C)
|
|
assert feature.ndim == 3
|
|
feature = feature.to(device)
|
|
|
|
supervisions = batch["supervisions"]
|
|
feature_lens = supervisions["num_frames"].to(device)
|
|
|
|
libri = is_libri(supervisions["cut"][0])
|
|
|
|
texts = batch["supervisions"]["text"]
|
|
y = sp.encode(texts, out_type=int)
|
|
y = k2.RaggedTensor(y).to(device)
|
|
|
|
with torch.set_grad_enabled(is_training):
|
|
simple_loss, pruned_loss = model(
|
|
x=feature,
|
|
x_lens=feature_lens,
|
|
y=y,
|
|
libri=libri,
|
|
prune_range=params.prune_range,
|
|
am_scale=params.am_scale,
|
|
lm_scale=params.lm_scale,
|
|
warmup=warmup,
|
|
)
|
|
# after the main warmup step, we keep pruned_loss_scale small
|
|
# for the same amount of time (model_warm_step), to avoid
|
|
# overwhelming the simple_loss and causing it to diverge,
|
|
# in case it had not fully learned the alignment yet.
|
|
pruned_loss_scale = (
|
|
0.0
|
|
if warmup < 1.0
|
|
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
|
)
|
|
loss = (
|
|
params.simple_loss_scale * simple_loss
|
|
+ pruned_loss_scale * pruned_loss
|
|
)
|
|
|
|
assert loss.requires_grad == is_training
|
|
|
|
info = MetricsTracker()
|
|
with warnings.catch_warnings():
|
|
warnings.simplefilter("ignore")
|
|
info["frames"] = (
|
|
(feature_lens // params.subsampling_factor).sum().item()
|
|
)
|
|
|
|
# Note: We use reduction=sum while computing the loss.
|
|
info["loss"] = loss.detach().cpu().item()
|
|
info["simple_loss"] = simple_loss.detach().cpu().item()
|
|
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
|
|
|
return loss, info
|
|
|
|
|
|
def compute_validation_loss(
|
|
params: AttributeDict,
|
|
model: nn.Module,
|
|
sp: spm.SentencePieceProcessor,
|
|
valid_dl: torch.utils.data.DataLoader,
|
|
world_size: int = 1,
|
|
) -> MetricsTracker:
|
|
"""Run the validation process."""
|
|
model.eval()
|
|
|
|
tot_loss = MetricsTracker()
|
|
|
|
for batch_idx, batch in enumerate(valid_dl):
|
|
loss, loss_info = compute_loss(
|
|
params=params,
|
|
model=model,
|
|
sp=sp,
|
|
batch=batch,
|
|
is_training=False,
|
|
)
|
|
assert loss.requires_grad is False
|
|
tot_loss = tot_loss + loss_info
|
|
|
|
if world_size > 1:
|
|
tot_loss.reduce(loss.device)
|
|
|
|
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
|
if loss_value < params.best_valid_loss:
|
|
params.best_valid_epoch = params.cur_epoch
|
|
params.best_valid_loss = loss_value
|
|
|
|
return tot_loss
|
|
|
|
|
|
def train_one_epoch(
|
|
params: AttributeDict,
|
|
model: nn.Module,
|
|
optimizer: torch.optim.Optimizer,
|
|
scheduler: LRSchedulerType,
|
|
sp: spm.SentencePieceProcessor,
|
|
train_dl: torch.utils.data.DataLoader,
|
|
giga_train_dl: torch.utils.data.DataLoader,
|
|
valid_dl: torch.utils.data.DataLoader,
|
|
rng: random.Random,
|
|
scaler: GradScaler,
|
|
tb_writer: Optional[SummaryWriter] = None,
|
|
world_size: int = 1,
|
|
rank: int = 0,
|
|
) -> None:
|
|
"""Train the model for one epoch.
|
|
|
|
The training loss from the mean of all frames is saved in
|
|
`params.train_loss`. It runs the validation process every
|
|
`params.valid_interval` batches.
|
|
|
|
Args:
|
|
params:
|
|
It is returned by :func:`get_params`.
|
|
model:
|
|
The model for training.
|
|
optimizer:
|
|
The optimizer we are using.
|
|
scheduler:
|
|
The learning rate scheduler, we call step() every step.
|
|
train_dl:
|
|
Dataloader for the training dataset.
|
|
giga_train_dl:
|
|
Dataloader for the GigaSpeech training dataset.
|
|
valid_dl:
|
|
Dataloader for the validation dataset.
|
|
rng:
|
|
For selecting which dataset to use.
|
|
scaler:
|
|
The scaler used for mix precision training.
|
|
tb_writer:
|
|
Writer to write log messages to tensorboard.
|
|
world_size:
|
|
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
|
rank:
|
|
The rank of the node in DDP training. If no DDP is used, it should
|
|
be set to 0.
|
|
"""
|
|
model.train()
|
|
|
|
libri_tot_loss = MetricsTracker()
|
|
giga_tot_loss = MetricsTracker()
|
|
tot_loss = MetricsTracker()
|
|
|
|
# index 0: for LibriSpeech
|
|
# index 1: for GigaSpeech
|
|
# This sets the probabilities for choosing which datasets
|
|
dl_weights = [1 - params.giga_prob, params.giga_prob]
|
|
|
|
iter_libri = iter(train_dl)
|
|
iter_giga = iter(giga_train_dl)
|
|
|
|
batch_idx = 0
|
|
|
|
while True:
|
|
idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
|
|
dl = iter_libri if idx == 0 else iter_giga
|
|
|
|
try:
|
|
batch = next(dl)
|
|
except StopIteration:
|
|
name = "libri" if idx == 0 else "giga"
|
|
logging.info(f"{name} reaches end of dataloader")
|
|
break
|
|
|
|
batch_idx += 1
|
|
|
|
params.batch_idx_train += 1
|
|
batch_size = len(batch["supervisions"]["text"])
|
|
|
|
libri = is_libri(batch["supervisions"]["cut"][0])
|
|
|
|
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
|
loss, loss_info = compute_loss(
|
|
params=params,
|
|
model=model,
|
|
sp=sp,
|
|
batch=batch,
|
|
is_training=True,
|
|
warmup=(params.batch_idx_train / params.model_warm_step),
|
|
)
|
|
# summary stats
|
|
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
|
|
|
if libri:
|
|
libri_tot_loss = (
|
|
libri_tot_loss * (1 - 1 / params.reset_interval)
|
|
) + loss_info
|
|
prefix = "libri" # for logging only
|
|
else:
|
|
giga_tot_loss = (
|
|
giga_tot_loss * (1 - 1 / params.reset_interval)
|
|
) + loss_info
|
|
prefix = "giga"
|
|
|
|
# NOTE: We use reduction==sum and loss is computed over utterances
|
|
# in the batch and there is no normalization to it so far.
|
|
scaler.scale(loss).backward()
|
|
scheduler.step_batch(params.batch_idx_train)
|
|
scaler.step(optimizer)
|
|
scaler.update()
|
|
optimizer.zero_grad()
|
|
|
|
if params.print_diagnostics and batch_idx == 5:
|
|
return
|
|
|
|
if (
|
|
params.batch_idx_train > 0
|
|
and params.batch_idx_train % params.save_every_n == 0
|
|
):
|
|
save_checkpoint_with_global_batch_idx(
|
|
out_dir=params.exp_dir,
|
|
global_batch_idx=params.batch_idx_train,
|
|
model=model,
|
|
params=params,
|
|
optimizer=optimizer,
|
|
scheduler=scheduler,
|
|
sampler=train_dl.sampler,
|
|
scaler=scaler,
|
|
rank=rank,
|
|
)
|
|
remove_checkpoints(
|
|
out_dir=params.exp_dir,
|
|
topk=params.keep_last_k,
|
|
rank=rank,
|
|
)
|
|
|
|
if batch_idx % params.log_interval == 0:
|
|
cur_lr = scheduler.get_last_lr()[0]
|
|
logging.info(
|
|
f"Epoch {params.cur_epoch}, "
|
|
f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
|
|
f"tot_loss[{tot_loss}], "
|
|
f"libri_tot_loss[{libri_tot_loss}], "
|
|
f"giga_tot_loss[{giga_tot_loss}], "
|
|
f"batch size: {batch_size}"
|
|
f"lr: {cur_lr:.2e}"
|
|
)
|
|
|
|
if tb_writer is not None:
|
|
tb_writer.add_scalar(
|
|
"train/learning_rate", cur_lr, params.batch_idx_train
|
|
)
|
|
|
|
loss_info.write_summary(
|
|
tb_writer,
|
|
f"train/current_{prefix}_",
|
|
params.batch_idx_train,
|
|
)
|
|
tot_loss.write_summary(
|
|
tb_writer, "train/tot_", params.batch_idx_train
|
|
)
|
|
libri_tot_loss.write_summary(
|
|
tb_writer, "train/libri_tot_", params.batch_idx_train
|
|
)
|
|
giga_tot_loss.write_summary(
|
|
tb_writer, "train/giga_tot_", params.batch_idx_train
|
|
)
|
|
|
|
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
|
logging.info("Computing validation loss")
|
|
valid_info = compute_validation_loss(
|
|
params=params,
|
|
model=model,
|
|
sp=sp,
|
|
valid_dl=valid_dl,
|
|
world_size=world_size,
|
|
)
|
|
model.train()
|
|
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
|
if tb_writer is not None:
|
|
valid_info.write_summary(
|
|
tb_writer, "train/valid_", params.batch_idx_train
|
|
)
|
|
|
|
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
|
params.train_loss = loss_value
|
|
if params.train_loss < params.best_train_loss:
|
|
params.best_train_epoch = params.cur_epoch
|
|
params.best_train_loss = params.train_loss
|
|
|
|
|
|
def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
|
|
def remove_short_and_long_utt(c: Cut):
|
|
# Keep only utterances with duration between 1 second and 20 seconds
|
|
#
|
|
# Caution: There is a reason to select 20.0 here. Please see
|
|
# ../local/display_manifest_statistics.py
|
|
#
|
|
# You should use ../local/display_manifest_statistics.py to get
|
|
# an utterance duration distribution for your dataset to select
|
|
# the threshold
|
|
return 1.0 <= c.duration <= 20.0
|
|
|
|
cuts = cuts.filter(remove_short_and_long_utt)
|
|
|
|
return cuts
|
|
|
|
|
|
def run(rank, world_size, args):
|
|
"""
|
|
Args:
|
|
rank:
|
|
It is a value between 0 and `world_size-1`, which is
|
|
passed automatically by `mp.spawn()` in :func:`main`.
|
|
The node with rank 0 is responsible for saving checkpoint.
|
|
world_size:
|
|
Number of GPUs for DDP training.
|
|
args:
|
|
The return value of get_parser().parse_args()
|
|
"""
|
|
params = get_params()
|
|
params.update(vars(args))
|
|
if params.full_libri is False:
|
|
params.valid_interval = 1600
|
|
|
|
fix_random_seed(params.seed)
|
|
rng = random.Random(params.seed)
|
|
if world_size > 1:
|
|
setup_dist(rank, world_size, params.master_port)
|
|
|
|
setup_logger(f"{params.exp_dir}/log/log-train")
|
|
logging.info("Training started")
|
|
|
|
if args.tensorboard and rank == 0:
|
|
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
|
else:
|
|
tb_writer = None
|
|
|
|
device = torch.device("cpu")
|
|
if torch.cuda.is_available():
|
|
device = torch.device("cuda", rank)
|
|
logging.info(f"Device: {device}")
|
|
|
|
sp = spm.SentencePieceProcessor()
|
|
sp.load(params.bpe_model)
|
|
|
|
# <blk> is defined in local/train_bpe_model.py
|
|
params.blank_id = sp.piece_to_id("<blk>")
|
|
params.vocab_size = sp.get_piece_size()
|
|
|
|
logging.info(params)
|
|
|
|
logging.info("About to create model")
|
|
model = get_transducer_model(params)
|
|
|
|
num_param = sum([p.numel() for p in model.parameters()])
|
|
logging.info(f"Number of model parameters: {num_param}")
|
|
|
|
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
|
|
|
model.to(device)
|
|
if world_size > 1:
|
|
logging.info("Using DDP")
|
|
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
|
model.device = device
|
|
|
|
optimizer = Eve(model.parameters(), lr=params.initial_lr)
|
|
|
|
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
|
|
|
if checkpoints and "optimizer" in checkpoints:
|
|
logging.info("Loading optimizer state dict")
|
|
optimizer.load_state_dict(checkpoints["optimizer"])
|
|
|
|
if (
|
|
checkpoints
|
|
and "scheduler" in checkpoints
|
|
and checkpoints["scheduler"] is not None
|
|
):
|
|
logging.info("Loading scheduler state dict")
|
|
scheduler.load_state_dict(checkpoints["scheduler"])
|
|
|
|
if params.print_diagnostics:
|
|
opts = diagnostics.TensorDiagnosticOptions(
|
|
2 ** 22
|
|
) # allow 4 megabytes per sub-module
|
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
|
|
|
librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
|
|
|
|
train_cuts = librispeech.train_clean_100_cuts()
|
|
if params.full_libri:
|
|
train_cuts += librispeech.train_clean_360_cuts()
|
|
train_cuts += librispeech.train_other_500_cuts()
|
|
|
|
train_cuts = filter_short_and_long_utterances(train_cuts)
|
|
|
|
gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
|
|
# XL 10k hours
|
|
# L 2.5k hours
|
|
# M 1k hours
|
|
# S 250 hours
|
|
# XS 10 hours
|
|
# DEV 12 hours
|
|
# Test 40 hours
|
|
if params.full_libri:
|
|
logging.info("Using the XL subset of GigaSpeech (10k hours)")
|
|
train_giga_cuts = gigaspeech.train_XL_cuts()
|
|
else:
|
|
logging.info("Using the S subset of GigaSpeech (250 hours)")
|
|
train_giga_cuts = gigaspeech.train_S_cuts()
|
|
|
|
train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts)
|
|
train_giga_cuts = train_giga_cuts.repeat(times=None)
|
|
|
|
if args.enable_musan:
|
|
cuts_musan = load_manifest(
|
|
Path(args.manifest_dir) / "cuts_musan.json.gz"
|
|
)
|
|
else:
|
|
cuts_musan = None
|
|
|
|
asr_datamodule = AsrDataModule(args)
|
|
|
|
train_dl = asr_datamodule.train_dataloaders(
|
|
train_cuts,
|
|
dynamic_bucketing=False,
|
|
on_the_fly_feats=False,
|
|
cuts_musan=cuts_musan,
|
|
)
|
|
|
|
giga_train_dl = asr_datamodule.train_dataloaders(
|
|
train_giga_cuts,
|
|
dynamic_bucketing=True,
|
|
on_the_fly_feats=False,
|
|
cuts_musan=cuts_musan,
|
|
)
|
|
|
|
valid_cuts = librispeech.dev_clean_cuts()
|
|
valid_cuts += librispeech.dev_other_cuts()
|
|
valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
|
|
|
|
# It's time consuming to include `giga_train_dl` here
|
|
# for dl in [train_dl, giga_train_dl]:
|
|
for dl in [train_dl]:
|
|
scan_pessimistic_batches_for_oom(
|
|
model=model,
|
|
train_dl=dl,
|
|
optimizer=optimizer,
|
|
sp=sp,
|
|
params=params,
|
|
)
|
|
|
|
scaler = GradScaler(enabled=params.use_fp16)
|
|
if checkpoints and "grad_scaler" in checkpoints:
|
|
logging.info("Loading grad scaler state dict")
|
|
scaler.load_state_dict(checkpoints["grad_scaler"])
|
|
|
|
for epoch in range(params.start_epoch, params.num_epochs):
|
|
scheduler.step_epoch(epoch)
|
|
fix_random_seed(params.seed + epoch)
|
|
train_dl.sampler.set_epoch(epoch)
|
|
|
|
if tb_writer is not None:
|
|
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
|
|
|
params.cur_epoch = epoch
|
|
|
|
train_one_epoch(
|
|
params=params,
|
|
model=model,
|
|
optimizer=optimizer,
|
|
scheduler=scheduler,
|
|
sp=sp,
|
|
train_dl=train_dl,
|
|
giga_train_dl=giga_train_dl,
|
|
valid_dl=valid_dl,
|
|
rng=rng,
|
|
scaler=scaler,
|
|
tb_writer=tb_writer,
|
|
world_size=world_size,
|
|
rank=rank,
|
|
)
|
|
|
|
if params.print_diagnostics:
|
|
diagnostic.print_diagnostics()
|
|
break
|
|
|
|
save_checkpoint(
|
|
params=params,
|
|
model=model,
|
|
optimizer=optimizer,
|
|
scheduler=scheduler,
|
|
sampler=train_dl.sampler,
|
|
scaler=scaler,
|
|
rank=rank,
|
|
)
|
|
|
|
logging.info("Done!")
|
|
|
|
if world_size > 1:
|
|
torch.distributed.barrier()
|
|
cleanup_dist()
|
|
|
|
|
|
def scan_pessimistic_batches_for_oom(
|
|
model: nn.Module,
|
|
train_dl: torch.utils.data.DataLoader,
|
|
optimizer: torch.optim.Optimizer,
|
|
sp: spm.SentencePieceProcessor,
|
|
params: AttributeDict,
|
|
):
|
|
from lhotse.dataset import find_pessimistic_batches
|
|
|
|
logging.info(
|
|
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
|
|
)
|
|
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
|
for criterion, cuts in batches.items():
|
|
batch = train_dl.dataset[cuts]
|
|
try:
|
|
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
|
# (i.e. are not remembered by the decaying-average in adam), because
|
|
# we want to avoid these params being subject to shrinkage in adam.
|
|
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
|
loss, _ = compute_loss(
|
|
params=params,
|
|
model=model,
|
|
sp=sp,
|
|
batch=batch,
|
|
is_training=True,
|
|
warmup=0.0,
|
|
)
|
|
loss.backward()
|
|
optimizer.step()
|
|
optimizer.zero_grad()
|
|
except RuntimeError as e:
|
|
if "CUDA out of memory" in str(e):
|
|
logging.error(
|
|
"Your GPU ran out of memory with the current "
|
|
"max_duration setting. We recommend decreasing "
|
|
"max_duration and trying again.\n"
|
|
f"Failing criterion: {criterion} "
|
|
f"(={crit_values[criterion]}) ..."
|
|
)
|
|
raise
|
|
|
|
|
|
def main():
|
|
parser = get_parser()
|
|
AsrDataModule.add_arguments(parser)
|
|
args = parser.parse_args()
|
|
args.exp_dir = Path(args.exp_dir)
|
|
|
|
assert 0 <= args.giga_prob < 1, args.giga_prob
|
|
|
|
world_size = args.world_size
|
|
assert world_size >= 1
|
|
if world_size > 1:
|
|
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
|
else:
|
|
run(rank=0, world_size=1, args=args)
|
|
|
|
|
|
torch.set_num_threads(1)
|
|
torch.set_num_interop_threads(1)
|
|
|
|
if __name__ == "__main__":
|
|
main()
|