Keep only needed changes from Liyong's branch

Daniel Povey 2023-01-05 12:23:32 +08:00
parent 096ebeaf23
commit b7be18c2f8
2 changed files with 191 additions and 101 deletions


@@ -519,7 +519,7 @@ class ScaledAdam(BatchedOptimizer):
         self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
     ):
         """
-        Show information of parameter wihch dominating tot_sumsq.
+        Show information of parameter which dominates tot_sumsq.
 
         Args:
             tuples: a list of tuples of (param, state, param_names)
@@ -678,12 +678,17 @@ class ScaledAdam(BatchedOptimizer):
             )
             is_too_small = param_rms < param_min_rms
-            is_too_large = param_rms > param_max_rms
 
             # when the param gets too small, just don't shrink it any further.
             scale_step.masked_fill_(is_too_small, 0.0)
-            # when it gets too large, stop it from getting any larger.
-            scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
+
+            # and ensure the parameter rms after update never exceeds param_max_rms.
+            # We have to look at the trained model for parameters at or around the
+            # param_max_rms, because sometimes they can indicate a problem with the
+            # topology or settings.
+            scale_step = torch.minimum(scale_step,
+                                       (param_max_rms - param_rms) / param_rms)
+
             delta = state["delta"]
             # the factor of (1-beta1) relates to momentum.
             delta.add_(p * scale_step, alpha=(1 - beta1))
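Aside: a minimal sketch (not part of the patch) of why the torch.minimum() cap above keeps the parameter RMS at or below param_max_rms. It assumes the size update acts, to first order and ignoring momentum, like p_new ≈ p * (1 + scale_step); the numbers are hypothetical.

    import torch

    param_max_rms = 2.0
    param_rms = torch.tensor([0.5, 1.9, 2.5])   # current RMS of three parameter tensors
    scale_step = torch.tensor([0.3, 0.3, 0.3])  # proposed relative size changes

    # the cap from the hunk above
    scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
    new_rms = param_rms * (1 + scale_step)
    # scale_step -> [0.3000, 0.0526, -0.2000]
    # new_rms    -> [0.6500, 2.0000, 2.0000], i.e. never above param_max_rms;
    # parameters already over the limit are pulled back down to it.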


@@ -59,6 +59,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
+from zipformer import Zipformer
+from scaling import ScheduledFloat
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -70,7 +72,6 @@ from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
-from zipformer import Zipformer
 
 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -79,81 +80,125 @@ from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
+from icefall.hooks import register_inf_check_hooks
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
-def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+def get_adjusted_batch_count(
+        params: AttributeDict) -> float:
+    # returns the number of batches we would have used so far if we had used the reference
+    # duration.  This is for purposes of set_batch_count().
+    return (params.batch_idx_train * params.ref_duration /
+            (params.max_duration * params.world_size))
+
+
+def set_batch_count(
+    model: Union[nn.Module, DDP], batch_count: float
+) -> None:
     if isinstance(model, DDP):
         # get underlying nn.Module
         model = model.module
-    for module in model.modules():
-        if hasattr(module, "batch_count"):
+    for name, module in model.named_modules():
+        if hasattr(module, 'batch_count'):
             module.batch_count = batch_count
+        if hasattr(module, 'name'):
+            module.name = name
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=str,
-        default="2,4,3,2,4",
-        help="Number of zipformer encoder layers, comma separated.",
+        default="4,4,4,4,4,4",
+        help="Number of zipformer encoder layers per stack, comma separated.",
     )
 
-    parser.add_argument(
-        "--feedforward-dims",
-        type=str,
-        default="1024,1024,2048,2048,1024",
-        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
-    )
-
     parser.add_argument(
-        "--nhead",
+        "--downsampling-factor",
         type=str,
-        default="8,8,8,8,8",
-        help="Number of attention heads in the zipformer encoder layers.",
-    )
-
-    parser.add_argument(
-        "--encoder-dims",
-        type=str,
-        default="384,384,384,384,384",
-        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
-    )
-
-    parser.add_argument(
-        "--attention-dims",
-        type=str,
-        default="192,192,192,192,192",
-        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
-        not the same as embedding dimension.""",
-    )
-
-    parser.add_argument(
-        "--encoder-unmasked-dims",
-        type=str,
-        default="256,256,256,256,256",
-        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
-        "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
-        " worse.",
-    )
-
-    parser.add_argument(
-        "--zipformer-downsampling-factors",
-        type=str,
-        default="1,2,4,8,2",
+        default="1,2,4,8,4,2",
         help="Downsampling factor for each stack of encoder layers.",
     )
 
     parser.add_argument(
-        "--cnn-module-kernels",
+        "--feedforward-dim",
         type=str,
-        default="31,31,31,31,31",
-        help="Sizes of kernels in convolution modules",
+        default="1792,1792,2304,2304,2304,1792",
+        help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+    )
+
+    parser.add_argument(
+        "--num-heads",
+        type=str,
+        default="8,8,8,16,8,8",
+        help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--attention-share-layers",
+        type=str,
+        default="2",
+        help="Number of layers that share attention weights within each zipformer stack: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=str,
+        default="384",
+        help="Embedding dimension in encoder stacks: a single int or comma-separated list."
+    )
+
+    parser.add_argument(
+        "--query-head-dim",
+        type=str,
+        default="32",
+        help="Query/key dimension per head in encoder stacks: a single int or comma-separated list."
+    )
+
+    parser.add_argument(
+        "--value-head-dim",
+        type=str,
+        default="12",
+        help="Value dimension per head in encoder stacks: a single int or comma-separated list."
+    )
+
+    parser.add_argument(
+        "--pos-head-dim",
+        type=str,
+        default="4",
+        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
+    )
+
+    parser.add_argument(
+        "--pos-dim",
+        type=int,
+        default="48",
+        help="Positional-encoding embedding dimension"
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dim",
+        type=str,
+        default="256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+        "A single int or comma-separated list. Must be <= each corresponding encoder_dim."
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernel",
+        type=str,
+        default="31",
+        help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+        "a single int or comma-separated list.",
     )
 
     parser.add_argument(
@@ -244,7 +289,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--base-lr", type=float, default=0.05, help="The base learning rate."
+        "--base-lr",
+        type=float,
+        default=0.05,
+        help="The base learning rate."
     )
 
     parser.add_argument(
@@ -263,11 +311,21 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--ref-duration",
+        type=float,
+        default=600,
+        help="Reference batch duration for purposes of adjusting batch counts for setting various "
+        "schedules inside the model"
+    )
+
     parser.add_argument(
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -290,7 +348,8 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network) part.",
+        help="The scale to smooth the loss with am (output of encoder network) "
+        "part.",
     )
 
     parser.add_argument(
@@ -327,7 +386,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=2000,
+        default=4000,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -442,21 +501,24 @@ def get_params() -> AttributeDict:
 def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Zipformer and Transformer
     def to_int_tuple(s: str):
-        return tuple(map(int, s.split(",")))
+        return tuple(map(int, s.split(',')))
 
     encoder = Zipformer(
         num_features=params.feature_dim,
         output_downsampling_factor=2,
-        zipformer_downsampling_factors=to_int_tuple(
-            params.zipformer_downsampling_factors
-        ),
-        encoder_dims=to_int_tuple(params.encoder_dims),
-        attention_dim=to_int_tuple(params.attention_dims),
-        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
-        nhead=to_int_tuple(params.nhead),
-        feedforward_dim=to_int_tuple(params.feedforward_dims),
-        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        downsampling_factor=to_int_tuple(params.downsampling_factor),
         num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+        encoder_dim=to_int_tuple(params.encoder_dim),
+        encoder_unmasked_dim=to_int_tuple(params.encoder_unmasked_dim),
+        query_head_dim=to_int_tuple(params.query_head_dim),
+        pos_head_dim=to_int_tuple(params.pos_head_dim),
+        value_head_dim=to_int_tuple(params.value_head_dim),
+        pos_dim=params.pos_dim,
+        num_heads=to_int_tuple(params.num_heads),
+        attention_share_layers=to_int_tuple(params.attention_share_layers),
+        feedforward_dim=to_int_tuple(params.feedforward_dim),
+        cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+        warmup_batches=4000.0,
     )
     return encoder
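Aside: the dropout argument above is a schedule, not a constant. Assuming ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) denotes a piecewise-linear function of the module's batch_count (an assumption about scaling.ScheduledFloat, and the reason set_batch_count() must keep batch_count current), a minimal sketch of the behaviour is:

    def scheduled_float(batch_count, points=((0.0, 0.3), (20000.0, 0.1))):
        # clamp outside the schedule, interpolate linearly inside it
        (x0, y0), (x1, y1) = points
        if batch_count <= x0:
            return y0
        if batch_count >= x1:
            return y1
        return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)

    print(scheduled_float(0.0))      # 0.3 at the start of training
    print(scheduled_float(10000.0))  # 0.2 halfway through the schedule
    print(scheduled_float(50000.0))  # 0.1 from batch 20000 onwards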
@@ -473,7 +535,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        encoder_dim=int(params.encoder_dim.split(',')[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -490,7 +552,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        encoder_dim=int(params.encoder_dim.split(',')[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -625,7 +687,7 @@ def compute_loss(
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute transducer loss given the model and its inputs.
+    Compute CTC loss given the model and its inputs.
 
     Args:
       params:
@@ -642,7 +704,11 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -672,24 +738,27 @@ def compute_loss(
         # take down the scale on the simple loss from 1.0 at the start
         # to params.simple_loss scale by warm_step.
         simple_loss_scale = (
-            s
-            if batch_idx_train >= warm_step
+            s if batch_idx_train >= warm_step
             else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
         )
         pruned_loss_scale = (
-            1.0
-            if batch_idx_train >= warm_step
+            1.0 if batch_idx_train >= warm_step
             else 0.1 + 0.9 * (batch_idx_train / warm_step)
         )
-        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+        loss = (
+            simple_loss_scale * simple_loss +
+            pruned_loss_scale * pruned_loss
+        )
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -784,7 +853,22 @@ def train_one_epoch(
 
     cur_batch_idx = params.get("cur_batch_idx", 0)
 
+    saved_bad_model = False
+
+    def save_bad_model(suffix: str = ""):
+        save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+                             model=model,
+                             model_avg=model_avg,
+                             params=params,
+                             optimizer=optimizer,
+                             scheduler=scheduler,
+                             sampler=train_dl.sampler,
+                             scaler=scaler,
+                             rank=0)
+
     for batch_idx, batch in enumerate(train_dl):
+        if batch_idx % 10 == 0:
+            set_batch_count(model, get_adjusted_batch_count(params))
+
         if batch_idx < cur_batch_idx:
             continue
         cur_batch_idx = batch_idx
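Aside: a worked example (hypothetical settings, not from the patch) of the get_adjusted_batch_count() value that the new set_batch_count() call above feeds to the model every 10 batches. As committed, it rescales the raw batch index by ref_duration / (max_duration * world_size):

    batch_idx_train = 3000
    ref_duration = 600.0    # default --ref-duration
    max_duration = 1000.0   # hypothetical --max-duration
    world_size = 4          # hypothetical number of GPUs

    adjusted = batch_idx_train * ref_duration / (max_duration * world_size)
    print(adjusted)  # 450.0 -- the batch_count seen by schedules such as ScheduledFloat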
@@ -807,13 +891,13 @@ def train_one_epoch(
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
             scaler.scale(loss).backward()
-            set_batch_count(model, params.batch_idx_train)
             scheduler.step_batch(params.batch_idx_train)
 
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
+            save_bad_model()
             display_and_save_batch(batch, params=params, sp=sp)
             raise
@@ -860,14 +944,17 @@ def train_one_epoch(
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()
-            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+
+            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
+                if not saved_bad_model:
+                    save_bad_model(suffix="-first-warning")
+                    saved_bad_model = True
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                save_bad_model()
+                raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
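Aside: under the new thresholds above, a collapsed grad scale recovers geometrically: any scale below 8.0 is doubled on every batch, while a scale in [8.0, 32.0) doubles only on every 400th batch. A sketch of the recovery path (pure arithmetic, no GradScaler involved):

    scale, batches = 0.125, 0
    while scale < 8.0:           # below 8.0 the scale doubles every batch
        scale, batches = scale * 2.0, batches + 1
    print(scale, batches)        # 8.0 6 -- six batches to climb back to 8.0
    # between 8.0 and 32.0 it doubles only when batch_idx % 400 == 0; above
    # 32.0 it is left to GradScaler's own growth/backoff behaviour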
@@ -877,8 +964,8 @@ def train_one_epoch(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
-                f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                f"lr: {cur_lr:.2e}, " +
+                (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )
 
         if tb_writer is not None:
@@ -889,14 +976,16 @@ def train_one_epoch(
             loss_info.write_summary(
                 tb_writer, "train/current_", params.batch_idx_train
             )
-            tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+            tot_loss.write_summary(
+                tb_writer, "train/tot_", params.batch_idx_train
+            )
             if params.use_fp16:
                 tb_writer.add_scalar(
-                    "train/grad_scale",
-                    cur_grad_scale,
-                    params.batch_idx_train,
+                    "train/grad_scale", cur_grad_scale, params.batch_idx_train
                 )
 
         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
@@ -908,9 +997,7 @@ def train_one_epoch(
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-            )
+            logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
             if tb_writer is not None:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
@@ -937,8 +1024,6 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
-    if params.full_libri is False:
-        params.valid_interval = 1600
 
     fix_random_seed(params.seed)
     if world_size > 1:
@@ -986,7 +1071,8 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+        model = DDP(model, device_ids=[rank],
+                    find_unused_parameters=True)
 
     optimizer = ScaledAdam(
         model.named_parameters(),
@@ -1010,7 +1096,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -1087,7 +1173,8 @@ def run(rank, world_size, args):
             params=params,
         )
 
-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = GradScaler(enabled=params.use_fp16,
+                        init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1208,9 +1295,7 @@ def scan_pessimistic_batches_for_oom(
         )
         display_and_save_batch(batch, params=params, sp=sp)
         raise
-    logging.info(
-        f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-    )
+    logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
 
 
 def main():