Keep only needed changes from Liyong's branch
This commit is contained in:
parent 096ebeaf23
commit b7be18c2f8
@@ -519,7 +519,7 @@ class ScaledAdam(BatchedOptimizer):
        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
    ):
        """
        Show information of parameter wihch dominating tot_sumsq.
        Show information of parameter which dominates tot_sumsq.

        Args:
            tuples: a list of tuples of (param, state, param_names)
@@ -678,12 +678,17 @@ class ScaledAdam(BatchedOptimizer):
        )

        is_too_small = param_rms < param_min_rms
        is_too_large = param_rms > param_max_rms

        # when the param gets too small, just don't shrink it any further.
        scale_step.masked_fill_(is_too_small, 0.0)
        # when it gets too large, stop it from getting any larger.
        scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)

        # and ensure the parameter rms after update never exceeds param_max_rms.
        # We have to look at the trained model for parameters at or around the
        # param_max_rms, because sometimes they can indicate a problem with the
        # topology or settings.
        scale_step = torch.minimum(scale_step,
                                   (param_max_rms - param_rms) / param_rms)

        delta = state["delta"]
        # the factor of (1-beta1) relates to momentum.
        delta.add_(p * scale_step, alpha=(1 - beta1))
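The clamping above keeps each parameter's RMS inside [param_min_rms, param_max_rms]: the size update multiplies the parameter roughly by (1 + scale_step), so capping scale_step at (param_max_rms - param_rms) / param_rms keeps the post-update RMS at or below param_max_rms. A self-contained sketch of that logic on made-up tensors (the RMS values, limits and size_lr / size_update_period settings here are illustrative assumptions, not the optimizer's defaults):

import torch

param_rms = torch.tensor([1e-6, 0.5, 3.0])       # hypothetical per-parameter RMS values
scale_step = torch.tensor([-0.02, 0.02, 0.05])   # hypothetical proposed scale updates
param_min_rms, param_max_rms = 1e-5, 2.0         # assumed limits
size_lr, size_update_period = 0.02, 4            # assumed size-update settings

is_too_small = param_rms < param_min_rms
is_too_large = param_rms > param_max_rms

# too small: stop shrinking; too large: push back down at a fixed rate
scale_step.masked_fill_(is_too_small, 0.0)
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
# cap the step so param_rms * (1 + scale_step) stays <= param_max_rms
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
print(scale_step)  # roughly tensor([0.0000, 0.0200, -0.3333])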
@@ -59,6 +59,8 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from zipformer import Zipformer
from scaling import ScheduledFloat
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -70,7 +72,6 @@ from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -79,81 +80,125 @@ from icefall.checkpoint import (
    save_checkpoint_with_global_batch_idx,
    update_averaged_model,
)
from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
LRSchedulerType = Union[
    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]


def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
def get_adjusted_batch_count(
        params: AttributeDict) -> float:
    # returns the number of batches we would have used so far if we had used the reference
    # duration. This is for purposes of set_batch_count().
    return (params.batch_idx_train * params.ref_duration /
            (params.max_duration * params.world_size))


def set_batch_count(
    model: Union[nn.Module, DDP], batch_count: float
) -> None:
    if isinstance(model, DDP):
        # get underlying nn.Module
        model = model.module
    for module in model.modules():
        if hasattr(module, "batch_count"):
    for name, module in model.named_modules():
        if hasattr(module, 'batch_count'):
            module.batch_count = batch_count
        if hasattr(module, 'name'):
            module.name = name
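get_adjusted_batch_count() rescales batch_idx_train as if every batch had carried ref_duration seconds of audio, presumably so that batch-count-driven schedules fed through set_batch_count() behave consistently across different --max-duration settings and GPU counts. A quick standalone check with assumed settings (only the 600 s --ref-duration default comes from this diff; the --max-duration and world-size values below are hypothetical):

def get_adjusted_batch_count(batch_idx_train, ref_duration, max_duration, world_size):
    # batches we would have seen so far if each batch held ref_duration seconds
    # of audio instead of max_duration seconds per GPU across world_size GPUs
    return batch_idx_train * ref_duration / (max_duration * world_size)

# e.g. 10_000 real steps at an assumed --max-duration of 1200 s on 4 GPUs
# count as 1_250 "reference" batches with the default --ref-duration of 600 s
print(get_adjusted_batch_count(10_000, ref_duration=600, max_duration=1200, world_size=4))  # 1250.0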
def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--num-encoder-layers",
        type=str,
        default="2,4,3,2,4",
        help="Number of zipformer encoder layers, comma separated.",
        default="4,4,4,4,4,4",
        help="Number of zipformer encoder layers per stack, comma separated.",
    )

    parser.add_argument(
        "--feedforward-dims",
        type=str,
        default="1024,1024,2048,2048,1024",
        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
    )

    parser.add_argument(
        "--nhead",
        "--downsampling-factor",
        type=str,
        default="8,8,8,8,8",
        help="Number of attention heads in the zipformer encoder layers.",
    )

    parser.add_argument(
        "--encoder-dims",
        type=str,
        default="384,384,384,384,384",
        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
    )

    parser.add_argument(
        "--attention-dims",
        type=str,
        default="192,192,192,192,192",
        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
        not the same as embedding dimension.""",
    )

    parser.add_argument(
        "--encoder-unmasked-dims",
        type=str,
        default="256,256,256,256,256",
        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
        "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
        " worse.",
    )

    parser.add_argument(
        "--zipformer-downsampling-factors",
        type=str,
        default="1,2,4,8,2",
        default="1,2,4,8,4,2",
        help="Downsampling factor for each stack of encoder layers.",
    )


    parser.add_argument(
        "--cnn-module-kernels",
        "--feedforward-dim",
        type=str,
        default="31,31,31,31,31",
        help="Sizes of kernels in convolution modules",
        default="1792,1792,2304,2304,2304,1792",
        help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
    )

    parser.add_argument(
        "--num-heads",
        type=str,
        default="8,8,8,16,8,8",
        help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
    )

    parser.add_argument(
        "--attention-share-layers",
        type=str,
        default="2",
        help="Number of layers that share attention weights within each zipformer stack: a single int or comma-separated list.",
    )

    parser.add_argument(
        "--encoder-dim",
        type=str,
        default="384",
        help="Embedding dimension in encoder stacks: a single int or comma-separated list."
    )

    parser.add_argument(
        "--query-head-dim",
        type=str,
        default="32",
        help="Query/key dimension per head in encoder stacks: a single int or comma-separated list."
    )

    parser.add_argument(
        "--value-head-dim",
        type=str,
        default="12",
        help="Value dimension per head in encoder stacks: a single int or comma-separated list."
    )

    parser.add_argument(
        "--pos-head-dim",
        type=str,
        default="4",
        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
    )

    parser.add_argument(
        "--pos-dim",
        type=int,
        default="48",
        help="Positional-encoding embedding dimension"
    )

    parser.add_argument(
        "--encoder-unmasked-dim",
        type=str,
        default="256",
        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
        "A single int or comma-separated list. Must be <= each corresponding encoder_dim."
    )

    parser.add_argument(
        "--cnn-module-kernel",
        type=str,
        default="31",
        help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
        "a single int or comma-separated list.",
    )

    parser.add_argument(
@@ -244,7 +289,10 @@ def get_parser():
    )

    parser.add_argument(
        "--base-lr", type=float, default=0.05, help="The base learning rate."
        "--base-lr",
        type=float,
        default=0.05,
        help="The base learning rate."
    )

    parser.add_argument(
@@ -263,11 +311,21 @@ def get_parser():
        """,
    )

    parser.add_argument(
        "--ref-duration",
        type=float,
        default=600,
        help="Reference batch duration for purposes of adjusting batch counts for setting various "
        "schedules inside the model"
    )


    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )

    parser.add_argument(
@@ -290,7 +348,8 @@ def get_parser():
        "--am-scale",
        type=float,
        default=0.0,
        help="The scale to smooth the loss with am (output of encoder network) part.",
        help="The scale to smooth the loss with am (output of encoder network)"
        "part.",
    )

    parser.add_argument(
@@ -327,7 +386,7 @@ def get_parser():
    parser.add_argument(
        "--save-every-n",
        type=int,
        default=2000,
        default=4000,
        help="""Save checkpoint after processing this number of batches"
        periodically. We save checkpoint to exp-dir/ whenever
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -442,21 +501,24 @@ def get_params() -> AttributeDict:
def get_encoder_model(params: AttributeDict) -> nn.Module:
    # TODO: We can add an option to switch between Zipformer and Transformer
    def to_int_tuple(s: str):
        return tuple(map(int, s.split(",")))

        return tuple(map(int, s.split(',')))
    encoder = Zipformer(
        num_features=params.feature_dim,
        output_downsampling_factor=2,
        zipformer_downsampling_factors=to_int_tuple(
            params.zipformer_downsampling_factors
        ),
        encoder_dims=to_int_tuple(params.encoder_dims),
        attention_dim=to_int_tuple(params.attention_dims),
        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
        nhead=to_int_tuple(params.nhead),
        feedforward_dim=to_int_tuple(params.feedforward_dims),
        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
        downsampling_factor=to_int_tuple(params.downsampling_factor),
        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
        encoder_dim=to_int_tuple(params.encoder_dim),
        encoder_unmasked_dim=to_int_tuple(params.encoder_unmasked_dim),
        query_head_dim=to_int_tuple(params.query_head_dim),
        pos_head_dim=to_int_tuple(params.pos_head_dim),
        value_head_dim=to_int_tuple(params.value_head_dim),
        pos_dim=params.pos_dim,
        num_heads=to_int_tuple(params.num_heads),
        attention_share_layers=to_int_tuple(params.attention_share_layers),
        feedforward_dim=to_int_tuple(params.feedforward_dim),
        cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
        warmup_batches=4000.0,
    )
    return encoder
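Each per-stack option added above is a comma-separated string that get_encoder_model() expands with the local to_int_tuple() helper before building the Zipformer. A minimal illustration using defaults taken from this diff (the remark that a single value applies to all stacks is an assumption based on the help text, not shown in the code here):

def to_int_tuple(s: str):
    return tuple(map(int, s.split(',')))

print(to_int_tuple("4,4,4,4,4,4"))                    # --num-encoder-layers -> (4, 4, 4, 4, 4, 4)
print(to_int_tuple("1792,1792,2304,2304,2304,1792"))  # --feedforward-dim, one value per stack
print(to_int_tuple("384"))                            # --encoder-dim -> (384,), a single value for every stack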
@@ -473,7 +535,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:

def get_joiner_model(params: AttributeDict) -> nn.Module:
    joiner = Joiner(
        encoder_dim=int(params.encoder_dims.split(",")[-1]),
        encoder_dim=int(params.encoder_dim.split(',')[-1]),
        decoder_dim=params.decoder_dim,
        joiner_dim=params.joiner_dim,
        vocab_size=params.vocab_size,
@@ -490,7 +552,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        encoder_dim=int(params.encoder_dims.split(",")[-1]),
        encoder_dim=int(params.encoder_dim.split(',')[-1]),
        decoder_dim=params.decoder_dim,
        joiner_dim=params.joiner_dim,
        vocab_size=params.vocab_size,
@@ -625,7 +687,7 @@ def compute_loss(
    is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute transducer loss given the model and its inputs.
    Compute CTC loss given the model and its inputs.

    Args:
      params:
@@ -642,7 +704,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
        values >= 1.0 are fully warmed up and have all modules present.
    """
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    device = (
        model.device
        if isinstance(model, DDP)
        else next(model.parameters()).device
    )
    feature = batch["inputs"]
    # at entry, feature is (N, T, C)
    assert feature.ndim == 3
@@ -672,24 +738,27 @@ def compute_loss(
        # take down the scale on the simple loss from 1.0 at the start
        # to params.simple_loss scale by warm_step.
        simple_loss_scale = (
            s
            if batch_idx_train >= warm_step
            s if batch_idx_train >= warm_step
            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
        )
        pruned_loss_scale = (
            1.0
            if batch_idx_train >= warm_step
            1.0 if batch_idx_train >= warm_step
            else 0.1 + 0.9 * (batch_idx_train / warm_step)
        )

        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
        loss = (
            simple_loss_scale * simple_loss +
            pruned_loss_scale * pruned_loss
        )

    assert loss.requires_grad == is_training

    info = MetricsTracker()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
        info["frames"] = (
            (feature_lens // params.subsampling_factor).sum().item()
        )

    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()
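During the first warm_step batches, compute_loss() blends the simple and pruned losses, ramping the simple-loss weight down to s and the pruned-loss weight up from 0.1 to 1.0. A small sketch of that schedule (the warm_step and s values below are assumed for illustration, not the recipe's settings):

def loss_scales(batch_idx_train, warm_step=2000, s=0.5):
    # mirrors the warm-up blending in compute_loss() above
    simple_loss_scale = (
        s if batch_idx_train >= warm_step
        else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    )
    pruned_loss_scale = (
        1.0 if batch_idx_train >= warm_step
        else 0.1 + 0.9 * (batch_idx_train / warm_step)
    )
    return simple_loss_scale, pruned_loss_scale

print(loss_scales(0))     # (1.0, 0.1): rely mostly on the simple loss at the start
print(loss_scales(1000))  # (0.75, 0.55): halfway through warm-up
print(loss_scales(5000))  # (0.5, 1.0): fully warmed up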
@@ -784,7 +853,22 @@ def train_one_epoch(

    cur_batch_idx = params.get("cur_batch_idx", 0)

    saved_bad_model = False
    def save_bad_model(suffix: str = ""):
        save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
                             model=model,
                             model_avg=model_avg,
                             params=params,
                             optimizer=optimizer,
                             scheduler=scheduler,
                             sampler=train_dl.sampler,
                             scaler=scaler,
                             rank=0)


    for batch_idx, batch in enumerate(train_dl):
        if batch_idx % 10 == 0:
            set_batch_count(model, get_adjusted_batch_count(params))
        if batch_idx < cur_batch_idx:
            continue
        cur_batch_idx = batch_idx
@@ -807,13 +891,13 @@ def train_one_epoch(
                # NOTE: We use reduction==sum and loss is computed over utterances
                # in the batch and there is no normalization to it so far.
                scaler.scale(loss).backward()
                set_batch_count(model, params.batch_idx_train)
                scheduler.step_batch(params.batch_idx_train)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        except:  # noqa
            save_bad_model()
            display_and_save_batch(batch, params=params, sp=sp)
            raise

@@ -860,14 +944,17 @@ def train_one_epoch(
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()
            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):

            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
                if cur_grad_scale < 1.0e-05:
                    raise RuntimeError(
                        f"grad_scale is too small, exiting: {cur_grad_scale}"
                    )
                    save_bad_model()
                    raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")

        if batch_idx % params.log_interval == 0:
            cur_lr = scheduler.get_last_lr()[0]
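With fp16 enabled, the loop above manages the GradScaler by hand: it doubles a small scale (always while below 8.0, and every 400 batches while below 32.0), saves a "bad model" checkpoint and warns once the scale drops below 0.01, and aborts below 1e-5. A plain-float sketch of those thresholds (manage_grad_scale is an illustrative stand-in; the real code reads scaler._scale and calls scaler.update()):

import logging

def manage_grad_scale(cur_grad_scale: float, batch_idx: int) -> float:
    new_scale = cur_grad_scale
    # bump the scale while it is small (the real code calls scaler.update(cur_grad_scale * 2.0))
    if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
        new_scale = cur_grad_scale * 2.0
    if cur_grad_scale < 0.01:
        # the real loop also saves a "bad-model-first-warning" checkpoint the first time this fires
        logging.warning(f"Grad scale is small: {cur_grad_scale}")
        if cur_grad_scale < 1.0e-05:
            raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
    return new_scale

print(manage_grad_scale(4.0, batch_idx=123))   # 8.0: below 8 it is always doubled
print(manage_grad_scale(16.0, batch_idx=400))  # 32.0: below 32 it is doubled every 400 batches
print(manage_grad_scale(16.0, batch_idx=123))  # 16.0: otherwise left alone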
@@ -877,8 +964,8 @@
                f"Epoch {params.cur_epoch}, "
                f"batch {batch_idx}, loss[{loss_info}], "
                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                f"lr: {cur_lr:.2e}, "
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
                f"lr: {cur_lr:.2e}, " +
                (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
@@ -889,14 +976,16 @@ def train_one_epoch(
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                tot_loss.write_summary(
                    tb_writer, "train/tot_", params.batch_idx_train
                )
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale",
                        cur_grad_scale,
                        params.batch_idx_train,
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
                    )


        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
@@ -908,9 +997,7 @@ def train_one_epoch(
            )
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )
            logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
@@ -937,8 +1024,6 @@ def run(rank, world_size, args):
    """
    params = get_params()
    params.update(vars(args))
    if params.full_libri is False:
        params.valid_interval = 1600

    fix_random_seed(params.seed)
    if world_size > 1:
@@ -986,7 +1071,8 @@ def run(rank, world_size, args):
    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
        model = DDP(model, device_ids=[rank],
                    find_unused_parameters=True)

    optimizer = ScaledAdam(
        model.named_parameters(),
@@ -1087,7 +1173,8 @@ def run(rank, world_size, args):
        params=params,
    )

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    scaler = GradScaler(enabled=params.use_fp16,
                        init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1208,9 +1295,7 @@ def scan_pessimistic_batches_for_oom(
            )
            display_and_save_batch(batch, params=params, sp=sp)
            raise
        logging.info(
            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
        )
        logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")


def main():