Mirror of https://github.com/k2-fsa/icefall.git

Commit 1d494556fc (parent ab38f4a926): update train.py

@@ -27,8 +27,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --num-epochs 30 \
   --start-epoch 1 \
   --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 300
+  --max-duration 750 \
+  --training-subset L

 # For mix precision training:

@@ -38,9 +38,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 550
+  --max-duration 750

 """


@@ -54,12 +52,10 @@ from typing import Any, Dict, Optional, Tuple, Union

 import k2
 import optim
-import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
-from zipformer import Zipformer
+from asr_datamodule import WenetSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -71,17 +67,20 @@ from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer

 from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
-from icefall.hooks import register_inf_check_hooks
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

 LRSchedulerType = Union[
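Taken together, the import and setup changes above swap the BPE/SentencePiece pipeline (LibriSpeech) for the character-based WenetSpeech one: sentencepiece is dropped, and CharCtcTrainingGraphCompiler plus Lexicon come in. A minimal sketch of how the new pieces fit together, assuming a prepared data/lang_char directory; the directory name, device index, and the toy transcript are illustrative and not taken from this commit:

# Illustrative sketch only; not part of this commit.
from pathlib import Path

import k2
import torch

from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.lexicon import Lexicon

device = torch.device("cuda", 0)  # assumed single-GPU setup
lexicon = Lexicon(Path("data/lang_char"))  # assumed lang dir from data prep
graph_compiler = CharCtcTrainingGraphCompiler(lexicon=lexicon, device=device)

# Same derivations the patch adds in run():
blank_id = lexicon.token_table["<blk>"]
vocab_size = max(lexicon.tokens) + 1

# Same conversion the patch adds in compute_loss():
texts = ["举个例子"]  # hypothetical supervision text
y = graph_compiler.texts_to_ids(texts)  # list of token-id lists
y = k2.RaggedTensor(y).to(device)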
@@ -89,14 +88,12 @@ LRSchedulerType = Union[
 ]


-def set_batch_count(
-    model: Union[nn.Module, DDP], batch_count: float
-) -> None:
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
     if isinstance(model, DDP):
         # get underlying nn.Module
         model = model.module
     for module in model.modules():
-        if hasattr(module, 'batch_count'):
+        if hasattr(module, "batch_count"):
             module.batch_count = batch_count


@@ -126,7 +123,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--encoder-dims",
         type=str,
         default="384,384,384,384,384",
-        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated"
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
     )

     parser.add_argument(
@@ -134,7 +131,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         type=str,
         default="192,192,192,192,192",
         help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
-        not the same as embedding dimension."""
+        not the same as embedding dimension.""",
     )

     parser.add_argument(
@@ -143,7 +140,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         default="256,256,256,256,256",
         help="Unmasked dimensions in the encoders, relates to augmentation during training. "
         "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
-        " worse."
+        " worse.",
     )

     parser.add_argument(
@@ -241,17 +238,17 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
     )

     parser.add_argument(
-        "--base-lr",
-        type=float,
-        default=0.05,
-        help="The base learning rate."
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
     )

     parser.add_argument(
@@ -451,11 +448,14 @@ def get_params() -> AttributeDict:
 def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Zipformer and Transformer
     def to_int_tuple(s: str):
-        return tuple(map(int, s.split(',')))
+        return tuple(map(int, s.split(",")))

     encoder = Zipformer(
         num_features=params.feature_dim,
         output_downsampling_factor=2,
-        zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors),
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
         encoder_dims=to_int_tuple(params.encoder_dims),
         attention_dim=to_int_tuple(params.attention_dims),
         encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
@@ -479,7 +479,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:

 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=int(params.encoder_dims.split(',')[-1]),
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -496,7 +496,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=int(params.encoder_dims.split(',')[-1]),
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -567,9 +567,6 @@ def load_checkpoint_if_available(
     if "cur_epoch" in saved_params:
         params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]
-
     return saved_params


@@ -626,7 +623,7 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     batch: dict,
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
@@ -665,7 +662,8 @@ def compute_loss(
     warm_step = params.warm_step

     texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
+    y = graph_compiler.texts_to_ids(texts)
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
@@ -682,18 +680,17 @@ def compute_loss(
         # take down the scale on the simple loss from 1.0 at the start
         # to params.simple_loss scale by warm_step.
         simple_loss_scale = (
-            s if batch_idx_train >= warm_step
+            s
+            if batch_idx_train >= warm_step
             else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
         )
         pruned_loss_scale = (
-            1.0 if batch_idx_train >= warm_step
+            1.0
+            if batch_idx_train >= warm_step
             else 0.1 + 0.9 * (batch_idx_train / warm_step)
         )

-        loss = (
-            simple_loss_scale * simple_loss +
-            pruned_loss_scale * pruned_loss
-        )
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

     assert loss.requires_grad == is_training

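The loss-combination block above is only reformatted, but the schedule it implements is worth spelling out: before warm_step the simple-loss weight decays linearly from 1.0 down to params.simple_loss_scale while the pruned-loss weight grows from 0.1 up to 1.0; after warm_step both are constant. A standalone restatement of that schedule (the helper name is illustrative, not from the file):

# Restatement of the warm-up schedule above; the helper name is illustrative.
from typing import Tuple


def loss_scales(batch_idx_train: int, warm_step: int, s: float) -> Tuple[float, float]:
    """Return (simple_loss_scale, pruned_loss_scale) for the current step."""
    if batch_idx_train >= warm_step:
        return s, 1.0
    frac = batch_idx_train / warm_step
    simple = 1.0 - frac * (1.0 - s)  # 1.0 -> s over warm_step batches
    pruned = 0.1 + 0.9 * frac        # 0.1 -> 1.0 over warm_step batches
    return simple, pruned


# Used as: loss = simple * simple_loss + pruned * pruned_loss, as in compute_loss()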
@@ -715,7 +712,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -728,7 +725,7 @@ def compute_validation_loss(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
-            sp=sp,
+            graph_compiler=graph_compiler,
             batch=batch,
             is_training=False,
         )
@@ -751,7 +748,7 @@ def train_one_epoch(
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -795,13 +792,7 @@ def train_one_epoch(

     tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
     for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

@@ -810,7 +801,7 @@ def train_one_epoch(
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
-                    sp=sp,
+                    graph_compiler=graph_compiler,
                     batch=batch,
                     is_training=True,
                 )
@@ -827,7 +818,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except: # noqa
-            display_and_save_batch(batch, params=params, sp=sp)
+            display_and_save_batch(batch, params=params)
             raise

         if params.print_diagnostics and batch_idx == 5:
@@ -848,7 +839,6 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -861,7 +851,6 @@ def train_one_epoch(
                 scaler=scaler,
                 rank=rank,
             )
-            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
@@ -873,12 +862,16 @@ def train_one_epoch(
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()
-            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+            if cur_grad_scale < 1.0 or (
+                cur_grad_scale < 8.0 and batch_idx % 400 == 0
+            ):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
+                    raise RuntimeError(
+                        f"grad_scale is too small, exiting: {cur_grad_scale}"
+                    )

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -888,8 +881,12 @@ def train_one_epoch(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
-                f"lr: {cur_lr:.2e}, " +
-                (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                f"lr: {cur_lr:.2e}, "
+                + (
+                    f"grad_scale: {scaler._scale.item()}"
+                    if params.use_fp16
+                    else ""
+                )
             )

             if tb_writer is not None:
@@ -905,23 +902,28 @@ def train_one_epoch(
                 )
                 if params.use_fp16:
                     tb_writer.add_scalar(
-                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
                     )

-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if (
+            batch_idx % params.valid_interval == 0
+            and not params.print_diagnostics
+        ):
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
                 model=model,
-                sp=sp,
+                graph_compiler=graph_compiler,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
             if tb_writer is not None:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
@@ -948,8 +950,6 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
-    if params.full_libri is False:
-        params.valid_interval = 1600

     fix_random_seed(params.seed)
     if world_size > 1:
@@ -968,12 +968,14 @@ def run(rank, world_size, args):
     device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    lexicon = Lexicon(params.lang_dir)
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )

-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1

     logging.info(params)

@@ -997,12 +999,11 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank],
-                    find_unused_parameters=True)
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

-    optimizer = ScaledAdam(model.parameters(),
-                           lr=params.base_lr,
-                           clipping_scale=2.0)
+    optimizer = ScaledAdam(
+        model.parameters(), lr=params.base_lr, clipping_scale=2.0
+    )

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

@@ -1027,26 +1028,26 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    librispeech = LibriSpeechAsrDataModule(args)
+    wenetspeech = WenetSpeechAsrDataModule(args)

-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    train_cuts = wenetspeech.train_cuts()
+    valid_cuts = wenetspeech.valid_cuts()

     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
+        # Keep only utterances with duration between 1 second and 19 seconds
         #
-        # Caution: There is a reason to select 20.0 here. Please see
+        # Caution: There is a reason to select 19.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 20.0
+        return 1.0 <= c.duration <= 19.0

     train_cuts = train_cuts.filter(remove_short_and_long_utt)

+    valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
+
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
         # saved in the middle of an epoch
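The filter above tightens the accepted utterance duration to 1-19 s for WenetSpeech, and the comment advises deriving the threshold from the duration distribution reported by ../local/display_manifest_statistics.py. A rough sketch of how such a distribution can be inspected with lhotse; the manifest path below is an assumption, not taken from the patch:

# Rough sketch for picking a duration threshold; the manifest path is assumed.
import numpy as np
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/cuts_L.jsonl.gz")  # hypothetical path
durations = np.array([c.duration for c in cuts])
for q in (50, 90, 99, 99.5, 99.9):
    print(f"p{q}: {np.percentile(durations, q):.2f} s")
# Pick an upper bound (19.0 s here) that drops only the extreme tail of long cuts.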
@@ -1054,25 +1055,20 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    train_dl = librispeech.train_dataloaders(
+    train_dl = wenetspeech.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )

-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
-
-    if not params.print_diagnostics:
+    if False and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
             optimizer=optimizer,
-            sp=sp,
+            graph_compiler=graph_compiler,
             params=params,
         )

-    scaler = GradScaler(enabled=params.use_fp16,
-                        init_scale=1.0)
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1093,7 +1089,7 @@ def run(rank, world_size, args):
             model_avg=model_avg,
             optimizer=optimizer,
             scheduler=scheduler,
-            sp=sp,
+            graph_compiler=graph_compiler,
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
@@ -1127,7 +1123,6 @@ def run(rank, world_size, args):
 def display_and_save_batch(
     batch: dict,
     params: AttributeDict,
-    sp: spm.SentencePieceProcessor,
 ) -> None:
     """Display the batch statistics and save the batch into disk.

@@ -1137,8 +1132,6 @@ def display_and_save_batch(
         for the content in it.
       params:
         Parameters for training. See :func:`get_params`.
-      sp:
-        The BPE model.
     """
     from lhotse.utils import uuid4

@@ -1146,13 +1139,13 @@ def display_and_save_batch(
     logging.info(f"Saving batch to {filename}")
     torch.save(batch, filename)

-    supervisions = batch["supervisions"]
     features = batch["inputs"]

     logging.info(f"features shape: {features.shape}")

-    y = sp.encode(supervisions["text"], out_type=int)
-    num_tokens = sum(len(i) for i in y)
+    texts = batch["supervisions"]["text"]
+    num_tokens = sum(len(i) for i in texts)
+
     logging.info(f"num tokens: {num_tokens}")


@@ -1160,7 +1153,7 @@ def scan_pessimistic_batches_for_oom(
     model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -1176,7 +1169,7 @@ def scan_pessimistic_batches_for_oom(
             loss, _ = compute_loss(
                 params=params,
                 model=model,
-                sp=sp,
+                graph_compiler=graph_compiler,
                 batch=batch,
                 is_training=True,
             )
@@ -1191,15 +1184,18 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-                display_and_save_batch(batch, params=params, sp=sp)
+                display_and_save_batch(batch, params=params)
                 raise
-        logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )


 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    WenetSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.lang_dir = Path(args.lang_dir)
     args.exp_dir = Path(args.exp_dir)

     world_size = args.world_size