Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)
Commit 0be47510a9: Merge a85592dc784de1ac77ca4dac861b46b5e13b539d into f84270c93528f4b77b99ada9ac0c9f7fb231d6a4
@@ -16,15 +16,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import k2
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
 
-from icefall.utils import add_sos, make_pad_mask
+from icefall.utils import add_sos, make_pad_mask, time_warp
+from lhotse.dataset import SpecAugment
 
 
 class AsrModel(nn.Module):

@@ -110,9 +112,8 @@ class AsrModel(nn.Module):
         if use_ctc:
             # Modules for CTC head
             self.ctc_output = nn.Sequential(
-                nn.Dropout(p=0.1),
+                nn.Dropout(p=0.1),  # TODO: test removing this
                 nn.Linear(encoder_dim, vocab_size),
-                nn.LogSoftmax(dim=-1),
             )
 
         self.use_attention_decoder = use_attention_decoder
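Note on the hunk above: with nn.LogSoftmax removed from self.ctc_output, the CTC head now returns raw logits, and normalization is deferred to forward_ctc, which also needs plain (non-log) probabilities for the smooth-regularization branch. A minimal sketch of the two normalization paths it uses; the shapes and the eps value here are illustrative, not taken from the recipe:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 10)  # stand-in for ctc_output(encoder_out): (N, T, C)
eps = 1e-6

log_probs = F.log_softmax(logits, dim=-1)  # path used when smooth regularization is off
probs = logits.softmax(dim=-1)             # plain probabilities kept for the smoothing branch
log_probs_sr = (probs + eps).log()         # path used when smooth regularization is on

# The two differ only by the eps smoothing.
print((log_probs - log_probs_sr).abs().max().item())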
@@ -158,28 +159,82 @@ class AsrModel(nn.Module):
         encoder_out_lens: torch.Tensor,
         targets: torch.Tensor,
         target_lengths: torch.Tensor,
-    ) -> torch.Tensor:
-        """Compute CTC loss.
+        use_consistency_reg: bool = False,
+        use_smooth_reg: bool = False,
+        smooth_kernel: List[float] = [0.25, 0.5, 0.25],
+        eps: float = 1e-6,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute CTC loss with consistency regularization loss.
         Args:
           encoder_out:
-            Encoder output, of shape (N, T, C).
+            Encoder output, of shape (N or 2 * N, T, C).
           encoder_out_lens:
-            Encoder output lengths, of shape (N,).
+            Encoder output lengths, of shape (N or 2 * N,).
           targets:
-            Target Tensor of shape (sum(target_lengths)). The targets are assumed
+            Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
             to be un-padded and concatenated within 1 dimension.
+          use_consistency_reg:
+            Whether to use consistency regularization.
+          use_smooth_reg:
+            Whether to use smooth regularization.
         """
-        # Compute CTC log-prob
-        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
+        ctc_output = self.ctc_output(encoder_out)  # (N or 2 * N, T, C)
+        length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
 
+        if not use_smooth_reg:
+            ctc_log_probs = F.log_softmax(ctc_output, dim=-1)
+        else:
+            ctc_probs = ctc_output.softmax(dim=-1)  # Used in sr_loss
+            ctc_log_probs = (ctc_probs + eps).log()
+
+        # Compute CTC loss
         ctc_loss = torch.nn.functional.ctc_loss(
-            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
+            log_probs=ctc_log_probs.permute(1, 0, 2),  # (T, N or 2 * N, C)
             targets=targets.cpu(),
             input_lengths=encoder_out_lens.cpu(),
             target_lengths=target_lengths.cpu(),
             reduction="sum",
         )
-        return ctc_loss
+
+        if use_consistency_reg:
+            assert ctc_log_probs.shape[0] % 2 == 0
+            # Compute cr_loss
+            exchanged_targets = ctc_log_probs.detach().chunk(2, dim=0)
+            exchanged_targets = torch.cat(
+                [exchanged_targets[1], exchanged_targets[0]], dim=0
+            )  # exchange: [x1, x2] -> [x2, x1]
+            cr_loss = nn.functional.kl_div(
+                input=ctc_log_probs,
+                target=exchanged_targets,
+                reduction="none",
+                log_target=True,
+            )  # (2 * N, T, C)
+            cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
+        else:
+            cr_loss = torch.empty(0)
+
+        if use_smooth_reg:
+            # Hard code the kernel here, could try other values
+            assert len(smooth_kernel) == 3 and sum(smooth_kernel) == 1.0, smooth_kernel
+            smooth_kernel = torch.tensor(smooth_kernel, dtype=ctc_probs.dtype,
+                                         device=ctc_probs.device, requires_grad=False)
+            smooth_kernel = smooth_kernel.unsqueeze(0).unsqueeze(1).expand(ctc_probs.shape[-1], 1, 3)
+            # Now kernel: (C, 1, 3)
+            smoothed_ctc_probs = F.conv1d(
+                ctc_probs.detach().permute(0, 2, 1),  # (N or 2 * N, C, T)
+                weight=smooth_kernel, stride=1, padding=0, groups=ctc_probs.shape[-1]
+            ).permute(0, 2, 1)  # (N or 2 * N, T - 2, C)
+            sr_loss = nn.functional.kl_div(
+                input=ctc_log_probs[:, 1:-1],
+                target=(smoothed_ctc_probs + eps).log(),
+                reduction="none",
+                log_target=True,
+            )  # (N or 2 * N, T - 2, C)
+            sr_loss = sr_loss.masked_fill(length_mask[:, 1:-1], 0.0).sum()
+        else:
+            sr_loss = torch.empty(0)
+
+        return ctc_loss, cr_loss, sr_loss
 
     def forward_transducer(
         self,
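To make the two regularizers in forward_ctc concrete, here is a self-contained sketch of the same computations on random logits. The batch size, frame count, vocabulary size, and kernel values are invented for illustration, padding masks are omitted, and the code mirrors the logic above rather than importing it:

import torch
import torch.nn.functional as F

N, T, C, eps = 2, 6, 5, 1e-6
# Pretend the batch was duplicated and the two halves saw different SpecAugment masks.
logits = torch.randn(2 * N, T, C, requires_grad=True)
probs = logits.softmax(dim=-1)
log_probs = (probs + eps).log()

# Consistency regularization: each copy is pulled towards the (detached)
# distribution predicted for the other copy of the same utterance.
first, second = log_probs.detach().chunk(2, dim=0)
exchanged = torch.cat([second, first], dim=0)  # [x1, x2] -> [x2, x1]
cr_loss = F.kl_div(
    input=log_probs, target=exchanged, reduction="none", log_target=True
).sum()

# Smooth regularization: each frame is pulled towards a 3-frame weighted average
# of its (detached) neighbours, implemented as a depthwise 1-D convolution.
kernel = torch.tensor([0.25, 0.5, 0.25]).expand(C, 1, 3)  # (C, 1, 3), one filter per symbol
smoothed = F.conv1d(
    probs.detach().permute(0, 2, 1), weight=kernel, groups=C
).permute(0, 2, 1)  # (2 * N, T - 2, C)
sr_loss = F.kl_div(
    input=log_probs[:, 1:-1], target=(smoothed + eps).log(),
    reduction="none", log_target=True,
).sum()

print(cr_loss.item(), sr_loss.item())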
@@ -296,7 +351,13 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        use_cr_ctc: bool = False,
+        use_sr_ctc: bool = False,
+        use_spec_aug: bool = False,
+        spec_augment: Optional[SpecAugment] = None,
+        supervision_segments: Optional[torch.Tensor] = None,
+        time_warp_factor: Optional[int] = 80,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
           x:

@@ -316,9 +377,28 @@ class AsrModel(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          use_cr_ctc:
+            Whether to use consistency-regularized CTC.
+          use_sr_ctc:
+            Whether to use smooth-regularized CTC.
+          use_spec_aug:
+            Whether to apply spec-augment manually, used only if use_cr_ctc is True.
+          spec_augment:
+            The SpecAugment instance that returns time masks,
+            used only if use_cr_ctc is True.
+          supervision_segments:
+            An int tensor of shape ``(S, 3)``. ``S`` is the number of
+            supervision segments that exist in ``features``.
+            Used only if use_cr_ctc is True.
+          time_warp_factor:
+            Parameter for the time warping; larger values mean more warping.
+            Set to ``None``, or less than ``1``, to disable.
+            Used only if use_cr_ctc is True.
 
         Returns:
-          Return the transducer losses and CTC loss,
-          in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss)
+          Return the transducer losses, CTC loss, AED loss, and the
+          regularization losses, in form of
+          (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss)
 
         Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of

@@ -334,6 +414,24 @@ class AsrModel(nn.Module):
 
         device = x.device
 
+        if use_cr_ctc:
+            assert self.use_ctc
+            if use_spec_aug:
+                assert spec_augment is not None and spec_augment.time_warp_factor < 1
+                # Apply time warping before input duplicating
+                assert supervision_segments is not None
+                x = time_warp(
+                    x,
+                    time_warp_factor=time_warp_factor,
+                    supervision_segments=supervision_segments,
+                )
+                # Independently apply frequency masking and time masking to the two copies
+                x = spec_augment(x.repeat(2, 1, 1))
+            else:
+                x = x.repeat(2, 1, 1)
+            x_lens = x_lens.repeat(2)
+            y = k2.ragged.cat([y, y], axis=0)
+
         # Compute encoder outputs
         encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
 
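The branch above is the heart of the CR-CTC data flow: the features are time-warped once, duplicated along the batch dimension, and SpecAugment then masks each copy independently so that the two views of the same utterance differ. A rough standalone sketch of that duplication step, assuming lhotse is installed; the tensor sizes and SpecAugment settings here are made up and are not the recipe's exact configuration:

import torch
from lhotse.dataset import SpecAugment

spec_augment = SpecAugment(
    time_warp_factor=0,   # warping is assumed to have been applied already
    num_frame_masks=10,
    features_mask_size=27,
    num_feature_masks=2,
    frames_mask_size=100,
)

x = torch.randn(4, 500, 80)                  # (N, T, F) fbank features
two_views = spec_augment(x.repeat(2, 1, 1))  # (2 * N, T, F), masked independently

# The first and second halves correspond to the same utterances but carry
# different masks, which is what the consistency loss later compares.
print((two_views[:4] - two_views[4:]).abs().sum().item() > 0)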
@@ -351,6 +449,9 @@ class AsrModel(nn.Module):
                 am_scale=am_scale,
                 lm_scale=lm_scale,
             )
+            if use_cr_ctc:
+                simple_loss = simple_loss * 0.5
+                pruned_loss = pruned_loss * 0.5
         else:
             simple_loss = torch.empty(0)
             pruned_loss = torch.empty(0)

@@ -358,14 +459,23 @@ class AsrModel(nn.Module):
         if self.use_ctc:
             # Compute CTC loss
             targets = y.values
-            ctc_loss = self.forward_ctc(
+            ctc_loss, cr_loss, sr_loss = self.forward_ctc(
                 encoder_out=encoder_out,
                 encoder_out_lens=encoder_out_lens,
                 targets=targets,
                 target_lengths=y_lens,
+                use_consistency_reg=use_cr_ctc,
+                use_smooth_reg=use_sr_ctc,
             )
+            if use_cr_ctc:
+                # We duplicate the batch when use_cr_ctc is True
+                ctc_loss = ctc_loss * 0.5
+                cr_loss = cr_loss * 0.5
+                sr_loss = sr_loss * 0.5
         else:
             ctc_loss = torch.empty(0)
+            cr_loss = torch.empty(0)
+            sr_loss = torch.empty(0)
 
         if self.use_attention_decoder:
             attention_decoder_loss = self.attention_decoder.calc_att_loss(

@@ -374,7 +484,9 @@ class AsrModel(nn.Module):
                 ys=y.to(device),
                 ys_lens=y_lens.to(device),
             )
+            if use_cr_ctc:
+                attention_decoder_loss = attention_decoder_loss * 0.5
         else:
             attention_decoder_loss = torch.empty(0)
 
-        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss
+        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss

@@ -72,6 +72,7 @@ from attention_decoder import AttentionDecoderModel
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
+from lhotse.dataset import SpecAugment
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel

@@ -304,6 +305,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use attention-decoder head.",
     )
 
+    parser.add_argument(
+        "--use-cr-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use consistency-regularized CTC.",
+    )
+
+    parser.add_argument(
+        "--use-sr-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use smooth-regularized CTC.",
+    )
+
 
 def get_parser():
     parser = argparse.ArgumentParser(

@@ -449,6 +464,27 @@ def get_parser():
         help="Scale for CTC loss.",
     )
 
+    parser.add_argument(
+        "--cr-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for consistency-regularization loss.",
+    )
+
+    parser.add_argument(
+        "--sr-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for smooth-regularization loss.",
+    )
+
+    parser.add_argument(
+        "--time-mask-ratio",
+        type=float,
+        default=2.5,
+        help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.",
+    )
+
     parser.add_argument(
         "--attention-decoder-loss-scale",
         type=float,

@@ -717,6 +753,24 @@ def get_model(params: AttributeDict) -> nn.Module:
     return model
 
 
+def get_spec_augment(params: AttributeDict) -> SpecAugment:
+    num_frame_masks = int(10 * params.time_mask_ratio)
+    max_frames_mask_fraction = 0.15 * params.time_mask_ratio
+    logging.info(
+        f"num_frame_masks: {num_frame_masks}, "
+        f"max_frames_mask_fraction: {max_frames_mask_fraction}"
+    )
+    spec_augment = SpecAugment(
+        time_warp_factor=0,  # Do time warping in model.py
+        num_frame_masks=num_frame_masks,  # default: 10
+        features_mask_size=27,
+        num_feature_masks=2,
+        frames_mask_size=100,
+        max_frames_mask_fraction=max_frames_mask_fraction,  # default: 0.15
+    )
+    return spec_augment
+
+
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
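For the default --time-mask-ratio of 2.5, get_spec_augment therefore asks for 25 time masks and allows up to 37.5% of the frames to be masked, versus lhotse's defaults of 10 masks and 15%. A quick check of that arithmetic:

time_mask_ratio = 2.5                              # --time-mask-ratio default
num_frame_masks = int(10 * time_mask_ratio)        # 25 instead of the default 10
max_frames_mask_fraction = 0.15 * time_mask_ratio  # 0.375 instead of the default 0.15
print(num_frame_masks, max_frames_mask_fraction)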
@@ -839,6 +893,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
+    spec_augment: Optional[SpecAugment] = None,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute loss given the model and its inputs.

@@ -855,8 +910,8 @@ def compute_loss(
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
-      warmup: a floating point value which increases throughout training;
-        values >= 1.0 are fully warmed up and have all modules present.
+      spec_augment:
+        The SpecAugment instance used only when use_cr_ctc is True.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]

@@ -874,14 +929,36 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
 
+    use_cr_ctc = params.use_cr_ctc
+    use_sr_ctc = params.use_sr_ctc
+    use_spec_aug = use_cr_ctc and is_training
+    if use_spec_aug:
+        supervision_intervals = batch["supervisions"]
+        supervision_segments = torch.stack(
+            [
+                supervision_intervals["sequence_idx"],
+                supervision_intervals["start_frame"],
+                supervision_intervals["num_frames"],
+            ],
+            dim=1,
+        )  # shape: (S, 3)
+    else:
+        supervision_segments = None
+
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model(
+        simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            use_cr_ctc=use_cr_ctc,
+            use_sr_ctc=use_sr_ctc,
+            use_spec_aug=use_spec_aug,
+            spec_augment=spec_augment,
+            supervision_segments=supervision_segments,
+            time_warp_factor=params.spec_aug_time_warp_factor,
         )
 
         loss = 0.0
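The supervision_segments tensor assembled above uses the same (S, 3) layout as lhotse: one row per supervision holding the sequence index within the batch, the first frame, and the number of frames. A toy illustration with hand-made values; in the recipe these fields come from batch["supervisions"] as produced by the dataloader:

import torch

supervision_intervals = {
    "sequence_idx": torch.tensor([0, 1, 2]),
    "start_frame": torch.tensor([0, 0, 10]),
    "num_frames": torch.tensor([480, 350, 500]),
}
supervision_segments = torch.stack(
    [
        supervision_intervals["sequence_idx"],
        supervision_intervals["start_frame"],
        supervision_intervals["num_frames"],
    ],
    dim=1,
)  # shape: (S, 3)
print(supervision_segments)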
@@ -904,6 +981,10 @@ def compute_loss(
 
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
+            if use_cr_ctc:
+                loss += params.cr_loss_scale * cr_loss
+            if use_sr_ctc:
+                loss += params.sr_loss_scale * sr_loss
 
         if params.use_attention_decoder:
             loss += params.attention_decoder_loss_scale * attention_decoder_loss

@@ -922,6 +1003,10 @@ def compute_loss(
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
     if params.use_ctc:
         info["ctc_loss"] = ctc_loss.detach().cpu().item()
+        if use_cr_ctc:
+            info["cr_loss"] = cr_loss.detach().cpu().item()
+        if use_sr_ctc:
+            info["sr_loss"] = sr_loss.detach().cpu().item()
     if params.use_attention_decoder:
         info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
 

@@ -971,6 +1056,7 @@ def train_one_epoch(
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
+   spec_augment: Optional[SpecAugment] = None,
    model_avg: Optional[nn.Module] = None,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,

@@ -997,6 +1083,8 @@ def train_one_epoch(
         Dataloader for the validation dataset.
       scaler:
         The scaler used for mix precision training.
+      spec_augment:
+        The SpecAugment instance used only when use_cr_ctc is True.
       model_avg:
         The stored model averaged from the start of training.
       tb_writer:

@@ -1043,6 +1131,7 @@ def train_one_epoch(
                     sp=sp,
                     batch=batch,
                     is_training=True,
+                    spec_augment=spec_augment,
                 )
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

@@ -1238,6 +1327,13 @@ def run(rank, world_size, args):
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
+    if params.use_cr_ctc:
+        assert params.use_ctc
+        assert not params.enable_spec_aug  # we will do spec_augment in model.py
+        spec_augment = get_spec_augment(params)
+    else:
+        spec_augment = None
+
     assert params.save_every_n >= params.average_period
     model_avg: Optional[nn.Module] = None
     if rank == 0:

@@ -1360,6 +1456,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            spec_augment=spec_augment,
         )
 
     scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)

@@ -1387,6 +1484,7 @@ def run(rank, world_size, args):
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
+            spec_augment=spec_augment,
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,

@@ -1452,6 +1550,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    spec_augment: Optional[SpecAugment] = None,
 ):
     from lhotse.dataset import find_pessimistic_batches
 

@@ -1471,6 +1570,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
+                    spec_augment=spec_augment,
                 )
             loss.backward()
             optimizer.zero_grad()

@@ -21,6 +21,7 @@ import argparse
 import collections
 import logging
 import os
+import random
 import re
 import subprocess
 from collections import defaultdict

@@ -38,6 +39,7 @@ import sentencepiece as spm
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
 from pypinyin import lazy_pinyin, pinyin
 from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
 from torch.utils.tensorboard import SummaryWriter

@@ -2271,3 +2273,41 @@ def num_tokens(
     if 0 in ans:
         num_tokens -= 1
     return num_tokens
+
+
+# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py
+def time_warp(
+    features: torch.Tensor,
+    p: float = 0.9,
+    time_warp_factor: Optional[int] = 80,
+    supervision_segments: Optional[torch.Tensor] = None,
+):
+    """Apply time warping on a batch of features
+    """
+    if time_warp_factor is None or time_warp_factor < 1:
+        return features
+    assert len(features.shape) == 3, (
+        "SpecAugment only supports batches of single-channel feature matrices."
+    )
+    features = features.clone()
+    if supervision_segments is None:
+        # No supervisions - apply spec augment to full feature matrices.
+        for sequence_idx in range(features.size(0)):
+            if random.random() > p:
+                # Randomly choose whether this transform is applied
+                continue
+            features[sequence_idx] = time_warp_impl(
+                features[sequence_idx], factor=time_warp_factor
+            )
+    else:
+        # Supervisions provided - we will apply time warping only on the supervised areas.
+        for sequence_idx, start_frame, num_frames in supervision_segments:
+            if random.random() > p:
+                # Randomly choose whether this transform is applied
+                continue
+            end_frame = start_frame + num_frames
+            features[sequence_idx, start_frame:end_frame] = time_warp_impl(
+                features[sequence_idx, start_frame:end_frame], factor=time_warp_factor
+            )
+
+    return features
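A short usage sketch for the new helper, again with invented shapes; it assumes icefall and lhotse are importable and only exercises the code path added above:

import torch
from icefall.utils import time_warp  # the helper defined in the last hunk

features = torch.randn(3, 600, 80)  # (N, T, F) batch of fbank features
segments = torch.tensor(
    [[0, 0, 600], [1, 0, 450], [2, 20, 500]]  # rows of (sequence_idx, start_frame, num_frames)
)

warped = time_warp(features, time_warp_factor=80, supervision_segments=segments)
print(warped.shape)  # warping stretches/compresses in time but keeps the (N, T, F) shape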