Merge 054e2399b9d21cbd1aac6186f80997f0eef2104f into ad28c8c5ebef4cef33b99482469c83e7b36ff07d

This commit is contained in:
Fangjun Kuang 2022-03-18 17:28:07 +01:00 committed by GitHub
commit 77e2c4a28a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 76 additions and 3 deletions

View File

@ -166,4 +166,23 @@ class Transducer(nn.Module):
reduction="sum",
)
return (simple_loss, pruned_loss)
B = px_grad.size(0)
S = px_grad.size(1)
T = px_grad.size(2) - 1
# px_grad's shape (B, S, T+1)
# py_grad's shape (B, S+1, T)
px_grad_pad = torch.zeros(
(B, 1, T + 1), dtype=px_grad.dtype, device=px_grad.device
)
py_grad_pad = torch.zeros(
(B, S + 1, 1), dtype=px_grad.dtype, device=px_grad.device
)
px_grad_padded = torch.cat([px_grad, px_grad_pad], dim=1)
py_grad_padded = torch.cat([py_grad, py_grad_pad], dim=2)
# tot_grad's shape (B, S+1, T+1)
tot_grad = px_grad_padded + py_grad_padded
return (simple_loss, pruned_loss, tot_grad, x_lens, y_lens)

View File

@ -35,7 +35,7 @@ import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
from typing import List, Optional, Tuple
import k2
import sentencepiece as spm
@ -434,7 +434,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
simple_loss, pruned_loss, tot_grad, x_lens, y_lens, = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -442,6 +442,12 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
cut_ids = [c.id for c in supervisions["cut"]]
save_and_plot_tot_grad(tot_grad, cut_ids, x_lens, y_lens)
import sys
sys.exit()
loss = params.simple_loss_scale * simple_loss + pruned_loss
assert loss.requires_grad == is_training
@ -491,6 +497,53 @@ def compute_validation_loss(
return tot_loss
def save_and_plot_tot_grad(
    tot_grad: torch.Tensor,
    cut_ids: List[str],
    x_lens: torch.Tensor,
    y_lens: torch.Tensor,
) -> None:
    """Save one heat-map image of ``tot_grad`` per utterance in the batch.

    Each utterance is written to ``<cut_id>.png`` in the current working
    directory, with the frame index ``t`` on the horizontal axis and the
    token index ``u`` on the vertical axis.

    Args:
      tot_grad:
        A tensor of shape (B, U+1, T+1). It contains the gradient of
        each node in the lattice.
      cut_ids:
        A list of size B, containing the cut ID of each utterance in the
        batch. Each ID is used both as the plot title and as the stem of
        the output filename.
      x_lens:
        A 1-D tensor of shape (B,), specifying the number of valid acoustic
        frames in tot_grad for each utterance in the batch.
      y_lens:
        A 1-D tensor of shape (B,), specifying the number of valid tokens
        in tot_grad for each utterance in the batch.
    """
    # Imported lazily so that matplotlib is only required when plotting.
    import matplotlib.pyplot as plt

    ext = "png"  # supported types: png, ps, pdf, svg

    # Move to CPU and swap the last two axes: (B, U+1, T+1) -> (B, T+1, U+1),
    # then split into one 2-D tensor per utterance.
    per_utt_grads = tot_grad.detach().cpu().permute(0, 2, 1).unbind(0)
    num_frames = x_lens.tolist()
    num_tokens = y_lens.tolist()

    for i, cid in enumerate(cut_ids):
        T = num_frames[i]
        U = num_tokens[i]
        # Keep only the first T frames and the first U token positions of
        # this utterance; everything beyond that is discarded as padding.
        grad = per_utt_grads[i][:T, :U]

        filename = f"{cid}.{ext}"
        logging.info(f"Saving to {filename}")

        # matshow() opens a new figure on every call; transpose so that
        # rows are u and columns are t.
        plt.matshow(grad.t(), origin="lower")
        plt.xlabel("t")
        plt.ylabel("u")
        plt.title(cid)
        plt.savefig(filename)
        # Close the figure just saved, otherwise one figure per utterance
        # stays alive and memory grows with the batch size.
        plt.close()
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
@ -577,6 +630,7 @@ def train_one_epoch(
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info