Add auxiliary losses.

parent 5e4b1e01fe
commit 8b4c571824
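In outline, this commit attaches an auxiliary branch to an intermediate encoder output, computes a second transducer loss on that branch with the shared joiner, and couples the main and auxiliary output distributions with a symmetrized KL divergence; training then minimizes transducer_loss + lambda_aux * (aux_transducer_loss + kl_loss). The snippet below is a minimal, self-contained sketch of that combination: the tensor shapes and the two placeholder transducer-loss values are made up, and only the KL and weighting arithmetic mirrors the diff.

import torch
import torch.nn.functional as F

B, T, U, V = 2, 50, 11, 500  # batch, frames, label positions, vocab size (made up)
lambda_aux = 0.3             # scale applied to the auxiliary losses

# Random stand-ins for the main and auxiliary joiner outputs.
main_log_probs = torch.randn(B, T, U, V).log_softmax(dim=-1)
aux_log_probs = torch.randn(B, T, U, V).log_softmax(dim=-1)

# Placeholders for the two transducer losses; the recipe computes them with
# optimized_transducer.transducer_loss on these log-probs.
transducer_loss = torch.tensor(123.4)
aux_transducer_loss = torch.tensor(150.7)

# Symmetrized KL divergence between the two output distributions.
kl_loss = 0.5 * (
    F.kl_div(main_log_probs, aux_log_probs, reduction="sum", log_target=True)
    + F.kl_div(aux_log_probs, main_log_probs, reduction="sum", log_target=True)
)

aux_loss = aux_transducer_loss + kl_loss
total_loss = transducer_loss + lambda_aux * aux_loss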
@@ -15,11 +15,12 @@
 # limitations under the License.
 
 import random
-from typing import Optional
+from typing import Optional, Tuple
 
 import k2
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from encoder_interface import EncoderInterface
 
 from icefall.utils import add_sos
@@ -37,6 +38,7 @@ class Transducer(nn.Module):
         joiner: nn.Module,
         decoder_giga: Optional[nn.Module] = None,
         joiner_giga: Optional[nn.Module] = None,
+        aux_module: Optional[nn.Module] = None,
     ):
         """
         Args:
@@ -57,6 +59,8 @@ class Transducer(nn.Module):
             The decoder for the GigaSpeech dataset.
           joiner_giga:
             The joiner for the GigaSpeech dataset.
+          aux_module:
+            Optional. The auxiliary branch for computing auxiliary losses.
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -73,6 +77,8 @@ class Transducer(nn.Module):
         self.decoder_giga = decoder_giga
         self.joiner_giga = joiner_giga
 
+        self.aux_module = aux_module
+
     def forward(
         self,
         x: torch.Tensor,
@@ -80,7 +86,7 @@ class Transducer(nn.Module):
         y: k2.RaggedTensor,
         libri: bool = True,
         modified_transducer_prob: float = 0.0,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
           x:
@@ -97,7 +103,10 @@ class Transducer(nn.Module):
           modified_transducer_prob:
             The probability to use modified transducer loss.
         Returns:
-          Return the transducer loss.
+          Return a tuple of 3 scalar tensors containing:
+            - transducer loss
+            - auxiliary transducer loss
+            - KL loss
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
@@ -105,7 +114,7 @@ class Transducer(nn.Module):
 
         assert x.size(0) == x_lens.size(0) == y.dim0
 
-        encoder_out, x_lens = self.encoder(x, x_lens)
+        encoder_out, x_lens, aux_input = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)
 
         # Now for the decoder, i.e., the prediction network
@@ -154,15 +163,51 @@ class Transducer(nn.Module):
         else:
             one_sym_per_frame = False
 
-        loss = optimized_transducer.transducer_loss(
-            logits=logits,
+        log_probs = F.log_softmax(logits, dim=-1)
+        transducer_loss = optimized_transducer.transducer_loss(
+            logits=log_probs,
             targets=y_padded,
             logit_lengths=x_lens,
             target_lengths=y_lens,
             blank=blank_id,
             reduction="sum",
             one_sym_per_frame=one_sym_per_frame,
-            from_log_softmax=False,
+            from_log_softmax=True,
         )
 
-        return loss
+        aux_output = self.aux_module(aux_input)
+
+        # Now process the auxiliary branch
+        aux_logits = joiner(
+            encoder_out=aux_output,
+            decoder_out=decoder_out,
+            encoder_out_len=x_lens,
+            decoder_out_len=y_lens + 1,
+        )
+        aux_log_probs = F.log_softmax(aux_logits, dim=-1)
+
+        aux_transducer_loss = optimized_transducer.transducer_loss(
+            logits=aux_log_probs,
+            targets=y_padded,
+            logit_lengths=x_lens,
+            target_lengths=y_lens,
+            blank=blank_id,
+            reduction="sum",
+            one_sym_per_frame=one_sym_per_frame,
+            from_log_softmax=True,
+        )
+        kl_loss_1 = F.kl_div(
+            input=log_probs,
+            target=aux_log_probs,
+            reduction="sum",
+            log_target=True,
+        )
+        kl_loss_2 = F.kl_div(
+            input=aux_log_probs,
+            target=log_probs,
+            reduction="sum",
+            log_target=True,
+        )
+        kl_loss = (kl_loss_1 + kl_loss_2) * 0.5
+
+        return transducer_loss, aux_transducer_loss, kl_loss
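A note on the two F.kl_div calls above: with log_target=True both arguments are log-probabilities, and F.kl_div(input, target) computes KL(target || input), i.e. the expectation is taken under the distribution passed as target, so averaging the two directions yields a penalty that is symmetric in the main and auxiliary outputs. A tiny self-contained check (the distributions are made up):

import torch
import torch.nn.functional as F

p = torch.tensor([0.7, 0.2, 0.1]).log()
q = torch.tensor([0.5, 0.3, 0.2]).log()

# F.kl_div(input=p, target=q, log_target=True) equals KL(q || p).
kl_q_p = F.kl_div(p, q, reduction="sum", log_target=True)
assert torch.allclose(kl_q_p, (q.exp() * (q - p)).sum())

# The commit averages both directions, which is symmetric in p and q.
sym_kl = 0.5 * (kl_q_p + F.kl_div(q, p, reduction="sum", log_target=True))

The remaining hunks below appear to be in the recipe's training script (get_parser, the get_*_model helpers, compute_loss, train_one_epoch, and so on).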
@@ -168,6 +168,13 @@ def get_parser():
         help="The probability to select a batch from the GigaSpeech dataset",
     )
 
+    parser.add_argument(
+        "--lambda-aux",
+        type=float,
+        default=0.3,
+        help="The scale applied to the auxiliary losses",
+    )
+
     return parser
 
 
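For reference, argparse maps the new --lambda-aux flag to an attribute named lambda_aux, which the recipe later reads as params.lambda_aux. A minimal stand-alone illustration, not the recipe's actual parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--lambda-aux",
    type=float,
    default=0.3,
    help="The scale applied to the auxiliary losses",
)

args = parser.parse_args(["--lambda-aux", "0.5"])
assert args.lambda_aux == 0.5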
@@ -280,6 +287,14 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     return joiner
 
 
+def get_aux_model(params: AttributeDict) -> nn.Module:
+    return nn.Sequential(
+        nn.Linear(params.attention_dim, params.encoder_out_dim),
+        nn.ReLU(inplace=True),
+        nn.Linear(params.encoder_out_dim, params.encoder_out_dim),
+    )
+
+
 def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
 
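The auxiliary module added here is a small two-layer MLP. Judging from the Linear dimensions, its job is to project the intermediate encoder output that the encoder now returns as aux_input (width params.attention_dim) down to params.encoder_out_dim, so the existing joiner can score the auxiliary branch exactly like the main encoder output. A quick shape check under assumed dimensions:

import torch
import torch.nn as nn

# Hypothetical values; in the recipe they come from params.attention_dim
# and params.encoder_out_dim.
attention_dim, encoder_out_dim = 512, 256

aux_module = nn.Sequential(
    nn.Linear(attention_dim, encoder_out_dim),
    nn.ReLU(inplace=True),
    nn.Linear(encoder_out_dim, encoder_out_dim),
)

aux_input = torch.randn(4, 100, attention_dim)  # (batch, frames, attention_dim)
aux_output = aux_module(aux_input)
assert aux_output.shape == (4, 100, encoder_out_dim)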
@@ -289,12 +304,15 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
     decoder_giga = get_decoder_model(params)
     joiner_giga = get_joiner_model(params)
 
+    aux_module = get_aux_model(params)
+
     model = Transducer(
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
         decoder_giga=decoder_giga,
         joiner_giga=joiner_giga,
+        aux_module=aux_module,
     )
     return model
 
@@ -436,7 +454,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        loss = model(
+        transducer_loss, aux_transducer_loss, kl_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -444,15 +462,25 @@ def compute_loss(
             modified_transducer_prob=params.modified_transducer_prob,
         )
 
-    assert loss.requires_grad == is_training
+    aux_loss = aux_transducer_loss + kl_loss
+
+    assert transducer_loss.requires_grad == is_training
+    assert aux_transducer_loss.requires_grad == is_training
+    assert kl_loss.requires_grad == is_training
 
     info = MetricsTracker()
     info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
-    info["loss"] = loss.detach().cpu().item()
+    info["tot_loss"] = (
+        (transducer_loss + params.lambda_aux * aux_loss).detach().cpu().item()
+    )
 
-    return loss, info
+    info["transducer_loss"] = transducer_loss.detach().cpu().item()
+    info["aux_transducer_loss"] = aux_transducer_loss.detach().cpu().item()
+    info["kl_loss"] = kl_loss.detach().cpu().item()
+
+    return transducer_loss, aux_loss, info
 
 
 def compute_validation_loss(
@@ -468,7 +496,7 @@ def compute_validation_loss(
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(valid_dl):
-        loss, loss_info = compute_loss(
+        transducer_loss, aux_loss, loss_info = compute_loss(
             params=params,
             model=model,
             sp=sp,
@@ -481,7 +509,7 @@ def compute_validation_loss(
     if world_size > 1:
         tot_loss.reduce(loss.device)
 
-    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    loss_value = tot_loss["tot_loss"] / tot_loss["frames"]
     if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value
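Because all three losses use reduction="sum", the numbers written into info are sums over the whole batch; per-frame values only appear after dividing by the accumulated frame count, which is what the hunk just above does via tot_loss["tot_loss"] / tot_loss["frames"]. A toy calculation with made-up numbers:

lambda_aux = 0.3
transducer_loss = 5200.0      # batch sum
aux_transducer_loss = 6100.0  # batch sum
kl_loss = 300.0               # batch sum
frames = 12000                # (feature_lens // subsampling_factor).sum()

aux_loss = aux_transducer_loss + kl_loss
tot_loss = transducer_loss + lambda_aux * aux_loss
per_frame_loss = tot_loss / frames  # the quantity compared against best_valid_loss
print(f"per-frame loss: {per_frame_loss:.4f}")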
@@ -557,7 +585,7 @@ def train_one_epoch(
 
         libri = is_libri(batch["supervisions"]["cut"][0])
 
-        loss, loss_info = compute_loss(
+        transducer_loss, aux_loss, loss_info = compute_loss(
            params=params,
            model=model,
            sp=sp,
@@ -581,7 +609,9 @@ def train_one_epoch(
         # in the batch and there is no normalization to it so far.
 
         optimizer.zero_grad()
-        loss.backward()
+
+        (transducer_loss + aux_loss * params.lambda_aux).backward()
+
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
 
@@ -849,14 +879,16 @@ def scan_pessimistic_batches_for_oom(
         batch = train_dl.dataset[cuts]
         try:
             optimizer.zero_grad()
-            loss, _ = compute_loss(
+            transducer_loss, aux_loss, _ = compute_loss(
                 params=params,
                 model=model,
                 sp=sp,
                 batch=batch,
                 is_training=True,
             )
-            loss.backward()
+
+            (transducer_loss + aux_loss * params.lambda_aux).backward()
+
             clip_grad_norm_(model.parameters(), 5.0, 2.0)
             optimizer.step()
         except RuntimeError as e: