update train.py

Author: Fangjun Kuang
Date: 2022-11-14 16:26:53 +08:00
parent ab38f4a926
commit 1d494556fc


@@ -27,8 +27,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless7/exp \
--full-libri 1 \
--max-duration 300
--max-duration 750 \
--training-subset L
# For mixed precision training:
@@ -38,9 +38,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7/exp \
--full-libri 1 \
--max-duration 550
--max-duration 750
"""
@@ -54,12 +52,10 @@ from typing import Any, Dict, Optional, Tuple, Union
import k2
import optim
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from zipformer import Zipformer
from asr_datamodule import WenetSpeechAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -71,17 +67,20 @@ from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
from icefall import diagnostics
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[
@@ -89,14 +88,12 @@ LRSchedulerType = Union[
]
def set_batch_count(
model: Union[nn.Module, DDP], batch_count: float
) -> None:
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
for module in model.modules():
if hasattr(module, 'batch_count'):
if hasattr(module, "batch_count"):
module.batch_count = batch_count
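
Note: a hedged usage sketch for the helper above (model and params are assumed to be in scope, as in the training loop):

    # Propagate the global batch count into every submodule that exposes a
    # `batch_count` attribute (e.g. Zipformer layers with scheduled behaviour).
    set_batch_count(model, float(params.batch_idx_train))
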
@@ -126,7 +123,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--encoder-dims",
type=str,
default="384,384,384,384,384",
help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated"
help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
)
parser.add_argument(
@@ -134,7 +131,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
type=str,
default="192,192,192,192,192",
help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
not the same as embedding dimension."""
not the same as embedding dimension.""",
)
parser.add_argument(
@@ -143,7 +140,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default="256,256,256,256,256",
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
" worse."
" worse.",
)
parser.add_argument(
@@ -241,17 +238,17 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--lang-dir",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--base-lr",
type=float,
default=0.05,
help="The base learning rate."
"--base-lr", type=float, default=0.05, help="The base learning rate."
)
parser.add_argument(
@@ -451,11 +448,14 @@ def get_params() -> AttributeDict:
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Zipformer and Transformer
def to_int_tuple(s: str):
return tuple(map(int, s.split(',')))
return tuple(map(int, s.split(",")))
encoder = Zipformer(
num_features=params.feature_dim,
output_downsampling_factor=2,
zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors),
zipformer_downsampling_factors=to_int_tuple(
params.zipformer_downsampling_factors
),
encoder_dims=to_int_tuple(params.encoder_dims),
attention_dim=to_int_tuple(params.attention_dims),
encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
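
Note: the comma-separated CLI strings are parsed into one integer per encoder block; a self-contained illustration of the helper above:

    def to_int_tuple(s: str):
        return tuple(map(int, s.split(",")))

    assert to_int_tuple("384,384,384,384,384") == (384, 384, 384, 384, 384)
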
@@ -479,7 +479,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
encoder_dim=int(params.encoder_dims.split(',')[-1]),
encoder_dim=int(params.encoder_dims.split(",")[-1]),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
@@ -496,7 +496,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(params.encoder_dims.split(',')[-1]),
encoder_dim=int(params.encoder_dims.split(",")[-1]),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
@@ -567,9 +567,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params
@@ -626,7 +623,7 @@ def save_checkpoint(
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
@@ -665,7 +662,8 @@ def compute_loss(
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = graph_compiler.texts_to_ids(texts)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
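
Note: this hunk swaps SentencePiece BPE tokenization for a char-based lexicon lookup. A hedged sketch of the new path (the token ids are invented; graph_compiler and device are as constructed in run()):

    texts = ["我爱你", "你好"]
    ids = graph_compiler.texts_to_ids(texts)  # e.g. [[23, 57, 91], [14, 8]]
    # Ragged because utterances differ in length.
    y = k2.RaggedTensor(ids).to(device)
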
@@ -682,18 +680,17 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s if batch_idx_train >= warm_step
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss = (
simple_loss_scale * simple_loss +
pruned_loss_scale * pruned_loss
)
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
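
Note: a worked example of the warm-up schedule above, assuming s = params.simple_loss_scale = 0.5 and warm_step = 2000 (the usual defaults in this recipe):

    for t in (0, 1000, 2000):
        simple = 0.5 if t >= 2000 else 1.0 - (t / 2000) * (1.0 - 0.5)
        pruned = 1.0 if t >= 2000 else 0.1 + 0.9 * (t / 2000)
        print(t, simple, pruned)  # (0, 1.0, 0.1), (1000, 0.75, 0.55), (2000, 0.5, 1.0)
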
@@ -715,7 +712,7 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
graph_compiler: CharCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
@@ -728,7 +725,7 @@ def compute_validation_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
graph_compiler=graph_compiler,
batch=batch,
is_training=False,
)
@@ -751,7 +748,7 @@ def train_one_epoch(
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
graph_compiler: CharCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
@@ -795,13 +792,7 @@ def train_one_epoch(
tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@@ -810,7 +801,7 @@ def train_one_epoch(
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
@@ -827,7 +818,7 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params, sp=sp)
display_and_save_batch(batch, params=params)
raise
if params.print_diagnostics and batch_idx == 5:
@@ -848,7 +839,6 @@ def train_one_epoch(
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
@@ -861,7 +851,6 @@ def train_one_epoch(
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@@ -873,12 +862,16 @@ def train_one_epoch(
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
if cur_grad_scale < 1.0 or (
cur_grad_scale < 8.0 and batch_idx % 400 == 0
):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
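
Note: a toy simulation of the manual scale-doubling heuristic above; with init_scale=1.0 (set further down in this diff) the scale reaches 8.0 by batch 800, faster than GradScaler's own growth_interval (default 2000 steps) would allow:

    scale, history = 1.0, []
    for batch_idx in range(1300):
        if scale < 1.0 or (scale < 8.0 and batch_idx % 400 == 0):
            scale *= 2.0  # mirrors scaler.update(cur_grad_scale * 2.0)
        history.append(scale)
    print(history[0], history[400], history[800])  # 2.0 4.0 8.0
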
@@ -888,8 +881,12 @@ def train_one_epoch(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, " +
(f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
f"lr: {cur_lr:.2e}, "
+ (
f"grad_scale: {scaler._scale.item()}"
if params.use_fp16
else ""
)
)
if tb_writer is not None:
@@ -905,23 +902,28 @@ def train_one_epoch(
)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
"train/grad_scale",
cur_grad_scale,
params.batch_idx_train,
)
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
if (
batch_idx % params.valid_interval == 0
and not params.print_diagnostics
):
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
@@ -948,8 +950,6 @@ def run(rank, world_size, args):
"""
params = get_params()
params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 1600
fix_random_seed(params.seed)
if world_size > 1:
@@ -968,12 +968,14 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
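
Note: with a char lang dir the token ids in tokens.txt are contiguous from 0 with "<blk>" as id 0, hence vocab_size = max(lexicon.tokens) + 1. A toy illustration with a hand-built symbol table (the symbols are made up):

    import k2

    token_table = k2.SymbolTable.from_str("<blk> 0\n<unk> 1\n我 2\n爱 3")
    blank_id = token_table["<blk>"]        # 0
    vocab_size = max(token_table.ids) + 1  # 4
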
@@ -997,12 +999,11 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank],
find_unused_parameters=True)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = ScaledAdam(model.parameters(),
lr=params.base_lr,
clipping_scale=2.0)
optimizer = ScaledAdam(
model.parameters(), lr=params.base_lr, clipping_scale=2.0
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
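
Note: a hedged sketch of the Eden schedule set up above, assuming the usual definition in this recipe's optim.py (verify against your checkout) and the parser defaults lr_batches=5000, lr_epochs=3.5:

    def eden_lr(base_lr, batch, epoch, lr_batches=5000, lr_epochs=3.5):
        # The lr decays smoothly in both the batch and the epoch dimension.
        return (
            base_lr
            * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
            * ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
        )

    print(eden_lr(0.05, batch=0, epoch=1))  # ~0.049, just under base_lr
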
@@ -1027,26 +1028,26 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
librispeech = LibriSpeechAsrDataModule(args)
wenetspeech = WenetSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_cuts = wenetspeech.train_cuts()
valid_cuts = wenetspeech.valid_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
# Keep only utterances with duration between 1 second and 19 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# Caution: There is a reason to select 19.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
return 1.0 <= c.duration <= 19.0
train_cuts = train_cuts.filter(remove_short_and_long_utt)
valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
# saved in the middle of an epoch
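
Note: the tighter 19 s cap follows the WenetSpeech duration statistics. A quick way to sanity-check the bounds against your own manifest before training (hypothetical path; see ../local/display_manifest_statistics.py):

    from lhotse import CutSet

    cuts = CutSet.from_file("data/fbank/cuts_L.jsonl.gz")  # hypothetical path
    cuts.describe()  # prints duration percentiles for the manifest
    kept = cuts.filter(lambda c: 1.0 <= c.duration <= 19.0)  # lazy filter
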
@@ -1054,25 +1055,20 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
train_dl = librispeech.train_dataloaders(
train_dl = wenetspeech.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
if False and not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
graph_compiler=graph_compiler,
params=params,
)
scaler = GradScaler(enabled=params.use_fp16,
init_scale=1.0)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1093,7 +1089,7 @@ def run(rank, world_size, args):
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sp=sp,
graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
@@ -1127,7 +1123,6 @@ def display_and_save_batch(
def display_and_save_batch(
batch: dict,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> None:
"""Display the batch statistics and save the batch into disk.
@@ -1137,8 +1132,6 @@ def display_and_save_batch(
for the content in it.
params:
Parameters for training. See :func:`get_params`.
sp:
The BPE model.
"""
from lhotse.utils import uuid4
@@ -1146,13 +1139,13 @@ def display_and_save_batch(
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
supervisions = batch["supervisions"]
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
y = sp.encode(supervisions["text"], out_type=int)
num_tokens = sum(len(i) for i in y)
texts = batch["supervisions"]["text"]
num_tokens = sum(len(i) for i in texts)
logging.info(f"num tokens: {num_tokens}")
@@ -1160,7 +1153,7 @@ def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
graph_compiler: CharCtcTrainingGraphCompiler,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
@@ -1176,7 +1169,7 @@ def scan_pessimistic_batches_for_oom(
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
@@ -1191,15 +1184,18 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(batch, params=params, sp=sp)
display_and_save_batch(batch, params=params)
raise
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
WenetSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.lang_dir = Path(args.lang_dir)
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size