Mirror of https://github.com/k2-fsa/icefall.git

Commit 3d4b8eb58b (parent 6ce36934cb): add ctc branch, with delay penalty
@@ -69,6 +69,12 @@ class Transducer(nn.Module):
         self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
 
+        self.ctc_output = nn.Sequential(
+            nn.Dropout(p=0.1),
+            ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5),
+            nn.LogSoftmax(dim=-1),
+        )
+
     def forward(
         self,
         x: torch.Tensor,
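As a standalone illustration of what the new head computes (a sketch, not code from this commit: a plain nn.Linear stands in for icefall's ScaledLinear, and all sizes are made up), the Sequential maps encoder frames to per-frame log-probabilities over the vocabulary:

import torch
import torch.nn as nn

encoder_dim, vocab_size = 512, 500          # illustrative sizes
ctc_output = nn.Sequential(
    nn.Dropout(p=0.1),
    nn.Linear(encoder_dim, vocab_size),     # ScaledLinear in the real model
    nn.LogSoftmax(dim=-1),                  # normalized log-probs per frame
)

encoder_out = torch.randn(2, 100, encoder_dim)   # (N, T, encoder_dim)
log_probs = ctc_output(encoder_out)              # (N, T, vocab_size)
assert log_probs.shape == (2, 100, vocab_size)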
@@ -80,7 +86,7 @@ class Transducer(nn.Module):
         warmup: float = 1.0,
         reduction: str = "sum",
         delay_penalty: float = 0.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
           x:
@@ -113,8 +119,7 @@ class Transducer(nn.Module):
           See https://github.com/k2-fsa/k2/issues/955 and
           https://arxiv.org/pdf/2211.00490.pdf for more details.
         Returns:
-          Return the transducer loss.
-
+          Return a tuple containing simple loss, pruned loss, and ctc-output.
 
         Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -132,6 +137,9 @@ class Transducer(nn.Module):
         encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
         assert torch.all(x_lens > 0)
 
+        # compute ctc log-probs
+        ctc_output = self.ctc_output(encoder_out)
+
         # Now for the decoder, i.e., the prediction network
         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]
@@ -204,4 +212,4 @@ class Transducer(nn.Module):
             reduction=reduction,
         )
 
-        return (simple_loss, pruned_loss)
+        return (simple_loss, pruned_loss, ctc_output)
@@ -97,6 +97,7 @@ from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     display_and_save_batch,
+    encode_supervisions,
     setup_logger,
     str2bool,
 )
@@ -273,6 +274,13 @@ def get_parser():
         "with this parameter before adding to the final loss.",
     )
 
+    parser.add_argument(
+        "--ctc-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for CTC loss.",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
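The new scale feeds the total training loss assembled in compute_loss further down. With scalar stand-ins (the 0.5 simple-loss scale and 1.0 pruned-loss scale here are illustrative values, not taken from this commit), the combination works out as:

simple_loss, pruned_loss, ctc_loss = 120.0, 45.0, 60.0  # made-up batch losses
simple_loss_scale, pruned_loss_scale, ctc_loss_scale = 0.5, 1.0, 0.2

loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
loss += ctc_loss_scale * ctc_loss   # the new term, gated by --ctc-loss-scale
print(loss)  # 117.0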
@@ -341,6 +349,16 @@ def get_parser():
         https://arxiv.org/pdf/2211.00490.pdf for more details.""",
     )
 
+    parser.add_argument(
+        "--ctc-delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay for CTC loss,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -412,6 +430,9 @@ def get_params() -> AttributeDict:
             "decoder_dim": 512,
             # parameters for joiner
             "joiner_dim": 512,
+            # parameters for ctc loss
+            "beam_size": 10,
+            "use_double_scores": True,
             # parameters for Noam
             "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
@@ -627,11 +648,11 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)
 
     texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+    token_ids = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(token_ids).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, ctc_output = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
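A note on the rename above: sp.encode returns plain Python lists of token ids, which the transducer consumes as a k2.RaggedTensor while the new CTC branch reuses the raw lists later. A sketch under that assumption (the BPE model path is a placeholder, not from this commit):

import k2
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="bpe.model")  # placeholder path
token_ids = sp.encode(["HELLO WORLD", "FOO BAR"], out_type=int)  # List[List[int]]
y = k2.RaggedTensor(token_ids)  # ragged form for the transducer losses
# `token_ids` itself is kept around for k2.ctc_graph in the CTC branch.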
@@ -640,7 +661,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
-            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
+            delay_penalty=params.delay_penalty if warmup >= 1.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
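The threshold change (2.0 to 1.0) means the transducer delay penalty now switches on as soon as warm-up completes, rather than a full warm-up period later. A sketch of the gate (hypothetical values; `warmup` grows with training progress, roughly batch count over model_warm_step):

def effective_delay_penalty(delay_penalty: float, warmup: float) -> float:
    # Mirrors the gate in compute_loss: zero until warm-up is done.
    return delay_penalty if warmup >= 1.0 else 0.0

print(effective_delay_penalty(0.1, warmup=0.5))  # 0.0 (still warming up)
print(effective_delay_penalty(0.1, warmup=1.5))  # 0.1 (penalty active)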
@@ -674,6 +695,38 @@ def compute_loss(
         )
         loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
+        # Compute ctc loss
+
+        # NOTE: We need `encode_supervisions` to sort sequences with
+        # different duration in decreasing order, required by
+        # `k2.intersect_dense` called in `k2.ctc_loss`
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            supervision_segments, token_ids = encode_supervisions(
+                supervisions,
+                subsampling_factor=params.subsampling_factor,
+                token_ids=token_ids,
+            )
+
+        # Works with a BPE model
+        decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
+        dense_fsa_vec = k2.DenseFsaVec(
+            ctc_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            delay_penalty=params.ctc_delay_penalty if warmup >= 1.0 else 0.0,
+            reduction="sum",
+            use_double_scores=params.use_double_scores,
+        )
+        assert ctc_loss.requires_grad == is_training
+        loss += params.ctc_loss_scale * ctc_loss
+
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
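Putting the new branch together outside the training loop, a self-contained sketch (requires k2; the shapes, frame counts, and token ids are made up, and only the API calls mirror the hunk above):

import k2
import torch

vocab_size, T = 6, 20
log_probs = torch.randn(2, T, vocab_size).log_softmax(-1)  # stands in for ctc_output

# One row per utterance: (sequence_idx, start_frame, num_frames), sorted by
# num_frames in decreasing order; this is the format encode_supervisions produces.
supervision_segments = torch.tensor([[0, 0, 20], [1, 0, 17]], dtype=torch.int32)
token_ids = [[2, 3, 5], [4, 1]]  # made-up BPE ids per utterance

decoding_graph = k2.ctc_graph(token_ids, modified=False)
dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)
ctc_loss = k2.ctc_loss(
    decoding_graph=decoding_graph,
    dense_fsa_vec=dense_fsa_vec,
    output_beam=10,
    delay_penalty=0.0,  # set > 0 to penalize late emissions, as in the hunk
    reduction="sum",
    use_double_scores=True,
)
print(ctc_loss)  # total negative log-likelihood over both utterances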
@@ -698,6 +751,7 @@ def compute_loss(
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
 
     return loss, info