Merge 054e2399b9d21cbd1aac6186f80997f0eef2104f into ad28c8c5ebef4cef33b99482469c83e7b36ff07d

Fangjun Kuang 2022-03-18 17:28:07 +01:00 committed by GitHub
commit 77e2c4a28a
2 changed files with 76 additions and 3 deletions


@@ -166,4 +166,23 @@ class Transducer(nn.Module):
             reduction="sum",
         )
-        return (simple_loss, pruned_loss)
+        B = px_grad.size(0)
+        S = px_grad.size(1)
+        T = px_grad.size(2) - 1
+
+        # px_grad's shape (B, S, T+1)
+        # py_grad's shape (B, S+1, T)
+        px_grad_pad = torch.zeros(
+            (B, 1, T + 1), dtype=px_grad.dtype, device=px_grad.device
+        )
+        py_grad_pad = torch.zeros(
+            (B, S + 1, 1), dtype=px_grad.dtype, device=px_grad.device
+        )
+        px_grad_padded = torch.cat([px_grad, px_grad_pad], dim=1)
+        py_grad_padded = torch.cat([py_grad, py_grad_pad], dim=2)
+
+        # tot_grad's shape (B, S+1, T+1)
+        tot_grad = px_grad_padded + py_grad_padded
+        return (simple_loss, pruned_loss, tot_grad, x_lens, y_lens)

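For context, here is a minimal standalone sketch (not part of the commit) of the padding trick in the hunk above: px_grad and py_grad, whose shapes are given in the diff's comments, are zero-padded along the axis each one is short on, so that both become (B, S+1, T+1) and can be summed into a per-lattice-node gradient. The tensors below are random stand-ins for the real gradients.

import torch

B, S, T = 2, 5, 20
px_grad = torch.randn(B, S, T + 1)  # stand-in for the real px_grad
py_grad = torch.randn(B, S + 1, T)  # stand-in for the real py_grad

# Pad each tensor along its missing axis.
px_grad_pad = torch.zeros(B, 1, T + 1)
py_grad_pad = torch.zeros(B, S + 1, 1)

tot_grad = torch.cat([px_grad, px_grad_pad], dim=1) + torch.cat(
    [py_grad, py_grad_pad], dim=2
)
assert tot_grad.shape == (B, S + 1, T + 1)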

@@ -35,7 +35,7 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 import k2
 import sentencepiece as spm
@@ -434,7 +434,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, tot_grad, x_lens, y_lens = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -442,6 +442,12 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        cut_ids = [c.id for c in supervisions["cut"]]
+        save_and_plot_tot_grad(tot_grad, cut_ids, x_lens, y_lens)
+
+        import sys
+
+        sys.exit()
         loss = params.simple_loss_scale * simple_loss + pruned_loss

     assert loss.requires_grad == is_training
@@ -491,6 +497,53 @@ def compute_validation_loss(
     return tot_loss


+def save_and_plot_tot_grad(
+    tot_grad: torch.Tensor,
+    cut_ids: List[str],
+    x_lens: torch.Tensor,
+    y_lens: torch.Tensor,
+):
+    """Save and plot the tot_grad.
+
+    Args:
+      tot_grad:
+        A tensor of shape (B, U+1, T+1). It contains the gradient of
+        each node in the lattice.
+      cut_ids:
+        A list of size B, containing the cut ID of each utterance
+        in the batch.
+      x_lens:
+        A 1-D tensor of shape (B,), specifying the number of valid acoustic
+        frames in tot_grad for each utterance in the batch.
+      y_lens:
+        A 1-D tensor of shape (B,), specifying the number of valid tokens
+        in tot_grad for each utterance in the batch.
+    """
+    import matplotlib.pyplot as plt
+
+    tot_grad = tot_grad.detach().cpu().permute(0, 2, 1)
+    ext = "png"  # supported types: png, ps, pdf, svg
+
+    x_lens = x_lens.tolist()
+    y_lens = y_lens.tolist()
+    tot_grad = tot_grad.unbind(0)
+
+    for i in range(len(cut_ids)):
+        cid = cut_ids[i]
+        T = x_lens[i]
+        U = y_lens[i]
+        grad = tot_grad[i][:T, :U]
+        filename = f"{cid}.{ext}"
+        logging.info(f"Saving to {filename}")
+        # plt.matshow(grad.t(), origin="lower", cmap="gray")
+        plt.matshow(grad.t(), origin="lower")
+        plt.xlabel("t")
+        plt.ylabel("u")
+        plt.title(cid)
+        plt.savefig(filename)
+
+
 def train_one_epoch(
     params: AttributeDict,
     model: nn.Module,
@@ -577,6 +630,7 @@ def train_one_epoch(
             batch=batch,
             is_training=True,
         )
+
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
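
For reference, a minimal standalone sketch (not part of the commit) that exercises the new save_and_plot_tot_grad helper with random data, to preview the per-utterance plots it writes. It assumes save_and_plot_tot_grad is imported from the modified training script; the cut IDs and lengths below are made up for illustration, whereas in training they come from supervisions["cut"] and the model outputs.

import logging

import torch

logging.basicConfig(level=logging.INFO)

B, U, T = 2, 8, 30
tot_grad = torch.randn(B, U + 1, T + 1)         # fake per-node gradients
cut_ids = [f"dummy-cut-{i}" for i in range(B)]  # hypothetical cut IDs
x_lens = torch.tensor([T, T - 5])               # valid frames per utterance
y_lens = torch.tensor([U, U - 3])               # valid tokens per utterance

save_and_plot_tot_grad(tot_grad, cut_ids, x_lens, y_lens)
# Writes dummy-cut-0.png and dummy-cut-1.png to the current directory.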