Ignore padding frames during RNN-T decoding.

Fangjun Kuang 2022-05-11 22:46:44 +08:00
parent bc284e88e6
commit a35e949a25
4 changed files with 85 additions and 22 deletions
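
Background: before this change, greedy_search_batch() and modified_beam_search()
looped over all T frames of the padded encoder output, so padding frames of the
shorter utterances in a batch were decoded as if they were real speech. The fix
packs the encoder output with torch.nn.utils.rnn.pack_padded_sequence() so that
only valid frames are ever visited. A minimal, self-contained sketch of the
pack_padded_sequence fields the new code relies on (tensor sizes invented for
illustration):

    import torch

    # Three "encoder outputs" with 4, 2, and 3 valid frames, padded to T=4;
    # feature dim C=1 keeps the printout small.
    encoder_out = torch.arange(12, dtype=torch.float32).reshape(3, 4, 1)
    encoder_out_lens = torch.tensor([4, 2, 3])

    packed = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,  # sort utterances by decreasing length internally
    )

    # packed.data holds only the valid frames, time step by time step, with
    # utterances sorted by decreasing length; no padding frame survives.
    print(packed.data.shape)        # torch.Size([9, 1]) == sum of the lengths
    # batch_sizes[t] == number of utterances still alive at time step t.
    print(packed.batch_sizes)       # tensor([3, 3, 2, 1]), non-increasing
    # unsorted_indices maps the sorted order back to the caller's order.
    print(packed.unsorted_indices)  # tensor([0, 2, 1])
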

View File

@@ -335,7 +335,9 @@ def greedy_search(
 def greedy_search_batch(
-    model: Transducer, encoder_out: torch.Tensor
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
 ) -> List[List[int]]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
     Args:
@@ -343,6 +345,9 @@ def greedy_search_batch(
         The transducer model.
       encoder_out:
         Output from the encoder. Its shape is (N, T, C), where N >= 1.
+      encoder_out_lens:
+        A 1-D tensor of shape (N,), containing number of valid frames in
+        encoder_out before padding.
     Returns:
       Return a list-of-list of token IDs containing the decoded results.
       len(ans) equals to encoder_out.size(0).
@@ -350,31 +355,49 @@ def greedy_search_batch(
     assert encoder_out.ndim == 3
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
 
-    batch_size = encoder_out.size(0)
-    T = encoder_out.size(1)
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
 
     device = next(model.parameters()).device
 
     blank_id = model.decoder.blank_id
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
 
-    hyps = [[blank_id] * context_size for _ in range(batch_size)]
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+
+    hyps = [[blank_id] * context_size for _ in range(N)]
 
     decoder_input = torch.tensor(
         hyps,
         device=device,
         dtype=torch.int64,
-    )  # (batch_size, context_size)
+    )  # (N, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
     decoder_out = model.joiner.decoder_proj(decoder_out)
-    encoder_out = model.joiner.encoder_proj(encoder_out)
 
-    # decoder_out: (batch_size, 1, decoder_out_dim)
-    for t in range(T):
-        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
+
+    # decoder_out: (N, 1, decoder_out_dim)
+    offset = 0
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
         # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+
+        decoder_out = decoder_out[:batch_size]
+
         logits = model.joiner(
             current_encoder_out, decoder_out.unsqueeze(1), project_input=False
         )
@@ -390,7 +413,7 @@ def greedy_search_batch(
                 emitted = True
         if emitted:
             # update decoder output
-            decoder_input = [h[-context_size:] for h in hyps]
+            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
             decoder_input = torch.tensor(
                 decoder_input,
                 device=device,
@@ -399,7 +422,12 @@ def greedy_search_batch(
             decoder_out = model.decoder(decoder_input, need_pad=False)
             decoder_out = model.joiner.decoder_proj(decoder_out)
 
-    ans = [h[context_size:] for h in hyps]
+    sorted_ans = [h[context_size:] for h in hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
+
     return ans
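
Note: the truncations above (decoder_out[:batch_size], hyps[:batch_size]) are
safe because pack_padded_sequence() sorts the batch by decreasing length, so
the utterances still alive at any time step always form a prefix of the sorted
batch and batch_size_list is non-increasing. A toy walk-through of the loop
skeleton (values invented for illustration):

    # batch_sizes for sorted valid lengths [4, 3, 2]
    batch_size_list = [3, 3, 2, 1]

    hyps = [["utt0"], ["utt2"], ["utt1"]]  # sorted by decreasing length
    offset = 0
    for batch_size in batch_size_list:
        start, end = offset, offset + batch_size
        offset = end
        # packed.data[start:end] are the frames of the alive prefix;
        # slicing to [:batch_size] drops exactly the finished utterances.
        alive = hyps[:batch_size]
        print(f"frames [{start}:{end}) -> {len(alive)} alive utterances")
    # frames [0:3) -> 3, [3:6) -> 3, [6:8) -> 2, [8:9) -> 1
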
@@ -557,6 +585,7 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
 def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
     beam: int = 4,
 ) -> List[List[int]]:
     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@@ -566,6 +595,9 @@ def modified_beam_search(
         The transducer model.
       encoder_out:
         Output from the encoder. Its shape is (N, T, C).
+      encoder_out_lens:
+        A 1-D tensor of shape (N,), containing number of valid frames in
+        encoder_out before padding.
       beam:
         Number of active paths during the beam search.
     Returns:
@@ -573,16 +605,26 @@ def modified_beam_search(
       for the i-th utterance.
     """
     assert encoder_out.ndim == 3, encoder_out.shape
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
 
-    batch_size = encoder_out.size(0)
-    T = encoder_out.size(1)
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
 
     blank_id = model.decoder.blank_id
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
     device = next(model.parameters()).device
 
-    B = [HypothesisList() for _ in range(batch_size)]
-    for i in range(batch_size):
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+
+    B = [HypothesisList() for _ in range(N)]
+    for i in range(N):
         B[i].add(
             Hypothesis(
                 ys=[blank_id] * context_size,
@@ -590,11 +632,20 @@ def modified_beam_search(
             )
         )
 
-    encoder_out = model.joiner.encoder_proj(encoder_out)
+    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
 
-    for t in range(T):
-        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+    offset = 0
+    finalized_B = []
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
         # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+
+        finalized_B = B[batch_size:] + finalized_B
+        B = B[:batch_size]
 
         hyps_shape = _get_hyps_shape(B).to(device)
@@ -668,8 +719,14 @@ def modified_beam_search(
                 new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
                 B[i].add(new_hyp)
 
-    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
-    ans = [h.ys[context_size:] for h in best_hyps]
+    B = B + finalized_B
+    best_hyps = [b.get_most_probable(length_norm=False) for b in B]
+
+    sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
 
     return ans
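
Two pieces of bookkeeping in the hunk above are worth spelling out. First,
finalized_B = B[batch_size:] + finalized_B stashes the hypothesis lists of
utterances that have just run out of valid frames, and B = B + finalized_B
restores them before the best hypotheses are picked. Second, because decoding
runs in length-sorted order, the results must be mapped back to the caller's
order via unsorted_indices; a toy demonstration of that final re-ordering
(strings invented for illustration):

    sorted_ans = ["longest", "middle", "shortest"]  # internal, sorted order
    unsorted_indices = [0, 2, 1]  # packed_encoder_out.unsorted_indices.tolist()

    ans = [sorted_ans[unsorted_indices[i]] for i in range(len(sorted_ans))]
    print(ans)  # ['longest', 'shortest', 'middle'] == original batch order
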

View File

@@ -270,6 +270,7 @@ def decode_one_batch(
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -277,6 +278,7 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):

View File

@@ -307,6 +307,7 @@ def decode_one_batch(
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -314,6 +315,7 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):

View File

@@ -284,6 +284,7 @@ def decode_one_batch(
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
        )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -291,6 +292,7 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):