add ctc branch, with delay penalty

yaozengwei 2023-02-13 12:11:35 +08:00
parent 6ce36934cb
commit 3d4b8eb58b
2 changed files with 70 additions and 8 deletions

View File

@@ -69,6 +69,12 @@ class Transducer(nn.Module):
self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5),
nn.LogSoftmax(dim=-1),
)
def forward(
self,
x: torch.Tensor,
@@ -80,7 +86,7 @@ class Transducer(nn.Module):
warmup: float = 1.0,
reduction: str = "sum",
delay_penalty: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
@@ -113,8 +119,7 @@ class Transducer(nn.Module):
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.
Returns:
Returns:
Return the transducer loss.
Return a tuple containing simple loss, pruned loss, and ctc-output.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -132,6 +137,9 @@ class Transducer(nn.Module):
encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
assert torch.all(x_lens > 0)
# compute ctc log-probs
ctc_output = self.ctc_output(encoder_out)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
@@ -204,4 +212,4 @@ class Transducer(nn.Module):
reduction=reduction,
)
return (simple_loss, pruned_loss)
return (simple_loss, pruned_loss, ctc_output)
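
For reference, here is a minimal standalone sketch of what the new CTC branch computes on top of the encoder output. A plain nn.Linear stands in for icefall's ScaledLinear, and the dimensions are illustrative, not taken from this commit:

import torch
import torch.nn as nn

# Hypothetical dimensions; the real head uses ScaledLinear(encoder_dim, vocab_size).
encoder_dim, vocab_size = 512, 500
ctc_head = nn.Sequential(
    nn.Dropout(p=0.1),
    nn.Linear(encoder_dim, vocab_size),
    nn.LogSoftmax(dim=-1),
)

encoder_out = torch.randn(8, 100, encoder_dim)  # (N, T, encoder_dim) from the encoder
ctc_output = ctc_head(encoder_out)              # (N, T, vocab_size) per-frame log-probs
# ctc_output is the third element now returned by Transducer.forward().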

View File

@@ -97,6 +97,7 @@ from icefall.utils import (
AttributeDict,
MetricsTracker,
display_and_save_batch,
encode_supervisions,
setup_logger,
str2bool,
)
@@ -273,6 +274,13 @@ def get_parser():
"with this parameter before adding to the final loss.",
)
parser.add_argument(
"--ctc-loss-scale",
type=float,
default=0.2,
help="Scale for CTC loss.",
)
parser.add_argument(
"--seed",
type=int,
@@ -341,6 +349,16 @@ def get_parser():
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)
parser.add_argument(
"--ctc-delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay for CTC loss,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)
add_model_arguments(parser)
return parser
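
Both new options go through get_parser(), so they are set from the command line like the existing flags. A quick check of the flag names, assuming get_parser from this training script is importable (values illustrative):

# argparse maps the dashes in the flag names to underscores on the namespace.
parser = get_parser()
args = parser.parse_args(["--ctc-loss-scale", "0.3", "--ctc-delay-penalty", "0.1"])
assert args.ctc_loss_scale == 0.3
assert args.ctc_delay_penalty == 0.1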
@@ -412,6 +430,9 @@ def get_params() -> AttributeDict:
"decoder_dim": 512,
# parameters for joiner
"joiner_dim": 512,
# parameters for ctc loss
"beam_size": 10,
"use_double_scores": True,
# parameters for Noam
"model_warm_step": 3000, # arg given to model, not for lrate
"env_info": get_env_info(),
@@ -627,11 +648,11 @@ def compute_loss(
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
token_ids = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(token_ids).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
simple_loss, pruned_loss, ctc_output = model(
x=feature,
x_lens=feature_lens,
y=y,
@@ -640,7 +661,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
delay_penalty=params.delay_penalty if warmup >= 1.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
@@ -674,6 +695,38 @@ def compute_loss(
)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
# Compute ctc loss
# NOTE: We need `encode_supervisions` to sort sequences with
# different durations in decreasing order, as required by
# `k2.intersect_dense`, which is called inside `k2.ctc_loss`
with warnings.catch_warnings():
warnings.simplefilter("ignore")
supervision_segments, token_ids = encode_supervisions(
supervisions,
subsampling_factor=params.subsampling_factor,
token_ids=token_ids,
)
# Works with a BPE model
decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
dense_fsa_vec = k2.DenseFsaVec(
ctc_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
delay_penalty=params.ctc_delay_penalty if warmup >= 1.0 else 0.0,
reduction="sum",
use_double_scores=params.use_double_scores,
)
assert ctc_loss.requires_grad == is_training
loss += params.ctc_loss_scale * ctc_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
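
To make the comment about sorting concrete: encode_supervisions returns a supervision_segments tensor (and the token ids reordered the same way) with one row per utterance after subsampling. A rough illustration with made-up values, not taken from this commit:

import torch

# Each row is (sequence_index, start_frame, num_frames), with frame counts
# already divided by the subsampling factor. Rows are sorted by num_frames
# in decreasing order, which is what k2.intersect_dense (called inside
# k2.ctc_loss) requires.
supervision_segments = torch.tensor(
    [
        [1, 0, 98],  # the longest utterance in the batch comes first
        [0, 0, 75],
        [2, 0, 40],
    ],
    dtype=torch.int32,
)
# token_ids is reordered to match: [ids of utt 1, ids of utt 0, ids of utt 2].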
@@ -698,6 +751,7 @@ def compute_loss(
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
info["ctc_loss"] = ctc_loss.detach().cpu().item()
return loss, info
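
Conceptually the new branch is a standard CTC objective on the per-frame log-probs. The commit computes it with k2.ctc_loss over decoding graphs built by k2.ctc_graph, which is what makes the delay penalty possible; purely as a point of comparison, an undelayed plain-PyTorch analogue with illustrative shapes and blank id 0 would look like this:

import torch
import torch.nn.functional as F

N, T, C = 2, 100, 500                                  # batch, frames, vocab (illustrative)
log_probs = torch.randn(N, T, C).log_softmax(dim=-1)   # same shape as ctc_output
input_lengths = torch.tensor([100, 75])
targets = torch.randint(1, C, (N, 20))                 # token ids; 0 is reserved for blank
target_lengths = torch.tensor([20, 15])

# torch.nn.functional.ctc_loss expects (T, N, C) and offers no delay penalty.
ctc = F.ctc_loss(
    log_probs.transpose(0, 1),
    targets,
    input_lengths,
    target_lengths,
    blank=0,
    reduction="sum",
)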