Add auxiliary losses.

This commit is contained in:
Fangjun Kuang 2022-02-22 16:15:00 +08:00
parent 5e4b1e01fe
commit 8b4c571824
2 changed files with 95 additions and 18 deletions

View File

@ -15,11 +15,12 @@
# limitations under the License.
import random
from typing import Optional
from typing import Optional, Tuple
import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
@ -37,6 +38,7 @@ class Transducer(nn.Module):
joiner: nn.Module,
decoder_giga: Optional[nn.Module] = None,
joiner_giga: Optional[nn.Module] = None,
aux_module: Optional[nn.Module] = None,
):
"""
Args:
@ -57,6 +59,8 @@ class Transducer(nn.Module):
The decoder for the GigaSpeech dataset.
joiner_giga:
The joiner for the GigaSpeech dataset.
aux_module:
Optional. The auxiliary branch for computing auxiliary losses.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
@ -73,6 +77,8 @@ class Transducer(nn.Module):
self.decoder_giga = decoder_giga
self.joiner_giga = joiner_giga
self.aux_module = aux_module
def forward(
self,
x: torch.Tensor,
@ -80,7 +86,7 @@ class Transducer(nn.Module):
y: k2.RaggedTensor,
libri: bool = True,
modified_transducer_prob: float = 0.0,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
@ -97,7 +103,10 @@ class Transducer(nn.Module):
modified_transducer_prob:
The probability to use modified transducer loss.
Returns:
Return the transducer loss.
Return a tuple of 3 scalar tensors containing:
- transducer loss
- auxiliary transducer loss
- KL loss
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
@ -105,7 +114,7 @@ class Transducer(nn.Module):
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
encoder_out, x_lens, aux_input = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
@ -154,15 +163,51 @@ class Transducer(nn.Module):
else:
one_sym_per_frame = False
loss = optimized_transducer.transducer_loss(
logits=logits,
log_probs = F.log_softmax(logits, dim=-1)
transducer_loss = optimized_transducer.transducer_loss(
logits=log_probs,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
one_sym_per_frame=one_sym_per_frame,
from_log_softmax=False,
from_log_softmax=True,
)
return loss
aux_output = self.aux_module(aux_input)
# Now process the auxiliary branch
aux_logits = joiner(
encoder_out=aux_output,
decoder_out=decoder_out,
encoder_out_len=x_lens,
decoder_out_len=y_lens + 1,
)
aux_log_probs = F.log_softmax(aux_logits, dim=-1)
aux_transducer_loss = optimized_transducer.transducer_loss(
logits=aux_log_probs,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
one_sym_per_frame=one_sym_per_frame,
from_log_softmax=True,
)
kl_loss_1 = F.kl_div(
input=log_probs,
target=aux_log_probs,
reduction="sum",
log_target=True,
)
kl_loss_2 = F.kl_div(
input=aux_log_probs,
target=log_probs,
reduction="sum",
log_target=True,
)
kl_loss = (kl_loss_1 + kl_loss_2) * 0.5
return transducer_loss, aux_transducer_loss, kl_loss

View File

@ -168,6 +168,13 @@ def get_parser():
help="The probability to select a batch from the GigaSpeech dataset",
)
parser.add_argument(
"--lambda-aux",
type=float,
default=0.3,
help="The scale applied to the auxiliary losses",
)
return parser
@ -280,6 +287,14 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
return joiner
def get_aux_model(params: AttributeDict) -> nn.Module:
    """Build the auxiliary branch used for computing the auxiliary losses.

    The branch is a small feed-forward stack: a linear layer mapping
    ``params.attention_dim`` features to ``params.encoder_out_dim``, an
    in-place ReLU, and a second linear layer that keeps the feature size
    at ``params.encoder_out_dim``.

    Args:
      params:
        Must provide the attributes ``attention_dim`` and
        ``encoder_out_dim``.
    Returns:
      Return an ``nn.Sequential`` made of Linear -> ReLU -> Linear.
    """
    in_dim = params.attention_dim
    out_dim = params.encoder_out_dim

    # The layers are created in the same order as they are applied, so the
    # parameter initialization order matches the layer order.
    layers = [
        nn.Linear(in_dim, out_dim),
        nn.ReLU(inplace=True),
        nn.Linear(out_dim, out_dim),
    ]
    return nn.Sequential(*layers)
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
@ -289,12 +304,15 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
decoder_giga = get_decoder_model(params)
joiner_giga = get_joiner_model(params)
aux_module = get_aux_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
decoder_giga=decoder_giga,
joiner_giga=joiner_giga,
aux_module=aux_module,
)
return model
@ -436,7 +454,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
loss = model(
transducer_loss, aux_transducer_loss, kl_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -444,15 +462,25 @@ def compute_loss(
modified_transducer_prob=params.modified_transducer_prob,
)
assert loss.requires_grad == is_training
aux_loss = aux_transducer_loss + kl_loss
assert transducer_loss.requires_grad == is_training
assert aux_transducer_loss.requires_grad == is_training
assert kl_loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["tot_loss"] = (
(transducer_loss + params.lambda_aux * aux_loss).detach().cpu().item()
)
return loss, info
info["transducer_loss"] = transducer_loss.detach().cpu().item()
info["aux_transducer_loss"] = aux_transducer_loss.detach().cpu().item()
info["kl_loss"] = kl_loss.detach().cpu().item()
return transducer_loss, aux_loss, info
def compute_validation_loss(
@ -468,7 +496,7 @@ def compute_validation_loss(
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
transduer_loss, aux_loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
@ -481,7 +509,7 @@ def compute_validation_loss(
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
loss_value = tot_loss["tot_loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
@ -557,7 +585,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0])
loss, loss_info = compute_loss(
transducer_loss, aux_loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
@ -581,7 +609,9 @@ def train_one_epoch(
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
(transducer_loss + aux_loss * params.lambda_aux).backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
@ -849,14 +879,16 @@ def scan_pessimistic_batches_for_oom(
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
transducer_loss, aux_loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
loss.backward()
(transducer_loss + aux_loss * params.lambda_aux).backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e: