mirror of https://github.com/k2-fsa/icefall.git
Add cr-ctc to libriheavy recipe
parent 57451b0382
commit 3bd4a2e6c3
egs/libriheavy/ASR/zipformer/attention_decoder.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/attention_decoder.py
egs/libriheavy/ASR/zipformer/ctc_decode.py (new executable file, 1221 lines)
Diff suppressed because it is too large.
egs/libriheavy/ASR/zipformer/label_smoothing.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/label_smoothing.py
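The hunks below wire consistency-regularized CTC (cr-ctc) and an optional attention-decoder head into the training script of the libriheavy zipformer recipe. For orientation, here is a minimal, self-contained sketch of the general cr-ctc idea: two differently SpecAugment-masked views of the same utterance are decoded by the same CTC head, and a consistency term between their posteriors is added to the CTC loss, scaled by --cr-loss-scale. Everything in the sketch (ToyCtcModel, the symmetric KL consistency term, blank id 0) is illustrative only and not the recipe's actual implementation, which lives in model.py and is not part of this diff.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyCtcModel(nn.Module):
    """A stand-in encoder + CTC head; the real recipe uses the Zipformer encoder."""

    def __init__(self, feat_dim: int = 80, vocab_size: int = 500):
        super().__init__()
        self.encoder = nn.LSTM(feat_dim, 256, batch_first=True)
        self.ctc_head = nn.Linear(256, vocab_size)

    def ctc_log_probs(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, T, feat_dim) -> (N, T, vocab_size) log-probabilities
        out, _ = self.encoder(x)
        return self.ctc_head(out).log_softmax(dim=-1)


def cr_ctc_loss(
    model: ToyCtcModel,
    x_aug1: torch.Tensor,         # first SpecAugment view of the features
    x_aug2: torch.Tensor,         # second SpecAugment view of the same features
    targets: torch.Tensor,        # concatenated target token ids
    input_lengths: torch.Tensor,  # (N,) number of frames per utterance
    target_lengths: torch.Tensor, # (N,) number of tokens per utterance
    cr_loss_scale: float = 0.15,  # plays the role of --cr-loss-scale
) -> torch.Tensor:
    lp1 = model.ctc_log_probs(x_aug1)  # (N, T, V)
    lp2 = model.ctc_log_probs(x_aug2)

    def ctc(lp: torch.Tensor) -> torch.Tensor:
        # F.ctc_loss expects log-probs shaped (T, N, V)
        return F.ctc_loss(
            lp.transpose(0, 1), targets, input_lengths, target_lengths, blank=0
        )

    # Ordinary CTC loss, averaged over the two augmented views.
    ctc_part = 0.5 * (ctc(lp1) + ctc(lp2))

    # Consistency regularization: pull each branch towards the (detached)
    # posterior of the other branch, symmetrically.
    cr_part = 0.5 * (
        F.kl_div(lp1, lp2.detach(), reduction="batchmean", log_target=True)
        + F.kl_div(lp2, lp1.detach(), reduction="batchmean", log_target=True)
    )
    return ctc_part + cr_loss_scale * cr_part
```

In the actual recipe the consistency term is computed inside model.py alongside the transducer, CTC and (optionally) attention-decoder losses, using the extra knobs added below (--cr-loss-scale, --time-mask-ratio, --cr-loss-masked-scale).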
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright    2021-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
+# Copyright    2021-2024  Xiaomi Corp.        (authors: Fangjun Kuang,
 #                                                       Wei Kang,
 #                                                       Mingshuang Luo,
 #                                                       Zengwei Yao,
@@ -67,9 +67,11 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriHeavyAsrDataModule
+from attention_decoder import AttentionDecoderModel
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
+from lhotse.dataset import SpecAugment
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -223,6 +225,41 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         """,
     )
 
+    parser.add_argument(
+        "--attention-decoder-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-num-layers",
+        type=int,
+        default=6,
+        help="""Number of transformer layers used in attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-attention-dim",
+        type=int,
+        default=512,
+        help="""Attention dimension used in attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-num-heads",
+        type=int,
+        default=8,
+        help="""Number of attention heads used in attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-feedforward-dim",
+        type=int,
+        default=2048,
+        help="""Feedforward dimension used in attention decoder""",
+    )
+
     parser.add_argument(
         "--causal",
         type=str2bool,
@@ -261,6 +298,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use CTC head.",
     )
 
+    parser.add_argument(
+        "--use-attention-decoder",
+        type=str2bool,
+        default=False,
+        help="If True, use attention-decoder head.",
+    )
+
+    parser.add_argument(
+        "--use-cr-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use consistency-regularized CTC.",
+    )
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -406,6 +457,34 @@ def get_parser():
         help="Scale for CTC loss.",
     )
 
+    parser.add_argument(
+        "--cr-loss-scale",
+        type=float,
+        default=0.15,
+        help="Scale for consistency-regularization loss.",
+    )
+
+    parser.add_argument(
+        "--time-mask-ratio",
+        type=float,
+        default=2.0,
+        help="When using cr-ctc, we increase the time-masking ratio.",
+    )
+
+    parser.add_argument(
+        "--cr-loss-masked-scale",
+        type=float,
+        default=1.0,
+        help="The value used to scale up the cr_loss at masked positions",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-loss-scale",
+        type=float,
+        default=0.8,
+        help="Scale for attention-decoder loss.",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -427,6 +506,17 @@ def get_parser():
         help="Add hooks to check for infinite module outputs and gradients.",
     )
 
+    parser.add_argument(
+        "--scan-for-oom-batches",
+        type=str2bool,
+        default=False,
+        help="""
+        Whether to scan for oom batches before training, this is helpful for
+        finding the suitable max_duration, you only need to run it once.
+        Caution: a little time consuming.
+        """,
+    )
+
     parser.add_argument(
         "--save-every-n",
         type=int,
@@ -541,6 +631,9 @@ def get_params() -> AttributeDict:
             # parameters for zipformer
             "feature_dim": 80,
             "subsampling_factor": 4,  # not passed in, this is fixed.
+            # parameters for attention-decoder
+            "ignore_id": -1,
+            "label_smoothing": 0.1,
             "warm_step": 2000,
             "env_info": get_env_info(),
         }
@@ -613,6 +706,23 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     return joiner
 
 
+def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = AttentionDecoderModel(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.attention_decoder_dim,
+        num_decoder_layers=params.attention_decoder_num_layers,
+        attention_dim=params.attention_decoder_attention_dim,
+        num_heads=params.attention_decoder_num_heads,
+        feedforward_dim=params.attention_decoder_feedforward_dim,
+        memory_dim=max(_to_int_tuple(params.encoder_dim)),
+        sos_id=params.sos_id,
+        eos_id=params.eos_id,
+        ignore_id=params.ignore_id,
+        label_smoothing=params.label_smoothing,
+    )
+    return decoder
+
+
 def get_model(params: AttributeDict) -> nn.Module:
     assert params.use_transducer or params.use_ctc, (
         f"At least one of them should be True, "
@@ -630,20 +740,45 @@ def get_model(params: AttributeDict) -> nn.Module:
         decoder = None
         joiner = None
 
+    if params.use_attention_decoder:
+        attention_decoder = get_attention_decoder_model(params)
+    else:
+        attention_decoder = None
+
     model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
+        attention_decoder=attention_decoder,
         encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
         vocab_size=params.vocab_size,
         use_transducer=params.use_transducer,
         use_ctc=params.use_ctc,
+        use_attention_decoder=params.use_attention_decoder,
     )
     return model
 
 
+def get_spec_augment(params: AttributeDict) -> SpecAugment:
+    num_frame_masks = int(10 * params.time_mask_ratio)
+    max_frames_mask_fraction = 0.15 * params.time_mask_ratio
+    logging.info(
+        f"num_frame_masks: {num_frame_masks}, "
+        f"max_frames_mask_fraction: {max_frames_mask_fraction}"
+    )
+    spec_augment = SpecAugment(
+        time_warp_factor=0,  # Do time warping in model.py
+        num_frame_masks=num_frame_masks,  # default: 10
+        features_mask_size=27,
+        num_feature_masks=2,
+        frames_mask_size=100,
+        max_frames_mask_fraction=max_frames_mask_fraction,  # default: 0.15
+    )
+    return spec_augment
+
+
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
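With the default --time-mask-ratio of 2.0, get_spec_augment works out to num_frame_masks = int(10 * 2.0) = 20 and max_frames_mask_fraction = 0.15 * 2.0 = 0.3, i.e. double the time masking of the defaults noted in the comments (10 and 0.15); time warping is disabled here because, as the comment says, it is done in model.py.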
@@ -766,6 +901,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
+    spec_augment: Optional[SpecAugment] = None,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute loss given the model and its inputs.
@@ -782,8 +918,8 @@ def compute_loss(
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
-      warmup: a floating point value which increases throughout training;
-        values >= 1.0 are fully warmed up and have all modules present.
+      spec_augment:
+        The SpecAugment instance used only when use_cr_ctc is True.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
@@ -802,6 +938,21 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
 
+    use_cr_ctc = params.use_cr_ctc
+    use_spec_aug = use_cr_ctc and is_training
+    if use_spec_aug:
+        supervision_intervals = batch["supervisions"]
+        supervision_segments = torch.stack(
+            [
+                supervision_intervals["sequence_idx"],
+                supervision_intervals["start_frame"],
+                supervision_intervals["num_frames"],
+            ],
+            dim=1,
+        )  # shape: (S, 3)
+    else:
+        supervision_segments = None
+
     with torch.set_grad_enabled(is_training):
         losses = model(
             x=feature,
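The (S, 3) supervision_segments tensor follows lhotse's one-row-per-supervision convention of (sequence_idx, start_frame, num_frames). It is only materialized when cr-ctc is active during training, presumably so that the SpecAugment applied inside the model (see the spec_augment and time_warp_factor arguments passed just below) can operate per supervision within the batch.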
@@ -810,6 +961,12 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            use_cr_ctc=use_cr_ctc,
+            use_spec_aug=use_spec_aug,
+            spec_augment=spec_augment,
+            supervision_segments=supervision_segments,
+            time_warp_factor=params.spec_aug_time_warp_factor,
+            cr_loss_masked_scale=params.cr_loss_masked_scale,
         )
         simple_loss, pruned_loss, ctc_loss = losses[:3]
 
@@ -833,6 +990,11 @@ def compute_loss(
 
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
+            if use_cr_ctc:
+                loss += params.cr_loss_scale * cr_loss
+
+        if params.use_attention_decoder:
+            loss += params.attention_decoder_loss_scale * attention_decoder_loss
 
     assert loss.requires_grad == is_training
 
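So with every head enabled, the auxiliary objective added on top of the transducer losses is ctc_loss_scale * ctc_loss + cr_loss_scale * cr_loss + attention_decoder_loss_scale * attention_decoder_loss, where --cr-loss-scale defaults to 0.15 and --attention-decoder-loss-scale to 0.8; the cr term is only included when --use-cr-ctc is set.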
@@ -848,6 +1010,10 @@ def compute_loss(
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
     if params.use_ctc:
         info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if params.use_cr_ctc:
+        info["cr_loss"] = cr_loss.detach().cpu().item()
+    if params.use_attention_decoder:
+        info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
 
     return loss, info
 
@@ -895,6 +1061,7 @@ def train_one_epoch(
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
+    spec_augment: Optional[SpecAugment] = None,
    model_avg: Optional[nn.Module] = None,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
@@ -921,6 +1088,8 @@ def train_one_epoch(
         Dataloader for the validation dataset.
       scaler:
         The scaler used for mix precision training.
+      spec_augment:
+        The SpecAugment instance used only when use_cr_ctc is True.
       model_avg:
         The stored model averaged from the start of training.
       tb_writer:
@@ -965,6 +1134,7 @@ def train_one_epoch(
                     sp=sp,
                     batch=batch,
                     is_training=True,
+                    spec_augment=spec_augment,
                 )
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -1128,10 +1298,17 @@ def run(rank, world_size, args):
 
     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.sos_id = params.eos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()
 
     if not params.use_transducer:
-        params.ctc_loss_scale = 1.0
+        if not params.use_attention_decoder:
+            params.ctc_loss_scale = 1.0
+        else:
+            assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
+                params.ctc_loss_scale,
+                params.attention_decoder_loss_scale,
+            )
 
     logging.info(params)
 
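Note that with the default --attention-decoder-loss-scale of 0.8, training without the transducer head but with the attention decoder enabled requires --ctc-loss-scale 0.2 so that the two scales sum to 1.0 and this assertion passes.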
@@ -1141,6 +1318,13 @@ def run(rank, world_size, args):
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
+    if params.use_cr_ctc:
+        assert params.use_ctc
+        assert not params.enable_spec_aug  # we will do spec_augment in model.py
+        spec_augment = get_spec_augment(params)
+    else:
+        spec_augment = None
+
     assert params.save_every_n >= params.average_period
     model_avg: Optional[nn.Module] = None
     if rank == 0:
@@ -1201,31 +1385,7 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if c.duration < 2.0 or c.duration > 30.0:
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
             return False
-
-        # In pruned RNN-T, we require that T >= S
-        # where T is the number of feature frames after subsampling
-        # and S is the number of tokens in the utterance
-
-        # In ./zipformer.py, the conv module uses the following expression
-        # for subsampling
-        T = ((c.num_frames - 7) // 2 + 1) // 2
-        tokens = sp.encode(c.supervisions[0].text, out_type=str)
-
-        if T < len(tokens):
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. "
-                f"Number of frames (before subsampling): {c.num_frames}. "
-                f"Number of frames (after subsampling): {T}. "
-                f"Text: {c.supervisions[0].text}. "
-                f"Tokens: {tokens}. "
-                f"Number of tokens: {len(tokens)}"
-            )
-            return False
         return True
 
     libriheavy = LibriHeavyAsrDataModule(args)
@@ -1259,14 +1419,15 @@ def run(rank, world_size, args):
 
     valid_dl = libriheavy.valid_dataloaders(valid_cuts)
 
-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         sp=sp,
-    #         params=params,
-    #     )
+    if not params.print_diagnostics and params.scan_for_oom_batches:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+            spec_augment=spec_augment,
+        )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1292,6 +1453,7 @@ def run(rank, world_size, args):
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
+            spec_augment=spec_augment,
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
@@ -1357,6 +1519,7 @@ def scan_pessimistic_batches_for_oom(
    optimizer: torch.optim.Optimizer,
    sp: spm.SentencePieceProcessor,
    params: AttributeDict,
+    spec_augment: Optional[SpecAugment] = None,
 ):
     from lhotse.dataset import find_pessimistic_batches
 
@@ -1374,6 +1537,7 @@ def scan_pessimistic_batches_for_oom(
                    sp=sp,
                    batch=batch,
                    is_training=True,
+                    spec_augment=spec_augment,
                )
            loss.backward()
            optimizer.zero_grad()
@@ -1,9 +1,10 @@
 #!/usr/bin/env python3
 #
-# Copyright    2021-2022  Xiaomi Corporation     (Author: Fangjun Kuang,
+# Copyright    2021-2024  Xiaomi Corporation     (Author: Fangjun Kuang,
 #                                                         Liyong Guo,
 #                                                         Quandong Wang,
-#                                                         Zengwei Yao)
+#                                                         Zengwei Yao,
+#                                                         Wei Kang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -498,6 +498,17 @@ def get_parser():
         help="Add hooks to check for infinite module outputs and gradients.",
     )
 
+    parser.add_argument(
+        "--scan-for-oom-batches",
+        type=str2bool,
+        default=False,
+        help="""
+        Whether to scan for oom batches before training, this is helpful for
+        finding the suitable max_duration, you only need to run it once.
+        Caution: a little time consuming.
+        """,
+    )
+
     parser.add_argument(
         "--save-every-n",
         type=int,
@@ -1388,27 +1399,6 @@ def run(rank, world_size, args):
             #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             # )
             return False
-
-        # In pruned RNN-T, we require that T >= S
-        # where T is the number of feature frames after subsampling
-        # and S is the number of tokens in the utterance
-
-        # In ./zipformer.py, the conv module uses the following expression
-        # for subsampling
-        T = ((c.num_frames - 7) // 2 + 1) // 2
-        tokens = sp.encode(c.supervisions[0].text, out_type=str)
-
-        if T < len(tokens):
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. "
-                f"Number of frames (before subsampling): {c.num_frames}. "
-                f"Number of frames (after subsampling): {T}. "
-                f"Text: {c.supervisions[0].text}. "
-                f"Tokens: {tokens}. "
-                f"Number of tokens: {len(tokens)}"
-            )
-            return False
         return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
@@ -1428,7 +1418,7 @@ def run(rank, world_size, args):
         valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
+    if not params.print_diagnostics and params.scan_for_oom_batches:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,