mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Replace torchaudio rnnt_loss to k2 pruned rnnt loss
This commit is contained in:
parent
6caff5fd38
commit
f4b8b0641a
@ -73,9 +73,9 @@ def greedy_search(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
logits = model.joiner(current_encoder_out, decoder_out)
|
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
|
||||||
# logits is (1, 1, 1, vocab_size)
|
# logits is (1, 1, 1, vocab_size)
|
||||||
|
|
||||||
y = logits.argmax().item()
|
y = logits.argmax().item()
|
||||||
|
@ -30,21 +30,14 @@ class Joiner(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
encoder_out:
|
encoder_out:
|
||||||
Output from the encoder. Its shape is (N, T, C).
|
Output from the encoder. Its shape is (N, T, s_range, C).
|
||||||
decoder_out:
|
decoder_out:
|
||||||
Output from the decoder. Its shape is (N, U, C).
|
Output from the decoder. Its shape is (N, T, s_range, C).
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, T, U, C).
|
Return a tensor of shape (N, T, s_range, C).
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == decoder_out.ndim == 3
|
assert encoder_out.ndim == decoder_out.ndim == 4
|
||||||
assert encoder_out.size(0) == decoder_out.size(0)
|
assert encoder_out.shape == decoder_out.shape
|
||||||
assert encoder_out.size(2) == decoder_out.size(2)
|
|
||||||
|
|
||||||
encoder_out = encoder_out.unsqueeze(2)
|
|
||||||
# Now encoder_out is (N, T, 1, C)
|
|
||||||
|
|
||||||
decoder_out = decoder_out.unsqueeze(1)
|
|
||||||
# Now decoder_out is (N, 1, U, C)
|
|
||||||
|
|
||||||
logit = encoder_out + decoder_out
|
logit = encoder_out + decoder_out
|
||||||
logit = torch.tanh(logit)
|
logit = torch.tanh(logit)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -14,15 +14,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
|
||||||
Note we use `rnnt_loss` from torchaudio, which exists only in
|
|
||||||
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
|
|
||||||
"""
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchaudio
|
|
||||||
import torchaudio.functional
|
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
|
||||||
from icefall.utils import add_sos
|
from icefall.utils import add_sos
|
||||||
@ -38,6 +33,7 @@ class Transducer(nn.Module):
|
|||||||
encoder: EncoderInterface,
|
encoder: EncoderInterface,
|
||||||
decoder: nn.Module,
|
decoder: nn.Module,
|
||||||
joiner: nn.Module,
|
joiner: nn.Module,
|
||||||
|
prune_range: int = 3,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -62,6 +58,7 @@ class Transducer(nn.Module):
|
|||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
self.joiner = joiner
|
self.joiner = joiner
|
||||||
|
self.prune_range = prune_range
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -102,24 +99,32 @@ class Transducer(nn.Module):
|
|||||||
|
|
||||||
decoder_out = self.decoder(sos_y_padded)
|
decoder_out = self.decoder(sos_y_padded)
|
||||||
|
|
||||||
logits = self.joiner(encoder_out, decoder_out)
|
|
||||||
|
|
||||||
# rnnt_loss requires 0 padded targets
|
|
||||||
# Note: y does not start with SOS
|
# Note: y does not start with SOS
|
||||||
y_padded = y.pad(mode="constant", padding_value=0)
|
y_padded = y.pad(mode="constant", padding_value=0)
|
||||||
|
|
||||||
assert hasattr(torchaudio.functional, "rnnt_loss"), (
|
y_padded = y_padded.to(torch.int64)
|
||||||
f"Current torchaudio version: {torchaudio.__version__}\n"
|
boundary = torch.zeros(
|
||||||
"Please install a version >= 0.10.0"
|
(x.size(0), 4), dtype=torch.int64, device=x.device
|
||||||
|
)
|
||||||
|
boundary[:, 2] = y_lens
|
||||||
|
boundary[:, 3] = x_lens
|
||||||
|
|
||||||
|
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
|
||||||
|
decoder_out, encoder_out, y_padded, blank_id, boundary, True
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = torchaudio.functional.rnnt_loss(
|
ranges = k2.get_rnnt_prune_ranges(
|
||||||
logits=logits,
|
px_grad, py_grad, boundary, self.prune_range
|
||||||
targets=y_padded,
|
|
||||||
logit_lengths=x_lens,
|
|
||||||
target_lengths=y_lens,
|
|
||||||
blank=blank_id,
|
|
||||||
reduction="sum",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
am_pruning, lm_pruning = k2.do_rnnt_pruning(
|
||||||
|
encoder_out, decoder_out, ranges
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = self.joiner(am_pruning, lm_pruning)
|
||||||
|
|
||||||
|
pruning_loss = k2.rnnt_loss_pruned(
|
||||||
|
logits, y_padded, ranges, blank_id, boundary
|
||||||
|
)
|
||||||
|
|
||||||
|
return (-torch.sum(simple_loss), -torch.sum(pruning_loss))
|
||||||
|
@ -60,7 +60,15 @@ from icefall.checkpoint import load_checkpoint
|
|||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
MetricsTracker,
|
||||||
|
measure_gradient_norms,
|
||||||
|
measure_weight_norms,
|
||||||
|
optim_step_and_measure_param_change,
|
||||||
|
setup_logger,
|
||||||
|
str2bool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -138,6 +146,14 @@ def get_parser():
|
|||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--prune-range",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="The prune range for rnnt loss, it means how many symbols(context)"
|
||||||
|
"we are using to compute the loss",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -195,6 +211,7 @@ def get_params() -> AttributeDict:
|
|||||||
"log_interval": 50,
|
"log_interval": 50,
|
||||||
"reset_interval": 200,
|
"reset_interval": 200,
|
||||||
"valid_interval": 3000, # For the 100h subset, use 800
|
"valid_interval": 3000, # For the 100h subset, use 800
|
||||||
|
"log_diagnostics": False,
|
||||||
# parameters for conformer
|
# parameters for conformer
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"encoder_out_dim": 512,
|
"encoder_out_dim": 512,
|
||||||
@ -383,7 +400,8 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
loss = model(x=feature, x_lens=feature_lens, y=y)
|
simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
|
||||||
|
loss = simple_loss + pruned_loss
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
@ -392,6 +410,8 @@ def compute_loss(
|
|||||||
|
|
||||||
# Note: We use reduction=sum while computing the loss.
|
# Note: We use reduction=sum while computing the loss.
|
||||||
info["loss"] = loss.detach().cpu().item()
|
info["loss"] = loss.detach().cpu().item()
|
||||||
|
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||||
|
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||||
|
|
||||||
return loss, info
|
return loss, info
|
||||||
|
|
||||||
@ -466,6 +486,45 @@ def train_one_epoch(
|
|||||||
|
|
||||||
tot_loss = MetricsTracker()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
|
def maybe_log_gradients(tag: str):
|
||||||
|
if (
|
||||||
|
params.log_diagnostics
|
||||||
|
and tb_writer is not None
|
||||||
|
and params.batch_idx_train % (params.log_interval * 5) == 0
|
||||||
|
):
|
||||||
|
tb_writer.add_scalars(
|
||||||
|
tag,
|
||||||
|
measure_gradient_norms(model, norm="l2"),
|
||||||
|
global_step=params.batch_idx_train,
|
||||||
|
)
|
||||||
|
|
||||||
|
def maybe_log_weights(tag: str):
|
||||||
|
if (
|
||||||
|
params.log_diagnostics
|
||||||
|
and tb_writer is not None
|
||||||
|
and params.batch_idx_train % (params.log_interval * 5) == 0
|
||||||
|
):
|
||||||
|
tb_writer.add_scalars(
|
||||||
|
tag,
|
||||||
|
measure_weight_norms(model, norm="l2"),
|
||||||
|
global_step=params.batch_idx_train,
|
||||||
|
)
|
||||||
|
|
||||||
|
def maybe_log_param_relative_changes():
|
||||||
|
if (
|
||||||
|
params.log_diagnostics
|
||||||
|
and tb_writer is not None
|
||||||
|
and params.batch_idx_train % (params.log_interval * 5) == 0
|
||||||
|
):
|
||||||
|
deltas = optim_step_and_measure_param_change(model, optimizer)
|
||||||
|
tb_writer.add_scalars(
|
||||||
|
"train/relative_param_change_per_minibatch",
|
||||||
|
deltas,
|
||||||
|
global_step=params.batch_idx_train,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
@ -483,10 +542,13 @@ def train_one_epoch(
|
|||||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||||
# in the batch and there is no normalization to it so far.
|
# in the batch and there is no normalization to it so far.
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
|
||||||
optimizer.step()
|
maybe_log_weights("train/param_norms")
|
||||||
|
maybe_log_gradients("train/grad_norms")
|
||||||
|
maybe_log_param_relative_changes()
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
logging.info(
|
logging.info(
|
||||||
|
@ -690,3 +690,94 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
|
|||||||
expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
|
expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
|
||||||
|
|
||||||
return expaned_lengths >= lengths.unsqueeze(1)
|
return expaned_lengths >= lengths.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
def l1_norm(x):
|
||||||
|
return torch.sum(torch.abs(x))
|
||||||
|
|
||||||
|
|
||||||
|
def l2_norm(x):
|
||||||
|
return torch.sum(torch.pow(x, 2))
|
||||||
|
|
||||||
|
|
||||||
|
def linf_norm(x):
|
||||||
|
return torch.max(torch.abs(x))
|
||||||
|
|
||||||
|
|
||||||
|
def measure_weight_norms(
|
||||||
|
model: nn.Module, norm: str = "l2"
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Compute the norms of the model's parameters.
|
||||||
|
|
||||||
|
:param model: a torch.nn.Module instance
|
||||||
|
:param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf'
|
||||||
|
:return: a dict mapping from parameter's name to its norm.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
norms = {}
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if norm == "l1":
|
||||||
|
val = l1_norm(param)
|
||||||
|
elif norm == "l2":
|
||||||
|
val = l2_norm(param)
|
||||||
|
elif norm == "linf":
|
||||||
|
val = linf_norm(param)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown norm type: {norm}")
|
||||||
|
norms[name] = val.item()
|
||||||
|
return norms
|
||||||
|
|
||||||
|
|
||||||
|
def measure_gradient_norms(
|
||||||
|
model: nn.Module, norm: str = "l1"
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Compute the norms of the gradients for each of model's parameters.
|
||||||
|
|
||||||
|
:param model: a torch.nn.Module instance
|
||||||
|
:param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf'
|
||||||
|
:return: a dict mapping from parameter's name to its gradient's norm.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
norms = {}
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if norm == "l1":
|
||||||
|
val = l1_norm(param.grad)
|
||||||
|
elif norm == "l2":
|
||||||
|
val = l2_norm(param.grad)
|
||||||
|
elif norm == "linf":
|
||||||
|
val = linf_norm(param.grad)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown norm type: {norm}")
|
||||||
|
norms[name] = val.item()
|
||||||
|
return norms
|
||||||
|
|
||||||
|
|
||||||
|
def optim_step_and_measure_param_change(
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
scaler: Optional[GradScaler] = None,
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Perform model weight update and measure the "relative change in parameters per minibatch."
|
||||||
|
It is understood as a ratio between the L2 norm of the difference between original and updates parameters,
|
||||||
|
and the L2 norm of the original parameter. It is given by the formula:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\begin{aligned}
|
||||||
|
\delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
|
||||||
|
\end{aligned}
|
||||||
|
"""
|
||||||
|
param_copy = {n: p.detach().clone() for n, p in model.named_parameters()}
|
||||||
|
if scaler:
|
||||||
|
scaler.step(optimizer)
|
||||||
|
else:
|
||||||
|
optimizer.step()
|
||||||
|
relative_change = {}
|
||||||
|
with torch.no_grad():
|
||||||
|
for n, p_new in model.named_parameters():
|
||||||
|
p_orig = param_copy[n]
|
||||||
|
delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
|
||||||
|
relative_change[n] = delta.item()
|
||||||
|
return relative_change
|
||||||
|
Loading…
x
Reference in New Issue
Block a user