take a couple files from liyong's branch

This commit is contained in:
Daniel Povey 2023-01-05 12:01:42 +08:00
parent f688066517
commit 096ebeaf23
2 changed files with 263 additions and 209 deletions

View File

@ -18,7 +18,7 @@ import contextlib
import logging
import random
from collections import defaultdict
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
from lhotse.utils import fix_random_seed
@ -132,6 +132,9 @@ class ScaledAdam(BatchedOptimizer):
Args:
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
Unlike common optimizers, which accept model.parameters() or groups of parameters,
this optimizer can also accept model.named_parameters() or groups of named_parameters.
See the docstring of _get_names_of_parameters for the 4 supported cases.
lr: The learning rate. We will typically use a learning rate schedule that starts
at 0.03 and decreases over time, i.e. much higher than other common
optimizers.
@ -178,15 +181,8 @@ class ScaledAdam(BatchedOptimizer):
scalar_max=10.0,
size_update_period=4,
clipping_update_period=100,
parameters_names=None,
show_dominant_parameters=True,
):
assert parameters_names is not None, (
"Please prepare parameters_names,"
"which is a List[List[str]]. Each List[str] is for a group"
"and each str is for a parameter"
)
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
@ -200,10 +196,135 @@ class ScaledAdam(BatchedOptimizer):
clipping_update_period=clipping_update_period,
)
# If params contains only parameters or groups of parameters,
# i.e. when parameter names are not given,
# this flag will be set to False in function _get_names_of_parameters.
self.show_dominant_parameters = True
params, parameters_names = self._get_names_of_parameters(params)
super(ScaledAdam, self).__init__(params, defaults)
assert len(self.param_groups) == len(parameters_names)
self.parameters_names = parameters_names
self.show_dominant_parameters = show_dominant_parameters
def _get_names_of_parameters(
self, params_or_named_params
) -> Tuple[List[Dict], List[List[str]]]:
"""
Args:
params_or_named_params: depending on how ScaledAdam is initialized in train.py,
this argument can be one of the following 4 cases:
case 1, a generator of parameters, e.g.:
optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0)
case 2, a list of parameter groups with different config, e.g.:
model_param_groups = [
{'params': model.encoder.parameters(), 'lr': 0.05},
{'params': model.decoder.parameters(), 'lr': 0.01},
{'params': model.joiner.parameters(), 'lr': 0.03},
]
optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0)
case 3, a generator of named_parameters, e.g.:
optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0)
case 4, a list of named_parameter groups with different config, e.g.:
model_named_param_groups = [
{'params': model.encoder.named_parameters(), 'lr': 0.05},
{'params': model.decoder.named_parameters(), 'lr': 0.01},
{'params': model.joiner.named_parameters(), 'lr': 0.03},
]
optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0)
For case 1 and case 2, the input params are used directly to initialize the underlying torch optimizer.
For case 3 and case 4, names and params are first extracted from the input named_params;
the extracted params are then used to initialize the underlying torch optimizer,
and the extracted names are used mainly by the function
`_show_gradient_dominating_parameter`.
Returns:
Returns a tuple containing 2 elements:
- `param_groups` with type List[Dict], each Dict element is a parameter group.
An example of `param_groups` could be:
[
{'params': `one iterable of Parameter`, 'lr': 0.05},
{'params': `another iterable of Parameter`, 'lr': 0.08},
{'params': `a third iterable of Parameter`, 'lr': 0.1},
]
- `param_groups_names` with type List[List[str]],
each `List[str]` is for a group['params'] in param_groups,
and each `str` is the name of a parameter.
A dummy name "foo" is assigned to each parameter
if the input params have no names, i.e. in case 1 or case 2.
"""
# variable naming convention in this function:
# p is short for param.
# np is short for named_param.
# p_or_np is short for param_or_named_param.
# cur is short for current.
# group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
# groups is a List[group]
iterable_or_groups = list(params_or_named_params)
if len(iterable_or_groups) == 0:
raise ValueError("optimizer got an empty parameter list")
# The first value of returned tuple.
param_groups = []
# The second value of returned tuple,
# a List[List[str]], each sub-List is for a group.
param_groups_names = []
if not isinstance(iterable_or_groups[0], dict):
# case 1 or case 3,
# the input is an iterable of parameters or named parameters.
param_iterable_cur_group = []
param_names_cur_group = []
for p_or_np in iterable_or_groups:
if isinstance(p_or_np, tuple):
# case 3
name, param = p_or_np
else:
# case 1
assert isinstance(p_or_np, torch.Tensor)
param = p_or_np
# Assign a dummy name as a placeholder
name = "foo"
self.show_dominant_parameters = False
param_iterable_cur_group.append(param)
param_names_cur_group.append(name)
param_groups.append({"params": param_iterable_cur_group})
param_groups_names.append(param_names_cur_group)
else:
# case 2 or case 4,
# the input is groups of parameters or named parameters.
for p_or_np_cur_group in iterable_or_groups:
param_iterable_cur_group = []
param_names_cur_group = []
p_or_np_iterable = p_or_np_cur_group["params"]
for p_or_np in p_or_np_iterable:
if isinstance(p_or_np, tuple):
# case 4
name, param = p_or_np
else:
# case 2
assert isinstance(p_or_np, torch.Tensor)
param = p_or_np
# Assign a dummy name as a placeholder
name = "foo"
self.show_dominant_parameters = False
param_iterable_cur_group.append(param)
param_names_cur_group.append(name)
# The original `params` field contains named_parameters.
# After the following assignment,
# it becomes an iterable of parameters,
# while any other fields, if present, keep their original values.
# So param_groups can later be used to initialize
# an underlying torch.Optimizer.
p_or_np_cur_group["params"] = param_iterable_cur_group
param_groups.append(p_or_np_cur_group)
param_groups_names.append(param_names_cur_group)
return param_groups, param_groups_names
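As an aside for readers, a minimal usage sketch of the two initialization styles handled above (the model and hyperparameters are illustrative only, and it assumes this optim.py is importable; it is not part of the diff):

import torch

from optim import ScaledAdam  # assumption: this file is on the import path

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))

# Case 1: plain parameters; every parameter gets the dummy name "foo" and
# dominant-parameter logging is switched off (show_dominant_parameters = False).
opt_plain = ScaledAdam(model.parameters(), lr=0.03, clipping_scale=2.0)

# Case 3: named parameters; the real names are kept so that
# _show_gradient_dominating_parameter can report which parameter dominates.
opt_named = ScaledAdam(model.named_parameters(), lr=0.03, clipping_scale=2.0)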
def __setstate__(self, state):
super(ScaledAdam, self).__setstate__(state)
@ -398,7 +519,7 @@ class ScaledAdam(BatchedOptimizer):
self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
):
"""
Show information of parameter wihch dominanting tot_sumsq.
Show information of the parameter which dominates tot_sumsq.
Args:
tuples: a list of tuples of (param, state, param_names)
@ -416,7 +537,7 @@ class ScaledAdam(BatchedOptimizer):
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
batch_sumsq_orig = batch_grad**2
# Dummpy values used by following `zip` statement.
# Dummy values used by the following `zip` statement.
batch_rms_orig = torch.ones(p.shape[0])
else:
batch_rms_orig = state["param_rms"]
@ -449,11 +570,11 @@ class ScaledAdam(BatchedOptimizer):
dominant_grad,
) = sorted_by_proportion[dominant_param_name]
logging.info(
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
f"Parameter dominating tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
f"={dominant_sumsq:.3e},"
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
f" grad_sumsq={(dominant_grad**2).sum():.3e},"
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
)
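For intuition, the proportions logged above can be reproduced with a small standalone sketch (a paraphrase, not the optimizer code): each parameter contributes grad_sumsq * orig_rms_sq, and its proportion is that contribution divided by tot_sumsq.

import torch

def dominant_proportions(named_grads_and_rms):
    # named_grads_and_rms: list of (name, grad, rms) tuples; rms is a scalar tensor.
    # Returns {name: share of tot_sumsq}, mirroring the log message above.
    sumsqs = {
        name: ((grad ** 2).sum() * (rms ** 2)).item()
        for name, grad, rms in named_grads_and_rms
    }
    tot = sum(sumsqs.values())
    return {name: s / tot for name, s in sumsqs.items()}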
@ -561,11 +682,8 @@ class ScaledAdam(BatchedOptimizer):
# when the param gets too small, just don't shrink it any further.
scale_step.masked_fill_(is_too_small, 0.0)
# and ensure the parameter rms after update never exceeds param_max_rms.
scale_step = torch.minimum(scale_step,
(param_max_rms - param_rms) / param_rms)
# when it gets too large, stop it from getting any larger.
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
delta = state["delta"]
# the factor of (1-beta1) relates to momentum.
delta.add_(p * scale_step, alpha=(1 - beta1))
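A hedged numerical illustration of the torch.minimum cap above: the size update scales the parameter by roughly (1 + scale_step), so clamping scale_step at (param_max_rms - param_rms) / param_rms keeps the post-update RMS at or below param_max_rms (the values below are made up).

import torch

param_max_rms = 2.0
param_rms = torch.tensor([0.9, 2.5])    # second parameter is already too large
scale_step = torch.tensor([0.10, 0.10])

scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
print(scale_step)  # tensor([ 0.1000, -0.2000]): the oversized parameter is shrunk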
@ -779,9 +897,7 @@ class Eden(LRScheduler):
def _test_eden():
m = torch.nn.Linear(100, 100)
parameters_names = [ [ x[0] for x in m.named_parameters() ] ]
optim = ScaledAdam(m.parameters(), lr=0.03,
parameters_names=parameters_names)
optim = ScaledAdam(m.parameters(), lr=0.03)
scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
@ -992,9 +1108,7 @@ def _test_scaled_adam(hidden_dim: int):
if iter == 0:
optim = Eve(m.parameters(), lr=0.003)
elif iter == 1:
parameters_names = [ [ x[0] for x in m.named_parameters() ] ]
optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0,
parameters_names=parameters_names)
optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
start = timeit.default_timer()

View File

@ -59,8 +59,6 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from zipformer import Zipformer
from scaling import ScheduledFloat
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@ -72,6 +70,7 @@ from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
@ -80,125 +79,81 @@ from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def get_adjusted_batch_count(
params: AttributeDict) -> float:
# returns the number of batches we would have used so far if we had used the reference
# duration. This is for purposes of set_batch_count().
return (params.batch_idx_train * params.ref_duration /
(params.max_duration * params.world_size))
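A quick worked example of the get_adjusted_batch_count arithmetic above (all numbers are illustrative, not recipe defaults):

batch_idx_train, ref_duration, max_duration, world_size = 10_000, 600.0, 750.0, 4
adjusted = batch_idx_train * ref_duration / (max_duration * world_size)
print(adjusted)  # 2000.0 batches "worth" of the reference duration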
def set_batch_count(
model: Union[nn.Module, DDP], batch_count: float
) -> None:
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
for name, module in model.named_modules():
if hasattr(module, 'batch_count'):
for module in model.modules():
if hasattr(module, "batch_count"):
module.batch_count = batch_count
if hasattr(module, 'name'):
module.name = name
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-encoder-layers",
type=str,
default="4,4,4,4,4,4",
help="Number of zipformer encoder layers per stack, comma separated.",
default="2,4,3,2,4",
help="Number of zipformer encoder layers, comma separated.",
)
parser.add_argument(
"--feedforward-dims",
type=str,
default="1024,1024,2048,2048,1024",
help="Feedforward dimension of the zipformer encoder layers, comma separated.",
)
parser.add_argument(
"--downsampling-factor",
"--nhead",
type=str,
default="1,2,4,8,4,2",
default="8,8,8,8,8",
help="Number of attention heads in the zipformer encoder layers.",
)
parser.add_argument(
"--encoder-dims",
type=str,
default="384,384,384,384,384",
help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
)
parser.add_argument(
"--attention-dims",
type=str,
default="192,192,192,192,192",
help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
not the same as embedding dimension.""",
)
parser.add_argument(
"--encoder-unmasked-dims",
type=str,
default="256,256,256,256,256",
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
" worse.",
)
parser.add_argument(
"--zipformer-downsampling-factors",
type=str,
default="1,2,4,8,2",
help="Downsampling factor for each stack of encoder layers.",
)
parser.add_argument(
"--feedforward-dim",
"--cnn-module-kernels",
type=str,
default="1792,1792,2304,2304,2304,1792",
help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
)
parser.add_argument(
"--num-heads",
type=str,
default="8,8,8,16,8,8",
help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
)
parser.add_argument(
"--attention-share-layers",
type=str,
default="2",
help="Number of layers that share attention weights within each zipformer stack: a single int or comma-separated list.",
)
parser.add_argument(
"--encoder-dim",
type=str,
default="384",
help="Embedding dimension in encoder stacks: a single int or comma-separated list."
)
parser.add_argument(
"--query-head-dim",
type=str,
default="32",
help="Query/key dimension per head in encoder stacks: a single int or comma-separated list."
)
parser.add_argument(
"--value-head-dim",
type=str,
default="12",
help="Value dimension per head in encoder stacks: a single int or comma-separated list."
)
parser.add_argument(
"--pos-head-dim",
type=str,
default="4",
help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
)
parser.add_argument(
"--pos-dim",
type=int,
default="48",
help="Positional-encoding embedding dimension"
)
parser.add_argument(
"--encoder-unmasked-dim",
type=str,
default="256",
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"A single int or comma-separated list. Must be <= each corresponding encoder_dim."
)
parser.add_argument(
"--cnn-module-kernel",
type=str,
default="31",
help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
"a single int or comma-separated list.",
default="31,31,31,31,31",
help="Sizes of kernels in convolution modules",
)
parser.add_argument(
@ -289,10 +244,7 @@ def get_parser():
)
parser.add_argument(
"--base-lr",
type=float,
default=0.05,
help="The base learning rate."
"--base-lr", type=float, default=0.05, help="The base learning rate."
)
parser.add_argument(
@ -311,21 +263,11 @@ def get_parser():
""",
)
parser.add_argument(
"--ref-duration",
type=float,
default=600,
help="Reference batch duration for purposes of adjusting batch counts for setting various "
"schedules inside the model"
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@ -348,8 +290,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
help="The scale to smooth the loss with am (output of encoder network) part.",
)
parser.add_argument(
@ -386,7 +327,7 @@ def get_parser():
parser.add_argument(
"--save-every-n",
type=int,
default=4000,
default=2000,
help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
@ -501,24 +442,21 @@ def get_params() -> AttributeDict:
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Zipformer and Transformer
def to_int_tuple(s: str):
return tuple(map(int, s.split(',')))
return tuple(map(int, s.split(",")))
encoder = Zipformer(
num_features=params.feature_dim,
output_downsampling_factor=2,
downsampling_factor=to_int_tuple(params.downsampling_factor),
zipformer_downsampling_factors=to_int_tuple(
params.zipformer_downsampling_factors
),
encoder_dims=to_int_tuple(params.encoder_dims),
attention_dim=to_int_tuple(params.attention_dims),
encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
nhead=to_int_tuple(params.nhead),
feedforward_dim=to_int_tuple(params.feedforward_dims),
cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
num_encoder_layers=to_int_tuple(params.num_encoder_layers),
encoder_dim=to_int_tuple(params.encoder_dim),
encoder_unmasked_dim=to_int_tuple(params.encoder_unmasked_dim),
query_head_dim=to_int_tuple(params.query_head_dim),
pos_head_dim=to_int_tuple(params.pos_head_dim),
value_head_dim=to_int_tuple(params.value_head_dim),
pos_dim=params.pos_dim,
num_heads=to_int_tuple(params.num_heads),
attention_share_layers=to_int_tuple(params.attention_share_layers),
feedforward_dim=to_int_tuple(params.feedforward_dim),
cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
)
return encoder
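A tiny illustration of the to_int_tuple helper used above; it simply turns the comma-separated CLI strings into integer tuples:

def to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))

assert to_int_tuple("2,4,3,2,4") == (2, 4, 3, 2, 4)            # e.g. --num-encoder-layers
assert to_int_tuple("31,31,31,31,31") == (31, 31, 31, 31, 31)  # e.g. --cnn-module-kernels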
@ -535,7 +473,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
encoder_dim=int(params.encoder_dim.split(',')[-1]),
encoder_dim=int(params.encoder_dims.split(",")[-1]),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
@ -552,7 +490,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(params.encoder_dim.split(',')[-1]),
encoder_dim=int(params.encoder_dims.split(",")[-1]),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
@ -687,7 +625,7 @@ def compute_loss(
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Compute transducer loss given the model and its inputs.
Args:
params:
@ -704,11 +642,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -738,27 +672,24 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s if batch_idx_train >= warm_step
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss = (
simple_loss_scale * simple_loss +
pruned_loss_scale * pruned_loss
)
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -853,22 +784,7 @@ def train_one_epoch(
cur_batch_idx = params.get("cur_batch_idx", 0)
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
model_avg=model_avg,
params=params,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=0)
for batch_idx, batch in enumerate(train_dl):
if batch_idx % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params))
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
@ -891,13 +807,13 @@ def train_one_epoch(
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
set_batch_count(model, params.batch_idx_train)
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
except: # noqa
save_bad_model()
display_and_save_batch(batch, params=params, sp=sp)
raise
@ -944,17 +860,14 @@ def train_one_epoch(
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
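Condensed into one hedged helper, the fp16 grad-scale handling above does roughly the following (it reads the private _scale attribute exactly as the script does; the bad-model checkpointing is omitted here):

import logging

from torch.cuda.amp import GradScaler

def watch_grad_scale(scaler: GradScaler, batch_idx: int) -> None:
    cur = scaler._scale.item()
    # GradScaler only grows the scale on its own schedule, so double it manually
    # when it has drifted too low.
    if cur < 1.0 or (cur < 8.0 and batch_idx % 400 == 0):
        scaler.update(cur * 2.0)
    if cur < 0.01:
        logging.warning(f"Grad scale is small: {cur}")
    if cur < 1.0e-05:
        raise RuntimeError(f"grad_scale is too small, exiting: {cur}")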
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
@ -964,8 +877,8 @@ def train_one_epoch(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, " +
(f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
@ -976,16 +889,14 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
"train/grad_scale",
cur_grad_scale,
params.batch_idx_train,
)
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
@ -997,7 +908,9 @@ def train_one_epoch(
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
@ -1024,6 +937,8 @@ def run(rank, world_size, args):
"""
params = get_params()
params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 1600
fix_random_seed(params.seed)
if world_size > 1:
@ -1071,14 +986,12 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank],
find_unused_parameters=True)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = ScaledAdam(
model.parameters(),
model.named_parameters(),
lr=params.base_lr,
clipping_scale=2.0,
parameters_names=[ [p[0] for p in model.named_parameters()] ],
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
@ -1097,7 +1010,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
@ -1120,7 +1033,33 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)
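A hedged worked example of the length check inside remove_short_and_long_utt above: a cut is kept only if its number of subsampled frames is at least the number of BPE tokens (the 100 frames/second figure below is only an illustrative assumption):

def subsampled_frames(num_frames: int) -> int:
    # Same expression as in the filter above, matching the conv front-end in zipformer.py.
    return ((num_frames - 7) // 2 + 1) // 2

T = subsampled_frames(500)  # a ~5 s cut at 100 frames/s: ((500 - 7) // 2 + 1) // 2 = 123
print(T)
# keep the cut only if T >= len(sp.encode(text, out_type=str))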
@ -1148,8 +1087,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16,
init_scale=1.0)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1270,7 +1208,9 @@ def scan_pessimistic_batches_for_oom(
)
display_and_save_batch(batch, params=params, sp=sp)
raise
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
def main():