Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-18 21:44:18 +00:00)

commit 634931cb61 (parent d70f6e21f2)

    minor updates
@@ -779,6 +779,78 @@ def greedy_search_batch(
 )


+def deprecated_greedy_search_batch(
+    model: nn.Module, encoder_out: torch.Tensor
+) -> List[List[int]]:
+    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C), where N >= 1.
+
+    Returns:
+      Return a list-of-list of token IDs containing the decoded results.
+      len(ans) equals encoder_out.size(0).
+    """
+    assert encoder_out.ndim == 3
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+    device = next(model.parameters()).device
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    unk_id = getattr(model, "unk_id", blank_id)
+    context_size = model.decoder.context_size
+
+    # Start each hypothesis with context_size blanks so the decoder
+    # (prediction network) has a full context window from the first frame.
+    hyps = [[blank_id] * context_size for _ in range(batch_size)]
+
+    decoder_input = torch.tensor(
+        hyps,
+        device=device,
+        dtype=torch.int64,
+    )  # (batch_size, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+    decoder_out = model.joiner.decoder_proj(decoder_out)
+    encoder_out = model.joiner.encoder_proj(encoder_out)
+
+    # decoder_out: (batch_size, 1, decoder_out_dim)
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+        # Pass a zero attn_encoder_out and apply_attn=False so the call
+        # matches the updated Joiner.forward signature; attention is
+        # disabled in this code path.
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out.unsqueeze(1),
+            torch.zeros_like(current_encoder_out),
+            None,
+            apply_attn=False,
+            project_input=False,
+        )
+        # Debug dumps, disabled so the search runs to completion:
+        # print(current_encoder_out)
+        # print(decoder_out.unsqueeze(1))
+        # print(logits)
+        # exit()
+        # logits' shape: (batch_size, 1, 1, vocab_size)
+        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
+        assert logits.ndim == 2, logits.shape
+        y = logits.argmax(dim=1).tolist()
+        emitted = False
+        for i, v in enumerate(y):
+            if v not in (blank_id, unk_id):
+                hyps[i].append(v)
+                emitted = True
+        if emitted:
+            # Re-run the decoder only when some hypothesis grew this frame.
+            decoder_input = [h[-context_size:] for h in hyps]
+            decoder_input = torch.tensor(
+                decoder_input,
+                device=device,
+                dtype=torch.int64,
+            )
+            decoder_out = model.decoder(decoder_input, need_pad=False)
+            decoder_out = model.joiner.decoder_proj(decoder_out)
+
+    ans = [h[context_size:] for h in hyps]
+    return ans
+
+
 def deprecated_greedy_search_batch_for_cross_attn(
     model: nn.Module,
     encoder_out: torch.Tensor,
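To make the per-frame rule in the function above concrete, here is a small self-contained toy (not icefall code; random logits stand in for the joiner output, and the ID values are illustrative): at each frame, take the argmax over the vocabulary, skip blank/unk, and refresh the decoder state only when at least one hypothesis grew.

```python
import torch

blank_id, unk_id = 0, 1          # illustrative IDs
batch_size, T, vocab_size = 2, 6, 10
torch.manual_seed(0)

hyps = [[] for _ in range(batch_size)]
for t in range(T):
    logits = torch.randn(batch_size, vocab_size)  # stand-in for joiner output
    y = logits.argmax(dim=1).tolist()             # at most one symbol per frame
    emitted = False
    for i, v in enumerate(y):
        if v not in (blank_id, unk_id):           # blank/unk emit nothing
            hyps[i].append(v)
            emitted = True
    # The real function re-runs the decoder on h[-context_size:] only when
    # `emitted` is True, which is where the batching saves work.
print(hyps)
```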
@@ -822,23 +894,29 @@ def deprecated_greedy_search_batch_for_cross_attn(

     # decoder_out: (batch_size, 1, decoder_out_dim)
     for t in range(T):
-        # current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
         # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
-        current_encoder_out = model.joiner.label_level_am_attention(
+        attn_encoder_out = model.joiner.label_level_am_attention(
             encoder_out.unsqueeze(2),
             decoder_out.unsqueeze(2),
-            # encoder_out_lens,
-            None,
+            encoder_out_lens,
         )
+        # print(encoder_out[:, t : t + 1, :].unsqueeze(2))
+        # current_encoder_out = torch.zeros_like(current_encoder_out)
         logits = model.joiner(
             current_encoder_out,
             decoder_out.unsqueeze(1),
+            torch.zeros_like(current_encoder_out),
             None,
             apply_attn=False,
             project_input=False,
         )
         # logits' shape: (batch_size, 1, 1, vocab_size)
         logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
+        # print(current_encoder_out)
+        # print(decoder_out)
+        # print(logits)
+        # exit()
         assert logits.ndim == 2, logits.shape
         y = logits.argmax(dim=1).tolist()
         emitted = False
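Note that on this debug path the attention output `attn_encoder_out` is computed but never reaches the joiner: a zero tensor is passed in its slot with `apply_attn=False`, so the sum inside the joiner reduces to the plain encoder+decoder term. A minimal sketch of that arithmetic (shapes illustrative):

```python
import torch

B, C = 2, 8                      # illustrative batch and joiner dims
enc = torch.randn(B, 1, 1, C)    # current_encoder_out
dec = torch.randn(B, 1, 1, C)    # decoder_out.unsqueeze(1)
attn = torch.zeros_like(enc)     # the zeros passed as attn_encoder_out

# With apply_attn=False and project_input=False the joiner computes
# enc + dec + attn, and the zero tensor makes the last term a no-op.
assert torch.equal(enc + dec + attn, enc + dec)
```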
@@ -39,6 +39,7 @@ class Joiner(nn.Module):
         self,
         encoder_out: torch.Tensor,
         decoder_out: torch.Tensor,
+        attn_encoder_out: torch.Tensor,
         lengths: torch.Tensor,
         apply_attn: bool = True,
         project_input: bool = True,
@@ -64,14 +65,14 @@ class Joiner(nn.Module):
         )

         if apply_attn and lengths is not None:
-            encoder_out = self.label_level_am_attention(
+            attn_encoder_out = self.label_level_am_attention(
                 encoder_out, decoder_out, lengths
             )

         if project_input:
             logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
         else:
-            logit = encoder_out + decoder_out
+            logit = encoder_out + decoder_out + attn_encoder_out

         logit = self.output_linear(torch.tanh(logit))

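The net effect of the joiner changes: `label_level_am_attention` no longer overwrites `encoder_out`; its output lands in the new `attn_encoder_out` argument and is added as a third term, but only on the `project_input=False` branch. A condensed sketch of the resulting control flow, with the attention module stubbed by a zero tensor since its internals are not part of this diff:

```python
import torch

def joiner_forward(encoder_out, decoder_out, attn_encoder_out, lengths,
                   apply_attn=True, project_input=True):
    if apply_attn and lengths is not None:
        # real code: attn_encoder_out = self.label_level_am_attention(
        #     encoder_out, decoder_out, lengths)
        attn_encoder_out = torch.zeros_like(encoder_out)  # attention stub
    if project_input:
        # real code applies self.encoder_proj / self.decoder_proj first
        logit = encoder_out + decoder_out
    else:
        logit = encoder_out + decoder_out + attn_encoder_out
    # real code: return self.output_linear(torch.tanh(logit))
    return torch.tanh(logit)
```

Note the asymmetry this introduces: on the `project_input=True` branch the attention output is computed but never used, so callers that want the attention term must pass `project_input=False`.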
@@ -265,7 +265,12 @@ class AsrModel(nn.Module):
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(
-            am_pruned, lm_pruned, encoder_out_lens, apply_attn=True, project_input=False
+            am_pruned,
+            lm_pruned,
+            None,
+            encoder_out_lens,
+            apply_attn=True,
+            project_input=False,
         )

         with torch.cuda.amp.autocast(enabled=False):
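For reference, how the reshaped call lines up with the new positional parameters of `Joiner.forward`: the inserted `None` occupies the `attn_encoder_out` slot and is recomputed inside the joiner, since `apply_attn=True` and `lengths` is not None.

```python
# Argument-to-parameter mapping for the call above:
#   am_pruned        -> encoder_out
#   lm_pruned        -> decoder_out
#   None             -> attn_encoder_out  (recomputed inside Joiner.forward)
#   encoder_out_lens -> lengths
#   apply_attn=True, project_input=False  -> keyword arguments
```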