From 46d03ed9f0aee13d0929d290ebd41c051d2564f0 Mon Sep 17 00:00:00 2001 From: PingFeng Luo Date: Wed, 19 Jan 2022 14:27:56 +0800 Subject: [PATCH] use rnn_t loss of K2 --- .../ASR/transducer_stateless/beam_search.py | 13 ++-- .../ASR/transducer_stateless/joiner.py | 17 ++--- .../ASR/transducer_stateless/model.py | 45 ++++++------ .../ASR/transducer_stateless/pretrained.py | 29 ++++---- .../ASR/transducer_stateless/train.py | 72 +++++++++++++++++-- 5 files changed, 120 insertions(+), 56 deletions(-) diff --git a/egs/wenetspeech/ASR/transducer_stateless/beam_search.py b/egs/wenetspeech/ASR/transducer_stateless/beam_search.py index 9ed9b2ad1..3d4818509 100644 --- a/egs/wenetspeech/ASR/transducer_stateless/beam_search.py +++ b/egs/wenetspeech/ASR/transducer_stateless/beam_search.py @@ -73,9 +73,9 @@ def greedy_search( continue # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) # fmt: on - logits = model.joiner(current_encoder_out, decoder_out) + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) # logits is (1, 1, 1, vocab_size) y = logits.argmax().item() @@ -128,7 +128,6 @@ class HypothesisList(object): def data(self): return self._data - # def add(self, ys: List[int], log_prob: float): def add(self, hyp: Hypothesis): """Add a Hypothesis to `self`. @@ -266,7 +265,7 @@ def beam_search( while t < T and sym_per_utt < max_sym_per_utt: # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) # fmt: on A = B B = HypothesisList() @@ -294,9 +293,11 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner(current_encoder_out, decoder_out) + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1) + ) - # TODO(fangjun): Ccale the blank posterior + # TODO(fangjun): Cache the blank posterior log_prob = logits.log_softmax(dim=-1) # log_prob is (1, 1, 1, vocab_size) diff --git a/egs/wenetspeech/ASR/transducer_stateless/joiner.py b/egs/wenetspeech/ASR/transducer_stateless/joiner.py index 2ef3f1de6..89fab9e4c 100644 --- a/egs/wenetspeech/ASR/transducer_stateless/joiner.py +++ b/egs/wenetspeech/ASR/transducer_stateless/joiner.py @@ -30,21 +30,14 @@ class Joiner(nn.Module): """ Args: encoder_out: - Output from the encoder. Its shape is (N, T, C). + Output from the encoder. Its shape is (N, T, s_range, C). decoder_out: - Output from the decoder. Its shape is (N, U, C). + Output from the decoder. Its shape is (N, T, s_range, C). Returns: - Return a tensor of shape (N, T, U, C). + Return a tensor of shape (N, T, s_range, C). """ - assert encoder_out.ndim == decoder_out.ndim == 3 - assert encoder_out.size(0) == decoder_out.size(0) - assert encoder_out.size(2) == decoder_out.size(2) - - encoder_out = encoder_out.unsqueeze(2) - # Now encoder_out is (N, T, 1, C) - - decoder_out = decoder_out.unsqueeze(1) - # Now decoder_out is (N, 1, U, C) + assert encoder_out.ndim == decoder_out.ndim == 4 + assert encoder_out.shape == decoder_out.shape logit = encoder_out + decoder_out logit = torch.tanh(logit) diff --git a/egs/wenetspeech/ASR/transducer_stateless/model.py b/egs/wenetspeech/ASR/transducer_stateless/model.py index 2f0f9a183..3d562e1a8 100644 --- a/egs/wenetspeech/ASR/transducer_stateless/model.py +++ b/egs/wenetspeech/ASR/transducer_stateless/model.py @@ -14,15 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Note we use `rnnt_loss` from torchaudio, which exists only in -torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0 -""" + import k2 import torch import torch.nn as nn -import torchaudio -import torchaudio.functional from encoder_interface import EncoderInterface from icefall.utils import add_sos @@ -38,6 +33,7 @@ class Transducer(nn.Module): encoder: EncoderInterface, decoder: nn.Module, joiner: nn.Module, + prune_range: int = 3, ): """ Args: @@ -62,6 +58,7 @@ class Transducer(nn.Module): self.encoder = encoder self.decoder = decoder self.joiner = joiner + self.prune_range = prune_range def forward( self, @@ -102,24 +99,32 @@ class Transducer(nn.Module): decoder_out = self.decoder(sos_y_padded) - logits = self.joiner(encoder_out, decoder_out) - - # rnnt_loss requires 0 padded targets # Note: y does not start with SOS y_padded = y.pad(mode="constant", padding_value=0) - assert hasattr(torchaudio.functional, "rnnt_loss"), ( - f"Current torchaudio version: {torchaudio.__version__}\n" - "Please install a version >= 0.10.0" + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple( + decoder_out, encoder_out, y_padded, blank_id, boundary, True ) - loss = torchaudio.functional.rnnt_loss( - logits=logits, - targets=y_padded, - logit_lengths=x_lens, - target_lengths=y_lens, - blank=blank_id, - reduction="sum", + ranges = k2.get_rnnt_prune_ranges( + px_grad, py_grad, boundary, self.prune_range ) - return loss + am_pruning, lm_pruning = k2.do_rnnt_pruning( + encoder_out, decoder_out, ranges + ) + + logits = self.joiner(am_pruning, lm_pruning) + + pruning_loss = k2.rnnt_loss_pruned( + logits, y_padded, ranges, blank_id, boundary + ) + + return (-torch.sum(simple_loss), -torch.sum(pruning_loss)) diff --git a/egs/wenetspeech/ASR/transducer_stateless/pretrained.py b/egs/wenetspeech/ASR/transducer_stateless/pretrained.py index e5dba8f0e..65ac5f3ff 100755 --- a/egs/wenetspeech/ASR/transducer_stateless/pretrained.py +++ b/egs/wenetspeech/ASR/transducer_stateless/pretrained.py @@ -45,9 +45,9 @@ import argparse import logging import math from typing import List +from pathlib import Path import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import beam_search, greedy_search @@ -59,6 +59,8 @@ from torch.nn.utils.rnn import pad_sequence from icefall.env import get_env_info from icefall.utils import AttributeDict +from icefall.lexicon import Lexicon +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler def get_parser(): @@ -76,9 +78,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--lang-dir", type=str, - help="""Path to bpe.model. + help="""Path to lang. Used only when method is ctc-decoding. """, ) @@ -220,18 +222,10 @@ def read_sound_files( def main(): parser = get_parser() args = parser.parse_args() + args.lang_dir = Path(args.lang_dir) params = get_params() - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - logging.info(f"{params}") device = torch.device("cpu") @@ -240,6 +234,15 @@ def main(): logging.info(f"device: {device}") + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = graph_compiler.texts_to_ids("")[0][0] + params.vocab_size = max(lexicon.tokens) + 1 + logging.info("Creating model") model = get_transducer_model(params) @@ -303,7 +306,7 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append([lexicon.token_table[i] for i in hyp]) s = "\n" for filename, hyp in zip(params.sound_files, hyps): diff --git a/egs/wenetspeech/ASR/transducer_stateless/train.py b/egs/wenetspeech/ASR/transducer_stateless/train.py index 5bcc82761..e2a382910 100755 --- a/egs/wenetspeech/ASR/transducer_stateless/train.py +++ b/egs/wenetspeech/ASR/transducer_stateless/train.py @@ -47,7 +47,15 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + measure_gradient_norms, + measure_weight_norms, + optim_step_and_measure_param_change, + setup_logger, + str2bool, +) def get_parser(): @@ -128,6 +136,14 @@ def get_parser(): "2 means tri-gram", ) + parser.add_argument( + "--prune-range", + type=int, + default=3, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + return parser @@ -185,6 +201,7 @@ def get_params() -> AttributeDict: "log_interval": 50, "reset_interval": 200, "valid_interval": 3000, # For the 100h subset, use 800 + "log_diagnostics": False, # parameters for conformer "feature_dim": 80, "encoder_out_dim": 512, @@ -373,7 +390,8 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=y) + simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y) + loss = simple_loss + pruned_loss assert loss.requires_grad == is_training @@ -382,6 +400,8 @@ def compute_loss( # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() return loss, info @@ -456,6 +476,45 @@ def train_one_epoch( tot_loss = MetricsTracker() + def maybe_log_gradients(tag: str): + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + tb_writer.add_scalars( + tag, + measure_gradient_norms(model, norm="l2"), + global_step=params.batch_idx_train, + ) + + def maybe_log_weights(tag: str): + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + tb_writer.add_scalars( + tag, + measure_weight_norms(model, norm="l2"), + global_step=params.batch_idx_train, + ) + + def maybe_log_param_relative_changes(): + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + deltas = optim_step_and_measure_param_change(model, optimizer) + tb_writer.add_scalars( + "train/relative_param_change_per_minibatch", + deltas, + global_step=params.batch_idx_train, + ) + else: + optimizer.step() + for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -473,10 +532,13 @@ def train_one_epoch( # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. - optimizer.zero_grad() loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() + + maybe_log_weights("train/param_norms") + maybe_log_gradients("train/grad_norms") + maybe_log_param_relative_changes() + + optimizer.zero_grad() if batch_idx % params.log_interval == 0: logging.info(