Add auxiliary losses.

Author: Fangjun Kuang
Date: 2022-02-22 16:15:00 +08:00
parent 5e4b1e01fe
commit 8b4c571824

2 changed files with 95 additions and 18 deletions

Changed file 1: the Transducer model definition

@@ -15,11 +15,12 @@
 # limitations under the License.

 import random
-from typing import Optional
+from typing import Optional, Tuple

 import k2
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from encoder_interface import EncoderInterface

 from icefall.utils import add_sos
@@ -37,6 +38,7 @@ class Transducer(nn.Module):
         joiner: nn.Module,
         decoder_giga: Optional[nn.Module] = None,
         joiner_giga: Optional[nn.Module] = None,
+        aux_module: Optional[nn.Module] = None,
     ):
         """
         Args:
@@ -57,6 +59,8 @@ class Transducer(nn.Module):
             The decoder for the GigaSpeech dataset.
           joiner_giga:
             The joiner for the GigaSpeech dataset.
+          aux_module:
+            Optional. The auxiliary branch for computing auxiliary losses.
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -73,6 +77,8 @@ class Transducer(nn.Module):
         self.decoder_giga = decoder_giga
         self.joiner_giga = joiner_giga

+        self.aux_module = aux_module
+
     def forward(
         self,
         x: torch.Tensor,
@@ -80,7 +86,7 @@ class Transducer(nn.Module):
         y: k2.RaggedTensor,
         libri: bool = True,
         modified_transducer_prob: float = 0.0,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
           x:
@@ -97,7 +103,10 @@ class Transducer(nn.Module):
           modified_transducer_prob:
             The probability to use modified transducer loss.
         Returns:
-          Return the transducer loss.
+          Return a tuple of 3 scalar tensors containing:
+            - transducer loss
+            - auxiliary transducer loss
+            - KL loss
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
@@ -105,7 +114,7 @@ class Transducer(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0

-        encoder_out, x_lens = self.encoder(x, x_lens)
+        encoder_out, x_lens, aux_input = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)

         # Now for the decoder, i.e., the prediction network
@@ -154,15 +163,51 @@ class Transducer(nn.Module):
         else:
             one_sym_per_frame = False

-        loss = optimized_transducer.transducer_loss(
-            logits=logits,
+        log_probs = F.log_softmax(logits, dim=-1)
+        transducer_loss = optimized_transducer.transducer_loss(
+            logits=log_probs,
             targets=y_padded,
             logit_lengths=x_lens,
             target_lengths=y_lens,
             blank=blank_id,
             reduction="sum",
             one_sym_per_frame=one_sym_per_frame,
-            from_log_softmax=False,
+            from_log_softmax=True,
         )

-        return loss
+        # Now process the auxiliary branch
+        aux_output = self.aux_module(aux_input)
+        aux_logits = joiner(
+            encoder_out=aux_output,
+            decoder_out=decoder_out,
+            encoder_out_len=x_lens,
+            decoder_out_len=y_lens + 1,
+        )
+        aux_log_probs = F.log_softmax(aux_logits, dim=-1)
+
+        aux_transducer_loss = optimized_transducer.transducer_loss(
+            logits=aux_log_probs,
+            targets=y_padded,
+            logit_lengths=x_lens,
+            target_lengths=y_lens,
+            blank=blank_id,
+            reduction="sum",
+            one_sym_per_frame=one_sym_per_frame,
+            from_log_softmax=True,
+        )
+
+        kl_loss_1 = F.kl_div(
+            input=log_probs,
+            target=aux_log_probs,
+            reduction="sum",
+            log_target=True,
+        )
+        kl_loss_2 = F.kl_div(
+            input=aux_log_probs,
+            target=log_probs,
+            reduction="sum",
+            log_target=True,
+        )
+        kl_loss = (kl_loss_1 + kl_loss_2) * 0.5
+
+        return transducer_loss, aux_transducer_loss, kl_loss
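
The two F.kl_div calls above form a symmetrized KL divergence between the main and auxiliary output distributions. A minimal standalone sketch (not part of the commit; shapes are made up) showing the same computation — note that with log_target=True, F.kl_div(input=p_log, target=q_log) computes KL(q || p), so averaging the two calls gives the symmetric form:

import torch
import torch.nn.functional as F

# Hypothetical logits for the two branches.
logits = torch.randn(4, 10, requires_grad=True)      # main branch
aux_logits = torch.randn(4, 10, requires_grad=True)  # auxiliary branch

log_probs = F.log_softmax(logits, dim=-1)
aux_log_probs = F.log_softmax(aux_logits, dim=-1)

# KL(aux || main); target holds log-probs because log_target=True.
kl_1 = F.kl_div(input=log_probs, target=aux_log_probs, reduction="sum", log_target=True)
# KL(main || aux)
kl_2 = F.kl_div(input=aux_log_probs, target=log_probs, reduction="sum", log_target=True)

kl_loss = (kl_1 + kl_2) * 0.5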

Changed file 2: the training script

@@ -168,6 +168,13 @@ def get_parser():
         help="The probability to select a batch from the GigaSpeech dataset",
     )

+    parser.add_argument(
+        "--lambda-aux",
+        type=float,
+        default=0.3,
+        help="The scale applied to the auxiliary losses",
+    )
+
     return parser
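
The new --lambda-aux flag scales the auxiliary losses when they are combined with the main transducer loss in compute_loss and train_one_epoch below. A small sketch (not from the commit) of the implied training objective, using the names from this diff:

import torch

def combine_losses(
    transducer_loss: torch.Tensor,
    aux_transducer_loss: torch.Tensor,
    kl_loss: torch.Tensor,
    lambda_aux: float = 0.3,  # the default of --lambda-aux
) -> torch.Tensor:
    # aux_loss groups the two auxiliary terms, as in compute_loss below.
    aux_loss = aux_transducer_loss + kl_loss
    return transducer_loss + lambda_aux * aux_loss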
@@ -280,6 +287,14 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     return joiner


+def get_aux_model(params: AttributeDict) -> nn.Module:
+    return nn.Sequential(
+        nn.Linear(params.attention_dim, params.encoder_out_dim),
+        nn.ReLU(inplace=True),
+        nn.Linear(params.encoder_out_dim, params.encoder_out_dim),
+    )
+
+
 def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
@@ -289,12 +304,15 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         decoder_giga = get_decoder_model(params)
         joiner_giga = get_joiner_model(params)

+    aux_module = get_aux_model(params)
+
     model = Transducer(
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
         decoder_giga=decoder_giga,
         joiner_giga=joiner_giga,
+        aux_module=aux_module,
     )
     return model
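
get_aux_model above is a two-layer MLP that projects the encoder's intermediate output (size attention_dim) to the dimension the joiner expects (encoder_out_dim). A shape-level sketch under the assumption that aux_input has shape (N, T, attention_dim); the sizes below are made up:

import torch
import torch.nn as nn

attention_dim, encoder_out_dim = 512, 512  # hypothetical values
aux_module = nn.Sequential(
    nn.Linear(attention_dim, encoder_out_dim),
    nn.ReLU(inplace=True),
    nn.Linear(encoder_out_dim, encoder_out_dim),
)

aux_input = torch.randn(8, 100, attention_dim)  # (N, T, attention_dim)
aux_output = aux_module(aux_input)              # nn.Linear acts on the last dim
print(aux_output.shape)                         # torch.Size([8, 100, 512])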
@@ -436,7 +454,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        loss = model(
+        transducer_loss, aux_transducer_loss, kl_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -444,15 +462,25 @@ def compute_loss(
             modified_transducer_prob=params.modified_transducer_prob,
         )

-        assert loss.requires_grad == is_training
+        aux_loss = aux_transducer_loss + kl_loss
+
+        assert transducer_loss.requires_grad == is_training
+        assert aux_transducer_loss.requires_grad == is_training
+        assert kl_loss.requires_grad == is_training

     info = MetricsTracker()
     info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
-    info["loss"] = loss.detach().cpu().item()
+    info["tot_loss"] = (
+        (transducer_loss + params.lambda_aux * aux_loss).detach().cpu().item()
+    )
+    info["transducer_loss"] = transducer_loss.detach().cpu().item()
+    info["aux_transducer_loss"] = aux_transducer_loss.detach().cpu().item()
+    info["kl_loss"] = kl_loss.detach().cpu().item()

-    return loss, info
+    return transducer_loss, aux_loss, info


 def compute_validation_loss(
@@ -468,7 +496,7 @@ def compute_validation_loss(
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(valid_dl):
-        loss, loss_info = compute_loss(
+        transducer_loss, aux_loss, loss_info = compute_loss(
             params=params,
             model=model,
             sp=sp,
@@ -481,7 +509,7 @@ def compute_validation_loss(
     if world_size > 1:
-        tot_loss.reduce(loss.device)
+        tot_loss.reduce(transducer_loss.device)

-    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    loss_value = tot_loss["tot_loss"] / tot_loss["frames"]
     if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value
@@ -557,7 +585,7 @@ def train_one_epoch(
         libri = is_libri(batch["supervisions"]["cut"][0])

-        loss, loss_info = compute_loss(
+        transducer_loss, aux_loss, loss_info = compute_loss(
             params=params,
             model=model,
             sp=sp,
@@ -581,7 +609,9 @@ def train_one_epoch(
         # in the batch and there is no normalization to it so far.

         optimizer.zero_grad()
-        loss.backward()
+
+        (transducer_loss + aux_loss * params.lambda_aux).backward()
+
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
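
The backward pass now runs on the combined objective rather than a single loss. A consolidated sketch (not from the commit) of one update step mirroring the diff, where clip_grad_norm_ is torch.nn.utils.clip_grad_norm_ with the same settings as above:

import torch
from torch.nn.utils import clip_grad_norm_

def training_step(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    transducer_loss: torch.Tensor,
    aux_loss: torch.Tensor,
    lambda_aux: float,
) -> None:
    optimizer.zero_grad()
    # Gradients flow through both the main and auxiliary branches.
    (transducer_loss + aux_loss * lambda_aux).backward()
    # Same clipping settings as in the diff: max_norm=5.0, norm_type=2.0.
    clip_grad_norm_(model.parameters(), 5.0, 2.0)
    optimizer.step()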
@@ -849,14 +879,16 @@ def scan_pessimistic_batches_for_oom(
         batch = train_dl.dataset[cuts]
         try:
             optimizer.zero_grad()
-            loss, _ = compute_loss(
+            transducer_loss, aux_loss, _ = compute_loss(
                 params=params,
                 model=model,
                 sp=sp,
                 batch=batch,
                 is_training=True,
             )
-            loss.backward()
+
+            (transducer_loss + aux_loss * params.lambda_aux).backward()
+
             clip_grad_norm_(model.parameters(), 5.0, 2.0)
             optimizer.step()
         except RuntimeError as e: