Use CTC loss as auxiliary loss.

This commit is contained in:
Fangjun Kuang 2022-01-22 16:04:17 +08:00
parent d6050eb02e
commit 69679e023e
2 changed files with 88 additions and 5 deletions

View File

@ -74,7 +74,10 @@ class Transducer(nn.Module):
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
Returns:
Return the transducer loss.
Return a tuple containing:
- the transducer loss, a tensor containing only one entry
- encoder_out, a tensor of shape (N, num_frames, encoder_out_dim)
- encoder_out_lens, a tensor of shape (N,)
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
@ -123,4 +126,8 @@ class Transducer(nn.Module):
from_log_softmax=False,
)
return loss
return (
loss,
encoder_out,
x_lens,
)

View File

@ -138,6 +138,15 @@ def get_parser():
"2 means tri-gram",
)
parser.add_argument(
"--ctc-weight",
type=float,
default=0.25,
help="""If not zero, the total loss is:
(1 - ctc_weight) * transdcuder_loss + ctc_weight * ctc_loss
""",
)
return parser
@ -206,6 +215,13 @@ def get_params() -> AttributeDict:
"vgg_frontend": False,
# parameters for Noam
"warm_step": 80000, # For the 100h subset, use 8k
#
# parameters for ctc_loss, used only when ctc_weight > 0
"modified_ctc_topo": False,
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
#
"env_info": get_env_info(),
}
)
@ -259,6 +275,17 @@ def get_transducer_model(params: AttributeDict):
return model
def get_ctc_model(params: AttributeDict):
if params.ctc_weight > 0:
return nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(params.encoder_out_dim, params.vocab_size),
nn.LogSoftmax(dim=-1),
)
else:
return None
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
@ -379,11 +406,52 @@ def compute_loss(
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
token_ids = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(token_ids).to(device)
with torch.set_grad_enabled(is_training):
loss = model(x=feature, x_lens=feature_lens, y=y)
transducer_loss, encoder_out, encoder_out_lens = model(
x=feature, x_lens=feature_lens, y=y
)
loss = transducer_loss
if params.ctc_weight > 0:
ctc_model = (
model.module.ctc if hasattr(model, "module") else model.ctc
)
ctc_graph = k2.ctc_graph(
token_ids, modified=params.modified_ctc_topo, device=device
)
# Note: We assume `encoder_out_lens` is sorted in descending order.
# If not, it will throw in k2.ctc_loss().
supervision_segments = torch.stack(
[
torch.arange(encoder_out.size(0), dtype=torch.int32),
torch.zeros(encoder_out.size(0), dtype=torch.int32),
encoder_out_lens.cpu(),
],
dim=1,
).to(torch.int32)
nnet_out = ctc_model(encoder_out)
dense_fsa_vec = k2.DenseFsaVec(
nnet_out,
supervision_segments,
allow_truncate=0,
)
# Note: transducer_loss should use the same reduction as ctc_loss
ctc_loss = k2.ctc_loss(
decoding_graph=ctc_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
assert ctc_loss.requires_grad == is_training
loss = (
1 - params.ctc_weight
) * transducer_loss + params.ctc_weight * ctc_loss
assert loss.requires_grad == is_training
@ -392,6 +460,9 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["transducer_loss"] = transducer_loss.detach().cpu().item()
if params.ctc_weight > 0:
info["ctc_loss"] = ctc_loss.detach().cpu().item()
return loss, info
@ -574,6 +645,11 @@ def run(rank, world_size, args):
logging.info("About to create model")
model = get_transducer_model(params)
model.ctc = get_ctc_model(params)
if model.ctc is not None:
logging.info(f"Enable CTC loss with weight: {params.ctc_weight}")
else:
logging.info("Disable CTC loss")
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")