mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
406a93be59
commit
ce8fb6fea0
Binary file not shown.
Binary file not shown.
@ -719,41 +719,44 @@ def compute_loss(
|
|||||||
|
|
||||||
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||||
|
|
||||||
'''
|
info = MetricsTracker()
|
||||||
# Compute ctc loss
|
|
||||||
|
if params.ctc_loss_scale > 0:
|
||||||
|
# Compute ctc loss
|
||||||
|
|
||||||
# NOTE: We need `encode_supervisions` to sort sequences with
|
# NOTE: We need `encode_supervisions` to sort sequences with
|
||||||
# different duration in decreasing order, required by
|
# different duration in decreasing order, required by
|
||||||
# `k2.intersect_dense` called in `k2.ctc_loss`
|
# `k2.intersect_dense` called in `k2.ctc_loss`
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
supervision_segments, token_ids = encode_supervisions(
|
supervision_segments, token_ids = encode_supervisions(
|
||||||
supervisions,
|
supervisions,
|
||||||
subsampling_factor=params.subsampling_factor,
|
subsampling_factor=params.subsampling_factor,
|
||||||
token_ids=token_ids,
|
token_ids=token_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Works with a BPE model
|
||||||
|
decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
|
||||||
|
dense_fsa_vec = k2.DenseFsaVec(
|
||||||
|
ctc_output,
|
||||||
|
supervision_segments,
|
||||||
|
allow_truncate=params.subsampling_factor - 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Works with a BPE model
|
ctc_loss = k2.ctc_loss(
|
||||||
decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
|
decoding_graph=decoding_graph,
|
||||||
dense_fsa_vec = k2.DenseFsaVec(
|
dense_fsa_vec=dense_fsa_vec,
|
||||||
ctc_output,
|
output_beam=params.beam_size,
|
||||||
supervision_segments,
|
reduction="sum",
|
||||||
allow_truncate=params.subsampling_factor - 1,
|
use_double_scores=params.use_double_scores,
|
||||||
)
|
)
|
||||||
|
assert ctc_loss.requires_grad == is_training
|
||||||
ctc_loss = k2.ctc_loss(
|
loss += params.ctc_loss_scale * ctc_loss
|
||||||
decoding_graph=decoding_graph,
|
|
||||||
dense_fsa_vec=dense_fsa_vec,
|
info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
||||||
output_beam=params.beam_size,
|
|
||||||
reduction="sum",
|
|
||||||
use_double_scores=params.use_double_scores,
|
|
||||||
)
|
|
||||||
assert ctc_loss.requires_grad == is_training
|
|
||||||
loss += params.ctc_loss_scale * ctc_loss
|
|
||||||
'''
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
info = MetricsTracker()
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||||
@ -762,7 +765,6 @@ def compute_loss(
|
|||||||
info["loss"] = loss.detach().cpu().item()
|
info["loss"] = loss.detach().cpu().item()
|
||||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||||
#info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
|
||||||
|
|
||||||
return loss, info
|
return loss, info
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user