Minor fixes.

This commit is contained in:
Fangjun Kuang 2021-10-18 17:23:08 +08:00
parent b1c08e3f9d
commit 6370d519d2

View File

@ -379,10 +379,11 @@ def compute_loss(
cut_ids=cut_ids, cut_ids=cut_ids,
alignments=ali, alignments=ali,
num_classes=nnet_output.shape[2], num_classes=nnet_output.shape[2],
log_score=-1,
).to(nnet_output) ).to(nnet_output)
min_len = min(nnet_output.shape[1], mask.shape[1]) min_len = min(nnet_output.shape[1], mask.shape[1])
ali_scale = 500.0 / (params.batch_idx_train + 500) ali_scale = 5000.0 / (params.batch_idx_train + 5000)
nnet_output = nnet_output.clone() nnet_output = nnet_output.clone()
nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :]