Add greedy search in batch mode.

Fangjun Kuang 2022-03-23 11:43:26 +08:00
parent aa71eaaac7
commit 7fa5860073
4 changed files with 130 additions and 41 deletions

View File

@@ -229,7 +229,11 @@ def greedy_search_batch(
         if emitted:
             # update decoder output
             decoder_input = [h[-context_size:] for h in hyps]
-            decoder_input = torch.tensor(decoder_input, device=device)
+            decoder_input = torch.tensor(
+                decoder_input,
+                device=device,
+                dtype=torch.int64,
+            )
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
     ans = [h[context_size:] for h in hyps]
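The dtype is made explicit here because the decoder begins with an embedding
lookup, and PyTorch embedding layers require integer indices. A minimal,
self-contained sketch of that constraint (the vocabulary and embedding sizes
are illustrative, not taken from this recipe):

    import torch
    import torch.nn as nn

    embedding = nn.Embedding(num_embeddings=500, embedding_dim=16)
    ids = torch.tensor([[3, 7], [9, 2]], dtype=torch.int64)  # (N, context_size)
    out = embedding(ids)  # (N, context_size, 16)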

View File

@@ -192,7 +192,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )
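Lowering the default from 3 to 1 aligns this flag with the new batched search,
which by construction emits at most one symbol per frame (a single argmax per
frame). Under that setting the two searches should agree; a hedged check,
assuming a loaded `model` and an `encoder_out` batch without padding (the
batched version below does not consult encoder_out_lens):

    per_utt = [
        greedy_search(
            model=model,
            encoder_out=encoder_out[i : i + 1],
            max_sym_per_frame=1,
        )
        for i in range(encoder_out.size(0))
    ]
    batched = greedy_search_batch(model=model, encoder_out=encoder_out)
    assert per_utt == batched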

View File

@@ -24,7 +24,7 @@ from model import Transducer
 def greedy_search(
     model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
 ) -> List[int]:
-    """
+    """Greedy search for a single utterance.
     Args:
       model:
         An instance of `Transducer`.
@@ -80,7 +80,7 @@ def greedy_search(
         logits = model.joiner(
             current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
         )
-        # logits is (1, 1, 1, vocab_size)
+        # logits is (1, vocab_size)
 
         y = logits.argmax().item()
         if y != blank_id:
@@ -101,6 +101,75 @@ def greedy_search(
     return hyp
 
 
+def greedy_search_batch(
+    model: Transducer, encoder_out: torch.Tensor
+) -> List[List[int]]:
+    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C), where N >= 1.
+    Returns:
+      Return a list of lists of token IDs containing the decoded results.
+      len(ans) equals encoder_out.size(0).
+    """
+    assert encoder_out.ndim == 3
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+    device = model.device
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+
+    hyps = [[blank_id] * context_size for _ in range(batch_size)]
+
+    decoder_input = torch.tensor(
+        hyps,
+        device=device,
+        dtype=torch.int64,
+    )  # (batch_size, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+    # decoder_out: (batch_size, 1, decoder_out_dim)
+
+    encoder_out_len = torch.ones(batch_size, dtype=torch.int32)
+    decoder_out_len = torch.ones(batch_size, dtype=torch.int32)
+
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa
+        # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
+        logits = model.joiner(
+            current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
+        )  # (batch_size, vocab_size)
+
+        assert logits.ndim == 2, logits.shape
+
+        y = logits.argmax(dim=1).tolist()
+        emitted = False
+        for i, v in enumerate(y):
+            if v != blank_id:
+                hyps[i].append(v)
+                emitted = True
+
+        if emitted:
+            # update decoder output
+            decoder_input = [h[-context_size:] for h in hyps]
+            decoder_input = torch.tensor(
+                decoder_input,
+                device=device,
+                dtype=torch.int64,
+            )  # (batch_size, context_size)
+            decoder_out = model.decoder(
+                decoder_input,
+                need_pad=False,
+            )  # (batch_size, 1, decoder_out_dim)
+
+    ans = [h[context_size:] for h in hyps]
+    return ans
+
+
 @dataclass
 class Hypothesis:
     # The predicted tokens so far.
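A hedged usage sketch of the new greedy_search_batch, assuming a trained
transducer `model` and a padded feature batch (`features` and `feature_lens`
are placeholder names). Note that this version loops over all T frames and
takes no encoder_out_lens argument, so trailing padded frames are decoded too:

    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
    # encoder_out: (N, T, C), padded to the longest utterance in the batch
    hyp_tokens = greedy_search_batch(model=model, encoder_out=encoder_out)
    assert len(hyp_tokens) == encoder_out.size(0)
    # hyp_tokens[i] holds the decoded token IDs for utterance i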
@@ -252,9 +321,11 @@ def run_decoder(
     device = model.device
 
-    decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [ys[-context_size:]],
+        device=device,
+        dtype=torch.int64,
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
     decoder_cache[key] = decoder_out
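For context, run_decoder memoizes decoder outputs keyed by the most recent
context tokens, so repeated hypotheses do not re-run the decoder. A hedged
sketch of that pattern (the string key construction is illustrative; `model`,
`device`, and `context_size` are assumed from the surrounding code):

    from typing import Dict, List

    import torch

    decoder_cache: Dict[str, torch.Tensor] = {}

    def cached_decoder_out(ys: List[int]) -> torch.Tensor:
        key = ",".join(map(str, ys[-context_size:]))  # illustrative cache key
        if key not in decoder_cache:
            decoder_input = torch.tensor(
                [ys[-context_size:]],
                device=device,
                dtype=torch.int64,
            ).reshape(1, context_size)
            decoder_cache[key] = model.decoder(decoder_input, need_pad=False)
        return decoder_cache[key]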
@@ -341,12 +412,6 @@ def modified_beam_search(
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
-
-    decoder_out = model.decoder(decoder_input, need_pad=False)
-
     T = encoder_out.size(1)
 
     B = HypothesisList()

View File

@@ -55,8 +55,13 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
-from train import get_transducer_model, get_params
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.utils import (
@@ -131,7 +136,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )
@@ -183,32 +188,47 @@ def decode_one_batch(
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
-    hyps = []
-    batch_size = encoder_out.size(0)
+    hyp_list: List[List[int]] = []
 
-    for i in range(batch_size):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.decoding_method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.decoding_method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.decoding_method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
-        hyps.append(sp.decode(hyp).split())
+    if (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            elif params.decoding_method == "modified_beam_search":
+                hyp = modified_beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyp_list.append(hyp)
+
+    hyps = [sp.decode(hyp).split() for hyp in hyp_list]
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
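As a final step, token IDs from any of the searches are mapped to word lists
in one place rather than inside the loop. A hedged illustration of that
conversion, assuming a trained SentencePiece BPE model (the path is a
placeholder):

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("bpe.model")  # placeholder path
    hyp = [25, 103, 7]  # token IDs produced by one of the searches
    words = sp.decode(hyp).split()  # word list used for scoring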