mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
use rnn_t loss of K2
This commit is contained in:
parent
8f21e92b5f
commit
46d03ed9f0
@ -73,9 +73,9 @@ def greedy_search(
|
||||
continue
|
||||
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||
# fmt: on
|
||||
logits = model.joiner(current_encoder_out, decoder_out)
|
||||
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
|
||||
# logits is (1, 1, 1, vocab_size)
|
||||
|
||||
y = logits.argmax().item()
|
||||
@ -128,7 +128,6 @@ class HypothesisList(object):
|
||||
def data(self):
|
||||
return self._data
|
||||
|
||||
# def add(self, ys: List[int], log_prob: float):
|
||||
def add(self, hyp: Hypothesis):
|
||||
"""Add a Hypothesis to `self`.
|
||||
|
||||
@ -266,7 +265,7 @@ def beam_search(
|
||||
|
||||
while t < T and sym_per_utt < max_sym_per_utt:
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||
# fmt: on
|
||||
A = B
|
||||
B = HypothesisList()
|
||||
@ -294,9 +293,11 @@ def beam_search(
|
||||
|
||||
cached_key += f"-t-{t}"
|
||||
if cached_key not in joint_cache:
|
||||
logits = model.joiner(current_encoder_out, decoder_out)
|
||||
logits = model.joiner(
|
||||
current_encoder_out, decoder_out.unsqueeze(1)
|
||||
)
|
||||
|
||||
# TODO(fangjun): Ccale the blank posterior
|
||||
# TODO(fangjun): Cache the blank posterior
|
||||
|
||||
log_prob = logits.log_softmax(dim=-1)
|
||||
# log_prob is (1, 1, 1, vocab_size)
|
||||
|
@ -30,21 +30,14 @@ class Joiner(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, C).
|
||||
Output from the encoder. Its shape is (N, T, s_range, C).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, U, C).
|
||||
Output from the decoder. Its shape is (N, T, s_range, C).
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, U, C).
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim == 3
|
||||
assert encoder_out.size(0) == decoder_out.size(0)
|
||||
assert encoder_out.size(2) == decoder_out.size(2)
|
||||
|
||||
encoder_out = encoder_out.unsqueeze(2)
|
||||
# Now encoder_out is (N, T, 1, C)
|
||||
|
||||
decoder_out = decoder_out.unsqueeze(1)
|
||||
# Now decoder_out is (N, 1, U, C)
|
||||
assert encoder_out.ndim == decoder_out.ndim == 4
|
||||
assert encoder_out.shape == decoder_out.shape
|
||||
|
||||
logit = encoder_out + decoder_out
|
||||
logit = torch.tanh(logit)
|
||||
|
@ -14,15 +14,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Note we use `rnnt_loss` from torchaudio, which exists only in
|
||||
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
|
||||
"""
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio
|
||||
import torchaudio.functional
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos
|
||||
@ -38,6 +33,7 @@ class Transducer(nn.Module):
|
||||
encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
prune_range: int = 3,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -62,6 +58,7 @@ class Transducer(nn.Module):
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
self.prune_range = prune_range
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -102,24 +99,32 @@ class Transducer(nn.Module):
|
||||
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
logits = self.joiner(encoder_out, decoder_out)
|
||||
|
||||
# rnnt_loss requires 0 padded targets
|
||||
# Note: y does not start with SOS
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
assert hasattr(torchaudio.functional, "rnnt_loss"), (
|
||||
f"Current torchaudio version: {torchaudio.__version__}\n"
|
||||
"Please install a version >= 0.10.0"
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(x.size(0), 4), dtype=torch.int64, device=x.device
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
|
||||
decoder_out, encoder_out, y_padded, blank_id, boundary, True
|
||||
)
|
||||
|
||||
loss = torchaudio.functional.rnnt_loss(
|
||||
logits=logits,
|
||||
targets=y_padded,
|
||||
logit_lengths=x_lens,
|
||||
target_lengths=y_lens,
|
||||
blank=blank_id,
|
||||
reduction="sum",
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad, py_grad, boundary, self.prune_range
|
||||
)
|
||||
|
||||
return loss
|
||||
am_pruning, lm_pruning = k2.do_rnnt_pruning(
|
||||
encoder_out, decoder_out, ranges
|
||||
)
|
||||
|
||||
logits = self.joiner(am_pruning, lm_pruning)
|
||||
|
||||
pruning_loss = k2.rnnt_loss_pruned(
|
||||
logits, y_padded, ranges, blank_id, boundary
|
||||
)
|
||||
|
||||
return (-torch.sum(simple_loss), -torch.sum(pruning_loss))
|
||||
|
@ -45,9 +45,9 @@ import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import beam_search, greedy_search
|
||||
@ -59,6 +59,8 @@ from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -76,9 +78,9 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Path to bpe.model.
|
||||
help="""Path to lang.
|
||||
Used only when method is ctc-decoding.
|
||||
""",
|
||||
)
|
||||
@ -220,18 +222,10 @@ def read_sound_files(
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
@ -240,6 +234,15 @@ def main():
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
@ -303,7 +306,7 @@ def main():
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
hyps.append([lexicon.token_table[i] for i in hyp])
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
|
@ -47,7 +47,15 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
measure_gradient_norms,
|
||||
measure_weight_norms,
|
||||
optim_step_and_measure_param_change,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -128,6 +136,14 @@ def get_parser():
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prune-range",
|
||||
type=int,
|
||||
default=3,
|
||||
help="The prune range for rnnt loss, it means how many symbols(context)"
|
||||
"we are using to compute the loss",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -185,6 +201,7 @@ def get_params() -> AttributeDict:
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000, # For the 100h subset, use 800
|
||||
"log_diagnostics": False,
|
||||
# parameters for conformer
|
||||
"feature_dim": 80,
|
||||
"encoder_out_dim": 512,
|
||||
@ -373,7 +390,8 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
loss = model(x=feature, x_lens=feature_lens, y=y)
|
||||
simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
|
||||
loss = simple_loss + pruned_loss
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
@ -382,6 +400,8 @@ def compute_loss(
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||
|
||||
return loss, info
|
||||
|
||||
@ -456,6 +476,45 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
def maybe_log_gradients(tag: str):
|
||||
if (
|
||||
params.log_diagnostics
|
||||
and tb_writer is not None
|
||||
and params.batch_idx_train % (params.log_interval * 5) == 0
|
||||
):
|
||||
tb_writer.add_scalars(
|
||||
tag,
|
||||
measure_gradient_norms(model, norm="l2"),
|
||||
global_step=params.batch_idx_train,
|
||||
)
|
||||
|
||||
def maybe_log_weights(tag: str):
|
||||
if (
|
||||
params.log_diagnostics
|
||||
and tb_writer is not None
|
||||
and params.batch_idx_train % (params.log_interval * 5) == 0
|
||||
):
|
||||
tb_writer.add_scalars(
|
||||
tag,
|
||||
measure_weight_norms(model, norm="l2"),
|
||||
global_step=params.batch_idx_train,
|
||||
)
|
||||
|
||||
def maybe_log_param_relative_changes():
|
||||
if (
|
||||
params.log_diagnostics
|
||||
and tb_writer is not None
|
||||
and params.batch_idx_train % (params.log_interval * 5) == 0
|
||||
):
|
||||
deltas = optim_step_and_measure_param_change(model, optimizer)
|
||||
tb_writer.add_scalars(
|
||||
"train/relative_param_change_per_minibatch",
|
||||
deltas,
|
||||
global_step=params.batch_idx_train,
|
||||
)
|
||||
else:
|
||||
optimizer.step()
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
@ -473,10 +532,13 @@ def train_one_epoch(
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||
optimizer.step()
|
||||
|
||||
maybe_log_weights("train/param_norms")
|
||||
maybe_log_gradients("train/grad_norms")
|
||||
maybe_log_param_relative_changes()
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
logging.info(
|
||||
|
Loading…
x
Reference in New Issue
Block a user