mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
support transformer-lm in lm rescoring
This commit is contained in:
parent
a6c02a4d8c
commit
6eedf4d156
@ -1575,8 +1575,13 @@ def modified_beam_search_lm_rescore(
|
||||
x = x.to(device).to(torch.int64)
|
||||
y = y.to(device).to(torch.int64)
|
||||
sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64)
|
||||
|
||||
if LM.lm_type == "rnn":
|
||||
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
|
||||
else:
|
||||
lm_scores = LM.lm(x=x, y=y, x_lens=sentence_token_lengths)
|
||||
lm_scores = lm_scores.reshape(x.size(0), -1)
|
||||
|
||||
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
|
||||
assert lm_scores.ndim == 2
|
||||
lm_scores = -1 * lm_scores.sum(dim=1)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user