remove cr-loss

Author: Fangjun Kuang
Date:   2025-05-29 11:49:30 +08:00
Parent: 9b95c72d19
Commit: dc74705d20


@@ -48,7 +48,6 @@ It supports training with:
 - transducer loss (default)
 - ctc loss
 - attention decoder loss
-- cr-ctc loss (should use half the max-duration compared to regular ctc)
 """
@@ -66,7 +65,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule_with_parallel_aug import LibriSpeechAsrDataModuleWithParallelAug
 from attention_decoder import AttentionDecoderModel
 from decoder import Decoder
 from joiner import Joiner
@@ -304,13 +303,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use attention-decoder head.",
     )
 
-    parser.add_argument(
-        "--use-cr-ctc",
-        type=str2bool,
-        default=False,
-        help="If True, use consistency-regularized CTC.",
-    )
-
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -456,20 +448,6 @@ def get_parser():
         help="Scale for CTC loss.",
     )
 
-    parser.add_argument(
-        "--cr-loss-scale",
-        type=float,
-        default=0.2,
-        help="Scale for consistency-regularization loss.",
-    )
-
-    parser.add_argument(
-        "--time-mask-ratio",
-        type=float,
-        default=2.5,
-        help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.",
-    )
-
     parser.add_argument(
         "--attention-decoder-loss-scale",
         type=float,
@@ -738,24 +716,6 @@ def get_model(params: AttributeDict) -> nn.Module:
     return model
 
 
-def get_spec_augment(params: AttributeDict) -> SpecAugment:
-    num_frame_masks = int(10 * params.time_mask_ratio)
-    max_frames_mask_fraction = 0.15 * params.time_mask_ratio
-    logging.info(
-        f"num_frame_masks: {num_frame_masks}, "
-        f"max_frames_mask_fraction: {max_frames_mask_fraction}"
-    )
-    spec_augment = SpecAugment(
-        time_warp_factor=0,  # Do time warping in model.py
-        num_frame_masks=num_frame_masks,  # default: 10
-        features_mask_size=27,
-        num_feature_masks=2,
-        frames_mask_size=100,
-        max_frames_mask_fraction=max_frames_mask_fraction,  # default: 0.15
-    )
-    return spec_augment
-
-
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
@@ -878,7 +838,6 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    spec_augment: Optional[SpecAugment] = None,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute loss given the model and its inputs.
@@ -895,8 +854,6 @@ def compute_loss(
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
-      spec_augment:
-        The SpecAugment instance used only when use_cr_ctc is True.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
@@ -914,21 +871,6 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
 
-    use_cr_ctc = params.use_cr_ctc
-    use_spec_aug = use_cr_ctc and is_training
-    if use_spec_aug:
-        supervision_intervals = batch["supervisions"]
-        supervision_segments = torch.stack(
-            [
-                supervision_intervals["sequence_idx"],
-                supervision_intervals["start_frame"],
-                supervision_intervals["num_frames"],
-            ],
-            dim=1,
-        )  # shape: (S, 3)
-    else:
-        supervision_segments = None
-
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model(
             x=feature,
@@ -937,11 +879,6 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
-            use_cr_ctc=use_cr_ctc,
-            use_spec_aug=use_spec_aug,
-            spec_augment=spec_augment,
-            supervision_segments=supervision_segments,
-            time_warp_factor=params.spec_aug_time_warp_factor,
         )
 
         loss = 0.0
@@ -964,8 +901,6 @@ def compute_loss(
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
 
-        if use_cr_ctc:
-            loss += params.cr_loss_scale * cr_loss
 
         if params.use_attention_decoder:
             loss += params.attention_decoder_loss_scale * attention_decoder_loss
@@ -984,8 +919,6 @@ def compute_loss(
             info["pruned_loss"] = pruned_loss.detach().cpu().item()
         if params.use_ctc:
             info["ctc_loss"] = ctc_loss.detach().cpu().item()
-        if params.use_cr_ctc:
-            info["cr_loss"] = cr_loss.detach().cpu().item()
         if params.use_attention_decoder:
             info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
@@ -1035,7 +968,6 @@ def train_one_epoch(
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
-    spec_augment: Optional[SpecAugment] = None,
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -1062,8 +994,6 @@ def train_one_epoch(
         Dataloader for the validation dataset.
       scaler:
         The scaler used for mix precision training.
-      spec_augment:
-        The SpecAugment instance used only when use_cr_ctc is True.
       model_avg:
         The stored model averaged from the start of training.
       tb_writer:
@@ -1110,7 +1040,6 @@ def train_one_epoch(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    spec_augment=spec_augment,
                 )
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -1317,13 +1246,6 @@ def run(rank, world_size, args):
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    if params.use_cr_ctc:
-        assert params.use_ctc
-        assert not params.enable_spec_aug  # we will do spec_augment in model.py
-        spec_augment = get_spec_augment(params)
-    else:
-        spec_augment = None
-
     assert params.save_every_n >= params.average_period
     model_avg: Optional[nn.Module] = None
     if rank == 0:
@@ -1369,7 +1291,7 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)
 
-    librispeech = LibriSpeechAsrDataModule(args)
+    librispeech = LibriSpeechAsrDataModuleWithParallelAug(args)
 
     if params.full_libri:
         train_cuts = librispeech.train_all_shuf_cuts()
@@ -1446,7 +1368,6 @@ def run(rank, world_size, args):
            optimizer=optimizer,
            sp=sp,
            params=params,
-           spec_augment=spec_augment,
        )
 
    scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
@@ -1474,7 +1395,6 @@ def run(rank, world_size, args):
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
-           spec_augment=spec_augment,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
@@ -1540,7 +1460,6 @@ def scan_pessimistic_batches_for_oom(
    optimizer: torch.optim.Optimizer,
    sp: spm.SentencePieceProcessor,
    params: AttributeDict,
-   spec_augment: Optional[SpecAugment] = None,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -1560,7 +1479,6 @@ def scan_pessimistic_batches_for_oom(
                    sp=sp,
                    batch=batch,
                    is_training=True,
-                   spec_augment=spec_augment,
                )
            loss.backward()
            optimizer.zero_grad()
@@ -1582,7 +1500,7 @@ def main():
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    LibriSpeechAsrDataModuleWithParallelAug.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)