mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-11 11:02:29 +00:00
Add auxiliary losses.
This commit is contained in:
parent
5e4b1e01fe
commit
8b4c571824
@ -15,11 +15,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
|
||||||
from icefall.utils import add_sos
|
from icefall.utils import add_sos
|
||||||
@ -37,6 +38,7 @@ class Transducer(nn.Module):
|
|||||||
joiner: nn.Module,
|
joiner: nn.Module,
|
||||||
decoder_giga: Optional[nn.Module] = None,
|
decoder_giga: Optional[nn.Module] = None,
|
||||||
joiner_giga: Optional[nn.Module] = None,
|
joiner_giga: Optional[nn.Module] = None,
|
||||||
|
aux_module: Optional[nn.Module] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -57,6 +59,8 @@ class Transducer(nn.Module):
|
|||||||
The decoder for the GigaSpeech dataset.
|
The decoder for the GigaSpeech dataset.
|
||||||
joiner_giga:
|
joiner_giga:
|
||||||
The joiner for the GigaSpeech dataset.
|
The joiner for the GigaSpeech dataset.
|
||||||
|
aux_module:
|
||||||
|
Optional. The auxiliary branch for computing auxiliary losses.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||||
@ -73,6 +77,8 @@ class Transducer(nn.Module):
|
|||||||
self.decoder_giga = decoder_giga
|
self.decoder_giga = decoder_giga
|
||||||
self.joiner_giga = joiner_giga
|
self.joiner_giga = joiner_giga
|
||||||
|
|
||||||
|
self.aux_module = aux_module
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -80,7 +86,7 @@ class Transducer(nn.Module):
|
|||||||
y: k2.RaggedTensor,
|
y: k2.RaggedTensor,
|
||||||
libri: bool = True,
|
libri: bool = True,
|
||||||
modified_transducer_prob: float = 0.0,
|
modified_transducer_prob: float = 0.0,
|
||||||
) -> torch.Tensor:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x:
|
x:
|
||||||
@ -97,7 +103,10 @@ class Transducer(nn.Module):
|
|||||||
modified_transducer_prob:
|
modified_transducer_prob:
|
||||||
The probability to use modified transducer loss.
|
The probability to use modified transducer loss.
|
||||||
Returns:
|
Returns:
|
||||||
Return the transducer loss.
|
Return a tuple of 3 scalar tensors containing:
|
||||||
|
- transducer loss
|
||||||
|
- auxiliary transducer loss
|
||||||
|
- KL loss
|
||||||
"""
|
"""
|
||||||
assert x.ndim == 3, x.shape
|
assert x.ndim == 3, x.shape
|
||||||
assert x_lens.ndim == 1, x_lens.shape
|
assert x_lens.ndim == 1, x_lens.shape
|
||||||
@ -105,7 +114,7 @@ class Transducer(nn.Module):
|
|||||||
|
|
||||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||||
|
|
||||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
encoder_out, x_lens, aux_input = self.encoder(x, x_lens)
|
||||||
assert torch.all(x_lens > 0)
|
assert torch.all(x_lens > 0)
|
||||||
|
|
||||||
# Now for the decoder, i.e., the prediction network
|
# Now for the decoder, i.e., the prediction network
|
||||||
@ -154,15 +163,51 @@ class Transducer(nn.Module):
|
|||||||
else:
|
else:
|
||||||
one_sym_per_frame = False
|
one_sym_per_frame = False
|
||||||
|
|
||||||
loss = optimized_transducer.transducer_loss(
|
log_probs = F.log_softmax(logits, dim=-1)
|
||||||
logits=logits,
|
transducer_loss = optimized_transducer.transducer_loss(
|
||||||
|
logits=log_probs,
|
||||||
targets=y_padded,
|
targets=y_padded,
|
||||||
logit_lengths=x_lens,
|
logit_lengths=x_lens,
|
||||||
target_lengths=y_lens,
|
target_lengths=y_lens,
|
||||||
blank=blank_id,
|
blank=blank_id,
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
one_sym_per_frame=one_sym_per_frame,
|
one_sym_per_frame=one_sym_per_frame,
|
||||||
from_log_softmax=False,
|
from_log_softmax=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
aux_output = self.aux_module(aux_input)
|
||||||
|
|
||||||
|
# Now process the auxiliary branch
|
||||||
|
aux_logits = joiner(
|
||||||
|
encoder_out=aux_output,
|
||||||
|
decoder_out=decoder_out,
|
||||||
|
encoder_out_len=x_lens,
|
||||||
|
decoder_out_len=y_lens + 1,
|
||||||
|
)
|
||||||
|
aux_log_probs = F.log_softmax(aux_logits, dim=-1)
|
||||||
|
|
||||||
|
aux_transducer_loss = optimized_transducer.transducer_loss(
|
||||||
|
logits=aux_log_probs,
|
||||||
|
targets=y_padded,
|
||||||
|
logit_lengths=x_lens,
|
||||||
|
target_lengths=y_lens,
|
||||||
|
blank=blank_id,
|
||||||
|
reduction="sum",
|
||||||
|
one_sym_per_frame=one_sym_per_frame,
|
||||||
|
from_log_softmax=True,
|
||||||
|
)
|
||||||
|
kl_loss_1 = F.kl_div(
|
||||||
|
input=log_probs,
|
||||||
|
target=aux_log_probs,
|
||||||
|
reduction="sum",
|
||||||
|
log_target=True,
|
||||||
|
)
|
||||||
|
kl_loss_2 = F.kl_div(
|
||||||
|
input=aux_log_probs,
|
||||||
|
target=log_probs,
|
||||||
|
reduction="sum",
|
||||||
|
log_target=True,
|
||||||
|
)
|
||||||
|
kl_loss = (kl_loss_1 + kl_loss_2) * 0.5
|
||||||
|
|
||||||
|
return transducer_loss, aux_transducer_loss, kl_loss
|
||||||
|
@ -168,6 +168,13 @@ def get_parser():
|
|||||||
help="The probability to select a batch from the GigaSpeech dataset",
|
help="The probability to select a batch from the GigaSpeech dataset",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lambda-aux",
|
||||||
|
type=float,
|
||||||
|
default=0.3,
|
||||||
|
help="The scale applied to the auxiliary losses",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -280,6 +287,14 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
|
def get_aux_model(params: AttributeDict) -> nn.Module:
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.Linear(params.attention_dim, params.encoder_out_dim),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(params.encoder_out_dim, params.encoder_out_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
|
|
||||||
@ -289,12 +304,15 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|||||||
decoder_giga = get_decoder_model(params)
|
decoder_giga = get_decoder_model(params)
|
||||||
joiner_giga = get_joiner_model(params)
|
joiner_giga = get_joiner_model(params)
|
||||||
|
|
||||||
|
aux_module = get_aux_model(params)
|
||||||
|
|
||||||
model = Transducer(
|
model = Transducer(
|
||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
decoder=decoder,
|
decoder=decoder,
|
||||||
joiner=joiner,
|
joiner=joiner,
|
||||||
decoder_giga=decoder_giga,
|
decoder_giga=decoder_giga,
|
||||||
joiner_giga=joiner_giga,
|
joiner_giga=joiner_giga,
|
||||||
|
aux_module=aux_module,
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -436,7 +454,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
loss = model(
|
transducer_loss, aux_transducer_loss, kl_loss = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -444,15 +462,25 @@ def compute_loss(
|
|||||||
modified_transducer_prob=params.modified_transducer_prob,
|
modified_transducer_prob=params.modified_transducer_prob,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
aux_loss = aux_transducer_loss + kl_loss
|
||||||
|
|
||||||
|
assert transducer_loss.requires_grad == is_training
|
||||||
|
assert aux_transducer_loss.requires_grad == is_training
|
||||||
|
assert kl_loss.requires_grad == is_training
|
||||||
|
|
||||||
info = MetricsTracker()
|
info = MetricsTracker()
|
||||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||||
|
|
||||||
# Note: We use reduction=sum while computing the loss.
|
# Note: We use reduction=sum while computing the loss.
|
||||||
info["loss"] = loss.detach().cpu().item()
|
info["tot_loss"] = (
|
||||||
|
(transducer_loss + params.lambda_aux * aux_loss).detach().cpu().item()
|
||||||
|
)
|
||||||
|
|
||||||
return loss, info
|
info["transducer_loss"] = transducer_loss.detach().cpu().item()
|
||||||
|
info["aux_transducer_loss"] = aux_transducer_loss.detach().cpu().item()
|
||||||
|
info["kl_loss"] = kl_loss.detach().cpu().item()
|
||||||
|
|
||||||
|
return transducer_loss, aux_loss, info
|
||||||
|
|
||||||
|
|
||||||
def compute_validation_loss(
|
def compute_validation_loss(
|
||||||
@ -468,7 +496,7 @@ def compute_validation_loss(
|
|||||||
tot_loss = MetricsTracker()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(valid_dl):
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
loss, loss_info = compute_loss(
|
transduer_loss, aux_loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
@ -481,7 +509,7 @@ def compute_validation_loss(
|
|||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
tot_loss.reduce(loss.device)
|
tot_loss.reduce(loss.device)
|
||||||
|
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
loss_value = tot_loss["tot_loss"] / tot_loss["frames"]
|
||||||
if loss_value < params.best_valid_loss:
|
if loss_value < params.best_valid_loss:
|
||||||
params.best_valid_epoch = params.cur_epoch
|
params.best_valid_epoch = params.cur_epoch
|
||||||
params.best_valid_loss = loss_value
|
params.best_valid_loss = loss_value
|
||||||
@ -557,7 +585,7 @@ def train_one_epoch(
|
|||||||
|
|
||||||
libri = is_libri(batch["supervisions"]["cut"][0])
|
libri = is_libri(batch["supervisions"]["cut"][0])
|
||||||
|
|
||||||
loss, loss_info = compute_loss(
|
transducer_loss, aux_loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
@ -581,7 +609,9 @@ def train_one_epoch(
|
|||||||
# in the batch and there is no normalization to it so far.
|
# in the batch and there is no normalization to it so far.
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
|
||||||
|
(transducer_loss + aux_loss * params.lambda_aux).backward()
|
||||||
|
|
||||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
@ -849,14 +879,16 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
batch = train_dl.dataset[cuts]
|
batch = train_dl.dataset[cuts]
|
||||||
try:
|
try:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss, _ = compute_loss(
|
transducer_loss, aux_loss, _ = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
)
|
)
|
||||||
loss.backward()
|
|
||||||
|
(transducer_loss + aux_loss * params.lambda_aux).backward()
|
||||||
|
|
||||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user