Mirror of https://github.com/k2-fsa/icefall.git

Commit 77261bc575 (parent c3b3123b27)

    Add modified beam search.
@@ -17,7 +17,6 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional
 
-import numpy as np
 import torch
 from model import Transducer
 
@@ -108,8 +107,11 @@ class Hypothesis:
     # Newly predicted tokens are appended to `ys`.
     ys: List[int]
 
-    # The log prob of ys
-    log_prob: float
+    # The log prob of ys.
+    # It contains only one entry.
+    # TODO(fangjun): It was a float before. We need to change its usage
+    # in greedy_search and beam_search.
+    log_prob: torch.Tensor
 
     @property
     def key(self) -> str:
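For readers skimming the diff: a minimal, runnable sketch of the updated `Hypothesis` in isolation. The `key` property shown here is an assumption (a string id joining `ys` with underscores, matching the `key = hyp.key` usage in the next hunk); the rest mirrors the diff above.

```python
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class Hypothesis:
    # Newly predicted tokens are appended to `ys`.
    ys: List[int]

    # The log prob of ys. It contains only one entry.
    log_prob: torch.Tensor

    @property
    def key(self) -> str:
        # Assumed implementation: a string id derived from `ys`,
        # used for deduplication in HypothesisList.
        return "_".join(map(str, self.ys))


# The initial hypothesis, as constructed in modified_beam_search below.
blank_id, context_size = 0, 2
hyp = Hypothesis(
    ys=[blank_id] * context_size,
    log_prob=torch.zeros(1, dtype=torch.float32),
)
print(hyp.key)       # "0_0"
print(hyp.log_prob)  # tensor([0.])
```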
@@ -145,8 +147,10 @@ class HypothesisList(object):
         """
         key = hyp.key
         if key in self:
-            old_hyp = self._data[key]
-            old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
+            old_hyp = self._data[key]  # shallow copy
+            torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+            )
         else:
             self._data[key] = hyp
 
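A hedged illustration of why this merge switched from `np.logaddexp` to `torch.logaddexp` with `out=`: the log probs are now one-element tensors, and writing into `out=old_hyp.log_prob` sums the probabilities of two identical hypotheses in log space, in place, so the stored tensor is updated without rebinding the attribute (the numbers below are made up):

```python
import torch

# Two duplicate hypotheses with log probs log(0.25) and log(0.5).
old_log_prob = torch.tensor([0.25]).log()
new_log_prob = torch.tensor([0.5]).log()

# log(exp(a) + exp(b)) computed stably; `out=` writes the result
# back into the existing tensor, so anything holding a reference
# to it sees the merged score.
torch.logaddexp(old_log_prob, new_log_prob, out=old_log_prob)

print(old_log_prob.exp())  # tensor([0.7500]) == 0.25 + 0.5
```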
@@ -348,47 +352,70 @@ def modified_beam_search(
     T = encoder_out.size(1)
 
     B = HypothesisList()
-    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
+    B.add(
+        Hypothesis(
+            ys=[blank_id] * context_size,
+            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+        )
+    )
 
     encoder_out_len = torch.tensor([1])
     decoder_out_len = torch.tensor([1])
 
-    decoder_cache: Dict[str, torch.Tensor] = {}
     for t in range(T):
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :]
+        # current_encoder_out is of shape (1, 1, encoder_out_dim)
         # fmt: on
-        A = B
+        A = list(B)
         B = HypothesisList()
 
-        joint_cache: Dict[str, torch.Tensor] = {}
+        ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
+        # ys_log_probs is of shape (num_hyps, 1)
 
-        for hyp in A:
-            decoder_out = run_decoder(
-                ys=hyp.ys, model=model, decoder_cache=decoder_cache
-            )
-            key = "_".join(map(str, hyp.ys[-context_size:]))
-            key += f"-t-{t}"
-            log_prob = run_joiner(
-                key=key,
-                model=model,
-                encoder_out=current_encoder_out,
-                decoder_out=decoder_out,
-                encoder_out_len=encoder_out_len,
-                decoder_out_len=decoder_out_len,
-                joint_cache=joint_cache,
-            )
-            log_prob = log_prob.cpu().tolist()
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyp in A],
+            device=device,
+        )
+        # decoder_input is of shape (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False)
+        # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
+
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
 
-            for i, v in enumerate(log_prob):
-                if i == blank_id:
-                    # Use [:] to make a copy
-                    new_ys = hyp.ys[:]
-                else:
-                    new_ys = hyp.ys + [i]
-                new_hyp = Hypothesis(ys=new_ys, log_prob=hyp.log_prob + v)
-                B.add(new_hyp)
-        B = B.topk(beam)
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            encoder_out_len.expand(decoder_out.size(0)),
+            decoder_out_len.expand(decoder_out.size(0)),
+        )
+        # logits is of shape (num_hyps, vocab_size)
+        log_probs = logits.log_softmax(dim=-1)
+
+        log_probs.add_(ys_log_probs)
+
+        log_probs = log_probs.reshape(-1)
+        topk_log_probs, topk_indexes = log_probs.topk(beam)
+
+        # topk_hyp_indexes are indexes into `A`
+        topk_hyp_indexes = topk_indexes // logits.size(-1)
+        topk_token_indexes = topk_indexes % logits.size(-1)
+
+        topk_hyp_indexes = topk_hyp_indexes.tolist()
+        topk_token_indexes = topk_token_indexes.tolist()
+
+        for i in range(len(topk_hyp_indexes)):
+            hyp = A[topk_hyp_indexes[i]]
+            new_ys = hyp.ys[:]
+            new_token = topk_token_indexes[i]
+            if new_token != blank_id:
+                new_ys.append(new_token)
+            new_log_prob = topk_log_probs[i]
+            new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+            B.add(new_hyp)
 
     best_hyp = B.get_most_probable(length_norm=True)
     ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
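The rewritten loop above replaces the per-hypothesis `run_decoder`/`run_joiner` calls with one batched joiner call per frame, then takes a single top-k over the flattened `(num_hyps, vocab_size)` score matrix. The `//` and `%` lines recover, for each surviving candidate, which hypothesis it extends and which token it adds. A self-contained sketch of that index arithmetic, with made-up shapes and random scores:

```python
import torch

num_hyps, vocab_size, beam = 3, 5, 4
torch.manual_seed(0)

# Stand-in for the joiner output at one frame: each row is a
# hypothesis' token log probs, already offset by its accumulated
# log prob (the log_probs.add_(ys_log_probs) step in the diff).
log_probs = torch.randn(num_hyps, vocab_size).log_softmax(dim=-1)

# One top-k over the flattened matrix replaces a per-hypothesis loop.
topk_log_probs, topk_indexes = log_probs.reshape(-1).topk(beam)

# Row (which hypothesis) and column (which token) of each winner.
topk_hyp_indexes = topk_indexes // vocab_size
topk_token_indexes = topk_indexes % vocab_size

for h, tok, lp in zip(
    topk_hyp_indexes.tolist(),
    topk_token_indexes.tolist(),
    topk_log_probs.tolist(),
):
    print(f"extend hypothesis {h} with token {tok}, log prob {lp:.3f}")
```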
@@ -75,24 +75,24 @@ class Decoder(nn.Module):
         """
         Args:
           y:
-            A 2-D tensor of shape (N, U) with blank prepended.
+            A 2-D tensor of shape (N, U).
           need_pad:
             True to left pad the input. Should be True during training.
            False to not pad the input. Should be False during inference.
         Returns:
          Return a tensor of shape (N, U, embedding_dim).
        """
-        embeding_out = self.embedding(y)
+        embedding_out = self.embedding(y)
         if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
+            embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
                 )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
-            embeding_out = self.conv(embeding_out)
-            embeding_out = embeding_out.permute(0, 2, 1)
-        return embeding_out
+                assert embedding_out.size(-1) == self.context_size
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        return embedding_out
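Context for the `need_pad` branch above: the stateless decoder runs a `Conv1d` across the last `context_size` token embeddings, which shortens the time axis by `context_size - 1`. Left-padding during training keeps the output aligned with the `U` input positions, while at inference exactly `context_size` tokens are fed, so no padding is needed (hence the assert). A minimal sketch of just the padding arithmetic; the layer sizes here are arbitrary, not the real model's:

```python
import torch
import torch.nn.functional as F

context_size = 2
N, embedding_dim, U = 1, 4, 6

# Embeddings arranged as (N, embedding_dim, U), the layout Conv1d
# expects after the permute in the decoder's forward.
embedding_out = torch.randn(N, embedding_dim, U)

# Left-pad the time axis with context_size - 1 zeros so that a
# Conv1d with kernel_size=context_size produces U outputs.
padded = F.pad(embedding_out, pad=(context_size - 1, 0))
conv = torch.nn.Conv1d(embedding_dim, embedding_dim, kernel_size=context_size)

print(padded.shape)        # torch.Size([1, 4, 7])
print(conv(padded).shape)  # torch.Size([1, 4, 6]) -> length U preserved
```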