Update gigaspeech train.py

pkufool 2024-10-08 12:17:23 +08:00
parent 6bba97514e
commit e4fa25a780
3 changed files with 182 additions and 2999 deletions


@@ -65,6 +65,7 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
from attention_decoder import AttentionDecoderModel
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -99,6 +100,8 @@ from icefall.utils import (
str2bool,
)
from spec_augment import SpecAugment
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -220,6 +223,41 @@ def add_model_arguments(parser: argparse.ArgumentParser):
""",
)
parser.add_argument(
"--attention-decoder-dim",
type=int,
default=512,
help="""Dimension used in the attention decoder""",
)
parser.add_argument(
"--attention-decoder-num-layers",
type=int,
default=6,
help="""Number of transformer layers used in attention decoder""",
)
parser.add_argument(
"--attention-decoder-attention-dim",
type=int,
default=512,
help="""Attention dimension used in attention decoder""",
)
parser.add_argument(
"--attention-decoder-num-heads",
type=int,
default=8,
help="""Number of attention heads used in attention decoder""",
)
parser.add_argument(
"--attention-decoder-feedforward-dim",
type=int,
default=2048,
help="""Feedforward dimension used in attention decoder""",
)
parser.add_argument(
"--causal",
type=str2bool,
@@ -258,6 +296,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="If True, use CTC head.",
)
parser.add_argument(
"--use-attention-decoder",
type=str2bool,
default=False,
help="If True, use attention-decoder head.",
)
parser.add_argument(
"--use-cr-ctc",
type=str2bool,
default=False,
help="If True, use consistency-regularized CTC.",
)
def get_parser():
parser = argparse.ArgumentParser(
@@ -403,6 +455,34 @@ def get_parser():
help="Scale for CTC loss.",
)
parser.add_argument(
"--cr-loss-scale",
type=float,
default=0.15,
help="Scale for consistency-regularization loss.",
)
parser.add_argument(
"--time-mask-ratio",
type=float,
default=2.0,
help="When using cr-ctc, we increase the time-masking ratio.",
)
parser.add_argument(
"--cr-loss-masked-scale",
type=float,
default=1.0,
help="The value used to scale up the cr_loss at masked positions",
)
parser.add_argument(
"--attention-decoder-loss-scale",
type=float,
default=0.8,
help="Scale for attention-decoder loss.",
)
parser.add_argument(
"--seed",
type=int,
@@ -542,6 +622,9 @@ def get_params() -> AttributeDict:
# parameters for zipformer
"feature_dim": 80,
"subsampling_factor": 4, # not passed in, this is fixed.
# parameters for attention-decoder
"ignore_id": -1,
"label_smoothing": 0.1,
"warm_step": 2000,
"env_info": get_env_info(),
}
@@ -614,6 +697,23 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
return joiner
def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
decoder = AttentionDecoderModel(
vocab_size=params.vocab_size,
decoder_dim=params.attention_decoder_dim,
num_decoder_layers=params.attention_decoder_num_layers,
attention_dim=params.attention_decoder_attention_dim,
num_heads=params.attention_decoder_num_heads,
feedforward_dim=params.attention_decoder_feedforward_dim,
memory_dim=max(_to_int_tuple(params.encoder_dim)),
sos_id=params.sos_id,
eos_id=params.eos_id,
ignore_id=params.ignore_id,
label_smoothing=params.label_smoothing,
)
return decoder
def get_model(params: AttributeDict) -> nn.Module:
assert params.use_transducer or params.use_ctc, (
f"At least one of them should be True, "
@@ -631,20 +731,45 @@ def get_model(params: AttributeDict) -> nn.Module:
decoder = None
joiner = None
if params.use_attention_decoder:
attention_decoder = get_attention_decoder_model(params)
else:
attention_decoder = None
model = AsrModel(
encoder_embed=encoder_embed,
encoder=encoder,
decoder=decoder,
joiner=joiner,
attention_decoder=attention_decoder,
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
vocab_size=params.vocab_size,
use_transducer=params.use_transducer,
use_ctc=params.use_ctc,
use_attention_decoder=params.use_attention_decoder,
)
return model
def get_spec_augment(params: AttributeDict) -> SpecAugment:
num_frame_masks = int(10 * params.time_mask_ratio)
max_frames_mask_fraction = 0.15 * params.time_mask_ratio
logging.info(
f"num_frame_masks: {num_frame_masks}, "
f"max_frames_mask_fraction: {max_frames_mask_fraction}"
)
spec_augment = SpecAugment(
time_warp_factor=0, # Do time warping in model.py
num_frame_masks=num_frame_masks, # default: 10
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15
)
return spec_augment
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
@@ -767,6 +892,7 @@ def compute_loss(
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
spec_augment: Optional[SpecAugment] = None,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute loss given the model and its inputs.
@@ -783,8 +909,8 @@
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
spec_augment:
The SpecAugment instance used only when use_cr_ctc is True.
"""
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
@@ -802,6 +928,21 @@
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y)
use_cr_ctc = params.use_cr_ctc
use_spec_aug = use_cr_ctc and is_training
if use_spec_aug:
supervision_intervals = batch["supervisions"]
supervision_segments = torch.stack(
[
supervision_intervals["sequence_idx"],
supervision_intervals["start_frame"],
supervision_intervals["num_frames"],
],
dim=1,
) # shape: (S, 3)
else:
supervision_segments = None
with torch.set_grad_enabled(is_training):
losses = model(
x=feature,
@@ -810,8 +951,14 @@
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
use_cr_ctc=use_cr_ctc,
use_spec_aug=use_spec_aug,
spec_augment=spec_augment,
supervision_segments=supervision_segments,
time_warp_factor=params.spec_aug_time_warp_factor,
cr_loss_masked_scale=params.cr_loss_masked_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = losses[:5]
loss = 0.0
@@ -833,6 +980,11 @@
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
if use_cr_ctc:
loss += params.cr_loss_scale * cr_loss
if params.use_attention_decoder:
loss += params.attention_decoder_loss_scale * attention_decoder_loss
assert loss.requires_grad == is_training
@@ -848,6 +1000,10 @@
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_ctc:
info["ctc_loss"] = ctc_loss.detach().cpu().item()
if params.use_cr_ctc:
info["cr_loss"] = cr_loss.detach().cpu().item()
if params.use_attention_decoder:
info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
return loss, info
@@ -895,6 +1051,7 @@ def train_one_epoch(
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
spec_augment: Optional[SpecAugment] = None,
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -921,6 +1078,8 @@
Dataloader for the validation dataset.
scaler:
The scaler used for mix precision training.
spec_augment:
The SpecAugment instance used only when use_cr_ctc is True.
model_avg:
The stored model averaged from the start of training.
tb_writer:
@@ -965,6 +1124,7 @@
sp=sp,
batch=batch,
is_training=True,
spec_augment=spec_augment,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -1124,10 +1284,17 @@ def run(rank, world_size, args):
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = params.eos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
if not params.use_transducer:
params.ctc_loss_scale = 1.0
if not params.use_attention_decoder:
params.ctc_loss_scale = 1.0
else:
assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
params.ctc_loss_scale,
params.attention_decoder_loss_scale,
)
logging.info(params)
@@ -1137,6 +1304,13 @@ def run(rank, world_size, args):
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
if params.use_cr_ctc:
assert params.use_ctc
assert not params.enable_spec_aug # we will do spec_augment in model.py
spec_augment = get_spec_augment(params)
else:
spec_augment = None
assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if rank == 0:
@@ -1215,6 +1389,7 @@
optimizer=optimizer,
sp=sp,
params=params,
spec_augment=spec_augment,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
@@ -1242,6 +1417,7 @@
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
spec_augment=spec_augment,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
@@ -1307,6 +1483,7 @@ def scan_pessimistic_batches_for_oom(
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
spec_augment: Optional[SpecAugment] = None,
):
from lhotse.dataset import find_pessimistic_batches
@@ -1324,6 +1501,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp,
batch=batch,
is_training=True,
spec_augment=spec_augment,
)
loss.backward()
optimizer.zero_grad()
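
A quick illustration of the CR-CTC SpecAugment scaling added above (not code from this commit; the helper name below is made up): get_spec_augment() multiplies only the time-masking parameters by --time-mask-ratio and leaves frequency masking at its defaults.

# Illustrative only: reproduces the arithmetic in get_spec_augment() above.
def cr_ctc_spec_augment_config(time_mask_ratio: float = 2.0) -> dict:
    return {
        "time_warp_factor": 0,                               # time warping is done in model.py instead
        "num_frame_masks": int(10 * time_mask_ratio),        # 10 -> 20 with the default ratio of 2.0
        "features_mask_size": 27,
        "num_feature_masks": 2,
        "frames_mask_size": 100,
        "max_frames_mask_fraction": 0.15 * time_mask_ratio,  # 0.15 -> 0.3 with the default ratio of 2.0
    }


if __name__ == "__main__":
    print(cr_ctc_spec_augment_config())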

File diff suppressed because it is too large.

File diff suppressed because it is too large.
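
A minimal sketch of how compute_loss() in this diff weights the individual loss terms once the new flags are enabled. This is illustrative only: combine_losses is a hypothetical helper, and params stands for the AttributeDict built in train.py.

import torch


def combine_losses(
    simple_loss: torch.Tensor,
    pruned_loss: torch.Tensor,
    ctc_loss: torch.Tensor,
    attention_decoder_loss: torch.Tensor,
    cr_loss: torch.Tensor,
    params,
) -> torch.Tensor:
    loss = torch.zeros((), device=ctc_loss.device)
    if params.use_transducer:
        # train.py additionally applies simple_loss_scale and a warm-up schedule
        # to the transducer terms; that logic is unchanged by this commit.
        loss = loss + simple_loss + pruned_loss
    if params.use_ctc:
        loss = loss + params.ctc_loss_scale * ctc_loss
    if params.use_cr_ctc:
        # consistency-regularization term, default --cr-loss-scale 0.15
        loss = loss + params.cr_loss_scale * cr_loss
    if params.use_attention_decoder:
        # default --attention-decoder-loss-scale 0.8
        loss = loss + params.attention_decoder_loss_scale * attention_decoder_loss
    return loss

When --use-attention-decoder is set, the diff additionally asserts that ctc_loss_scale + attention_decoder_loss_scale equals 1.0, so the default attention-decoder scale of 0.8 implies a CTC scale of 0.2.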