mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Minor fixes.
This commit is contained in:
parent
b1c08e3f9d
commit
6370d519d2
@ -379,10 +379,11 @@ def compute_loss(
|
|||||||
cut_ids=cut_ids,
|
cut_ids=cut_ids,
|
||||||
alignments=ali,
|
alignments=ali,
|
||||||
num_classes=nnet_output.shape[2],
|
num_classes=nnet_output.shape[2],
|
||||||
|
log_score=-1,
|
||||||
).to(nnet_output)
|
).to(nnet_output)
|
||||||
|
|
||||||
min_len = min(nnet_output.shape[1], mask.shape[1])
|
min_len = min(nnet_output.shape[1], mask.shape[1])
|
||||||
ali_scale = 500.0 / (params.batch_idx_train + 500)
|
ali_scale = 5000.0 / (params.batch_idx_train + 5000)
|
||||||
|
|
||||||
nnet_output = nnet_output.clone()
|
nnet_output = nnet_output.clone()
|
||||||
nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :]
|
nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user