Merge 3698ca040f2be2572e64d6bec9753fe3b27021b0 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9

Wei Kang 2025-07-25 09:16:26 +02:00, committed by GitHub
commit 3ee6d5ee44
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
6 changed files with 1431 additions and 61 deletions

View File

@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/attention_decoder.py

File diff suppressed because it is too large.

View File

@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/label_smoothing.py

View File

@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
 #                                            Wei Kang,
 #                                            Mingshuang Luo,
 #                                            Zengwei Yao,
@@ -67,9 +67,11 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriHeavyAsrDataModule
+from attention_decoder import AttentionDecoderModel
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
+from lhotse.dataset import SpecAugment
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -223,6 +225,41 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         """,
     )
 
+    parser.add_argument(
+        "--attention-decoder-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-num-layers",
+        type=int,
+        default=6,
+        help="""Number of transformer layers used in attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-attention-dim",
+        type=int,
+        default=512,
+        help="""Attention dimension used in attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-num-heads",
+        type=int,
+        default=8,
+        help="""Number of attention heads used in attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-feedforward-dim",
+        type=int,
+        default=2048,
+        help="""Feedforward dimension used in attention decoder""",
+    )
+
     parser.add_argument(
         "--causal",
         type=str2bool,
@@ -261,6 +298,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use CTC head.",
     )
 
+    parser.add_argument(
+        "--use-attention-decoder",
+        type=str2bool,
+        default=False,
+        help="If True, use attention-decoder head.",
+    )
+
+    parser.add_argument(
+        "--use-cr-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use consistency-regularized CTC.",
+    )
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -406,6 +457,27 @@ def get_parser():
         help="Scale for CTC loss.",
     )
 
+    parser.add_argument(
+        "--cr-loss-scale",
+        type=float,
+        default=0.15,
+        help="Scale for consistency-regularization loss.",
+    )
+
+    parser.add_argument(
+        "--time-mask-ratio",
+        type=float,
+        default=2.0,
+        help="When using cr-ctc, we increase the time-masking ratio.",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-loss-scale",
+        type=float,
+        default=0.8,
+        help="Scale for attention-decoder loss.",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -427,6 +499,17 @@ def get_parser():
         help="Add hooks to check for infinite module outputs and gradients.",
     )
 
+    parser.add_argument(
+        "--scan-for-oom-batches",
+        type=str2bool,
+        default=False,
+        help="""
+        Whether to scan for OOM batches before training. This is helpful for
+        finding a suitable max_duration; you only need to run it once.
+        Caution: it is somewhat time consuming.
+        """,
+    )
+
     parser.add_argument(
         "--save-every-n",
         type=int,
@@ -541,6 +624,9 @@ def get_params() -> AttributeDict:
             # parameters for zipformer
             "feature_dim": 80,
             "subsampling_factor": 4,  # not passed in, this is fixed.
+            # parameters for attention-decoder
+            "ignore_id": -1,
+            "label_smoothing": 0.1,
             "warm_step": 2000,
             "env_info": get_env_info(),
         }
@@ -613,6 +699,23 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     return joiner
 
 
+def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = AttentionDecoderModel(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.attention_decoder_dim,
+        num_decoder_layers=params.attention_decoder_num_layers,
+        attention_dim=params.attention_decoder_attention_dim,
+        num_heads=params.attention_decoder_num_heads,
+        feedforward_dim=params.attention_decoder_feedforward_dim,
+        memory_dim=max(_to_int_tuple(params.encoder_dim)),
+        sos_id=params.sos_id,
+        eos_id=params.eos_id,
+        ignore_id=params.ignore_id,
+        label_smoothing=params.label_smoothing,
+    )
+    return decoder
+
+
 def get_model(params: AttributeDict) -> nn.Module:
     assert params.use_transducer or params.use_ctc, (
         f"At least one of them should be True, "
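For orientation, here is a minimal, hypothetical sketch of what get_attention_decoder_model builds with the defaults introduced in this commit. It assumes the symlinked attention_decoder.py is importable; the vocab size, memory_dim, and sos/eos ids are placeholder values that in train.py come from the BPE model and the encoder configuration.

from attention_decoder import AttentionDecoderModel

# Illustrative values only; in train.py they come from the parser defaults,
# get_params(), and the BPE model loaded in run().
decoder = AttentionDecoderModel(
    vocab_size=500,           # assumed BPE vocab size
    decoder_dim=512,          # --attention-decoder-dim
    num_decoder_layers=6,     # --attention-decoder-num-layers
    attention_dim=512,        # --attention-decoder-attention-dim
    num_heads=8,              # --attention-decoder-num-heads
    feedforward_dim=2048,     # --attention-decoder-feedforward-dim
    memory_dim=512,           # assumed largest encoder dimension
    sos_id=1,                 # assumed id of <sos/eos>
    eos_id=1,
    ignore_id=-1,             # get_params()["ignore_id"]
    label_smoothing=0.1,      # get_params()["label_smoothing"]
)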
@@ -630,20 +733,45 @@ def get_model(params: AttributeDict) -> nn.Module:
         decoder = None
         joiner = None
 
+    if params.use_attention_decoder:
+        attention_decoder = get_attention_decoder_model(params)
+    else:
+        attention_decoder = None
+
     model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
+        attention_decoder=attention_decoder,
         encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
         vocab_size=params.vocab_size,
         use_transducer=params.use_transducer,
         use_ctc=params.use_ctc,
+        use_attention_decoder=params.use_attention_decoder,
     )
     return model
 
 
+def get_spec_augment(params: AttributeDict) -> SpecAugment:
+    num_frame_masks = int(10 * params.time_mask_ratio)
+    max_frames_mask_fraction = 0.15 * params.time_mask_ratio
+    logging.info(
+        f"num_frame_masks: {num_frame_masks}, "
+        f"max_frames_mask_fraction: {max_frames_mask_fraction}"
+    )
+    spec_augment = SpecAugment(
+        time_warp_factor=0,  # Do time warping in model.py
+        num_frame_masks=num_frame_masks,  # default: 10
+        features_mask_size=27,
+        num_feature_masks=2,
+        frames_mask_size=100,
+        max_frames_mask_fraction=max_frames_mask_fraction,  # default: 0.15
+    )
+    return spec_augment
+
+
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
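As a quick sanity check of the scaling in get_spec_augment above: with the default --time-mask-ratio of 2.0 used for CR-CTC, the time-masking parameters double relative to the usual SpecAugment policy. A minimal sketch of the arithmetic, with no icefall imports needed:

# CR-CTC uses a stronger time-masking policy than the default SpecAugment settings.
time_mask_ratio = 2.0                              # --time-mask-ratio
num_frame_masks = int(10 * time_mask_ratio)        # 20 (default policy uses 10)
max_frames_mask_fraction = 0.15 * time_mask_ratio  # 0.3 (default policy uses 0.15)
print(num_frame_masks, max_frames_mask_fraction)   # 20 0.3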
@@ -766,6 +894,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
+    spec_augment: Optional[SpecAugment] = None,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute loss given the model and its inputs.
@@ -782,8 +911,8 @@
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
-     warmup: a floating point value which increases throughout training;
-        values >= 1.0 are fully warmed up and have all modules present.
+     spec_augment:
+       The SpecAugment instance used only when use_cr_ctc is True.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
@@ -802,6 +931,21 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
 
+    use_cr_ctc = params.use_cr_ctc
+    use_spec_aug = use_cr_ctc and is_training
+    if use_spec_aug:
+        supervision_intervals = batch["supervisions"]
+        supervision_segments = torch.stack(
+            [
+                supervision_intervals["sequence_idx"],
+                supervision_intervals["start_frame"],
+                supervision_intervals["num_frames"],
+            ],
+            dim=1,
+        )  # shape: (S, 3)
+    else:
+        supervision_segments = None
+
     with torch.set_grad_enabled(is_training):
         losses = model(
             x=feature,
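To make the (S, 3) layout of supervision_segments concrete, here is a small self-contained sketch with made-up supervision intervals; the dict keys mirror the batch["supervisions"] fields used above.

import torch

# Hypothetical batch with S = 3 supervisions.
supervision_intervals = {
    "sequence_idx": torch.tensor([0, 1, 2], dtype=torch.int32),  # which cut in the batch
    "start_frame": torch.tensor([0, 0, 0], dtype=torch.int32),
    "num_frames": torch.tensor([1500, 980, 1210], dtype=torch.int32),
}
supervision_segments = torch.stack(
    [
        supervision_intervals["sequence_idx"],
        supervision_intervals["start_frame"],
        supervision_intervals["num_frames"],
    ],
    dim=1,
)
print(supervision_segments.shape)  # torch.Size([3, 3]): one row per supervision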
@@ -810,8 +954,13 @@
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            use_cr_ctc=use_cr_ctc,
+            use_spec_aug=use_spec_aug,
+            spec_augment=spec_augment,
+            supervision_segments=supervision_segments,
+            time_warp_factor=params.spec_aug_time_warp_factor,
         )
 
-        simple_loss, pruned_loss, ctc_loss = losses[:3]
+        simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = losses[:5]
 
         loss = 0.0
@@ -833,6 +982,11 @@
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
+            if use_cr_ctc:
+                loss += params.cr_loss_scale * cr_loss
+
+        if params.use_attention_decoder:
+            loss += params.attention_decoder_loss_scale * attention_decoder_loss
 
     assert loss.requires_grad == is_training
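As a worked example of how these scales combine (the loss values and the CTC scale are made up for illustration; only the cr and attention-decoder scales are the defaults from this diff, and the warmup-dependent simple/pruned scaling is omitted):

# Assumed per-batch loss values, for illustration only.
simple_loss, pruned_loss, ctc_loss = 10.0, 5.0, 4.0
attention_decoder_loss, cr_loss = 3.0, 2.0

ctc_loss_scale = 0.2                # assumed; set via --ctc-loss-scale
cr_loss_scale = 0.15                # default of --cr-loss-scale
attention_decoder_loss_scale = 0.8  # default of --attention-decoder-loss-scale

loss = 0.0
loss += simple_loss + pruned_loss                              # transducer part
loss += ctc_loss_scale * ctc_loss                              # + 0.8
loss += cr_loss_scale * cr_loss                                # + 0.3 (only with --use-cr-ctc)
loss += attention_decoder_loss_scale * attention_decoder_loss  # + 2.4 (only with --use-attention-decoder)
print(loss)  # 18.5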
@@ -848,6 +1002,10 @@
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
     if params.use_ctc:
         info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if params.use_cr_ctc:
+        info["cr_loss"] = cr_loss.detach().cpu().item()
+    if params.use_attention_decoder:
+        info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
 
     return loss, info
@@ -895,6 +1053,7 @@ def train_one_epoch(
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
+    spec_augment: Optional[SpecAugment] = None,
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -921,6 +1080,8 @@
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mix precision training.
+     spec_augment:
+       The SpecAugment instance used only when use_cr_ctc is True.
      model_avg:
        The stored model averaged from the start of training.
      tb_writer:
@@ -965,6 +1126,7 @@
                     sp=sp,
                     batch=batch,
                     is_training=True,
+                    spec_augment=spec_augment,
                 )
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -1128,10 +1290,17 @@ def run(rank, world_size, args):
     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.sos_id = params.eos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()
 
     if not params.use_transducer:
-        params.ctc_loss_scale = 1.0
+        if not params.use_attention_decoder:
+            params.ctc_loss_scale = 1.0
+        else:
+            assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
+                params.ctc_loss_scale,
+                params.attention_decoder_loss_scale,
+            )
 
     logging.info(params)
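In other words, when the transducer head is disabled and both the CTC and attention-decoder heads are enabled, the two scales must sum to 1.0. A hedged example of a consistent setting (the 0.2 value is only illustrative; 0.8 is the --attention-decoder-loss-scale default):

# Illustrative check mirroring the assertion above.
use_transducer = False
use_attention_decoder = True
ctc_loss_scale = 0.2                # would be passed via --ctc-loss-scale
attention_decoder_loss_scale = 0.8  # default
if not use_transducer and use_attention_decoder:
    assert ctc_loss_scale + attention_decoder_loss_scale == 1.0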
@@ -1141,6 +1310,13 @@ def run(rank, world_size, args):
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
+    if params.use_cr_ctc:
+        assert params.use_ctc
+        assert not params.enable_spec_aug  # we will do spec_augment in model.py
+        spec_augment = get_spec_augment(params)
+    else:
+        spec_augment = None
+
     assert params.save_every_n >= params.average_period
     model_avg: Optional[nn.Module] = None
     if rank == 0:
@@ -1201,31 +1377,7 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if c.duration < 2.0 or c.duration > 30.0:
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
             return False
 
-        # In pruned RNN-T, we require that T >= S
-        # where T is the number of feature frames after subsampling
-        # and S is the number of tokens in the utterance
-
-        # In ./zipformer.py, the conv module uses the following expression
-        # for subsampling
-        T = ((c.num_frames - 7) // 2 + 1) // 2
-        tokens = sp.encode(c.supervisions[0].text, out_type=str)
-
-        if T < len(tokens):
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. "
-                f"Number of frames (before subsampling): {c.num_frames}. "
-                f"Number of frames (after subsampling): {T}. "
-                f"Text: {c.supervisions[0].text}. "
-                f"Tokens: {tokens}. "
-                f"Number of tokens: {len(tokens)}"
-            )
-            return False
-
         return True
 
     libriheavy = LibriHeavyAsrDataModule(args)
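For reference, the filter removed above enforced T >= S, where T is the frame count after the initial convolutional subsampling (roughly a factor of 4) and S is the number of BPE tokens. A quick check of the arithmetic with a hypothetical cut:

# Hypothetical cut: 1000 feature frames (about 10 s at a 10 ms frame shift).
num_frames = 1000
T = ((num_frames - 7) // 2 + 1) // 2  # subsampling formula quoted in the removed comment
print(T)  # 248 -> such a cut was kept only if its transcript had at most 248 tokens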
@@ -1259,14 +1411,15 @@
     valid_dl = libriheavy.valid_dataloaders(valid_cuts)
 
-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         sp=sp,
-    #         params=params,
-    #     )
+    if not params.print_diagnostics and params.scan_for_oom_batches:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+            spec_augment=spec_augment,
+        )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1292,6 +1445,7 @@
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
+            spec_augment=spec_augment,
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
@@ -1357,6 +1511,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    spec_augment: Optional[SpecAugment] = None,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -1374,6 +1529,7 @@ def scan_pessimistic_batches_for_oom(
                 sp=sp,
                 batch=batch,
                 is_training=True,
+                spec_augment=spec_augment,
             )
             loss.backward()
             optimizer.zero_grad()

View File

@@ -1,9 +1,10 @@
 #!/usr/bin/env python3
 #
-# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
 #                                                 Liyong Guo,
 #                                                 Quandong Wang,
-#                                                 Zengwei Yao)
+#                                                 Zengwei Yao,
+#                                                 Wei Kang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #

View File

@@ -499,6 +499,17 @@ def get_parser():
         help="Add hooks to check for infinite module outputs and gradients.",
     )
 
+    parser.add_argument(
+        "--scan-for-oom-batches",
+        type=str2bool,
+        default=False,
+        help="""
+        Whether to scan for OOM batches before training. This is helpful for
+        finding a suitable max_duration; you only need to run it once.
+        Caution: it is somewhat time consuming.
+        """,
+    )
+
     parser.add_argument(
         "--save-every-n",
        type=int,
@@ -1398,27 +1409,6 @@ def run(rank, world_size, args):
             #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             # )
             return False
 
-        # In pruned RNN-T, we require that T >= S
-        # where T is the number of feature frames after subsampling
-        # and S is the number of tokens in the utterance
-
-        # In ./zipformer.py, the conv module uses the following expression
-        # for subsampling
-        T = ((c.num_frames - 7) // 2 + 1) // 2
-        tokens = sp.encode(c.supervisions[0].text, out_type=str)
-
-        if T < len(tokens):
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. "
-                f"Number of frames (before subsampling): {c.num_frames}. "
-                f"Number of frames (after subsampling): {T}. "
-                f"Text: {c.supervisions[0].text}. "
-                f"Tokens: {tokens}. "
-                f"Number of tokens: {len(tokens)}"
-            )
-            return False
-
         return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
@@ -1438,7 +1428,7 @@ def run(rank, world_size, args):
         valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
+    if not params.print_diagnostics and params.scan_for_oom_batches:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,