use rnn_t loss of K2

This commit is contained in:
PingFeng Luo 2022-01-19 14:27:56 +08:00
parent 8f21e92b5f
commit 46d03ed9f0
5 changed files with 120 additions and 56 deletions

View File

@ -73,9 +73,9 @@ def greedy_search(
continue
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# fmt: on
logits = model.joiner(current_encoder_out, decoder_out)
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
# logits is (1, 1, 1, vocab_size)
y = logits.argmax().item()
@ -128,7 +128,6 @@ class HypothesisList(object):
def data(self):
return self._data
# def add(self, ys: List[int], log_prob: float):
def add(self, hyp: Hypothesis):
"""Add a Hypothesis to `self`.
@ -266,7 +265,7 @@ def beam_search(
while t < T and sym_per_utt < max_sym_per_utt:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# fmt: on
A = B
B = HypothesisList()
@ -294,9 +293,11 @@ def beam_search(
cached_key += f"-t-{t}"
if cached_key not in joint_cache:
logits = model.joiner(current_encoder_out, decoder_out)
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1)
)
# TODO(fangjun): Ccale the blank posterior
# TODO(fangjun): Cache the blank posterior
log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)

View File

@ -30,21 +30,14 @@ class Joiner(nn.Module):
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, C).
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, U, C).
Output from the decoder. Its shape is (N, T, s_range, C).
Returns:
Return a tensor of shape (N, T, U, C).
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim == 3
assert encoder_out.size(0) == decoder_out.size(0)
assert encoder_out.size(2) == decoder_out.size(2)
encoder_out = encoder_out.unsqueeze(2)
# Now encoder_out is (N, T, 1, C)
decoder_out = decoder_out.unsqueeze(1)
# Now decoder_out is (N, 1, U, C)
assert encoder_out.ndim == decoder_out.ndim == 4
assert encoder_out.shape == decoder_out.shape
logit = encoder_out + decoder_out
logit = torch.tanh(logit)

View File

@ -14,15 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note we use `rnnt_loss` from torchaudio, which exists only in
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
"""
import k2
import torch
import torch.nn as nn
import torchaudio
import torchaudio.functional
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
@ -38,6 +33,7 @@ class Transducer(nn.Module):
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
prune_range: int = 3,
):
"""
Args:
@ -62,6 +58,7 @@ class Transducer(nn.Module):
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
self.prune_range = prune_range
def forward(
self,
@ -102,24 +99,32 @@ class Transducer(nn.Module):
decoder_out = self.decoder(sos_y_padded)
logits = self.joiner(encoder_out, decoder_out)
# rnnt_loss requires 0 padded targets
# Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0)
assert hasattr(torchaudio.functional, "rnnt_loss"), (
f"Current torchaudio version: {torchaudio.__version__}\n"
"Please install a version >= 0.10.0"
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
decoder_out, encoder_out, y_padded, blank_id, boundary, True
)
loss = torchaudio.functional.rnnt_loss(
logits=logits,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
ranges = k2.get_rnnt_prune_ranges(
px_grad, py_grad, boundary, self.prune_range
)
return loss
am_pruning, lm_pruning = k2.do_rnnt_pruning(
encoder_out, decoder_out, ranges
)
logits = self.joiner(am_pruning, lm_pruning)
pruning_loss = k2.rnnt_loss_pruned(
logits, y_padded, ranges, blank_id, boundary
)
return (-torch.sum(simple_loss), -torch.sum(pruning_loss))

View File

@ -45,9 +45,9 @@ import argparse
import logging
import math
from typing import List
from pathlib import Path
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import beam_search, greedy_search
@ -59,6 +59,8 @@ from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.utils import AttributeDict
from icefall.lexicon import Lexicon
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
def get_parser():
@ -76,9 +78,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--lang-dir",
type=str,
help="""Path to bpe.model.
help="""Path to lang.
Used only when method is ctc-decoding.
""",
)
@ -220,18 +222,10 @@ def read_sound_files(
def main():
parser = get_parser()
args = parser.parse_args()
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
@ -240,6 +234,15 @@ def main():
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
params.vocab_size = max(lexicon.tokens) + 1
logging.info("Creating model")
model = get_transducer_model(params)
@ -303,7 +306,7 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append([lexicon.token_table[i] for i in hyp])
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):

View File

@ -47,7 +47,15 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
measure_gradient_norms,
measure_weight_norms,
optim_step_and_measure_param_change,
setup_logger,
str2bool,
)
def get_parser():
@ -128,6 +136,14 @@ def get_parser():
"2 means tri-gram",
)
parser.add_argument(
"--prune-range",
type=int,
default=3,
help="The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss",
)
return parser
@ -185,6 +201,7 @@ def get_params() -> AttributeDict:
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
"log_diagnostics": False,
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
@ -373,7 +390,8 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
loss = model(x=feature, x_lens=feature_lens, y=y)
simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
loss = simple_loss + pruned_loss
assert loss.requires_grad == is_training
@ -382,6 +400,8 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
return loss, info
@ -456,6 +476,45 @@ def train_one_epoch(
tot_loss = MetricsTracker()
def maybe_log_gradients(tag: str):
if (
params.log_diagnostics
and tb_writer is not None
and params.batch_idx_train % (params.log_interval * 5) == 0
):
tb_writer.add_scalars(
tag,
measure_gradient_norms(model, norm="l2"),
global_step=params.batch_idx_train,
)
def maybe_log_weights(tag: str):
if (
params.log_diagnostics
and tb_writer is not None
and params.batch_idx_train % (params.log_interval * 5) == 0
):
tb_writer.add_scalars(
tag,
measure_weight_norms(model, norm="l2"),
global_step=params.batch_idx_train,
)
def maybe_log_param_relative_changes():
if (
params.log_diagnostics
and tb_writer is not None
and params.batch_idx_train % (params.log_interval * 5) == 0
):
deltas = optim_step_and_measure_param_change(model, optimizer)
tb_writer.add_scalars(
"train/relative_param_change_per_minibatch",
deltas,
global_step=params.batch_idx_train,
)
else:
optimizer.step()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@ -473,10 +532,13 @@ def train_one_epoch(
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
maybe_log_weights("train/param_norms")
maybe_log_gradients("train/grad_norms")
maybe_log_param_relative_changes()
optimizer.zero_grad()
if batch_idx % params.log_interval == 0:
logging.info(