Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-13 12:02:21 +00:00)

Commit 3d4b8eb58b (parent 6ce36934cb): add ctc branch, with delay penalty
@@ -69,6 +69,12 @@ class Transducer(nn.Module):
         self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

+        self.ctc_output = nn.Sequential(
+            nn.Dropout(p=0.1),
+            ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5),
+            nn.LogSoftmax(dim=-1),
+        )
+
     def forward(
         self,
         x: torch.Tensor,
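For context, a minimal standalone sketch (not part of the commit) of what this new head computes: a plain nn.Linear stands in for icefall's ScaledLinear (initial_speed omitted), and the dimensions and batch shape are made up. The head maps encoder frames of shape (N, T, encoder_dim) to per-frame log-probabilities over the vocabulary.

    # Sketch of the CTC head added above (illustration only).
    import torch
    import torch.nn as nn

    encoder_dim, vocab_size = 512, 500
    ctc_output = nn.Sequential(
        nn.Dropout(p=0.1),
        nn.Linear(encoder_dim, vocab_size),  # stand-in for ScaledLinear
        nn.LogSoftmax(dim=-1),
    )

    encoder_out = torch.randn(8, 100, encoder_dim)  # (N, T, encoder_dim)
    log_probs = ctc_output(encoder_out)             # (N, T, vocab_size)
    # each frame is a proper log-distribution over tokens
    assert torch.allclose(log_probs.exp().sum(-1), torch.ones(8, 100))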
@@ -80,7 +86,7 @@ class Transducer(nn.Module):
         warmup: float = 1.0,
         reduction: str = "sum",
         delay_penalty: float = 0.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
           x:
@@ -113,8 +119,7 @@ class Transducer(nn.Module):
            See https://github.com/k2-fsa/k2/issues/955 and
            https://arxiv.org/pdf/2211.00490.pdf for more details.
-
-        Returns:
-          Return the transducer loss.
+        Returns:
+          Return a tuple containing simple loss, pruned loss, and ctc-output.

         Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -132,6 +137,9 @@ class Transducer(nn.Module):
         encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
         assert torch.all(x_lens > 0)

+        # compute ctc log-probs
+        ctc_output = self.ctc_output(encoder_out)
+
         # Now for the decoder, i.e., the prediction network
         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]
@@ -204,4 +212,4 @@ class Transducer(nn.Module):
                reduction=reduction,
            )

-        return (simple_loss, pruned_loss)
+        return (simple_loss, pruned_loss, ctc_output)
@@ -97,6 +97,7 @@ from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     display_and_save_batch,
+    encode_supervisions,
     setup_logger,
     str2bool,
 )
@@ -273,6 +274,13 @@ def get_parser():
         "with this parameter before adding to the final loss.",
     )

+    parser.add_argument(
+        "--ctc-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for CTC loss.",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -341,6 +349,16 @@ def get_parser():
         https://arxiv.org/pdf/2211.00490.pdf for more details.""",
     )

+    parser.add_argument(
+        "--ctc-delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay for CTC loss,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)

     return parser
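As a quick check of how the two new options surface on the command line, a small sketch (not from the commit) that registers only the CTC-related flags and parses them in isolation; the real get_parser() registers many more options.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--ctc-loss-scale", type=float, default=0.2,
                        help="Scale for CTC loss.")
    parser.add_argument("--ctc-delay-penalty", type=float, default=0.0,
                        help="Penalty on symbol delay for the CTC loss.")

    args = parser.parse_args(["--ctc-loss-scale", "0.3", "--ctc-delay-penalty", "0.1"])
    print(args.ctc_loss_scale, args.ctc_delay_penalty)  # 0.3 0.1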
@@ -412,6 +430,9 @@ def get_params() -> AttributeDict:
             "decoder_dim": 512,
             # parameters for joiner
             "joiner_dim": 512,
+            # parameters for ctc loss
+            "beam_size": 10,
+            "use_double_scores": True,
             # parameters for Noam
             "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
@@ -627,11 +648,11 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)

     texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+    token_ids = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(token_ids).to(device)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, ctc_output = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -640,7 +661,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
-            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
+            delay_penalty=params.delay_penalty if warmup >= 1.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
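The only change in this hunk is the warmup threshold for the transducer delay penalty (2.0 lowered to 1.0). A tiny hypothetical helper (not in the commit) makes the gating explicit: the penalty is disabled while the model is still warming up and applied at full strength afterwards; the CTC delay penalty below is gated the same way.

    def effective_delay_penalty(delay_penalty: float, warmup: float) -> float:
        # No penalty during warmup; full penalty once warmup >= 1.0.
        return delay_penalty if warmup >= 1.0 else 0.0

    assert effective_delay_penalty(0.1, 0.5) == 0.0
    assert effective_delay_penalty(0.1, 1.5) == 0.1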
@@ -674,6 +695,38 @@ def compute_loss(
         )
         loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

+        # Compute ctc loss
+
+        # NOTE: We need `encode_supervisions` to sort sequences with
+        # different duration in decreasing order, required by
+        # `k2.intersect_dense` called in `k2.ctc_loss`
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            supervision_segments, token_ids = encode_supervisions(
+                supervisions,
+                subsampling_factor=params.subsampling_factor,
+                token_ids=token_ids,
+            )
+
+        # Works with a BPE model
+        decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
+        dense_fsa_vec = k2.DenseFsaVec(
+            ctc_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            delay_penalty=params.ctc_delay_penalty if warmup >= 1.0 else 0.0,
+            reduction="sum",
+            use_double_scores=params.use_double_scores,
+        )
+        assert ctc_loss.requires_grad == is_training
+        loss += params.ctc_loss_scale * ctc_loss
+
     assert loss.requires_grad == is_training

     info = MetricsTracker()
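The NOTE in this hunk refers to the layout k2 expects: supervision_segments is a 2-D int32 CPU tensor with one row [sequence_idx, start_frame, num_frames] per utterance, with frame counts taken after subsampling and rows sorted by decreasing num_frames, and token_ids reordered to match. A toy sketch of that contract (not icefall's real encode_supervisions; all values are made up):

    import torch

    num_frames = torch.tensor([230, 410, 305])          # frames before subsampling
    token_ids = [[12, 7, 99], [3, 3, 8, 41], [5, 22]]   # made-up BPE ids
    subsampling_factor = 4

    durations = num_frames // subsampling_factor
    order = torch.argsort(durations, descending=True)
    supervision_segments = torch.stack(
        [order, torch.zeros_like(order), durations[order]], dim=1
    ).to(torch.int32)
    sorted_token_ids = [token_ids[i] for i in order.tolist()]

    print(supervision_segments)
    # tensor([[  1,   0, 102],
    #         [  2,   0,  76],
    #         [  0,   0,  57]], dtype=torch.int32)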
@@ -698,6 +751,7 @@ def compute_loss(
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()

     return loss, info
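Putting the pieces together, the total training loss after this commit is a weighted sum of three terms. A schematic sketch with made-up scalar values (in the real script simple_loss and pruned_loss come from the transducer, ctc_loss from k2.ctc_loss, and the scales are params.simple_loss_scale, the warmup-dependent pruned_loss_scale, and params.ctc_loss_scale):

    import torch

    simple_loss = torch.tensor(250.0)  # illustrative values
    pruned_loss = torch.tensor(180.0)
    ctc_loss = torch.tensor(320.0)

    simple_loss_scale = 0.5   # illustrative
    pruned_loss_scale = 1.0   # illustrative; ramped up during warmup in the script
    ctc_loss_scale = 0.2      # default of --ctc-loss-scale above

    loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
    loss += ctc_loss_scale * ctc_loss
    print(loss)  # tensor(369.)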