take a couple files from liyong's branch
commit 096ebeaf23 (parent f688066517)
@@ -18,7 +18,7 @@ import contextlib
 import logging
 import random
 from collections import defaultdict
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from lhotse.utils import fix_random_seed
@@ -132,6 +132,9 @@ class ScaledAdam(BatchedOptimizer):

     Args:
       params: The parameters or param_groups to optimize (like other Optimizer subclasses)
+          Unlike common optimizers, which accept model.parameters() or groups of parameters(),
+          this optimizer could accept model.named_parameters() or groups of named_parameters().
+          See comments of function _get_names_of_parameters for its 4 possible cases.
       lr: The learning rate. We will typically use a learning rate schedule that starts
           at 0.03 and decreases over time, i.e. much higher than other common
           optimizers.
@@ -178,15 +181,8 @@ class ScaledAdam(BatchedOptimizer):
         scalar_max=10.0,
         size_update_period=4,
         clipping_update_period=100,
-        parameters_names=None,
-        show_dominant_parameters=True,
     ):

-        assert parameters_names is not None, (
-            "Please prepare parameters_names,"
-            "which is a List[List[str]]. Each List[str] is for a group"
-            "and each str is for a parameter"
-        )
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -200,10 +196,135 @@ class ScaledAdam(BatchedOptimizer):
             clipping_update_period=clipping_update_period,
         )

+        # If params only contains parameters or groups of parameters,
+        # i.e. when parameter names are not given,
+        # this flag will be set to False in function _get_names_of_parameters.
+        self.show_dominant_parameters = True
+        params, parameters_names = self._get_names_of_parameters(params)
         super(ScaledAdam, self).__init__(params, defaults)
         assert len(self.param_groups) == len(parameters_names)
         self.parameters_names = parameters_names
-        self.show_dominant_parameters = show_dominant_parameters
+
+    def _get_names_of_parameters(
+        self, params_or_named_params
+    ) -> Tuple[List[Dict], List[List[str]]]:
+        """
+        Args:
+          params_or_named_params: according to the way ScaledAdam is initialized in train.py,
+            this argument could be one of the following 4 cases:
+
+            case 1, a generator of parameter, e.g.:
+              optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0)
+
+            case 2, a list of parameter groups with different config, e.g.:
+              model_param_groups = [
+                  {'params': model.encoder.parameters(), 'lr': 0.05},
+                  {'params': model.decoder.parameters(), 'lr': 0.01},
+                  {'params': model.joiner.parameters(), 'lr': 0.03},
+              ]
+              optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0)
+
+            case 3, a generator of named_parameter, e.g.:
+              optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0)
+
+            case 4, a list of named_parameter groups with different config, e.g.:
+              model_named_param_groups = [
+                  {'params': model.encoder.named_parameters(), 'lr': 0.05},
+                  {'params': model.decoder.named_parameters(), 'lr': 0.01},
+                  {'params': model.joiner.named_parameters(), 'lr': 0.03},
+              ]
+              optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0)
+
+          For case 1 and case 2, the input params are used to initialize the underlying torch.optimizer.
+          For case 3 and case 4, names and params are first extracted from the input named_params;
+          the extracted params are then used to initialize the underlying torch.optimizer,
+          and the extracted names are mainly used by function
+          `_show_gradient_dominating_parameter`.
+
+        Returns:
+          Returns a tuple containing 2 elements:
+          - `param_groups` with type List[Dict], each Dict element is a parameter group.
+            An example of `param_groups` could be:
+            [
+                {'params': `one iterable of Parameter`, 'lr': 0.05},
+                {'params': `another iterable of Parameter`, 'lr': 0.08},
+                {'params': `a third iterable of Parameter`, 'lr': 0.1},
+            ]
+          - `param_groups_names` with type List[List[str]],
+            each `List[str]` is for a group['params'] in param_groups,
+            and each `str` is the name of a parameter.
+            A dummy name "foo" is assigned to each parameter
+            if the input is params without names, i.e. case 1 or case 2.
+        """
+        # variable naming convention in this function:
+        # p is short for param.
+        # np is short for named_param.
+        # p_or_np is short for param_or_named_param.
+        # cur is short for current.
+        # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
+        # groups is a List[group]
+
+        iterable_or_groups = list(params_or_named_params)
+        if len(iterable_or_groups) == 0:
+            raise ValueError("optimizer got an empty parameter list")
+
+        # The first value of the returned tuple.
+        param_groups = []
+
+        # The second value of the returned tuple,
+        # a List[List[str]], each sub-List is for a group.
+        param_groups_names = []
+
+        if not isinstance(iterable_or_groups[0], dict):
+            # case 1 or case 3,
+            # the input is an iterable of parameter or named parameter.
+            param_iterable_cur_group = []
+            param_names_cur_group = []
+            for p_or_np in iterable_or_groups:
+                if isinstance(p_or_np, tuple):
+                    # case 3
+                    name, param = p_or_np
+                else:
+                    # case 1
+                    assert isinstance(p_or_np, torch.Tensor)
+                    param = p_or_np
+                    # Assign a dummy name as a placeholder
+                    name = "foo"
+                    self.show_dominant_parameters = False
+                param_iterable_cur_group.append(param)
+                param_names_cur_group.append(name)
+            param_groups.append({"params": param_iterable_cur_group})
+            param_groups_names.append(param_names_cur_group)
+        else:
+            # case 2 or case 4,
+            # the input is groups of parameter or named parameter.
+            for p_or_np_cur_group in iterable_or_groups:
+                param_iterable_cur_group = []
+                param_names_cur_group = []
+                p_or_np_iterable = p_or_np_cur_group["params"]
+                for p_or_np in p_or_np_iterable:
+                    if isinstance(p_or_np, tuple):
+                        # case 4
+                        name, param = p_or_np
+                    else:
+                        # case 2
+                        assert isinstance(p_or_np, torch.Tensor)
+                        param = p_or_np
+                        # Assign a dummy name as a placeholder
+                        name = "foo"
+                        self.show_dominant_parameters = False
+                    param_iterable_cur_group.append(param)
+                    param_names_cur_group.append(name)
+
+                # The original `params` field contains named_parameters.
+                # After the following assignment,
+                # it will be changed to an iterable of parameter,
+                # and other fields, if present, keep their original values.
+                # So param_groups could be used to initialize
+                # an underlying torch.Optimizer later.
+                p_or_np_cur_group["params"] = param_iterable_cur_group
+                param_groups.append(p_or_np_cur_group)
+                param_groups_names.append(param_names_cur_group)
+        return param_groups, param_groups_names

     def __setstate__(self, state):
         super(ScaledAdam, self).__setstate__(state)
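With _get_names_of_parameters in place, callers no longer pass parameters_names explicitly; the optimizer extracts names itself when it is handed named parameters. A minimal usage sketch of cases 1 and 3, assuming the optim.py changed above is importable as `optim` and using a stand-in nn.Linear rather than the real transducer model:

import torch
import torch.nn as nn

from optim import ScaledAdam  # assumed import path for the file changed above

model = nn.Linear(100, 100)  # stand-in model, not the actual recipe model

# case 1: plain parameters; dominant-parameter logging is switched off internally.
optimizer_plain = ScaledAdam(model.parameters(), lr=0.03, clipping_scale=2.0)

# case 3: named parameters; the extracted names feed _show_gradient_dominating_parameter.
optimizer_named = ScaledAdam(model.named_parameters(), lr=0.03, clipping_scale=2.0)

loss = model(torch.randn(4, 100)).sum()
loss.backward()
optimizer_named.step()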
@@ -398,7 +519,7 @@ class ScaledAdam(BatchedOptimizer):
         self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
     ):
         """
-        Show information of parameter wihch dominanting tot_sumsq.
+        Show information of the parameter which is dominating tot_sumsq.

         Args:
             tuples: a list of tuples of (param, state, param_names)
@@ -416,7 +537,7 @@ class ScaledAdam(BatchedOptimizer):
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
                 batch_sumsq_orig = batch_grad**2
-                # Dummpy values used by following `zip` statement.
+                # Dummy values used by following `zip` statement.
                 batch_rms_orig = torch.ones(p.shape[0])
             else:
                 batch_rms_orig = state["param_rms"]
@@ -449,11 +570,11 @@ class ScaledAdam(BatchedOptimizer):
             dominant_grad,
         ) = sorted_by_proportion[dominant_param_name]
         logging.info(
-            f"Parameter Dominanting tot_sumsq {dominant_param_name}"
+            f"Parameter dominating tot_sumsq {dominant_param_name}"
             f" with proportion {dominant_proportion:.2f},"
             f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
             f"={dominant_sumsq:.3e},"
-            f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
+            f" grad_sumsq={(dominant_grad**2).sum():.3e},"
             f" orig_rms_sq={(dominant_rms**2).item():.3e}"
         )

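The `_show_gradient_dominating_parameter` hunks above weight each parameter's squared-gradient norm by its RMS before deciding which one dominates tot_sumsq. A toy sketch of that bookkeeping, with made-up parameter names and shapes:

import torch

# Made-up (grad, param_rms) pairs standing in for the optimizer's batched state.
grads_and_rms = {
    "encoder.linear.weight": (torch.randn(4, 4), torch.tensor(0.10)),
    "decoder.embed.weight": (torch.randn(8, 4), torch.tensor(0.05)),
}

# Each parameter contributes grad_sumsq * orig_rms_sq to tot_sumsq.
contrib = {name: (g**2).sum() * rms**2 for name, (g, rms) in grads_and_rms.items()}
tot_sumsq = sum(contrib.values())

name, value = max(contrib.items(), key=lambda kv: kv[1])
proportion = (value / tot_sumsq).item()
print(f"Parameter dominating tot_sumsq {name} with proportion {proportion:.2f}")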
@@ -561,11 +682,8 @@ class ScaledAdam(BatchedOptimizer):

         # when the param gets too small, just don't shrink it any further.
         scale_step.masked_fill_(is_too_small, 0.0)
-        # and ensure the parameter rms after update never exceeds param_max_rms.
-        scale_step = torch.minimum(scale_step,
-                                   (param_max_rms - param_rms) / param_rms)
+        # when it gets too large, stop it from getting any larger.
+        scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)

         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
         delta.add_(p * scale_step, alpha=(1 - beta1))
@@ -779,9 +897,7 @@ class Eden(LRScheduler):

 def _test_eden():
     m = torch.nn.Linear(100, 100)
-    parameters_names = [ [ x[0] for x in m.named_parameters() ] ]
-    optim = ScaledAdam(m.parameters(), lr=0.03,
-                       parameters_names=parameters_names)
+    optim = ScaledAdam(m.parameters(), lr=0.03)

     scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)

@@ -992,9 +1108,7 @@ def _test_scaled_adam(hidden_dim: int):
         if iter == 0:
             optim = Eve(m.parameters(), lr=0.003)
         elif iter == 1:
-            parameters_names = [ [ x[0] for x in m.named_parameters() ] ]
-            optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0,
-                               parameters_names=parameters_names)
+            optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
         scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

         start = timeit.default_timer()
@@ -59,8 +59,6 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from zipformer import Zipformer
-from scaling import ScheduledFloat
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -72,6 +70,7 @@ from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer

 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -80,125 +79,81 @@ from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
-from icefall.hooks import register_inf_check_hooks
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


-def get_adjusted_batch_count(
-        params: AttributeDict) -> float:
-    # returns the number of batches we would have used so far if we had used the reference
-    # duration.  This is for purposes of set_batch_count().
-    return (params.batch_idx_train * params.ref_duration /
-            (params.max_duration * params.world_size))
-
-
-def set_batch_count(
-    model: Union[nn.Module, DDP], batch_count: float
-) -> None:
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
     if isinstance(model, DDP):
         # get underlying nn.Module
         model = model.module
-    for name, module in model.named_modules():
-        if hasattr(module, 'batch_count'):
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
             module.batch_count = batch_count
-        if hasattr(module, 'name'):
-            module.name = name


 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=str,
-        default="4,4,4,4,4,4",
-        help="Number of zipformer encoder layers per stack, comma separated.",
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
     )

+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
     parser.add_argument(
-        "--downsampling-factor",
+        "--nhead",
         type=str,
-        default="1,2,4,8,4,2",
+        default="8,8,8,8,8",
+        help="Number of attention heads in the zipformer encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dims",
+        type=str,
+        default="384,384,384,384,384",
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="192,192,192,192,192",
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+        not the same as embedding dimension.""",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256,256,256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+        "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
+        " worse.",
+    )
+
+    parser.add_argument(
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2,4,8,2",
         help="Downsampling factor for each stack of encoder layers.",
     )

     parser.add_argument(
-        "--feedforward-dim",
+        "--cnn-module-kernels",
         type=str,
-        default="1792,1792,2304,2304,2304,1792",
-        help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
-    )
-
-    parser.add_argument(
-        "--num-heads",
-        type=str,
-        default="8,8,8,16,8,8",
-        help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
-    )
-
-    parser.add_argument(
-        "--attention-share-layers",
-        type=str,
-        default="2",
-        help="Number of layers that share attention weights within each zipformer stack: a single int or comma-separated list.",
-    )
-
-    parser.add_argument(
-        "--encoder-dim",
-        type=str,
-        default="384",
-        help="Embedding dimension in encoder stacks: a single int or comma-separated list."
-    )
-
-    parser.add_argument(
-        "--query-head-dim",
-        type=str,
-        default="32",
-        help="Query/key dimension per head in encoder stacks: a single int or comma-separated list."
-    )
-
-    parser.add_argument(
-        "--value-head-dim",
-        type=str,
-        default="12",
-        help="Value dimension per head in encoder stacks: a single int or comma-separated list."
-    )
-
-    parser.add_argument(
-        "--pos-head-dim",
-        type=str,
-        default="4",
-        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
-    )
-
-    parser.add_argument(
-        "--pos-dim",
-        type=int,
-        default="48",
-        help="Positional-encoding embedding dimension"
-    )
-
-    parser.add_argument(
-        "--encoder-unmasked-dim",
-        type=str,
-        default="256",
-        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
-        "A single int or comma-separated list. Must be <= each corresponding encoder_dim."
-    )
-
-    parser.add_argument(
-        "--cnn-module-kernel",
-        type=str,
-        default="31",
-        help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
-        "a single int or comma-separated list.",
+        default="31,31,31,31,31",
+        help="Sizes of kernels in convolution modules",
     )

     parser.add_argument(
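The simplified set_batch_count above no longer records module names; it just walks every submodule and stamps a batch_count attribute on those that define one. A small sketch of that traversal on a toy model (module layout invented for illustration):

import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.batch_count = 0.0  # read by schedules inside the module

def set_batch_count(model: nn.Module, batch_count: float) -> None:
    # Same idea as in the diff: no names needed, only the attribute is updated.
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count

model = nn.Sequential(Block(), Block())
set_batch_count(model, 1000.0)
print([block.batch_count for block in model])  # [1000.0, 1000.0]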
@@ -289,10 +244,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--base-lr",
-        type=float,
-        default=0.05,
-        help="The base learning rate."
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
     )

     parser.add_argument(
@@ -311,21 +263,11 @@ def get_parser():
     """,
     )

-    parser.add_argument(
-        "--ref-duration",
-        type=float,
-        default=600,
-        help="Reference batch duration for purposes of adjusting batch counts for setting various "
-        "schedules inside the model"
-    )
-
     parser.add_argument(
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )

     parser.add_argument(
@@ -348,8 +290,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )

     parser.add_argument(
@@ -386,7 +327,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=4000,
+        default=2000,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -501,24 +442,21 @@ def get_params() -> AttributeDict:
 def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Zipformer and Transformer
     def to_int_tuple(s: str):
-        return tuple(map(int, s.split(',')))
+        return tuple(map(int, s.split(",")))

     encoder = Zipformer(
         num_features=params.feature_dim,
         output_downsampling_factor=2,
-        downsampling_factor=to_int_tuple(params.downsampling_factor),
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
         num_encoder_layers=to_int_tuple(params.num_encoder_layers),
-        encoder_dim=to_int_tuple(params.encoder_dim),
-        encoder_unmasked_dim=to_int_tuple(params.encoder_unmasked_dim),
-        query_head_dim=to_int_tuple(params.query_head_dim),
-        pos_head_dim=to_int_tuple(params.pos_head_dim),
-        value_head_dim=to_int_tuple(params.value_head_dim),
-        pos_dim=params.pos_dim,
-        num_heads=to_int_tuple(params.num_heads),
-        attention_share_layers=to_int_tuple(params.attention_share_layers),
-        feedforward_dim=to_int_tuple(params.feedforward_dim),
-        cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
-        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
-        warmup_batches=4000.0,
     )
     return encoder

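Every comma-separated model option above is turned into a tuple of ints by the small to_int_tuple helper before being handed to Zipformer; a quick illustration with arbitrary values:

def to_int_tuple(s: str):
    # "2,4,3,2,4" -> (2, 4, 3, 2, 4)
    return tuple(map(int, s.split(",")))

assert to_int_tuple("2,4,3,2,4") == (2, 4, 3, 2, 4)
assert to_int_tuple("1,2,4,8,2") == (1, 2, 4, 8, 2)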
@@ -535,7 +473,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:

 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=int(params.encoder_dim.split(',')[-1]),
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -552,7 +490,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=int(params.encoder_dim.split(',')[-1]),
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -687,7 +625,7 @@ def compute_loss(
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute CTC loss given the model and its inputs.
+    Compute transducer loss given the model and its inputs.

     Args:
       params:
@@ -704,11 +642,7 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -738,27 +672,24 @@ def compute_loss(
         # take down the scale on the simple loss from 1.0 at the start
         # to params.simple_loss scale by warm_step.
         simple_loss_scale = (
-            s if batch_idx_train >= warm_step
+            s
+            if batch_idx_train >= warm_step
             else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
         )
         pruned_loss_scale = (
-            1.0 if batch_idx_train >= warm_step
+            1.0
+            if batch_idx_train >= warm_step
             else 0.1 + 0.9 * (batch_idx_train / warm_step)
         )

-        loss = (
-            simple_loss_scale * simple_loss +
-            pruned_loss_scale * pruned_loss
-        )
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

     assert loss.requires_grad == is_training

     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
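The reflowed warm-up logic above linearly moves the two loss weights from (1.0, 0.1) at the start of training to (s, 1.0) once batch_idx_train reaches warm_step. A small numeric sketch of the same formulas; warm_step=2000 and s=0.5 are illustrative values, not the recipe's defaults:

def loss_scales(batch_idx_train: int, warm_step: int = 2000, s: float = 0.5):
    # Same expressions as in compute_loss above.
    simple_loss_scale = (
        s
        if batch_idx_train >= warm_step
        else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    )
    pruned_loss_scale = (
        1.0
        if batch_idx_train >= warm_step
        else 0.1 + 0.9 * (batch_idx_train / warm_step)
    )
    return simple_loss_scale, pruned_loss_scale

print(loss_scales(0))     # (1.0, 0.1): early on, mostly the simple loss
print(loss_scales(1000))  # (0.75, 0.55): halfway through the warm-up
print(loss_scales(5000))  # (0.5, 1.0): fully warmed up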
@@ -853,22 +784,7 @@ def train_one_epoch(

     cur_batch_idx = params.get("cur_batch_idx", 0)

-    saved_bad_model = False
-    def save_bad_model(suffix: str = ""):
-        save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
-                             model=model,
-                             model_avg=model_avg,
-                             params=params,
-                             optimizer=optimizer,
-                             scheduler=scheduler,
-                             sampler=train_dl.sampler,
-                             scaler=scaler,
-                             rank=0)
-
-
     for batch_idx, batch in enumerate(train_dl):
-        if batch_idx % 10 == 0:
-            set_batch_count(model, get_adjusted_batch_count(params))
         if batch_idx < cur_batch_idx:
             continue
         cur_batch_idx = batch_idx
@@ -891,13 +807,13 @@ def train_one_epoch(
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
             scaler.scale(loss).backward()
+            set_batch_count(model, params.batch_idx_train)
             scheduler.step_batch(params.batch_idx_train)

             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            save_bad_model()
             display_and_save_batch(batch, params=params, sp=sp)
             raise

@@ -944,17 +860,14 @@ def train_one_epoch(
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()
-            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
-                if not saved_bad_model:
-                    save_bad_model(suffix="-first-warning")
-                    saved_bad_model = True
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                save_bad_model()
-                raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
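With fp16 enabled, the new thresholds above double a collapsed grad scale on every batch while it is below 1.0 and every 400 batches while it is below 8.0, instead of the old 8.0/32.0 limits. A toy rehearsal of that schedule, detached from GradScaler (only the thresholds come from the diff):

def maybe_boost(cur_grad_scale: float, batch_idx: int) -> float:
    # Condition copied from the new code; the doubling mirrors scaler.update(cur_grad_scale * 2.0).
    if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
        return cur_grad_scale * 2.0
    return cur_grad_scale

scale = 0.125
for batch_idx in range(1, 1201):
    scale = maybe_boost(scale, batch_idx)
print(scale)  # recovers above 1.0 within a few batches, then doubles only every 400 batches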
@@ -964,8 +877,8 @@ def train_one_epoch(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
-                f"lr: {cur_lr:.2e}, " +
-                (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )

             if tb_writer is not None:
@@ -976,16 +889,14 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
-                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
                     )

         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
@@ -997,7 +908,9 @@ def train_one_epoch(
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
             if tb_writer is not None:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
@@ -1024,6 +937,8 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
+    if params.full_libri is False:
+        params.valid_interval = 1600

     fix_random_seed(params.seed)
     if world_size > 1:
@@ -1071,14 +986,12 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank],
-                    find_unused_parameters=True)
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

     optimizer = ScaledAdam(
-        model.parameters(),
+        model.named_parameters(),
         lr=params.base_lr,
         clipping_scale=2.0,
-        parameters_names=[ [p[0] for p in model.named_parameters()] ],
     )

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
@@ -1097,7 +1010,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -1120,7 +1033,33 @@ def run(rank, world_size, args):
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 20.0
+        if c.duration < 1.0 or c.duration > 20.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./zipformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True

     train_cuts = train_cuts.filter(remove_short_and_long_utt)

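The new filter keeps a cut only if the number of frames after convolutional subsampling (T) is at least the number of BPE tokens (S), using the same arithmetic as the expression in the diff. A quick check of that formula on a made-up utterance:

def subsampled_frames(num_frames: int) -> int:
    # Frame count after the conv front end, as used in remove_short_and_long_utt above.
    return ((num_frames - 7) // 2 + 1) // 2

num_frames = 1000  # e.g. a 10-second cut at 100 frames per second (illustrative)
T = subsampled_frames(num_frames)
print(T)  # 248; the cut is dropped if its transcript has more than 248 tokens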
@@ -1148,8 +1087,7 @@ def run(rank, world_size, args):
         params=params,
     )

-    scaler = GradScaler(enabled=params.use_fp16,
-                        init_scale=1.0)
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1270,7 +1208,9 @@ def scan_pessimistic_batches_for_oom(
             )
             display_and_save_batch(batch, params=params, sp=sp)
             raise
-        logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )


 def main():