mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 23:54:17 +00:00
Ignore padding frames during RNN-T decoding.
This commit is contained in:
parent
bc284e88e6
commit
a35e949a25
@ -335,7 +335,9 @@ def greedy_search(
|
|||||||
|
|
||||||
|
|
||||||
def greedy_search_batch(
|
def greedy_search_batch(
|
||||||
model: Transducer, encoder_out: torch.Tensor
|
model: Transducer,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
Args:
|
Args:
|
||||||
@ -343,6 +345,9 @@ def greedy_search_batch(
|
|||||||
The transducer model.
|
The transducer model.
|
||||||
encoder_out:
|
encoder_out:
|
||||||
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||||
|
encoder_out before padding.
|
||||||
Returns:
|
Returns:
|
||||||
Return a list-of-list of token IDs containing the decoded results.
|
Return a list-of-list of token IDs containing the decoded results.
|
||||||
len(ans) equals to encoder_out.size(0).
|
len(ans) equals to encoder_out.size(0).
|
||||||
@ -350,31 +355,49 @@ def greedy_search_batch(
|
|||||||
assert encoder_out.ndim == 3
|
assert encoder_out.ndim == 3
|
||||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
device = next(model.parameters()).device
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
batch_size = encoder_out.size(0)
|
device = next(model.parameters()).device
|
||||||
T = encoder_out.size(1)
|
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
unk_id = getattr(model, "unk_id", blank_id)
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
hyps = [[blank_id] * context_size for _ in range(batch_size)]
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
hyps,
|
hyps,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
) # (batch_size, context_size)
|
) # (N, context_size)
|
||||||
|
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
# decoder_out: (N, 1, decoder_out_dim)
|
||||||
|
|
||||||
# decoder_out: (batch_size, 1, decoder_out_dim)
|
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||||
for t in range(T):
|
|
||||||
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
offset = 0
|
||||||
|
for batch_size in batch_size_list:
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
decoder_out = decoder_out[:batch_size]
|
||||||
|
|
||||||
logits = model.joiner(
|
logits = model.joiner(
|
||||||
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
||||||
)
|
)
|
||||||
@ -390,7 +413,7 @@ def greedy_search_batch(
|
|||||||
emitted = True
|
emitted = True
|
||||||
if emitted:
|
if emitted:
|
||||||
# update decoder output
|
# update decoder output
|
||||||
decoder_input = [h[-context_size:] for h in hyps]
|
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
decoder_input,
|
decoder_input,
|
||||||
device=device,
|
device=device,
|
||||||
@ -399,7 +422,12 @@ def greedy_search_batch(
|
|||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
|
|
||||||
ans = [h[context_size:] for h in hyps]
|
sorted_ans = [h[context_size:] for h in hyps]
|
||||||
|
ans = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
@ -557,6 +585,7 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
|||||||
def modified_beam_search(
|
def modified_beam_search(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||||
@ -566,6 +595,9 @@ def modified_beam_search(
|
|||||||
The transducer model.
|
The transducer model.
|
||||||
encoder_out:
|
encoder_out:
|
||||||
Output from the encoder. Its shape is (N, T, C).
|
Output from the encoder. Its shape is (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||||
|
encoder_out before padding.
|
||||||
beam:
|
beam:
|
||||||
Number of active paths during the beam search.
|
Number of active paths during the beam search.
|
||||||
Returns:
|
Returns:
|
||||||
@ -573,16 +605,26 @@ def modified_beam_search(
|
|||||||
for the i-th utterance.
|
for the i-th utterance.
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == 3, encoder_out.shape
|
assert encoder_out.ndim == 3, encoder_out.shape
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
batch_size = encoder_out.size(0)
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
T = encoder_out.size(1)
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
unk_id = getattr(model, "unk_id", blank_id)
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
B = [HypothesisList() for _ in range(batch_size)]
|
|
||||||
for i in range(batch_size):
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
B = [HypothesisList() for _ in range(N)]
|
||||||
|
for i in range(N):
|
||||||
B[i].add(
|
B[i].add(
|
||||||
Hypothesis(
|
Hypothesis(
|
||||||
ys=[blank_id] * context_size,
|
ys=[blank_id] * context_size,
|
||||||
@ -590,11 +632,20 @@ def modified_beam_search(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||||
|
|
||||||
for t in range(T):
|
offset = 0
|
||||||
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
finalized_B = []
|
||||||
|
for batch_size in batch_size_list:
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
finalized_B = B[batch_size:] + finalized_B
|
||||||
|
B = B[:batch_size]
|
||||||
|
|
||||||
hyps_shape = _get_hyps_shape(B).to(device)
|
hyps_shape = _get_hyps_shape(B).to(device)
|
||||||
|
|
||||||
@ -668,8 +719,14 @@ def modified_beam_search(
|
|||||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||||
B[i].add(new_hyp)
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
B = B + finalized_B
|
||||||
ans = [h.ys[context_size:] for h in best_hyps]
|
best_hyps = [b.get_most_probable(length_norm=False) for b in B]
|
||||||
|
|
||||||
|
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||||
|
ans = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
@ -270,6 +270,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -277,6 +278,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
@ -307,6 +307,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -314,6 +315,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
@ -284,6 +284,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -291,6 +292,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user