From 054e2399b9d21cbd1aac6186f80997f0eef2104f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 14 Mar 2022 13:55:07 +0800 Subject: [PATCH] [Not for Merge]: Visualize the gradient of each node in the lattice. --- .../ASR/pruned_transducer_stateless/model.py | 21 ++++++- .../ASR/pruned_transducer_stateless/train.py | 58 ++++++++++++++++++- 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 2f019bcdb..8fbd954e4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -166,4 +166,23 @@ class Transducer(nn.Module): reduction="sum", ) - return (simple_loss, pruned_loss) + B = px_grad.size(0) + S = px_grad.size(1) + T = px_grad.size(2) - 1 + # px_grad's shape (B, S, T+1) + # py_grad's shape (B, S+1, T) + + px_grad_pad = torch.zeros( + (B, 1, T + 1), dtype=px_grad.dtype, device=px_grad.device + ) + py_grad_pad = torch.zeros( + (B, S + 1, 1), dtype=px_grad.dtype, device=px_grad.device + ) + + px_grad_padded = torch.cat([px_grad, px_grad_pad], dim=1) + py_grad_padded = torch.cat([py_grad, py_grad_pad], dim=2) + + # tot_grad's shape (B, S+1, T+1) + tot_grad = px_grad_padded + py_grad_padded + + return (simple_loss, pruned_loss, tot_grad, x_lens, y_lens) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index f0ea2ccaa..7a8d11c27 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -35,7 +35,7 @@ import argparse import logging from pathlib import Path from shutil import copyfile -from typing import Optional, Tuple +from typing import List, Optional, Tuple import k2 import sentencepiece as spm @@ -434,7 +434,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( + simple_loss, pruned_loss, tot_grad, x_lens, y_lens, = model( x=feature, x_lens=feature_lens, y=y, @@ -442,6 +442,12 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + cut_ids = [c.id for c in supervisions["cut"]] + save_and_plot_tot_grad(tot_grad, cut_ids, x_lens, y_lens) + import sys + + sys.exit() + loss = params.simple_loss_scale * simple_loss + pruned_loss assert loss.requires_grad == is_training @@ -491,6 +497,53 @@ def compute_validation_loss( return tot_loss +def save_and_plot_tot_grad( + tot_grad: torch.Tensor, + cut_ids: List[str], + x_lens: torch.Tensor, + y_lens: torch.Tensor, +): + """Save and plot the tot_grad. + + Args: + tot_grad: + A tensor of shape (B, U+1, T+1). It contains the gradient of + each node in the lattice. + cut_ids: + A list of size B, containing the cut ID of each utterance in the batch. + x_lens: + A 1-D tensor of shape (B,), specifying the number of valid acoustic + frames in tot_grad for each utterance in the batch. + y_lens: + A 1-D tensor of shape (B,), specifying the number of valid tokens + in tot_grad for each utterance in the batch. + """ + import matplotlib.pyplot as plt + + tot_grad = tot_grad.detach().cpu().permute(0, 2, 1) + + ext = "png" # supported types: png, ps, pdf, svg + + x_lens = x_lens.tolist() + y_lens = y_lens.tolist() + + tot_grad = tot_grad.unbind(0) + for i in range(len(cut_ids)): + cid = cut_ids[i] + T = x_lens[i] + U = y_lens[i] + grad = tot_grad[i][:T, :U] + + filename = f"{cid}.{ext}" + logging.info(f"Saving to {filename}") + # plt.matshow(grad.t(), origin="lower", cmap="gray") + plt.matshow(grad.t(), origin="lower") + plt.xlabel("t") + plt.ylabel("u") + plt.title(cid) + plt.savefig(filename) + + def train_one_epoch( params: AttributeDict, model: nn.Module, @@ -577,6 +630,7 @@ def train_one_epoch( batch=batch, is_training=True, ) + # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info