from local

dohe0342 2022-12-10 11:08:49 +09:00
parent 406a93be59
commit ce8fb6fea0
3 changed files with 33 additions and 31 deletions

@@ -719,41 +719,44 @@ def compute_loss(
     loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
-    '''
-    # Compute ctc loss
-    # NOTE: We need `encode_supervisions` to sort sequences with
-    # different duration in decreasing order, required by
-    # `k2.intersect_dense` called in `k2.ctc_loss`
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        supervision_segments, token_ids = encode_supervisions(
-            supervisions,
-            subsampling_factor=params.subsampling_factor,
-            token_ids=token_ids,
-        )
-    # Works with a BPE model
-    decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
-    dense_fsa_vec = k2.DenseFsaVec(
-        ctc_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
-    )
-    ctc_loss = k2.ctc_loss(
-        decoding_graph=decoding_graph,
-        dense_fsa_vec=dense_fsa_vec,
-        output_beam=params.beam_size,
-        reduction="sum",
-        use_double_scores=params.use_double_scores,
-    )
-    assert ctc_loss.requires_grad == is_training
-    loss += params.ctc_loss_scale * ctc_loss
-    '''
+    info = MetricsTracker()
+    if params.ctc_loss_scale > 0:
+        # Compute ctc loss
+        # NOTE: We need `encode_supervisions` to sort sequences with
+        # different duration in decreasing order, required by
+        # `k2.intersect_dense` called in `k2.ctc_loss`
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            supervision_segments, token_ids = encode_supervisions(
+                supervisions,
+                subsampling_factor=params.subsampling_factor,
+                token_ids=token_ids,
+            )
+        # Works with a BPE model
+        decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
+        dense_fsa_vec = k2.DenseFsaVec(
+            ctc_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            reduction="sum",
+            use_double_scores=params.use_double_scores,
+        )
+        assert ctc_loss.requires_grad == is_training
+        loss += params.ctc_loss_scale * ctc_loss
+        info["ctc_loss"] = ctc_loss.detach().cpu().item()
     assert loss.requires_grad == is_training
-    info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
@@ -762,7 +765,6 @@ def compute_loss(
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
-    #info["ctc_loss"] = ctc_loss.detach().cpu().item()
     return loss, info
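
The pattern this commit introduces is to compute the auxiliary CTC term only when `params.ctc_loss_scale > 0`, fold it into the combined transducer loss, and record it in the metrics tracker. The snippet below is a minimal, self-contained sketch of that pattern, not the icefall implementation: `torch.nn.functional.ctc_loss` stands in for the `k2.ctc_loss` call over a CTC decoding graph used above, and the `combine_losses` helper, its default scales, and the dummy tensors are invented for illustration.

# A minimal sketch of the gating pattern above, NOT the icefall code:
# torch.nn.functional.ctc_loss stands in for k2.ctc_loss on a CTC decoding
# graph, and combine_losses plus its dummy inputs are invented for the example.
import torch
import torch.nn.functional as F


def combine_losses(
    simple_loss: torch.Tensor,
    pruned_loss: torch.Tensor,
    log_probs: torch.Tensor,       # (T, N, C) log-probabilities
    targets: torch.Tensor,         # flattened token ids, length sum(target_lengths)
    input_lengths: torch.Tensor,
    target_lengths: torch.Tensor,
    simple_loss_scale: float = 0.5,
    pruned_loss_scale: float = 0.5,
    ctc_loss_scale: float = 0.2,
):
    """Combine transducer losses with an optional CTC term, mirroring compute_loss."""
    loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

    info = {}
    if ctc_loss_scale > 0:
        # The CTC branch is evaluated and logged only when its scale is positive,
        # mirroring the `if params.ctc_loss_scale > 0:` guard added in this commit.
        ctc_loss = F.ctc_loss(
            log_probs,
            targets,
            input_lengths,
            target_lengths,
            blank=0,
            reduction="sum",
        )
        loss = loss + ctc_loss_scale * ctc_loss
        info["ctc_loss"] = ctc_loss.detach().cpu().item()

    info["loss"] = loss.detach().cpu().item()
    return loss, info


if __name__ == "__main__":
    T, N, C = 50, 4, 32
    logits = torch.randn(T, N, C, requires_grad=True)
    target_lengths = torch.tensor([10, 12, 8, 9])
    loss, info = combine_losses(
        simple_loss=torch.tensor(1.5, requires_grad=True),
        pruned_loss=torch.tensor(0.8, requires_grad=True),
        log_probs=logits.log_softmax(dim=-1),
        targets=torch.randint(1, C, (int(target_lengths.sum()),)),
        input_lengths=torch.full((N,), T, dtype=torch.long),
        target_lengths=target_lengths,
    )
    print(info)  # e.g. {'ctc_loss': ..., 'loss': ...}

Gating on the scale keeps graph construction and the extra CTC forward pass off the training path entirely when the term is disabled, instead of computing a loss and multiplying it by zero.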