Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-02 21:54:18 +00:00)

Merge 054e2399b9d21cbd1aac6186f80997f0eef2104f into ad28c8c5ebef4cef33b99482469c83e7b36ff07d

This commit is contained in commit 77e2c4a28a.
@@ -166,4 +166,23 @@ class Transducer(nn.Module):
             reduction="sum",
         )
 
-        return (simple_loss, pruned_loss)
+        B = px_grad.size(0)
+        S = px_grad.size(1)
+        T = px_grad.size(2) - 1
+        # px_grad's shape (B, S, T+1)
+        # py_grad's shape (B, S+1, T)
+
+        px_grad_pad = torch.zeros(
+            (B, 1, T + 1), dtype=px_grad.dtype, device=px_grad.device
+        )
+        py_grad_pad = torch.zeros(
+            (B, S + 1, 1), dtype=px_grad.dtype, device=px_grad.device
+        )
+
+        px_grad_padded = torch.cat([px_grad, px_grad_pad], dim=1)
+        py_grad_padded = torch.cat([py_grad, py_grad_pad], dim=2)
+
+        # tot_grad's shape (B, S+1, T+1)
+        tot_grad = px_grad_padded + py_grad_padded
+
+        return (simple_loss, pruned_loss, tot_grad, x_lens, y_lens)
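
For reference, a minimal standalone sketch (not part of the commit) of how the padding above aligns the two gradient tensors; the sizes are hypothetical, and the random tensors stand in for the occupation gradients returned by k2:

import torch

B, S, T = 2, 3, 5                      # hypothetical batch/symbol/frame sizes
px_grad = torch.randn(B, S, T + 1)     # stand-in for k2's px_grad
py_grad = torch.randn(B, S + 1, T)     # stand-in for k2's py_grad

# Pad each tensor along its short axis so both become (B, S+1, T+1),
# then sum node-wise, exactly as the diff above does.
px_grad_pad = torch.zeros((B, 1, T + 1), dtype=px_grad.dtype)
py_grad_pad = torch.zeros((B, S + 1, 1), dtype=py_grad.dtype)

tot_grad = torch.cat([px_grad, px_grad_pad], dim=1) + torch.cat(
    [py_grad, py_grad_pad], dim=2
)
assert tot_grad.shape == (B, S + 1, T + 1)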
@@ -35,7 +35,7 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import k2
 import sentencepiece as spm
@@ -434,7 +434,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, tot_grad, x_lens, y_lens = model(
            x=feature,
            x_lens=feature_lens,
            y=y,
@@ -442,6 +442,12 @@ def compute_loss(
            am_scale=params.am_scale,
            lm_scale=params.lm_scale,
        )
+        cut_ids = [c.id for c in supervisions["cut"]]
+        save_and_plot_tot_grad(tot_grad, cut_ids, x_lens, y_lens)
+        import sys
+
+        sys.exit()
+
        loss = params.simple_loss_scale * simple_loss + pruned_loss
 
    assert loss.requires_grad == is_training
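
The `import sys` / `sys.exit()` pair makes this a one-shot debugging run: the process plots the first batch's gradients and stops. Note also that `supervisions["cut"]` is only present when the dataloader keeps the cuts in the batch; a sketch of the assumed lhotse setup (the `return_cuts` parameter is lhotse's, the rest is hypothetical):

from lhotse.dataset import K2SpeechRecognitionDataset

# For supervisions["cut"] to exist, the dataset must be built with
# return_cuts=True; each batch then carries its originating lhotse cuts.
dataset = K2SpeechRecognitionDataset(return_cuts=True)

# Given such a batch, the IDs used below as plot filenames can be read off:
#   cut_ids = [c.id for c in batch["supervisions"]["cut"]]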
@@ -491,6 +497,53 @@ def compute_validation_loss(
     return tot_loss
 
 
+def save_and_plot_tot_grad(
+    tot_grad: torch.Tensor,
+    cut_ids: List[str],
+    x_lens: torch.Tensor,
+    y_lens: torch.Tensor,
+):
+    """Save and plot the tot_grad.
+
+    Args:
+      tot_grad:
+        A tensor of shape (B, U+1, T+1). It contains the gradient of
+        each node in the lattice.
+      cut_ids:
+        A list of size B, containing the cut ID of each utterance in the batch.
+      x_lens:
+        A 1-D tensor of shape (B,), specifying the number of valid acoustic
+        frames in tot_grad for each utterance in the batch.
+      y_lens:
+        A 1-D tensor of shape (B,), specifying the number of valid tokens
+        in tot_grad for each utterance in the batch.
+    """
+    import matplotlib.pyplot as plt
+
+    tot_grad = tot_grad.detach().cpu().permute(0, 2, 1)
+
+    ext = "png"  # supported types: png, ps, pdf, svg
+
+    x_lens = x_lens.tolist()
+    y_lens = y_lens.tolist()
+
+    tot_grad = tot_grad.unbind(0)
+    for i in range(len(cut_ids)):
+        cid = cut_ids[i]
+        T = x_lens[i]
+        U = y_lens[i]
+        grad = tot_grad[i][:T, :U]
+
+        filename = f"{cid}.{ext}"
+        logging.info(f"Saving to {filename}")
+        # plt.matshow(grad.t(), origin="lower", cmap="gray")
+        plt.matshow(grad.t(), origin="lower")
+        plt.xlabel("t")
+        plt.ylabel("u")
+        plt.title(cid)
+        plt.savefig(filename)
+
+
 def train_one_epoch(
     params: AttributeDict,
     model: nn.Module,
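
A usage sketch for the helper above, with dummy data in place of real model outputs (all values hypothetical; `save_and_plot_tot_grad` is the function added in this commit):

import logging

import torch

logging.basicConfig(level=logging.INFO)

tot_grad = torch.rand(2, 8, 20)      # (B, U+1, T+1) with B=2, U=7, T=19
cut_ids = ["cut-0001", "cut-0002"]   # hypothetical cut IDs
x_lens = torch.tensor([18, 20])      # valid frames per utterance (<= T+1)
y_lens = torch.tensor([5, 7])        # valid tokens per utterance (<= U+1)

# Writes cut-0001.png and cut-0002.png: one heat map of the lattice node
# gradients per utterance, with t on the x-axis and u on the y-axis.
save_and_plot_tot_grad(tot_grad, cut_ids, x_lens, y_lens)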
@@ -577,6 +630,7 @@ def train_one_epoch(
                batch=batch,
                is_training=True,
            )
+
            # summary stats
            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
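
The `tot_loss` update in the last hunk is a leaky running average: each step scales the accumulated statistics by `1 - 1/reset_interval` before adding the new batch, so old batches decay geometrically. A worked sketch with plain floats (icefall's `MetricsTracker` applies the same arithmetic field-wise; the `reset_interval` value here is illustrative):

reset_interval = 200
decay = 1 - 1 / reset_interval       # 0.995

tot = 0.0
for batch_loss in [4.0, 3.5, 3.0]:
    tot = tot * decay + batch_loss   # same form as the tot_loss update
# A batch seen j steps ago contributes batch_loss * decay**j, so roughly
# the most recent reset_interval batches dominate the reported average.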