mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 15:14:18 +00:00
Add greedy search in batch mode.
This commit is contained in:
parent
aa71eaaac7
commit
7fa5860073
@ -229,7 +229,11 @@ def greedy_search_batch(
|
|||||||
if emitted:
|
if emitted:
|
||||||
# update decoder output
|
# update decoder output
|
||||||
decoder_input = [h[-context_size:] for h in hyps]
|
decoder_input = [h[-context_size:] for h in hyps]
|
||||||
decoder_input = torch.tensor(decoder_input, device=device)
|
decoder_input = torch.tensor(
|
||||||
|
decoder_input,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.in64,
|
||||||
|
)
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
|
||||||
ans = [h[context_size:] for h in hyps]
|
ans = [h[context_size:] for h in hyps]
|
||||||
|
@ -192,7 +192,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
type=int,
|
type=int,
|
||||||
default=3,
|
default=1,
|
||||||
help="""Maximum number of symbols per frame.
|
help="""Maximum number of symbols per frame.
|
||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
@ -24,7 +24,7 @@ from model import Transducer
|
|||||||
def greedy_search(
|
def greedy_search(
|
||||||
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
|
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""
|
"""Greedy search for a single utterance.
|
||||||
Args:
|
Args:
|
||||||
model:
|
model:
|
||||||
An instance of `Transducer`.
|
An instance of `Transducer`.
|
||||||
@ -80,7 +80,7 @@ def greedy_search(
|
|||||||
logits = model.joiner(
|
logits = model.joiner(
|
||||||
current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
|
current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
|
||||||
)
|
)
|
||||||
# logits is (1, 1, 1, vocab_size)
|
# logits is (1, vocab_size)
|
||||||
|
|
||||||
y = logits.argmax().item()
|
y = logits.argmax().item()
|
||||||
if y != blank_id:
|
if y != blank_id:
|
||||||
@ -101,6 +101,75 @@ def greedy_search(
|
|||||||
return hyp
|
return hyp
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search_batch(
|
||||||
|
model: Transducer, encoder_out: torch.Tensor
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The transducer model.
|
||||||
|
encoder_out:
|
||||||
|
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
||||||
|
Returns:
|
||||||
|
Return a list-of-list of token IDs containing the decoded results.
|
||||||
|
len(ans) equals to encoder_out.size(0).
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
|
device = model.device
|
||||||
|
|
||||||
|
batch_size = encoder_out.size(0)
|
||||||
|
T = encoder_out.size(1)
|
||||||
|
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
|
hyps = [[blank_id] * context_size for _ in range(batch_size)]
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
hyps,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (batch_size, context_size)
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
# decoder_out: (batch_size, 1, decoder_out_dim)
|
||||||
|
|
||||||
|
encoder_out_len = torch.ones(batch_size, dtype=torch.int32)
|
||||||
|
decoder_out_len = torch.ones(batch_size, dtype=torch.int32)
|
||||||
|
|
||||||
|
for t in range(T):
|
||||||
|
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
|
||||||
|
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
|
||||||
|
) # (batch_size, vocab_size)
|
||||||
|
|
||||||
|
assert logits.ndim == 2, logits.shape
|
||||||
|
y = logits.argmax(dim=1).tolist()
|
||||||
|
emitted = False
|
||||||
|
for i, v in enumerate(y):
|
||||||
|
if v != blank_id:
|
||||||
|
hyps[i].append(v)
|
||||||
|
emitted = True
|
||||||
|
|
||||||
|
if emitted:
|
||||||
|
# update decoder output
|
||||||
|
decoder_input = [h[-context_size:] for h in hyps]
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
decoder_input,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (batch_size, context_size)
|
||||||
|
decoder_out = model.decoder(
|
||||||
|
decoder_input,
|
||||||
|
need_pad=False,
|
||||||
|
) # (batch_size, 1, decoder_out_dim)
|
||||||
|
|
||||||
|
ans = [h[context_size:] for h in hyps]
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Hypothesis:
|
class Hypothesis:
|
||||||
# The predicted tokens so far.
|
# The predicted tokens so far.
|
||||||
@ -252,9 +321,11 @@ def run_decoder(
|
|||||||
|
|
||||||
device = model.device
|
device = model.device
|
||||||
|
|
||||||
decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
|
decoder_input = torch.tensor(
|
||||||
1, context_size
|
[ys[-context_size:]],
|
||||||
)
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
).reshape(1, context_size)
|
||||||
|
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
decoder_cache[key] = decoder_out
|
decoder_cache[key] = decoder_out
|
||||||
@ -341,12 +412,6 @@ def modified_beam_search(
|
|||||||
|
|
||||||
device = model.device
|
device = model.device
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
|
||||||
[blank_id] * context_size, device=device
|
|
||||||
).reshape(1, context_size)
|
|
||||||
|
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
|
||||||
|
|
||||||
T = encoder_out.size(1)
|
T = encoder_out.size(1)
|
||||||
|
|
||||||
B = HypothesisList()
|
B = HypothesisList()
|
||||||
|
@ -55,8 +55,13 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from beam_search import beam_search, greedy_search, modified_beam_search
|
from beam_search import (
|
||||||
from train import get_transducer_model, get_params
|
beam_search,
|
||||||
|
greedy_search,
|
||||||
|
greedy_search_batch,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
@ -131,7 +136,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
type=int,
|
type=int,
|
||||||
default=3,
|
default=1,
|
||||||
help="""Maximum number of symbols per frame.
|
help="""Maximum number of symbols per frame.
|
||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
@ -183,32 +188,47 @@ def decode_one_batch(
|
|||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
x=feature, x_lens=feature_lens
|
x=feature, x_lens=feature_lens
|
||||||
)
|
)
|
||||||
hyps = []
|
hyp_list: List[List[int]] = []
|
||||||
batch_size = encoder_out.size(0)
|
|
||||||
|
|
||||||
for i in range(batch_size):
|
if (
|
||||||
# fmt: off
|
params.decoding_method == "greedy_search"
|
||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
and params.max_sym_per_frame == 1
|
||||||
# fmt: on
|
):
|
||||||
if params.decoding_method == "greedy_search":
|
hyp_list = greedy_search_batch(
|
||||||
hyp = greedy_search(
|
model=model,
|
||||||
model=model,
|
encoder_out=encoder_out,
|
||||||
encoder_out=encoder_out_i,
|
)
|
||||||
max_sym_per_frame=params.max_sym_per_frame,
|
else:
|
||||||
)
|
batch_size = encoder_out.size(0)
|
||||||
elif params.decoding_method == "beam_search":
|
for i in range(batch_size):
|
||||||
hyp = beam_search(
|
# fmt: off
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
)
|
# fmt: on
|
||||||
elif params.decoding_method == "modified_beam_search":
|
if params.decoding_method == "greedy_search":
|
||||||
hyp = modified_beam_search(
|
hyp = greedy_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model,
|
||||||
)
|
encoder_out=encoder_out_i,
|
||||||
else:
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
raise ValueError(
|
)
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
elif params.decoding_method == "beam_search":
|
||||||
)
|
hyp = beam_search(
|
||||||
hyps.append(sp.decode(hyp).split())
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp = modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
|
)
|
||||||
|
hyp_list.append(hyp)
|
||||||
|
|
||||||
|
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": hyps}
|
return {"greedy_search": hyps}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user