modified beam search for stateless3,4

pkufool 2022-07-23 17:18:47 +08:00
parent 72d76a4ff8
commit 353863a55c
4 changed files with 114 additions and 292 deletions


@@ -0,0 +1 @@
+../pruned_transducer_stateless2/streaming_beam_search.py
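The first changed file is a new symlink that reuses streaming_beam_search.py from the pruned_transducer_stateless2 recipe. For orientation, the sketch below shows the core idea behind modified beam search in plain Python: each frame may emit at most one symbol, every live hypothesis is scored against every token, and only the `beam` best candidates survive. This is a minimal toy, not the icefall implementation; all names are hypothetical, the real module batches decoder/joiner calls across streams, and duplicates are merged here with max where real implementations typically log-add.

    # Toy sketch of one frame of modified beam search (hypothetical names).
    import math
    from typing import Callable, Dict, List, Tuple

    def modified_beam_search_step(
        hyps: Dict[Tuple[int, ...], float],
        log_probs: Callable[[Tuple[int, ...]], List[float]],
        beam: int,
        blank_id: int = 0,
    ) -> Dict[Tuple[int, ...], float]:
        """Advance one frame; `hyps` maps token tuples to log scores."""
        candidates: Dict[Tuple[int, ...], float] = {}
        for tokens, score in hyps.items():
            # Per-hypothesis log-probs, i.e. decoder+joiner output at this frame.
            lps = log_probs(tokens)
            for token, lp in enumerate(lps):
                # Blank keeps the token sequence unchanged; at most one
                # non-blank symbol can be appended per frame.
                new_tokens = tokens if token == blank_id else tokens + (token,)
                new_score = score + lp
                if new_tokens in candidates:
                    # Merge hypotheses that yield the same token sequence.
                    candidates[new_tokens] = max(candidates[new_tokens], new_score)
                else:
                    candidates[new_tokens] = new_score
        top = sorted(candidates.items(), key=lambda kv: kv[1], reverse=True)[:beam]
        return dict(top)

    # Tiny usage example with a fake 3-token vocabulary (blank=0).
    uniform = lambda tokens: [math.log(1 / 3)] * 3
    hyps: Dict[Tuple[int, ...], float] = {(): 0.0}
    for _ in range(4):  # four frames
        hyps = modified_beam_search_step(hyps, uniform, beam=4)
    print(sorted(hyps.values()))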


@@ -44,6 +44,11 @@ from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
 from librispeech import LibriSpeech
+from streaming_beam_search import (
+    fast_beam_search_one_best,
+    greedy_search,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model
@@ -52,10 +57,8 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.decode import one_best_decoding
 from icefall.utils import (
     AttributeDict,
-    get_texts,
     setup_logger,
     store_transcripts,
     write_error_stats,
@@ -115,10 +118,21 @@ def get_parser():
         "--decoding-method",
         type=str,
         default="greedy_search",
-        help="""Support only greedy_search and fast_beam_search now.
+        help="""Supported decoding methods are:
+        greedy_search
+        modified_beam_search
+        fast_beam_search
         """,
     )

+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is modified_beam_search.""",
+    )
+
     parser.add_argument(
         "--beam",
         type=float,
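For reference, a stand-alone parser with just the two options this hunk touches behaves as follows; this is a minimal sketch mirroring the diff, not the recipe's full get_parser():

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--decoding-method", type=str, default="greedy_search")
    parser.add_argument("--beam-size", type=int, default=4)

    args = parser.parse_args(
        ["--decoding-method", "modified_beam_search", "--beam-size", "8"]
    )
    assert args.decoding_method == "modified_beam_search"
    assert args.beam_size == 8  # consulted only for modified_beam_search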
@@ -186,109 +200,6 @@ def get_parser():
     return parser


-def greedy_search(
-    model: nn.Module,
-    encoder_out: torch.Tensor,
-    streams: List[DecodeStream],
-) -> List[List[int]]:
-    assert len(streams) == encoder_out.size(0)
-    assert encoder_out.ndim == 3
-
-    blank_id = model.decoder.blank_id
-    context_size = model.decoder.context_size
-    device = model.device
-    T = encoder_out.size(1)
-
-    decoder_input = torch.tensor(
-        [stream.hyp[-context_size:] for stream in streams],
-        device=device,
-        dtype=torch.int64,
-    )
-    # decoder_out is of shape (N, decoder_out_dim)
-    decoder_out = model.decoder(decoder_input, need_pad=False)
-    decoder_out = model.joiner.decoder_proj(decoder_out)
-    # logging.info(f"decoder_out shape : {decoder_out.shape}")
-
-    for t in range(T):
-        # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
-        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa
-
-        logits = model.joiner(
-            current_encoder_out.unsqueeze(2),
-            decoder_out.unsqueeze(1),
-            project_input=False,
-        )
-        # logits'shape (batch_size, vocab_size)
-        logits = logits.squeeze(1).squeeze(1)
-
-        assert logits.ndim == 2, logits.shape
-        y = logits.argmax(dim=1).tolist()
-        emitted = False
-        for i, v in enumerate(y):
-            if v != blank_id:
-                streams[i].hyp.append(v)
-                emitted = True
-        if emitted:
-            # update decoder output
-            decoder_input = torch.tensor(
-                [stream.hyp[-context_size:] for stream in streams],
-                device=device,
-                dtype=torch.int64,
-            )
-            decoder_out = model.decoder(
-                decoder_input,
-                need_pad=False,
-            )
-            decoder_out = model.joiner.decoder_proj(decoder_out)
-
-    hyp_tokens = []
-    for stream in streams:
-        hyp_tokens.append(stream.hyp)
-    return hyp_tokens
-
-
-def fast_beam_search(
-    model: nn.Module,
-    encoder_out: torch.Tensor,
-    processed_lens: torch.Tensor,
-    decoding_streams: k2.RnntDecodingStreams,
-) -> List[List[int]]:
-    B, T, C = encoder_out.shape
-    for t in range(T):
-        # shape is a RaggedShape of shape (B, context)
-        # contexts is a Tensor of shape (shape.NumElements(), context_size)
-        shape, contexts = decoding_streams.get_contexts()
-        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
-        contexts = contexts.to(torch.int64)
-        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
-        decoder_out = model.decoder(contexts, need_pad=False)
-        decoder_out = model.joiner.decoder_proj(decoder_out)
-        # current_encoder_out is of shape
-        # (shape.NumElements(), 1, joiner_dim)
-        # fmt: off
-        current_encoder_out = torch.index_select(
-            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
-        )
-        # fmt: on
-        logits = model.joiner(
-            current_encoder_out.unsqueeze(2),
-            decoder_out.unsqueeze(1),
-            project_input=False,
-        )
-        logits = logits.squeeze(1).squeeze(1)
-        log_probs = logits.log_softmax(dim=-1)
-        decoding_streams.advance(log_probs)
-    decoding_streams.terminate_and_flush_to_streams()
-    lattice = decoding_streams.format_output(processed_lens.tolist())
-    best_path = one_best_decoding(lattice)
-    hyp_tokens = get_texts(best_path)
-    return hyp_tokens


 def decode_one_chunk(
     params: AttributeDict,
     model: nn.Module,
@@ -313,7 +224,6 @@ def decode_one_chunk(
     features = []
     feature_lens = []
     states = []
-    rnnt_stream_list = []
     processed_lens = []

     for stream in decode_streams:
@@ -324,8 +234,6 @@ def decode_one_chunk(
         feature_lens.append(feat_len)
         states.append(stream.states)
         processed_lens.append(stream.done_frames)
-        if params.decoding_method == "fast_beam_search":
-            rnnt_stream_list.append(stream.rnnt_decoding_stream)

     feature_lens = torch.tensor(feature_lens, device=device)
     features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
@@ -337,19 +245,13 @@ def decode_one_chunk(
     # frames.
     tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
     if features.size(1) < tail_length:
-        feature_lens += tail_length - features.size(1)
-        features = torch.cat(
-            [
-                features,
-                torch.tensor(
-                    LOG_EPS, dtype=features.dtype, device=device
-                ).expand(
-                    features.size(0),
-                    tail_length - features.size(1),
-                    features.size(2),
-                ),
-            ],
-            dim=1,
-        )
+        pad_length = tail_length - features.size(1)
+        feature_lens += pad_length
+        features = torch.nn.functional.pad(
+            features,
+            (0, 0, 0, pad_length),
+            mode="constant",
+            value=LOG_EPS,
+        )

     states = [
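The rewritten padding relies on torch.nn.functional.pad, which reads its pad tuple from the last dimension backwards: for a (batch, time, feature) tensor, (0, 0, 0, pad_length) leaves the feature axis untouched and appends pad_length frames to the right of the time axis. A quick self-contained check (LOG_EPS here is a stand-in value, not necessarily the constant the script defines):

    import math

    import torch
    import torch.nn.functional as F

    LOG_EPS = math.log(1e-10)  # stand-in value

    features = torch.zeros(2, 5, 80)  # (batch, time, feature)
    pad_length = 3
    padded = F.pad(features, (0, 0, 0, pad_length), mode="constant", value=LOG_EPS)
    assert padded.shape == (2, 8, 80)  # only the time axis grew
    assert bool((padded[:, 5:, :] == LOG_EPS).all())  # tail filled with LOG_EPS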
@@ -370,22 +272,31 @@ def decode_one_chunk(
     encoder_out = model.joiner.encoder_proj(encoder_out)

     if params.decoding_method == "greedy_search":
-        hyp_tokens = greedy_search(model, encoder_out, decode_streams)
+        greedy_search(
+            model=model, encoder_out=encoder_out, streams=decode_streams
+        )
     elif params.decoding_method == "fast_beam_search":
-        config = k2.RnntDecodingConfig(
-            vocab_size=params.vocab_size,
-            decoder_history_len=params.context_size,
-            beam=params.beam,
-            max_contexts=params.max_contexts,
-            max_states=params.max_states,
-        )
-        decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
         processed_lens = processed_lens + encoder_out_lens
-        hyp_tokens = fast_beam_search(
-            model, encoder_out, processed_lens, decoding_streams
+        fast_beam_search_one_best(
+            model=model,
+            encoder_out=encoder_out,
+            processed_lens=processed_lens,
+            streams=decode_streams,
+            beam=params.beam,
+            max_states=params.max_states,
+            max_contexts=params.max_contexts,
         )
+    elif params.decoding_method == "modified_beam_search":
+        modified_beam_search(
+            model=model,
+            streams=decode_streams,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
     else:
-        assert False
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )

     states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
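Note the changed contract in this hunk: the search functions no longer return hyp_tokens but write hypotheses into each DecodeStream in place, which is why the per-method hyp bookkeeping below disappears and decode_dataset switches to decoding_result(). The toy class below illustrates the assumed shape of that method, inferred from the deleted hyp[params.context_size:] stripping; the real class lives in decode_stream.py:

    from typing import List

    class ToyStream:
        """Hypothetical stand-in for DecodeStream's result handling."""

        def __init__(self, context_size: int = 2, blank_id: int = 0):
            self.context_size = context_size
            # Greedy search seeds the hypothesis with context_size blanks.
            self.hyp: List[int] = [blank_id] * context_size

        def decoding_result(self) -> List[int]:
            # Strip the seeding blanks so callers get only real tokens
            # (assumed behavior, inferred from the deleted stripping logic).
            return self.hyp[self.context_size:]

    stream = ToyStream()
    stream.hyp.extend([17, 23])  # tokens appended during decoding
    assert stream.decoding_result() == [17, 23]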
@@ -393,8 +304,6 @@ def decode_one_chunk(
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
         decode_streams[i].done_frames += encoder_out_lens[i]
-        if params.decoding_method == "fast_beam_search":
-            decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
             finished_streams.append(i)
@@ -478,13 +387,10 @@ def decode_dataset(
                 params=params, model=model, decode_streams=decode_streams
             )
             for i in sorted(finished_streams, reverse=True):
-                hyp = decode_streams[i].hyp
-                if params.decoding_method == "greedy_search":
-                    hyp = hyp[params.context_size :]  # noqa
                 decode_results.append(
                     (
                         decode_streams[i].ground_truth.split(),
-                        sp.decode(hyp).split(),
+                        sp.decode(decode_streams[i].decoding_result()).split(),
                     )
                 )
                 del decode_streams[i]
@@ -498,24 +404,28 @@ def decode_dataset(
             params=params, model=model, decode_streams=decode_streams
         )
         for i in sorted(finished_streams, reverse=True):
-            hyp = decode_streams[i].hyp
-            if params.decoding_method == "greedy_search":
-                hyp = hyp[params.context_size :]  # noqa
             decode_results.append(
                 (
                     decode_streams[i].ground_truth.split(),
-                    sp.decode(hyp).split(),
+                    sp.decode(decode_streams[i].decoding_result()).split(),
                 )
             )
             del decode_streams[i]

-    key = "greedy_search"
-    if params.decoding_method == "fast_beam_search":
+    if params.decoding_method == "greedy_search":
+        key = "greedy_search"
+    elif params.decoding_method == "fast_beam_search":
         key = (
             f"beam_{params.beam}_"
             f"max_contexts_{params.max_contexts}_"
             f"max_states_{params.max_states}"
         )
+    elif params.decoding_method == "modified_beam_search":
+        key = f"beam_size_{params.beam_size}"
+    else:
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )

     return {key: decode_results}
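For reference, the three key branches produce result-dictionary keys of the following shape (values illustrative only; the defaults for --max-contexts and --max-states are not shown in this diff):

    beam, max_contexts, max_states, beam_size = 4.0, 4, 32, 4  # illustrative
    print("greedy_search")
    print(f"beam_{beam}_max_contexts_{max_contexts}_max_states_{max_states}")
    # -> beam_4.0_max_contexts_4_max_states_32
    print(f"beam_size_{beam_size}")
    # -> beam_size_4

The two files below repeat the same change set for the second recipe named in the commit title.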


@@ -0,0 +1 @@
+../pruned_transducer_stateless2/streaming_beam_search.py


@@ -43,6 +43,11 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
+from streaming_beam_search import (
+    fast_beam_search_one_best,
+    greedy_search,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model
@@ -52,10 +57,8 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.decode import one_best_decoding
 from icefall.utils import (
     AttributeDict,
-    get_texts,
     setup_logger,
     store_transcripts,
     str2bool,
@@ -127,10 +130,21 @@ def get_parser():
         "--decoding-method",
         type=str,
         default="greedy_search",
-        help="""Support only greedy_search and fast_beam_search now.
+        help="""Supported decoding methods are:
+        greedy_search
+        modified_beam_search
+        fast_beam_search
         """,
     )

+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is modified_beam_search.""",
+    )
+
     parser.add_argument(
         "--beam",
         type=float,
@@ -198,109 +212,6 @@ def get_parser():
     return parser


-def greedy_search(
-    model: nn.Module,
-    encoder_out: torch.Tensor,
-    streams: List[DecodeStream],
-) -> List[List[int]]:
-    assert len(streams) == encoder_out.size(0)
-    assert encoder_out.ndim == 3
-
-    blank_id = model.decoder.blank_id
-    context_size = model.decoder.context_size
-    device = model.device
-    T = encoder_out.size(1)
-
-    decoder_input = torch.tensor(
-        [stream.hyp[-context_size:] for stream in streams],
-        device=device,
-        dtype=torch.int64,
-    )
-    # decoder_out is of shape (N, decoder_out_dim)
-    decoder_out = model.decoder(decoder_input, need_pad=False)
-    decoder_out = model.joiner.decoder_proj(decoder_out)
-    # logging.info(f"decoder_out shape : {decoder_out.shape}")
-
-    for t in range(T):
-        # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
-        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa
-
-        logits = model.joiner(
-            current_encoder_out.unsqueeze(2),
-            decoder_out.unsqueeze(1),
-            project_input=False,
-        )
-        # logits'shape (batch_size, vocab_size)
-        logits = logits.squeeze(1).squeeze(1)
-
-        assert logits.ndim == 2, logits.shape
-        y = logits.argmax(dim=1).tolist()
-        emitted = False
-        for i, v in enumerate(y):
-            if v != blank_id:
-                streams[i].hyp.append(v)
-                emitted = True
-        if emitted:
-            # update decoder output
-            decoder_input = torch.tensor(
-                [stream.hyp[-context_size:] for stream in streams],
-                device=device,
-                dtype=torch.int64,
-            )
-            decoder_out = model.decoder(
-                decoder_input,
-                need_pad=False,
-            )
-            decoder_out = model.joiner.decoder_proj(decoder_out)
-
-    hyp_tokens = []
-    for stream in streams:
-        hyp_tokens.append(stream.hyp)
-    return hyp_tokens
-
-
-def fast_beam_search(
-    model: nn.Module,
-    encoder_out: torch.Tensor,
-    processed_lens: torch.Tensor,
-    decoding_streams: k2.RnntDecodingStreams,
-) -> List[List[int]]:
-    B, T, C = encoder_out.shape
-    for t in range(T):
-        # shape is a RaggedShape of shape (B, context)
-        # contexts is a Tensor of shape (shape.NumElements(), context_size)
-        shape, contexts = decoding_streams.get_contexts()
-        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
-        contexts = contexts.to(torch.int64)
-        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
-        decoder_out = model.decoder(contexts, need_pad=False)
-        decoder_out = model.joiner.decoder_proj(decoder_out)
-        # current_encoder_out is of shape
-        # (shape.NumElements(), 1, joiner_dim)
-        # fmt: off
-        current_encoder_out = torch.index_select(
-            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
-        )
-        # fmt: on
-        logits = model.joiner(
-            current_encoder_out.unsqueeze(2),
-            decoder_out.unsqueeze(1),
-            project_input=False,
-        )
-        logits = logits.squeeze(1).squeeze(1)
-        log_probs = logits.log_softmax(dim=-1)
-        decoding_streams.advance(log_probs)
-    decoding_streams.terminate_and_flush_to_streams()
-    lattice = decoding_streams.format_output(processed_lens.tolist())
-    best_path = one_best_decoding(lattice)
-    hyp_tokens = get_texts(best_path)
-    return hyp_tokens


 def decode_one_chunk(
     params: AttributeDict,
     model: nn.Module,
@@ -325,7 +236,6 @@ def decode_one_chunk(
     features = []
     feature_lens = []
     states = []
-    rnnt_stream_list = []
    processed_lens = []

     for stream in decode_streams:
@@ -336,8 +246,6 @@ def decode_one_chunk(
         feature_lens.append(feat_len)
         states.append(stream.states)
         processed_lens.append(stream.done_frames)
-        if params.decoding_method == "fast_beam_search":
-            rnnt_stream_list.append(stream.rnnt_decoding_stream)

     feature_lens = torch.tensor(feature_lens, device=device)
     features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
@@ -349,19 +257,13 @@ def decode_one_chunk(
     # frames.
     tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
     if features.size(1) < tail_length:
-        feature_lens += tail_length - features.size(1)
-        features = torch.cat(
-            [
-                features,
-                torch.tensor(
-                    LOG_EPS, dtype=features.dtype, device=device
-                ).expand(
-                    features.size(0),
-                    tail_length - features.size(1),
-                    features.size(2),
-                ),
-            ],
-            dim=1,
-        )
+        pad_length = tail_length - features.size(1)
+        feature_lens += pad_length
+        features = torch.nn.functional.pad(
+            features,
+            (0, 0, 0, pad_length),
+            mode="constant",
+            value=LOG_EPS,
+        )

     states = [
@@ -382,22 +284,31 @@ def decode_one_chunk(
     encoder_out = model.joiner.encoder_proj(encoder_out)

     if params.decoding_method == "greedy_search":
-        hyp_tokens = greedy_search(model, encoder_out, decode_streams)
+        greedy_search(
+            model=model, encoder_out=encoder_out, streams=decode_streams
+        )
     elif params.decoding_method == "fast_beam_search":
-        config = k2.RnntDecodingConfig(
-            vocab_size=params.vocab_size,
-            decoder_history_len=params.context_size,
-            beam=params.beam,
-            max_contexts=params.max_contexts,
-            max_states=params.max_states,
-        )
-        decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
         processed_lens = processed_lens + encoder_out_lens
-        hyp_tokens = fast_beam_search(
-            model, encoder_out, processed_lens, decoding_streams
+        fast_beam_search_one_best(
+            model=model,
+            encoder_out=encoder_out,
+            processed_lens=processed_lens,
+            streams=decode_streams,
+            beam=params.beam,
+            max_states=params.max_states,
+            max_contexts=params.max_contexts,
         )
+    elif params.decoding_method == "modified_beam_search":
+        modified_beam_search(
+            model=model,
+            streams=decode_streams,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
     else:
-        assert False
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )

     states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
@@ -405,8 +316,6 @@ def decode_one_chunk(
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
         decode_streams[i].done_frames += encoder_out_lens[i]
-        if params.decoding_method == "fast_beam_search":
-            decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
             finished_streams.append(i)
@@ -490,13 +399,10 @@ def decode_dataset(
                 params=params, model=model, decode_streams=decode_streams
             )
             for i in sorted(finished_streams, reverse=True):
-                hyp = decode_streams[i].hyp
-                if params.decoding_method == "greedy_search":
-                    hyp = hyp[params.context_size :]  # noqa
                 decode_results.append(
                     (
                         decode_streams[i].ground_truth.split(),
-                        sp.decode(hyp).split(),
+                        sp.decode(decode_streams[i].decoding_result()).split(),
                     )
                 )
                 del decode_streams[i]
@@ -510,24 +416,28 @@ def decode_dataset(
             params=params, model=model, decode_streams=decode_streams
         )
         for i in sorted(finished_streams, reverse=True):
-            hyp = decode_streams[i].hyp
-            if params.decoding_method == "greedy_search":
-                hyp = hyp[params.context_size :]  # noqa
             decode_results.append(
                 (
                     decode_streams[i].ground_truth.split(),
-                    sp.decode(hyp).split(),
+                    sp.decode(decode_streams[i].decoding_result()).split(),
                 )
            )
             del decode_streams[i]

-    key = "greedy_search"
-    if params.decoding_method == "fast_beam_search":
+    if params.decoding_method == "greedy_search":
+        key = "greedy_search"
+    elif params.decoding_method == "fast_beam_search":
         key = (
             f"beam_{params.beam}_"
             f"max_contexts_{params.max_contexts}_"
             f"max_states_{params.max_states}"
         )
+    elif params.decoding_method == "modified_beam_search":
+        key = f"beam_size_{params.beam_size}"
+    else:
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )

     return {key: decode_results}