Zengwei Yao b3e6bf66df
Add modified beam search decoding for streaming inference with emformer model (#327)
* Fix torch.nn.Embedding error for torch below 1.8.0

* Changes to fbank computation, use lilcom chunky writer

* Add min in q,k,v of attention

* Remove learnable offset, use relu instead.

* Experiments based on SpecAugment change

* Merge specaug change from Mingshuang.

* Use much more aggressive SpecAug setup

* Fix to num_feature_masks bug I introduced; reduce max_frames_mask_fraction 0.4->0.3

* Change p=0.5->0.9, mask_fraction 0.3->0.2

* Change p=0.9 to p=0.8 in SpecAug

* Fix num_time_masks code; revert 0.8 to 0.9

* Change max_frames from 0.2 to 0.15

* Remove ReLU in attention

* Adding diagnostics code...

* Refactor/simplify ConformerEncoder

* First version of rand-combine iterated-training-like idea.

* Improvements to diagnostics (RE those with 1 dim

* Add pelu to this good-performing setup..

* Small bug fixes/imports

* Add baseline for the PeLU expt, keeping only the small normalization-related changes.

* pelu_base->expscale, add 2xExpScale in subsampling, and in feedforward units.

* Double learning rate of exp-scale units

* Combine ExpScale and swish for memory reduction

* Add import

* Fix backprop bug

* Fix bug in diagnostics

* Increase scale on Scale from 4 to 20

* Increase scale from 20 to 50.

* Fix duplicate Swish; replace norm+swish with swish+exp-scale in convolution module

* Reduce scale from 50 to 20

* Add deriv-balancing code

* Double the threshold in brelu; slightly increase max_factor.

* Fix exp dir

* Convert swish nonlinearities to ReLU

* Replace relu with swish-squared.

* Restore ConvolutionModule to state before changes; change all Swish,Swish(Swish) to SwishOffset.

* Replace norm on input layer with scale of 0.1.

* Extensions to diagnostics code

* Update diagnostics

* Add BasicNorm module

* Replace most normalizations with scales (still have norm in conv)

* Change exp dir

* Replace norm in ConvolutionModule with a scaling factor.

* use nonzero threshold in DerivBalancer

* Add min-abs-value 0.2

* Fix dirname

* Change min-abs threshold from 0.2 to 0.5

* Scale up pos_bias_u and pos_bias_v before use.

* Reduce max_factor to 0.01

* Fix q*scaling logic

* Change max_factor in DerivBalancer from 0.025 to 0.01; fix scaling code.

* init 1st conv module to smaller variance

* Change how scales are applied; fix residual bug

* Reduce min_abs from 0.5 to 0.2

* Introduce in_scale=0.5 for SwishExpScale

* Fix scale from 0.5 to 2.0 as I really intended..

* Set scaling on SwishExpScale

* Add identity pre_norm_final for diagnostics.

* Add learnable post-scale for mha

* Fix self.post-scale-mha

* Another rework, use scales on linear/conv

* Change dir name

* Reduce initial scaling of modules

* Bug-fix RE bias

* Cosmetic change

* Reduce initial_scale.

* Replace ExpScaleRelu with DoubleSwish()

* DoubleSwish fix

* Use learnable scales for joiner and decoder

* Add max-abs-value constraint in DerivBalancer

* Add max-abs-value

* Change dir name

* Remove ExpScale in feedforward layes.

* Reduce max-abs limit from 1000 to 100; introduce 2 DerivBalancer modules in conv layer.

* Make DoubleSwish more memory efficient

* Reduce constraints from deriv-balancer in ConvModule.

* Add warmup mode

* Remove max-positive constraint in deriv-balancing; add second DerivBalancer in conv module.

* Add some extra info to diagnostics

* Add deriv-balancer at output of embedding.

* Add more stats.

* Make epsilon in BasicNorm learnable, optionally.

* Draft of 0mean changes..

* Rework of initialization

* Fix typo

* Remove dead code

* Modifying initialization from normal->uniform; add initial_scale when initializing

* bug fix re sqrt

* Remove xscale from pos_embedding

* Remove some dead code.

* Cosmetic changes/renaming things

* Start adding some files..

* Add more files..

* update decode.py file type

* Add remaining files in pruned_transducer_stateless2

* Fix diagnostics-getting code

* Scale down pruned loss in warmup mode

* Reduce warmup scale on pruned loss form 0.1 to 0.01.

* Remove scale_speed, make swish deriv more efficient.

* Cosmetic changes to swish

* Double warm_step

* Fix bug with import

* Change initial std from 0.05 to 0.025.

* Set also scale for embedding to 0.025.

* Remove logging code that broke with newer Lhotse; fix bug with pruned_loss

* Add norm+balancer to VggSubsampling

* Incorporate changes from master into pruned_transducer_stateless2.

* Add max-abs=6, debugged version

* Change 0.025,0.05 to 0.01 in initializations

* Fix balancer code

* Whitespace fix

* Reduce initial pruned_loss scale from 0.01 to 0.0

* Increase warm_step (and valid_interval)

* Change max-abs from 6 to 10

* Change how warmup works.

* Add changes from master to decode.py, train.py

* Simplify the warmup code; max_abs 10->6

* Make warmup work by scaling layer contributions; leave residual layer-drop

* Fix bug

* Fix test mode with random layer dropout

* Add random-number-setting function in dataloader

* Fix/patch how fix_random_seed() is imported.

* Reduce layer-drop prob

* Reduce layer-drop prob after warmup to 1 in 100

* Change power of lr-schedule from -0.5 to -0.333

* Increase model_warm_step to 4k

* Change max-keep-prob to 0.95

* Refactoring and simplifying conformer and frontend

* Rework conformer, remove some code.

* Reduce 1st conv channels from 64 to 32

* Add another convolutional layer

* Fix padding bug

* Remove dropout in output layer

* Reduce speed of some components

* Initial refactoring to remove unnecessary vocab_size

* Fix RE identity

* Bug-fix

* Add final dropout to conformer

* Remove some un-used code

* Replace nn.Linear with ScaledLinear in simple joiner

* Make 2 projections..

* Reduce initial_speed

* Use initial_speed=0.5

* Reduce initial_speed further from 0.5 to 0.25

* Reduce initial_speed from 0.5 to 0.25

* Change how warmup is applied.

* Bug fix to warmup_scale

* Fix test-mode

* Remove final dropout

* Make layer dropout rate 0.075, was 0.1.

* First draft of model rework

* Various bug fixes

* Change learning speed of simple_lm_proj

* Revert transducer_stateless/ to state in upstream/master

* Fix to joiner to allow different dims

* Some cleanups

* Make training more efficient, avoid redoing some projections.

* Change how warm-step is set

* First draft of new approach to learning rates + init

* Some fixes..

* Change initialization to 0.25

* Fix type of parameter

* Fix weight decay formula by adding 1/1-beta

* Fix weight decay formula by adding 1/1-beta

* Fix checkpoint-writing

* Fix to reading scheudler from optim

* Simplified optimizer, rework somet things..

* Reduce model_warm_step from 4k to 3k

* Fix bug in lambda

* Bug-fix RE sign of target_rms

* Changing initial_speed from 0.25 to 01

* Change some defaults in LR-setting rule.

* Remove initial_speed

* Set new scheduler

* Change exponential part of lrate to be epoch based

* Fix bug

* Set 2n rule..

* Implement 2o schedule

* Make lrate rule more symmetric

* Implement 2p version of learning rate schedule.

* Refactor how learning rate is set.

* Fix import

* Modify init (#301)

* update icefall/__init__.py to import more common functions.

* update icefall/__init__.py

* make imports style consistent.

* exclude black check for icefall/__init__.py in pyproject.toml.

* Minor fixes for logging (#296)

* Minor fixes for logging

* Minor fix

* Fix dir names

* Modify beam search to be efficient with current joienr

* Fix adding learning rate to tensorboard

* Fix docs in optim.py

* Support mix precision training on the reworked model (#305)

* Add mix precision support

* Minor fixes

* Minor fixes

* Minor fixes

* Tedlium3 pruned transducer stateless (#261)

* update tedlium3-pruned-transducer-stateless-codes

* update README.md

* update README.md

* add fast beam search for decoding

* do a change for RESULTS.md

* do a change for RESULTS.md

* do a fix

* do some changes for pruned RNN-T

* Add mix precision support

* Minor fixes

* Minor fixes

* Updating RESULTS.md; fix in beam_search.py

* Fix rebase

* Code style check for librispeech pruned transducer stateless2 (#308)

* Update results for tedlium3 pruned RNN-T (#307)

* Update README.md

* Fix CI errors. (#310)

* Add more results

* Fix tensorboard log location

* Add one more epoch of full expt

* fix comments

* Add results for mixed precision with max-duration 300

* Changes for pretrained.py (tedlium3 pruned RNN-T) (#311)

* GigaSpeech recipe (#120)

* initial commit

* support download, data prep, and fbank

* on-the-fly feature extraction by default

* support BPE based lang

* support HLG for BPE

* small fix

* small fix

* chunked feature extraction by default

* Compute features for GigaSpeech by splitting the manifest.

* Fixes after review.

* Split manifests into 2000 pieces.

* set audio duration mismatch tolerance to 0.01

* small fix

* add conformer training recipe

* Add conformer.py without pre-commit checking

* lazy loading and use SingleCutSampler

* DynamicBucketingSampler

* use KaldifeatFbank to compute fbank for musan

* use pretrained language model and lexicon

* use 3gram to decode, 4gram to rescore

* Add decode.py

* Update .flake8

* Delete compute_fbank_gigaspeech.py

* Use BucketingSampler for valid and test dataloader

* Update params in train.py

* Use bpe_500

* update params in decode.py

* Decrease num_paths while CUDA OOM

* Added README

* Update RESULTS

* black

* Decrease num_paths while CUDA OOM

* Decode with post-processing

* Update results

* Remove lazy_load option

* Use default `storage_type`

* Keep the original tolerance

* Use split-lazy

* black

* Update pretrained model

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* Add LG decoding (#277)

* Add LG decoding

* Add log weight pushing

* Minor fixes

* Support computing RNN-T loss with torchaudio (#316)

* Support modified beam search decoding for streaming inference with Emformer model.

* Formatted imports.

* Update results for torchaudio RNN-T. (#322)

* Fixed streaming decoding codes for emformer model.

* Fixed docs.

* Sorted imports for transducer_emformer/streaming_feature_extractor.py

* Minor fix for transducer_emformer/streaming_feature_extractor.py

Co-authored-by: pkufool <wkang@pku.org.cn>
Co-authored-by: Daniel Povey <dpovey@gmail.com>
Co-authored-by: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com>
Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
Co-authored-by: Guo Liyong <guonwpu@qq.com>
Co-authored-by: Wang, Guanbo <wgb14@outlook.com>
2022-04-22 18:06:07 +08:00

792 lines
23 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer_stateless/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer_stateless/exp \
--full-libri 1 \
--max-duration 250 \
--lr-factor 2.5
"""
import argparse
import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
transducer_stateless/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lr-factor",
type=float,
default=5.0,
help="The lr_factor for Noam optimizer",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--modified-transducer-prob",
type=float,
default=0.25,
help="""The probability to use modified transducer loss.
In modified transduer, it limits the maximum number of symbols
per frame to 1. See also the option --max-sym-per-frame in
transducer_stateless/decode.py
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--print-diagnostics",
type=str2bool,
default=False,
help="Accumulate stats on activations, print them and exit.",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for Noam
"warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Conformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = model.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
loss = model(
x=feature,
x_lens=feature_lens,
y=y,
modified_transducer_prob=params.modified_transducer_prob,
)
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if params.print_diagnostics and batch_idx == 5:
return
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 800
params.warm_step = 8000
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = Noam(
model.parameters(),
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
return 1.0 <= c.duration <= 20.0
num_in_total = len(train_cuts)
train_cuts = train_cuts.filter(remove_short_and_long_utt)
try:
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(
f"Before removing short and long utterances: {num_in_total}"
)
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(
f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
)
except TypeError as e:
# You can ignore this error as previous versions of Lhotse work fine
# for the above code. In recent versions of Lhotse, it uses
# lazy filter, producing cutsets that don't have the __len__ method
logging.info(str(e))
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()