support transformer-lm in lm rescoring

This commit is contained in:
Ryuk 2024-08-25 16:25:09 +08:00
parent a6c02a4d8c
commit 6eedf4d156

View File

@ -1575,8 +1575,13 @@ def modified_beam_search_lm_rescore(
x = x.to(device).to(torch.int64)
y = y.to(device).to(torch.int64)
sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64)
if LM.lm_type == "rnn":
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
else:
lm_scores = LM.lm(x=x, y=y, x_lens=sentence_token_lengths)
lm_scores = lm_scores.reshape(x.size(0), -1)
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
assert lm_scores.ndim == 2
lm_scores = -1 * lm_scores.sum(dim=1)