Merge f5f6903581d8d56b6a8b7236d2c3077c92a38902 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9

This commit is contained in:
Ryuk 2025-08-08 15:14:11 +08:00 committed by GitHub
commit b881007e2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1579,7 +1579,12 @@ def modified_beam_search_lm_rescore(
y = y.to(device).to(torch.int64)
sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64)
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
if LM.lm_type == "rnn":
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
else:
lm_scores = LM.lm(x=x, y=y, x_lens=sentence_token_lengths)
lm_scores = lm_scores.reshape(x.size(0), -1)
assert lm_scores.ndim == 2
lm_scores = -1 * lm_scores.sum(dim=1)