Mirror of https://github.com/k2-fsa/icefall.git
Use CTC loss as auxiliary loss.

commit 69679e023e
parent d6050eb02e
@@ -74,7 +74,10 @@ class Transducer(nn.Module):
             A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
         Returns:
-          Return the transducer loss.
+          Return a tuple containing:
+            - the transducer loss, a tensor containing only one entry
+            - encoder_out, a tensor of shape (N, num_frames, encoder_out_dim)
+            - encoder_out_lens, a tensor of shape (N,)
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
@@ -123,4 +126,8 @@ class Transducer(nn.Module):
             from_log_softmax=False,
         )

-        return loss
+        return (
+            loss,
+            encoder_out,
+            x_lens,
+        )
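
With this change, every call site that previously did `loss = model(...)` must unpack a 3-tuple. A minimal sketch of a hypothetical call site (`model`, `x`, `x_lens`, and `y` are assumed to already exist; only the names of the returned values come from the diff):

    # The model now returns a 3-tuple instead of a bare loss tensor.
    transducer_loss, encoder_out, encoder_out_lens = model(x=x, x_lens=x_lens, y=y)
    # transducer_loss:  scalar tensor
    # encoder_out:      (N, num_frames, encoder_out_dim), reusable by an auxiliary CTC head
    # encoder_out_lens: (N,), valid frames per utterance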
@@ -138,6 +138,15 @@ def get_parser():
         "2 means tri-gram",
     )

+    parser.add_argument(
+        "--ctc-weight",
+        type=float,
+        default=0.25,
+        help="""If not zero, the total loss is:
+          (1 - ctc_weight) * transducer_loss + ctc_weight * ctc_loss
+        """,
+    )
+
     return parser
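
For concreteness, here is the interpolation from the help string evaluated on made-up loss values; a standalone sketch, not code from the repository:

    import torch

    transducer_loss = torch.tensor(123.4)  # hypothetical value (reduction="sum")
    ctc_loss = torch.tensor(98.7)          # hypothetical value, same reduction
    ctc_weight = 0.25                      # the default added by this commit

    loss = (1 - ctc_weight) * transducer_loss + ctc_weight * ctc_loss
    print(loss)  # tensor(117.2250)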
@@ -206,6 +215,13 @@ def get_params() -> AttributeDict:
             "vgg_frontend": False,
             # parameters for Noam
             "warm_step": 80000,  # For the 100h subset, use 8k
+            #
+            # parameters for ctc_loss, used only when ctc_weight > 0
+            "modified_ctc_topo": False,
+            "beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            #
             "env_info": get_env_info(),
         }
     )
@@ -259,6 +275,17 @@ def get_transducer_model(params: AttributeDict):
     return model


+def get_ctc_model(params: AttributeDict):
+    if params.ctc_weight > 0:
+        return nn.Sequential(
+            nn.Dropout(p=0.1),
+            nn.Linear(params.encoder_out_dim, params.vocab_size),
+            nn.LogSoftmax(dim=-1),
+        )
+    else:
+        return None
+
+
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
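
The CTC head is just dropout, a linear projection from `encoder_out_dim` to `vocab_size`, and a log-softmax, so its output can be fed to `k2.DenseFsaVec` directly. A self-contained shape check with hypothetical dimensions:

    import torch
    import torch.nn as nn

    encoder_out_dim, vocab_size = 512, 500  # hypothetical sizes
    ctc = nn.Sequential(
        nn.Dropout(p=0.1),
        nn.Linear(encoder_out_dim, vocab_size),
        nn.LogSoftmax(dim=-1),
    )

    encoder_out = torch.randn(2, 100, encoder_out_dim)  # (N, T, encoder_out_dim)
    nnet_out = ctc(encoder_out)
    assert nnet_out.shape == (2, 100, vocab_size)
    # Every frame is a normalized log-probability distribution over the vocabulary.
    assert torch.allclose(nnet_out.exp().sum(dim=-1), torch.ones(2, 100), atol=1e-4)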
@@ -379,11 +406,52 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)

     texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+    token_ids = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(token_ids).to(device)

     with torch.set_grad_enabled(is_training):
-        loss = model(x=feature, x_lens=feature_lens, y=y)
+        transducer_loss, encoder_out, encoder_out_lens = model(
+            x=feature, x_lens=feature_lens, y=y
+        )
+        loss = transducer_loss
+
+        if params.ctc_weight > 0:
+            ctc_model = (
+                model.module.ctc if hasattr(model, "module") else model.ctc
+            )
+            ctc_graph = k2.ctc_graph(
+                token_ids, modified=params.modified_ctc_topo, device=device
+            )
+            # Note: We assume `encoder_out_lens` is sorted in descending order.
+            # If not, it will throw in k2.ctc_loss().
+            supervision_segments = torch.stack(
+                [
+                    torch.arange(encoder_out.size(0), dtype=torch.int32),
+                    torch.zeros(encoder_out.size(0), dtype=torch.int32),
+                    encoder_out_lens.cpu(),
+                ],
+                dim=1,
+            ).to(torch.int32)
+            nnet_out = ctc_model(encoder_out)
+
+            dense_fsa_vec = k2.DenseFsaVec(
+                nnet_out,
+                supervision_segments,
+                allow_truncate=0,
+            )
+
+            # Note: transducer_loss should use the same reduction as ctc_loss
+            ctc_loss = k2.ctc_loss(
+                decoding_graph=ctc_graph,
+                dense_fsa_vec=dense_fsa_vec,
+                output_beam=params.beam_size,
+                reduction=params.reduction,
+                use_double_scores=params.use_double_scores,
+            )
+            assert ctc_loss.requires_grad == is_training
+            loss = (
+                1 - params.ctc_weight
+            ) * transducer_loss + params.ctc_weight * ctc_loss

     assert loss.requires_grad == is_training
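
Each row of `supervision_segments` has the layout `[sequence_index, start_frame, num_frames]` that `k2.DenseFsaVec` expects. A standalone illustration with hypothetical lengths:

    import torch

    encoder_out_lens = torch.tensor([95, 80, 64])  # hypothetical, already descending
    n = encoder_out_lens.size(0)
    supervision_segments = torch.stack(
        [
            torch.arange(n, dtype=torch.int32),  # index of each utterance in the batch
            torch.zeros(n, dtype=torch.int32),   # start frame, always 0 here
            encoder_out_lens.to(torch.int32),    # number of valid encoder frames
        ],
        dim=1,
    )
    print(supervision_segments)
    # tensor([[ 0,  0, 95],
    #         [ 1,  0, 80],
    #         [ 2,  0, 64]], dtype=torch.int32)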
@@ -392,6 +460,9 @@ def compute_loss(

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
+    info["transducer_loss"] = transducer_loss.detach().cpu().item()
+    if params.ctc_weight > 0:
+        info["ctc_loss"] = ctc_loss.detach().cpu().item()

     return loss, info
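
Because both losses use reduction="sum", the logged values grow with the amount of audio in a batch. One common normalization for logging (not part of this commit) is to divide by the total number of frames; a sketch with hypothetical numbers:

    # Hypothetical values, for illustration only.
    info = {"loss": 11722.5, "transducer_loss": 12340.0, "ctc_loss": 9870.0}
    num_frames = 100  # sum of encoder_out_lens over the batch
    per_frame = {k: v / num_frames for k, v in info.items()}
    print(per_frame)  # {'loss': 117.225, 'transducer_loss': 123.4, 'ctc_loss': 98.7}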
@@ -574,6 +645,11 @@ def run(rank, world_size, args):

     logging.info("About to create model")
     model = get_transducer_model(params)
+    model.ctc = get_ctc_model(params)
+    if model.ctc is not None:
+        logging.info(f"Enable CTC loss with weight: {params.ctc_weight}")
+    else:
+        logging.info("Disable CTC loss")

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
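
A hypothetical invocation enabling the auxiliary loss; the script path is an assumption, and only `--ctc-weight` is introduced by this commit:

    python ./transducer/train.py --ctc-weight 0.25
    # A value of 0 makes get_ctc_model() return None, so training falls back
    # to the pure transducer loss.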