mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
support transformer-lm when using lm rescoring
This commit is contained in:
parent
a6c02a4d8c
commit
9865cd379f
@ -1576,7 +1576,12 @@ def modified_beam_search_lm_rescore(
|
|||||||
y = y.to(device).to(torch.int64)
|
y = y.to(device).to(torch.int64)
|
||||||
sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64)
|
sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64)
|
||||||
|
|
||||||
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
|
if LM.lm_type == "rnn":
|
||||||
|
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
|
||||||
|
else:
|
||||||
|
lm_scores = LM.lm(x=x, y=y, x_lens=sentence_token_lengths)
|
||||||
|
lm_scores = lm_scores.reshape(x.size(0), -1)
|
||||||
|
|
||||||
assert lm_scores.ndim == 2
|
assert lm_scores.ndim == 2
|
||||||
lm_scores = -1 * lm_scores.sum(dim=1)
|
lm_scores = -1 * lm_scores.sum(dim=1)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user