mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Use CTC loss as auxiliary loss.
This commit is contained in:
parent
d6050eb02e
commit
69679e023e
@ -74,7 +74,10 @@ class Transducer(nn.Module):
|
|||||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||||
utterance.
|
utterance.
|
||||||
Returns:
|
Returns:
|
||||||
Return the transducer loss.
|
Return a tuple containing:
|
||||||
|
- the transducer loss, a tensor containing only one entry
|
||||||
|
- encoder_out, a tensor of shape (N, num_frames, encoder_out_dim)
|
||||||
|
- encoder_out_lens, a tensor of shape (N,)
|
||||||
"""
|
"""
|
||||||
assert x.ndim == 3, x.shape
|
assert x.ndim == 3, x.shape
|
||||||
assert x_lens.ndim == 1, x_lens.shape
|
assert x_lens.ndim == 1, x_lens.shape
|
||||||
@ -123,4 +126,8 @@ class Transducer(nn.Module):
|
|||||||
from_log_softmax=False,
|
from_log_softmax=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
return (
|
||||||
|
loss,
|
||||||
|
encoder_out,
|
||||||
|
x_lens,
|
||||||
|
)
|
||||||
|
@ -138,6 +138,15 @@ def get_parser():
|
|||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ctc-weight",
|
||||||
|
type=float,
|
||||||
|
default=0.25,
|
||||||
|
help="""If greater than zero, the total loss is:
|
||||||
|
(1 - ctc_weight) * transducer_loss + ctc_weight * ctc_loss
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -206,6 +215,13 @@ def get_params() -> AttributeDict:
|
|||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
# parameters for Noam
|
# parameters for Noam
|
||||||
"warm_step": 80000, # For the 100h subset, use 8k
|
"warm_step": 80000, # For the 100h subset, use 8k
|
||||||
|
#
|
||||||
|
# parameters for ctc_loss, used only when ctc_weight > 0
|
||||||
|
"modified_ctc_topo": False,
|
||||||
|
"beam_size": 10,
|
||||||
|
"reduction": "sum",
|
||||||
|
"use_double_scores": True,
|
||||||
|
#
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -259,6 +275,17 @@ def get_transducer_model(params: AttributeDict):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_ctc_model(params: AttributeDict):
|
||||||
|
if params.ctc_weight > 0:
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.Dropout(p=0.1),
|
||||||
|
nn.Linear(params.encoder_out_dim, params.vocab_size),
|
||||||
|
nn.LogSoftmax(dim=-1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint_if_available(
|
def load_checkpoint_if_available(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@ -379,11 +406,52 @@ def compute_loss(
|
|||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
y = sp.encode(texts, out_type=int)
|
token_ids = sp.encode(texts, out_type=int)
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(token_ids).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
loss = model(x=feature, x_lens=feature_lens, y=y)
|
transducer_loss, encoder_out, encoder_out_lens = model(
|
||||||
|
x=feature, x_lens=feature_lens, y=y
|
||||||
|
)
|
||||||
|
loss = transducer_loss
|
||||||
|
|
||||||
|
if params.ctc_weight > 0:
|
||||||
|
ctc_model = (
|
||||||
|
model.module.ctc if hasattr(model, "module") else model.ctc
|
||||||
|
)
|
||||||
|
ctc_graph = k2.ctc_graph(
|
||||||
|
token_ids, modified=params.modified_ctc_topo, device=device
|
||||||
|
)
|
||||||
|
# Note: We assume `encoder_out_lens` is sorted in descending order.
|
||||||
|
# If not, it will throw in k2.ctc_loss().
|
||||||
|
supervision_segments = torch.stack(
|
||||||
|
[
|
||||||
|
torch.arange(encoder_out.size(0), dtype=torch.int32),
|
||||||
|
torch.zeros(encoder_out.size(0), dtype=torch.int32),
|
||||||
|
encoder_out_lens.cpu(),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
).to(torch.int32)
|
||||||
|
nnet_out = ctc_model(encoder_out)
|
||||||
|
|
||||||
|
dense_fsa_vec = k2.DenseFsaVec(
|
||||||
|
nnet_out,
|
||||||
|
supervision_segments,
|
||||||
|
allow_truncate=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Note: transducer_loss should use the same reduction as ctc_loss
|
||||||
|
ctc_loss = k2.ctc_loss(
|
||||||
|
decoding_graph=ctc_graph,
|
||||||
|
dense_fsa_vec=dense_fsa_vec,
|
||||||
|
output_beam=params.beam_size,
|
||||||
|
reduction=params.reduction,
|
||||||
|
use_double_scores=params.use_double_scores,
|
||||||
|
)
|
||||||
|
assert ctc_loss.requires_grad == is_training
|
||||||
|
loss = (
|
||||||
|
1 - params.ctc_weight
|
||||||
|
) * transducer_loss + params.ctc_weight * ctc_loss
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
@ -392,6 +460,9 @@ def compute_loss(
|
|||||||
|
|
||||||
# Note: We use reduction=sum while computing the loss.
|
# Note: We use reduction=sum while computing the loss.
|
||||||
info["loss"] = loss.detach().cpu().item()
|
info["loss"] = loss.detach().cpu().item()
|
||||||
|
info["transducer_loss"] = transducer_loss.detach().cpu().item()
|
||||||
|
if params.ctc_weight > 0:
|
||||||
|
info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
||||||
|
|
||||||
return loss, info
|
return loss, info
|
||||||
|
|
||||||
@ -574,6 +645,11 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = get_transducer_model(params)
|
model = get_transducer_model(params)
|
||||||
|
model.ctc = get_ctc_model(params)
|
||||||
|
if model.ctc is not None:
|
||||||
|
logging.info(f"Enable CTC loss with weight: {params.ctc_weight}")
|
||||||
|
else:
|
||||||
|
logging.info("Disable CTC loss")
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user