Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)
remove cr-loss

parent 9b95c72d19
commit dc74705d20
@@ -48,7 +48,6 @@ It supports training with:
   - transducer loss (default)
   - ctc loss
   - attention decoder loss
-  - cr-ctc loss (should use half the max-duration compared to regular ctc)
 """
 
 
@@ -66,7 +65,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule_with_parallel_aug import LibriSpeechAsrDataModuleWithParallelAug
 from attention_decoder import AttentionDecoderModel
 from decoder import Decoder
 from joiner import Joiner
@@ -304,13 +303,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use attention-decoder head.",
     )
 
-    parser.add_argument(
-        "--use-cr-ctc",
-        type=str2bool,
-        default=False,
-        help="If True, use consistency-regularized CTC.",
-    )
-
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -456,20 +448,6 @@ def get_parser():
         help="Scale for CTC loss.",
     )
 
-    parser.add_argument(
-        "--cr-loss-scale",
-        type=float,
-        default=0.2,
-        help="Scale for consistency-regularization loss.",
-    )
-
-    parser.add_argument(
-        "--time-mask-ratio",
-        type=float,
-        default=2.5,
-        help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.",
-    )
-
     parser.add_argument(
         "--attention-decoder-loss-scale",
         type=float,
@@ -738,24 +716,6 @@ def get_model(params: AttributeDict) -> nn.Module:
     return model
 
 
-def get_spec_augment(params: AttributeDict) -> SpecAugment:
-    num_frame_masks = int(10 * params.time_mask_ratio)
-    max_frames_mask_fraction = 0.15 * params.time_mask_ratio
-    logging.info(
-        f"num_frame_masks: {num_frame_masks}, "
-        f"max_frames_mask_fraction: {max_frames_mask_fraction}"
-    )
-    spec_augment = SpecAugment(
-        time_warp_factor=0,  # Do time warping in model.py
-        num_frame_masks=num_frame_masks,  # default: 10
-        features_mask_size=27,
-        num_feature_masks=2,
-        frames_mask_size=100,
-        max_frames_mask_fraction=max_frames_mask_fraction,  # default: 0.15
-    )
-    return spec_augment
-
-
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
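Note on the removed get_spec_augment helper above: all it did was rescale Lhotse's SpecAugment time-masking by the (also removed) --time-mask-ratio flag; the inline comments in the removed lines mark 10 frame masks and a 0.15 max mask fraction as the library defaults. A minimal worked sketch of that scaling with the removed default ratio of 2.5 (plain Python, numbers taken from the removed code):

# Sketch of the arithmetic done by the removed get_spec_augment();
# it mirrors the removed lines, not a helper that still exists in the file.
time_mask_ratio = 2.5                               # removed --time-mask-ratio default
num_frame_masks = int(10 * time_mask_ratio)         # 25 time masks instead of 10
max_frames_mask_fraction = 0.15 * time_mask_ratio   # 0.375 of the frames instead of 0.15
print(num_frame_masks, max_frames_mask_fraction)    # 25 0.375

In other words, cr-ctc training used 2.5x heavier time masking than the regular recipe, applied inside the model rather than in the data loader.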
@@ -878,7 +838,6 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    spec_augment: Optional[SpecAugment] = None,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute loss given the model and its inputs.
@@ -895,8 +854,6 @@ def compute_loss(
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
-      spec_augment:
-        The SpecAugment instance used only when use_cr_ctc is True.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
@@ -914,21 +871,6 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
 
-    use_cr_ctc = params.use_cr_ctc
-    use_spec_aug = use_cr_ctc and is_training
-    if use_spec_aug:
-        supervision_intervals = batch["supervisions"]
-        supervision_segments = torch.stack(
-            [
-                supervision_intervals["sequence_idx"],
-                supervision_intervals["start_frame"],
-                supervision_intervals["num_frames"],
-            ],
-            dim=1,
-        )  # shape: (S, 3)
-    else:
-        supervision_segments = None
-
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model(
             x=feature,
@@ -937,11 +879,6 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
-            use_cr_ctc=use_cr_ctc,
-            use_spec_aug=use_spec_aug,
-            spec_augment=spec_augment,
-            supervision_segments=supervision_segments,
-            time_warp_factor=params.spec_aug_time_warp_factor,
         )
 
         loss = 0.0
@@ -964,8 +901,6 @@ def compute_loss(
 
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
-            if use_cr_ctc:
-                loss += params.cr_loss_scale * cr_loss
 
         if params.use_attention_decoder:
             loss += params.attention_decoder_loss_scale * attention_decoder_loss
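The cr_loss dropped above was returned by the model's forward pass alongside the CTC loss and weighted by the removed --cr-loss-scale (default 0.2). For readers unfamiliar with the term, the sketch below shows the general shape of such a consistency-regularization term: the same batch is forwarded twice under different SpecAugment masks, and the frame-level CTC posteriors of the two copies are pulled together by a symmetric, stop-gradient KL divergence. This is a generic illustration under those assumptions, not the repository's model.py; the helper name is hypothetical.

import torch
import torch.nn.functional as F

def cr_loss_sketch(log_probs_a: torch.Tensor, log_probs_b: torch.Tensor) -> torch.Tensor:
    """Symmetric, stop-gradient KL between frame-level CTC log-posteriors.

    log_probs_a / log_probs_b: (T, N, C) log-softmax outputs of the CTC head
    for two differently SpecAugment-masked copies of the same batch.
    Hypothetical helper for illustration only.
    """
    kl_a = F.kl_div(log_probs_a, log_probs_b.detach(), reduction="sum", log_target=True)
    kl_b = F.kl_div(log_probs_b, log_probs_a.detach(), reduction="sum", log_target=True)
    return 0.5 * (kl_a + kl_b)

# In a recipe of this kind the term would be blended in roughly as
#   loss += cr_loss_scale * cr_loss   (removed default scale: 0.2)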
@@ -984,8 +919,6 @@ def compute_loss(
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
     if params.use_ctc:
         info["ctc_loss"] = ctc_loss.detach().cpu().item()
-    if params.use_cr_ctc:
-        info["cr_loss"] = cr_loss.detach().cpu().item()
     if params.use_attention_decoder:
         info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
 
@@ -1035,7 +968,6 @@ def train_one_epoch(
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
-    spec_augment: Optional[SpecAugment] = None,
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -1062,8 +994,6 @@ def train_one_epoch(
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mix precision training.
-      spec_augment:
-        The SpecAugment instance used only when use_cr_ctc is True.
      model_avg:
        The stored model averaged from the start of training.
      tb_writer:
@@ -1110,7 +1040,6 @@ def train_one_epoch(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    spec_augment=spec_augment,
                 )
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -1317,13 +1246,6 @@ def run(rank, world_size, args):
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    if params.use_cr_ctc:
-        assert params.use_ctc
-        assert not params.enable_spec_aug  # we will do spec_augment in model.py
-        spec_augment = get_spec_augment(params)
-    else:
-        spec_augment = None
-
     assert params.save_every_n >= params.average_period
     model_avg: Optional[nn.Module] = None
     if rank == 0:
@@ -1369,7 +1291,7 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)
 
-    librispeech = LibriSpeechAsrDataModule(args)
+    librispeech = LibriSpeechAsrDataModuleWithParallelAug(args)
 
     if params.full_libri:
         train_cuts = librispeech.train_all_shuf_cuts()
@@ -1446,7 +1368,6 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
-            spec_augment=spec_augment,
         )
 
     scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
@@ -1474,7 +1395,6 @@ def run(rank, world_size, args):
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
-            spec_augment=spec_augment,
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
@@ -1540,7 +1460,6 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
-    spec_augment: Optional[SpecAugment] = None,
 ):
     from lhotse.dataset import find_pessimistic_batches
 
@@ -1560,7 +1479,6 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    spec_augment=spec_augment,
                 )
             loss.backward()
             optimizer.zero_grad()
@@ -1582,7 +1500,7 @@ def scan_pessimistic_batches_for_oom(
 
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    LibriSpeechAsrDataModuleWithParallelAug.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 