use rnn_t loss of K2

This commit is contained in:
PingFeng Luo 2022-01-19 14:27:56 +08:00
parent 8f21e92b5f
commit 46d03ed9f0
5 changed files with 120 additions and 56 deletions

View File

@ -73,9 +73,9 @@ def greedy_search(
continue continue
# fmt: off # fmt: off
current_encoder_out = encoder_out[:, t:t+1, :] current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# fmt: on # fmt: on
logits = model.joiner(current_encoder_out, decoder_out) logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
# logits is (1, 1, 1, vocab_size) # logits is (1, 1, 1, vocab_size)
y = logits.argmax().item() y = logits.argmax().item()
@ -128,7 +128,6 @@ class HypothesisList(object):
def data(self): def data(self):
return self._data return self._data
# def add(self, ys: List[int], log_prob: float):
def add(self, hyp: Hypothesis): def add(self, hyp: Hypothesis):
"""Add a Hypothesis to `self`. """Add a Hypothesis to `self`.
@ -266,7 +265,7 @@ def beam_search(
while t < T and sym_per_utt < max_sym_per_utt: while t < T and sym_per_utt < max_sym_per_utt:
# fmt: off # fmt: off
current_encoder_out = encoder_out[:, t:t+1, :] current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# fmt: on # fmt: on
A = B A = B
B = HypothesisList() B = HypothesisList()
@ -294,9 +293,11 @@ def beam_search(
cached_key += f"-t-{t}" cached_key += f"-t-{t}"
if cached_key not in joint_cache: if cached_key not in joint_cache:
logits = model.joiner(current_encoder_out, decoder_out) logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1)
)
# TODO(fangjun): Ccale the blank posterior # TODO(fangjun): Cache the blank posterior
log_prob = logits.log_softmax(dim=-1) log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size) # log_prob is (1, 1, 1, vocab_size)

View File

@ -30,21 +30,14 @@ class Joiner(nn.Module):
""" """
Args: Args:
encoder_out: encoder_out:
Output from the encoder. Its shape is (N, T, C). Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out: decoder_out:
Output from the decoder. Its shape is (N, U, C). Output from the decoder. Its shape is (N, T, s_range, C).
Returns: Returns:
Return a tensor of shape (N, T, U, C). Return a tensor of shape (N, T, s_range, C).
""" """
assert encoder_out.ndim == decoder_out.ndim == 3 assert encoder_out.ndim == decoder_out.ndim == 4
assert encoder_out.size(0) == decoder_out.size(0) assert encoder_out.shape == decoder_out.shape
assert encoder_out.size(2) == decoder_out.size(2)
encoder_out = encoder_out.unsqueeze(2)
# Now encoder_out is (N, T, 1, C)
decoder_out = decoder_out.unsqueeze(1)
# Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out logit = encoder_out + decoder_out
logit = torch.tanh(logit) logit = torch.tanh(logit)

View File

@ -14,15 +14,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Note we use `rnnt_loss` from torchaudio, which exists only in
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
"""
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchaudio
import torchaudio.functional
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from icefall.utils import add_sos from icefall.utils import add_sos
@ -38,6 +33,7 @@ class Transducer(nn.Module):
encoder: EncoderInterface, encoder: EncoderInterface,
decoder: nn.Module, decoder: nn.Module,
joiner: nn.Module, joiner: nn.Module,
prune_range: int = 3,
): ):
""" """
Args: Args:
@ -62,6 +58,7 @@ class Transducer(nn.Module):
self.encoder = encoder self.encoder = encoder
self.decoder = decoder self.decoder = decoder
self.joiner = joiner self.joiner = joiner
self.prune_range = prune_range
def forward( def forward(
self, self,
@ -102,24 +99,32 @@ class Transducer(nn.Module):
decoder_out = self.decoder(sos_y_padded) decoder_out = self.decoder(sos_y_padded)
logits = self.joiner(encoder_out, decoder_out)
# rnnt_loss requires 0 padded targets
# Note: y does not start with SOS # Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0) y_padded = y.pad(mode="constant", padding_value=0)
assert hasattr(torchaudio.functional, "rnnt_loss"), ( y_padded = y_padded.to(torch.int64)
f"Current torchaudio version: {torchaudio.__version__}\n" boundary = torch.zeros(
"Please install a version >= 0.10.0" (x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
decoder_out, encoder_out, y_padded, blank_id, boundary, True
) )
loss = torchaudio.functional.rnnt_loss( ranges = k2.get_rnnt_prune_ranges(
logits=logits, px_grad, py_grad, boundary, self.prune_range
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
) )
return loss am_pruning, lm_pruning = k2.do_rnnt_pruning(
encoder_out, decoder_out, ranges
)
logits = self.joiner(am_pruning, lm_pruning)
pruning_loss = k2.rnnt_loss_pruned(
logits, y_padded, ranges, blank_id, boundary
)
return (-torch.sum(simple_loss), -torch.sum(pruning_loss))

View File

@ -45,9 +45,9 @@ import argparse
import logging import logging
import math import math
from typing import List from typing import List
from pathlib import Path
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from beam_search import beam_search, greedy_search from beam_search import beam_search, greedy_search
@ -59,6 +59,8 @@ from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import AttributeDict from icefall.utils import AttributeDict
from icefall.lexicon import Lexicon
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
def get_parser(): def get_parser():
@ -76,9 +78,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--lang-dir",
type=str, type=str,
help="""Path to bpe.model. help="""Path to lang.
Used only when method is ctc-decoding. Used only when method is ctc-decoding.
""", """,
) )
@ -220,18 +222,10 @@ def read_sound_files(
def main(): def main():
parser = get_parser() parser = get_parser()
args = parser.parse_args() args = parser.parse_args()
args.lang_dir = Path(args.lang_dir)
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}") logging.info(f"{params}")
device = torch.device("cpu") device = torch.device("cpu")
@ -240,6 +234,15 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
params.vocab_size = max(lexicon.tokens) + 1
logging.info("Creating model") logging.info("Creating model")
model = get_transducer_model(params) model = get_transducer_model(params)
@ -303,7 +306,7 @@ def main():
else: else:
raise ValueError(f"Unsupported method: {params.method}") raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split()) hyps.append([lexicon.token_table[i] for i in hyp])
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):

View File

@ -47,7 +47,15 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import (
AttributeDict,
MetricsTracker,
measure_gradient_norms,
measure_weight_norms,
optim_step_and_measure_param_change,
setup_logger,
str2bool,
)
def get_parser(): def get_parser():
@ -128,6 +136,14 @@ def get_parser():
"2 means tri-gram", "2 means tri-gram",
) )
parser.add_argument(
    "--prune-range",
    type=int,
    default=3,
    # NOTE: adjacent string literals are concatenated verbatim, so the
    # first literal must end with a space; otherwise the rendered help
    # text reads "symbols(context)we are using".
    help="The prune range for the rnnt loss, i.e., how many symbols "
    "(context) we are using to compute the loss",
)
return parser return parser
@ -185,6 +201,7 @@ def get_params() -> AttributeDict:
"log_interval": 50, "log_interval": 50,
"reset_interval": 200, "reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800 "valid_interval": 3000, # For the 100h subset, use 800
"log_diagnostics": False,
# parameters for conformer # parameters for conformer
"feature_dim": 80, "feature_dim": 80,
"encoder_out_dim": 512, "encoder_out_dim": 512,
@ -373,7 +390,8 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
loss = model(x=feature, x_lens=feature_lens, y=y) simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
loss = simple_loss + pruned_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
@ -382,6 +400,8 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
return loss, info return loss, info
@ -456,6 +476,45 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
def maybe_log_gradients(tag: str):
    """Log the model's L2 gradient norms to TensorBoard under `tag`.

    Nothing is written unless `params.log_diagnostics` is truthy, a
    TensorBoard writer is available, and `params.batch_idx_train` is a
    multiple of five logging intervals.

    Args:
      tag:
        Main tag passed to `tb_writer.add_scalars`.
    """
    if not params.log_diagnostics or tb_writer is None:
        return
    if params.batch_idx_train % (params.log_interval * 5) != 0:
        return

    grad_norms = measure_gradient_norms(model, norm="l2")
    tb_writer.add_scalars(
        tag, grad_norms, global_step=params.batch_idx_train
    )
def maybe_log_weights(tag: str):
    """Log the model's L2 weight norms to TensorBoard under `tag`.

    The write happens only when diagnostics are enabled in `params`, a
    TensorBoard writer exists, and the current training batch index is a
    multiple of `params.log_interval * 5`.

    Args:
      tag:
        Main tag passed to `tb_writer.add_scalars`.
    """
    diagnostics_on = params.log_diagnostics and tb_writer is not None
    if diagnostics_on and (
        params.batch_idx_train % (params.log_interval * 5) == 0
    ):
        weight_norms = measure_weight_norms(model, norm="l2")
        tb_writer.add_scalars(
            tag, weight_norms, global_step=params.batch_idx_train
        )
def maybe_log_param_relative_changes():
    """Advance the optimizer, logging parameter changes when due.

    On ordinary batches this simply calls `optimizer.step()`. When
    diagnostics are enabled, a TensorBoard writer exists, and the batch
    index is a multiple of five logging intervals, the update is instead
    routed through `optim_step_and_measure_param_change` and its result is
    written under "train/relative_param_change_per_minibatch".
    """
    diagnostics_due = (
        params.log_diagnostics
        and tb_writer is not None
        and params.batch_idx_train % (params.log_interval * 5) == 0
    )
    if not diagnostics_due:
        optimizer.step()
        return

    # NOTE(review): presumably this helper performs the optimizer step
    # itself, since no explicit step is taken on this path -- confirm.
    param_changes = optim_step_and_measure_param_change(model, optimizer)
    tb_writer.add_scalars(
        "train/relative_param_change_per_minibatch",
        param_changes,
        global_step=params.batch_idx_train,
    )
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -473,10 +532,13 @@ def train_one_epoch(
# NOTE: We use reduction==sum and loss is computed over utterances # NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far. # in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward() loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step() maybe_log_weights("train/param_norms")
maybe_log_gradients("train/grad_norms")
maybe_log_param_relative_changes()
optimizer.zero_grad()
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
logging.info( logging.info(