Replace torchaudio rnnt_loss with k2 pruned rnnt loss

pkufool 2022-01-07 13:02:46 +08:00
parent 6caff5fd38
commit f4b8b0641a
5 changed files with 191 additions and 40 deletions

View File

@@ -73,9 +73,9 @@ def greedy_search(
            continue
        # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
+        current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
        # fmt: on
-        logits = model.joiner(current_encoder_out, decoder_out)
+        logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
        # logits is (1, 1, 1, vocab_size)
        y = logits.argmax().item()
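
The two unsqueeze calls exist because the joiner (see the next file) no longer reshapes its inputs internally; the caller must hand it two tensors of identical shape (N, T, s_range, C). Below is a minimal shape sketch, with the batch size, the dimension C, and the final vocabulary projection all made up for illustration (they are not taken from this hunk):

    import torch

    C = 512                                     # assumed joiner input dimension
    current_encoder_out = torch.randn(1, 1, C)  # one frame during greedy search: (N=1, T=1, C)
    decoder_out = torch.randn(1, 1, C)          # current decoder state: (N=1, U=1, C)

    enc = current_encoder_out.unsqueeze(2)      # (1, 1, 1, C)
    dec = decoder_out.unsqueeze(1)              # (1, 1, 1, C)
    logit = torch.tanh(enc + dec)               # (1, 1, 1, C); the real Joiner applies a final
                                                # linear layer after the tanh to get vocab_size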

View File

@@ -30,21 +30,14 @@ class Joiner(nn.Module):
        """
        Args:
          encoder_out:
-            Output from the encoder. Its shape is (N, T, C).
+            Output from the encoder. Its shape is (N, T, s_range, C).
          decoder_out:
-            Output from the decoder. Its shape is (N, U, C).
+            Output from the decoder. Its shape is (N, T, s_range, C).
        Returns:
-          Return a tensor of shape (N, T, U, C).
+          Return a tensor of shape (N, T, s_range, C).
        """
-        assert encoder_out.ndim == decoder_out.ndim == 3
-        assert encoder_out.size(0) == decoder_out.size(0)
-        assert encoder_out.size(2) == decoder_out.size(2)
-        encoder_out = encoder_out.unsqueeze(2)
-        # Now encoder_out is (N, T, 1, C)
-        decoder_out = decoder_out.unsqueeze(1)
-        # Now decoder_out is (N, 1, U, C)
+        assert encoder_out.ndim == decoder_out.ndim == 4
+        assert encoder_out.shape == decoder_out.shape
        logit = encoder_out + decoder_out
        logit = torch.tanh(logit)
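
The new s_range axis is the number of symbol positions kept per frame after pruning (it mirrors the --prune-range option added to the training script in a later hunk). A small, self-contained shape check, with all sizes made up for illustration:

    import torch

    N, T, s_range, C = 2, 100, 3, 512             # illustrative sizes; s_range mirrors --prune-range
    am_pruned = torch.randn(N, T, s_range, C)     # encoder output gathered at the kept symbol positions
    lm_pruned = torch.randn(N, T, s_range, C)     # decoder output gathered at the same positions

    assert am_pruned.ndim == lm_pruned.ndim == 4  # the new preconditions in Joiner.forward
    assert am_pruned.shape == lm_pruned.shape
    logit = torch.tanh(am_pruned + lm_pruned)     # (N, T, s_range, C), as the updated docstring says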

View File

@@ -1,4 +1,4 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -14,15 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Note we use `rnnt_loss` from torchaudio, which exists only in
-torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
-"""
+import k2
import torch
import torch.nn as nn
-import torchaudio
-import torchaudio.functional
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
@@ -38,6 +33,7 @@ class Transducer(nn.Module):
        encoder: EncoderInterface,
        decoder: nn.Module,
        joiner: nn.Module,
+        prune_range: int = 3,
    ):
        """
        Args:
@@ -62,6 +58,7 @@ class Transducer(nn.Module):
        self.encoder = encoder
        self.decoder = decoder
        self.joiner = joiner
+        self.prune_range = prune_range

    def forward(
        self,
@@ -102,24 +99,32 @@ class Transducer(nn.Module):
        decoder_out = self.decoder(sos_y_padded)
-        logits = self.joiner(encoder_out, decoder_out)
-        # rnnt_loss requires 0 padded targets
        # Note: y does not start with SOS
        y_padded = y.pad(mode="constant", padding_value=0)
-        assert hasattr(torchaudio.functional, "rnnt_loss"), (
-            f"Current torchaudio version: {torchaudio.__version__}\n"
-            "Please install a version >= 0.10.0"
+        y_padded = y_padded.to(torch.int64)
+        boundary = torch.zeros(
+            (x.size(0), 4), dtype=torch.int64, device=x.device
        )
+        boundary[:, 2] = y_lens
+        boundary[:, 3] = x_lens
+        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
+            decoder_out, encoder_out, y_padded, blank_id, boundary, True
+        )
-        loss = torchaudio.functional.rnnt_loss(
-            logits=logits,
-            targets=y_padded,
-            logit_lengths=x_lens,
-            target_lengths=y_lens,
-            blank=blank_id,
-            reduction="sum",
+        ranges = k2.get_rnnt_prune_ranges(
+            px_grad, py_grad, boundary, self.prune_range
        )
-        return loss
+        am_pruning, lm_pruning = k2.do_rnnt_pruning(
+            encoder_out, decoder_out, ranges
+        )
+        logits = self.joiner(am_pruning, lm_pruning)
+        pruning_loss = k2.rnnt_loss_pruned(
+            logits, y_padded, ranges, blank_id, boundary
+        )
+        return (-torch.sum(simple_loss), -torch.sum(pruning_loss))
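
For readers new to the k2 API, the rewritten forward() boils down to the pipeline below. This is a condensed sketch, not the recipe's code: the function name and the joiner callable are placeholders, the k2 calls and their argument order are copied from the diff above, and k2 treats the last dimension of the simple-loss inputs as the vocabulary dimension.

    import k2
    import torch

    def pruned_rnnt_loss_sketch(
        encoder_out, decoder_out, y_padded, x_lens, y_lens, blank_id, joiner, prune_range=3
    ):
        # encoder_out: (N, T, C); decoder_out: (N, U+1, C), since the decoder input starts
        # with SOS; y_padded: (N, U) int64 targets padded with 0.
        boundary = torch.zeros(
            (encoder_out.size(0), 4), dtype=torch.int64, device=encoder_out.device
        )
        boundary[:, 2] = y_lens
        boundary[:, 3] = x_lens

        # 1) A cheap "simple" loss over the full (T x U) lattice; the trailing True asks k2
        #    to also return the gradients that indicate which lattice cells matter.
        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
            decoder_out, encoder_out, y_padded, blank_id, boundary, True
        )

        # 2) Use those gradients to keep only prune_range symbol positions per frame.
        ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, boundary, prune_range)

        # 3) Gather encoder/decoder outputs at the kept positions: both become
        #    (N, T, prune_range, C), which is what the 4-D joiner above expects.
        am_pruned, lm_pruned = k2.do_rnnt_pruning(encoder_out, decoder_out, ranges)

        # 4) Run the joiner only on the pruned cells and compute the exact loss there.
        logits = joiner(am_pruned, lm_pruned)  # (N, T, prune_range, vocab_size)
        pruned_loss = k2.rnnt_loss_pruned(logits, y_padded, ranges, blank_id, boundary)

        # k2 returns per-utterance log-probs, so negate and sum as the diff does.
        return -torch.sum(simple_loss), -torch.sum(pruned_loss)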

View File

@@ -60,7 +60,15 @@ from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    measure_gradient_norms,
+    measure_weight_norms,
+    optim_step_and_measure_param_change,
+    setup_logger,
+    str2bool,
+)
def get_parser():
@@ -138,6 +146,14 @@ def get_parser():
        "2 means tri-gram",
    )
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=3,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )

    return parser
@@ -195,6 +211,7 @@ def get_params() -> AttributeDict:
            "log_interval": 50,
            "reset_interval": 200,
            "valid_interval": 3000,  # For the 100h subset, use 800
+            "log_diagnostics": False,
            # parameters for conformer
            "feature_dim": 80,
            "encoder_out_dim": 512,
@@ -383,7 +400,8 @@ def compute_loss(
    y = k2.RaggedTensor(y).to(device)

    with torch.set_grad_enabled(is_training):
-        loss = model(x=feature, x_lens=feature_lens, y=y)
+        simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
+        loss = simple_loss + pruned_loss

    assert loss.requires_grad == is_training
@@ -392,6 +410,8 @@ def compute_loss(
    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()

    return loss, info
@@ -466,6 +486,45 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

+    def maybe_log_gradients(tag: str):
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            tb_writer.add_scalars(
+                tag,
+                measure_gradient_norms(model, norm="l2"),
+                global_step=params.batch_idx_train,
+            )
+
+    def maybe_log_weights(tag: str):
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            tb_writer.add_scalars(
+                tag,
+                measure_weight_norms(model, norm="l2"),
+                global_step=params.batch_idx_train,
+            )
+
+    def maybe_log_param_relative_changes():
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            deltas = optim_step_and_measure_param_change(model, optimizer)
+            tb_writer.add_scalars(
+                "train/relative_param_change_per_minibatch",
+                deltas,
+                global_step=params.batch_idx_train,
+            )
+        else:
+            optimizer.step()

    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -483,10 +542,13 @@ def train_one_epoch(
        # NOTE: We use reduction==sum and loss is computed over utterances
        # in the batch and there is no normalization to it so far.
-        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), 5.0, 2.0)
-        optimizer.step()
+        maybe_log_weights("train/param_norms")
+        maybe_log_gradients("train/grad_norms")
+        maybe_log_param_relative_changes()
+        optimizer.zero_grad()

        if batch_idx % params.log_interval == 0:
            logging.info(
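
One subtle consequence of the hunk above: optimizer.step() is no longer called directly in the loop. It happens inside maybe_log_param_relative_changes(), either through optim_step_and_measure_param_change or its plain else branch, and zero_grad() now runs after the step. A toy sketch of the resulting per-batch order, with a stand-in model in place of the real transducer:

    import torch
    import torch.nn as nn
    from torch.nn.utils import clip_grad_norm_

    model = nn.Linear(4, 2)                                   # toy stand-in for the transducer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    loss = model(torch.randn(8, 4)).sum()                     # stands in for simple_loss + pruned_loss
    loss.backward()
    clip_grad_norm_(model.parameters(), 5.0, 2.0)

    # maybe_log_weights("train/param_norms") and maybe_log_gradients("train/grad_norms")
    # would log diagnostics here every log_interval * 5 batches when params.log_diagnostics is on.
    optimizer.step()                                          # done by maybe_log_param_relative_changes()
    optimizer.zero_grad()                                     # zero_grad moved after the step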

View File

@@ -690,3 +690,94 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
    return expaned_lengths >= lengths.unsqueeze(1)
+
+
+def l1_norm(x):
+    return torch.sum(torch.abs(x))
+
+
+def l2_norm(x):
+    return torch.sum(torch.pow(x, 2))
+
+
+def linf_norm(x):
+    return torch.max(torch.abs(x))
+
+
+def measure_weight_norms(
+    model: nn.Module, norm: str = "l2"
+) -> Dict[str, float]:
+    """
+    Compute the norms of the model's parameters.
+
+    :param model: a torch.nn.Module instance
+    :param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf'
+    :return: a dict mapping from parameter's name to its norm.
+    """
+    with torch.no_grad():
+        norms = {}
+        for name, param in model.named_parameters():
+            if norm == "l1":
+                val = l1_norm(param)
+            elif norm == "l2":
+                val = l2_norm(param)
+            elif norm == "linf":
+                val = linf_norm(param)
+            else:
+                raise ValueError(f"Unknown norm type: {norm}")
+            norms[name] = val.item()
+        return norms
+
+
+def measure_gradient_norms(
+    model: nn.Module, norm: str = "l1"
+) -> Dict[str, float]:
+    """
+    Compute the norms of the gradients for each of model's parameters.
+
+    :param model: a torch.nn.Module instance
+    :param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf'
+    :return: a dict mapping from parameter's name to its gradient's norm.
+    """
+    with torch.no_grad():
+        norms = {}
+        for name, param in model.named_parameters():
+            if norm == "l1":
+                val = l1_norm(param.grad)
+            elif norm == "l2":
+                val = l2_norm(param.grad)
+            elif norm == "linf":
+                val = linf_norm(param.grad)
+            else:
+                raise ValueError(f"Unknown norm type: {norm}")
+            norms[name] = val.item()
+        return norms
+
+
+def optim_step_and_measure_param_change(
+    model: nn.Module,
+    optimizer: torch.optim.Optimizer,
+    scaler: Optional[GradScaler] = None,
+) -> Dict[str, float]:
+    """
+    Perform model weight update and measure the "relative change in parameters per minibatch."
+
+    It is understood as a ratio between the L2 norm of the difference between original and updated parameters,
+    and the L2 norm of the original parameter. It is given by the formula:
+
+    .. math::
+        \begin{aligned}
+            \delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
+        \end{aligned}
+    """
+    param_copy = {n: p.detach().clone() for n, p in model.named_parameters()}
+    if scaler:
+        scaler.step(optimizer)
+    else:
+        optimizer.step()
+
+    relative_change = {}
+    with torch.no_grad():
+        for n, p_new in model.named_parameters():
+            p_orig = param_copy[n]
+            delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
+            relative_change[n] = delta.item()
+    return relative_change
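
Finally, a quick usage sketch of the three new helpers, assuming they are imported from icefall.utils the way the training script above does (the toy model and optimizer are illustrative):

    import torch
    import torch.nn as nn

    from icefall.utils import (
        measure_gradient_norms,
        measure_weight_norms,
        optim_step_and_measure_param_change,
    )

    model = nn.Linear(4, 2)                                  # toy stand-in for the transducer
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    loss = model(torch.randn(8, 4)).pow(2).sum()
    loss.backward()

    print(measure_weight_norms(model, norm="l2"))            # {"weight": ..., "bias": ...}
    print(measure_gradient_norms(model, norm="l2"))          # gradient norms per parameter
    # Performs optimizer.step() and reports the relative parameter change per tensor:
    print(optim_step_and_measure_param_change(model, optimizer))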