support transformer-lm when using lm rescoring

This commit is contained in:
Ryuk 2024-08-25 16:25:09 +08:00
parent a6c02a4d8c
commit 9865cd379f

View File

@ -1576,7 +1576,12 @@ def modified_beam_search_lm_rescore(
y = y.to(device).to(torch.int64)
sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64)
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
if LM.lm_type == "rnn":
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
else:
lm_scores = LM.lm(x=x, y=y, x_lens=sentence_token_lengths)
lm_scores = lm_scores.reshape(x.size(0), -1)
assert lm_scores.ndim == 2
lm_scores = -1 * lm_scores.sum(dim=1)