Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Update beam_search.py
This commit is contained in:
parent 79b8e60f93
commit 0ed25de93a
@@ -779,74 +779,6 @@ def greedy_search_batch(
     )
 
 
-def deprecated_greedy_search_batch_for_cross_attn(
-    model: nn.Module, encoder_out: torch.Tensor
-) -> List[List[int]]:
-    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
-    Args:
-      model:
-        The transducer model.
-      encoder_out:
-        Output from the encoder. Its shape is (N, T, C), where N >= 1.
-    Returns:
-      Return a list-of-list of token IDs containing the decoded results.
-      len(ans) equals encoder_out.size(0).
-    """
-    assert encoder_out.ndim == 3
-    assert encoder_out.size(0) >= 1, encoder_out.size(0)
-
-    device = next(model.parameters()).device
-
-    batch_size = encoder_out.size(0)
-    T = encoder_out.size(1)
-
-    blank_id = model.decoder.blank_id
-    unk_id = getattr(model, "unk_id", blank_id)
-    context_size = model.decoder.context_size
-
-    hyps = [[blank_id] * context_size for _ in range(batch_size)]
-
-    decoder_input = torch.tensor(
-        hyps,
-        device=device,
-        dtype=torch.int64,
-    )  # (batch_size, context_size)
-
-    decoder_out = model.decoder(decoder_input, need_pad=False)
-    decoder_out = model.joiner.decoder_proj(decoder_out)
-    encoder_out = model.joiner.encoder_proj(encoder_out)
-
-    # decoder_out: (batch_size, 1, decoder_out_dim)
-    for t in range(T):
-        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
-        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
-        logits = model.joiner(
-            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
-        )
-        # logits' shape: (batch_size, 1, 1, vocab_size)
-        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
-        assert logits.ndim == 2, logits.shape
-        y = logits.argmax(dim=1).tolist()
-        emitted = False
-        for i, v in enumerate(y):
-            if v not in (blank_id, unk_id):
-                hyps[i].append(v)
-                emitted = True
-        if emitted:
-            # update decoder output
-            decoder_input = [h[-context_size:] for h in hyps]
-            decoder_input = torch.tensor(
-                decoder_input,
-                device=device,
-                dtype=torch.int64,
-            )
-            decoder_out = model.decoder(decoder_input, need_pad=False)
-            decoder_out = model.joiner.decoder_proj(decoder_out)
-
-    ans = [h[context_size:] for h in hyps]
-    return ans
-
-
 def deprecated_greedy_search_batch_for_cross_attn(
     model: nn.Module,
     encoder_out: torch.Tensor,
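For context, a minimal sketch (not part of this commit) of how a batched greedy search like the one in the removed lines is invoked. It assumes the function above is in scope; the Toy* modules below are hypothetical stand-ins for icefall's real decoder and joiner, modeling only the attributes the search touches (blank_id, context_size, need_pad, decoder_proj, encoder_proj, project_input), so the shapes line up but the decodes from random weights are meaningless.

import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for icefall's stateless decoder."""

    def __init__(self, vocab_size: int, embed_dim: int, context_size: int):
        super().__init__()
        self.blank_id = 0
        self.context_size = context_size
        self.embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, y: torch.Tensor, need_pad: bool = False) -> torch.Tensor:
        # (N, context_size) -> (N, 1, embed_dim): pool the context embeddings.
        return self.embed(y).mean(dim=1, keepdim=True)


class ToyJoiner(nn.Module):
    """Hypothetical stand-in for icefall's joiner with input projections."""

    def __init__(self, enc_dim: int, dec_dim: int, join_dim: int, vocab_size: int):
        super().__init__()
        self.encoder_proj = nn.Linear(enc_dim, join_dim)
        self.decoder_proj = nn.Linear(dec_dim, join_dim)
        self.output = nn.Linear(join_dim, vocab_size)

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        # The search projects both inputs up front and passes
        # project_input=False, so this branch is skipped there.
        if project_input:
            encoder_out = self.encoder_proj(encoder_out)
            decoder_out = self.decoder_proj(decoder_out)
        # (N, 1, 1, join_dim) + (N, 1, 1, join_dim) -> (N, 1, 1, vocab_size)
        return self.output(torch.tanh(encoder_out + decoder_out))


class ToyTransducer(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = ToyDecoder(vocab_size=10, embed_dim=8, context_size=2)
        self.joiner = ToyJoiner(enc_dim=16, dec_dim=8, join_dim=12, vocab_size=10)


model = ToyTransducer().eval()
encoder_out = torch.randn(3, 20, 16)  # (N, T, C): 3 utterances, 20 frames
with torch.no_grad():
    hyps = deprecated_greedy_search_batch_for_cross_attn(model, encoder_out)
assert len(hyps) == encoder_out.size(0)  # one hypothesis per utterance

Because the decoder is stateless with a fixed context window, the search only ever feeds back the last context_size emitted tokens, which is what lets it re-run the decoder once per frame for the whole batch instead of tracking per-utterance recurrent state.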