mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Merge f5f6903581d8d56b6a8b7236d2c3077c92a38902 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9
This commit is contained in:
commit
b881007e2f
@ -1579,7 +1579,12 @@ def modified_beam_search_lm_rescore(
|
||||
y = y.to(device).to(torch.int64)
|
||||
sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64)
|
||||
|
||||
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
|
||||
if LM.lm_type == "rnn":
|
||||
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
|
||||
else:
|
||||
lm_scores = LM.lm(x=x, y=y, x_lens=sentence_token_lengths)
|
||||
lm_scores = lm_scores.reshape(x.size(0), -1)
|
||||
|
||||
assert lm_scores.ndim == 2
|
||||
lm_scores = -1 * lm_scores.sum(dim=1)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user