diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_beam_search.py
new file mode 120000
index 000000000..3a5f89833
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_beam_search.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/streaming_beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
index 8af2788be..ed638bd0a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
@@ -44,6 +44,11 @@ from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
 from librispeech import LibriSpeech
+from streaming_beam_search import (
+    fast_beam_search_one_best,
+    greedy_search,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model

@@ -52,10 +57,8 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.decode import one_best_decoding
 from icefall.utils import (
     AttributeDict,
-    get_texts,
     setup_logger,
     store_transcripts,
     write_error_stats,
@@ -115,10 +118,21 @@ def get_parser():
         "--decoding-method",
         type=str,
         default="greedy_search",
-        help="""Support only greedy_search and fast_beam_search now.
+        help="""Supported decoding methods are:
+        greedy_search
+        modified_beam_search
+        fast_beam_search
         """,
     )

+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is modified_beam_search.""",
+    )
+
     parser.add_argument(
         "--beam",
         type=float,
@@ -186,109 +200,6 @@ def get_parser():
     return parser


-def greedy_search(
-    model: nn.Module,
-    encoder_out: torch.Tensor,
-    streams: List[DecodeStream],
-) -> List[List[int]]:
-
-    assert len(streams) == encoder_out.size(0)
-    assert encoder_out.ndim == 3
-
-    blank_id = model.decoder.blank_id
-    context_size = model.decoder.context_size
-    device = model.device
-    T = encoder_out.size(1)
-
-    decoder_input = torch.tensor(
-        [stream.hyp[-context_size:] for stream in streams],
-        device=device,
-        dtype=torch.int64,
-    )
-    # decoder_out is of shape (N, decoder_out_dim)
-    decoder_out = model.decoder(decoder_input, need_pad=False)
-    decoder_out = model.joiner.decoder_proj(decoder_out)
-    # logging.info(f"decoder_out shape : {decoder_out.shape}")
-
-    for t in range(T):
-        # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
-        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa
-
-        logits = model.joiner(
-            current_encoder_out.unsqueeze(2),
-            decoder_out.unsqueeze(1),
-            project_input=False,
-        )
-        # logits'shape (batch_size, vocab_size)
-        logits = logits.squeeze(1).squeeze(1)
-
-        assert logits.ndim == 2, logits.shape
-        y = logits.argmax(dim=1).tolist()
-        emitted = False
-        for i, v in enumerate(y):
-            if v != blank_id:
-                streams[i].hyp.append(v)
-                emitted = True
-        if emitted:
-            # update decoder output
-            decoder_input = torch.tensor(
-                [stream.hyp[-context_size:] for stream in streams],
-                device=device,
-                dtype=torch.int64,
-            )
-            decoder_out = model.decoder(
-                decoder_input,
-                need_pad=False,
-            )
-            decoder_out = model.joiner.decoder_proj(decoder_out)
-
-    hyp_tokens = []
-    for stream in streams:
-        hyp_tokens.append(stream.hyp)
-    return hyp_tokens
-
-
-def fast_beam_search(
-    model: nn.Module,
-    encoder_out: torch.Tensor,
-    processed_lens: torch.Tensor,
-    decoding_streams: k2.RnntDecodingStreams,
-) -> List[List[int]]:
-
-    B, T, C = encoder_out.shape
-    for t in range(T):
-        # shape is a RaggedShape of shape (B, context)
-        # contexts is a Tensor of shape (shape.NumElements(), context_size)
-        shape, contexts = decoding_streams.get_contexts()
-        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
-        contexts = contexts.to(torch.int64)
-        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
-        decoder_out = model.decoder(contexts, need_pad=False)
-        decoder_out = model.joiner.decoder_proj(decoder_out)
-        # current_encoder_out is of shape
-        # (shape.NumElements(), 1, joiner_dim)
-        # fmt: off
-        current_encoder_out = torch.index_select(
-            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
-        )
-        # fmt: on
-        logits = model.joiner(
-            current_encoder_out.unsqueeze(2),
-            decoder_out.unsqueeze(1),
-            project_input=False,
-        )
-        logits = logits.squeeze(1).squeeze(1)
-        log_probs = logits.log_softmax(dim=-1)
-        decoding_streams.advance(log_probs)
-
-    decoding_streams.terminate_and_flush_to_streams()
-
-    lattice = decoding_streams.format_output(processed_lens.tolist())
-    best_path = one_best_decoding(lattice)
-    hyp_tokens = get_texts(best_path)
-    return hyp_tokens
-
-
 def decode_one_chunk(
     params: AttributeDict,
     model: nn.Module,
@@ -313,7 +224,6 @@ def decode_one_chunk(
     feature_lens = []
     states = []

-    rnnt_stream_list = []
     processed_lens = []

     for stream in decode_streams:
@@ -324,8 +234,6 @@
         feature_lens.append(feat_len)
         states.append(stream.states)
         processed_lens.append(stream.done_frames)
-        if params.decoding_method == "fast_beam_search":
-            rnnt_stream_list.append(stream.rnnt_decoding_stream)

     feature_lens = torch.tensor(feature_lens, device=device)
     features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
@@ -337,19 +245,13 @@
     # frames.
     tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
     if features.size(1) < tail_length:
-        feature_lens += tail_length - features.size(1)
-        features = torch.cat(
-            [
-                features,
-                torch.tensor(
-                    LOG_EPS, dtype=features.dtype, device=device
-                ).expand(
-                    features.size(0),
-                    tail_length - features.size(1),
-                    features.size(2),
-                ),
-            ],
-            dim=1,
+        pad_length = tail_length - features.size(1)
+        feature_lens += pad_length
+        features = torch.nn.functional.pad(
+            features,
+            (0, 0, 0, pad_length),
+            mode="constant",
+            value=LOG_EPS,
         )

     states = [
@@ -370,22 +272,31 @@
     encoder_out = model.joiner.encoder_proj(encoder_out)

     if params.decoding_method == "greedy_search":
-        hyp_tokens = greedy_search(model, encoder_out, decode_streams)
-    elif params.decoding_method == "fast_beam_search":
-        config = k2.RnntDecodingConfig(
-            vocab_size=params.vocab_size,
-            decoder_history_len=params.context_size,
-            beam=params.beam,
-            max_contexts=params.max_contexts,
-            max_states=params.max_states,
+        greedy_search(
+            model=model, encoder_out=encoder_out, streams=decode_streams
         )
-        decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
+    elif params.decoding_method == "fast_beam_search":
         processed_lens = processed_lens + encoder_out_lens
-        hyp_tokens = fast_beam_search(
-            model, encoder_out, processed_lens, decoding_streams
+        fast_beam_search_one_best(
+            model=model,
+            encoder_out=encoder_out,
+            processed_lens=processed_lens,
+            streams=decode_streams,
+            beam=params.beam,
+            max_states=params.max_states,
+            max_contexts=params.max_contexts,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        modified_beam_search(
+            model=model,
+            streams=decode_streams,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
         )
     else:
-        assert False
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )

     states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]

@@ -393,8 +304,6 @@
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
         decode_streams[i].done_frames += encoder_out_lens[i]
-        if params.decoding_method == "fast_beam_search":
-            decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
             finished_streams.append(i)

@@ -478,13 +387,10 @@
                 params=params, model=model, decode_streams=decode_streams
             )
             for i in sorted(finished_streams, reverse=True):
-                hyp = decode_streams[i].hyp
-                if params.decoding_method == "greedy_search":
-                    hyp = hyp[params.context_size :]  # noqa
                 decode_results.append(
                     (
                         decode_streams[i].ground_truth.split(),
-                        sp.decode(hyp).split(),
+                        sp.decode(decode_streams[i].decoding_result()).split(),
                     )
                 )
                 del decode_streams[i]
@@ -498,24 +404,28 @@
             params=params, model=model, decode_streams=decode_streams
         )
         for i in sorted(finished_streams, reverse=True):
-            hyp = decode_streams[i].hyp
-            if params.decoding_method == "greedy_search":
-                hyp = hyp[params.context_size :]  # noqa
             decode_results.append(
                 (
                     decode_streams[i].ground_truth.split(),
-                    sp.decode(hyp).split(),
+                    sp.decode(decode_streams[i].decoding_result()).split(),
                 )
             )
             del decode_streams[i]

-    key = "greedy_search"
-    if params.decoding_method == "fast_beam_search":
+    if params.decoding_method == "greedy_search":
+        key = "greedy_search"
+    elif params.decoding_method == "fast_beam_search":
         key = (
             f"beam_{params.beam}_"
             f"max_contexts_{params.max_contexts}_"
             f"max_states_{params.max_states}"
         )
+    elif params.decoding_method == "modified_beam_search":
+        key = f"beam_size_{params.beam_size}"
+    else:
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )
     return {key: decode_results}


diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_beam_search.py
new file mode 120000
index 000000000..3a5f89833
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_beam_search.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/streaming_beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
index 57fd06980..171f17f03 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
@@ -43,6 +43,11 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
+from streaming_beam_search import (
+    fast_beam_search_one_best,
+    greedy_search,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model

@@ -52,10 +57,8 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.decode import one_best_decoding
 from icefall.utils import (
     AttributeDict,
-    get_texts,
     setup_logger,
     store_transcripts,
     str2bool,
@@ -127,10 +130,21 @@
         "--decoding-method",
         type=str,
         default="greedy_search",
-        help="""Support only greedy_search and fast_beam_search now.
+        help="""Supported decoding methods are:
+        greedy_search
+        modified_beam_search
+        fast_beam_search
         """,
     )

+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is modified_beam_search.""",
+    )
+
     parser.add_argument(
         "--beam",
         type=float,
@@ -198,109 +212,6 @@ def get_parser():
     return parser


-def greedy_search(
-    model: nn.Module,
-    encoder_out: torch.Tensor,
-    streams: List[DecodeStream],
-) -> List[List[int]]:
-
-    assert len(streams) == encoder_out.size(0)
-    assert encoder_out.ndim == 3
-
-    blank_id = model.decoder.blank_id
-    context_size = model.decoder.context_size
-    device = model.device
-    T = encoder_out.size(1)
-
-    decoder_input = torch.tensor(
-        [stream.hyp[-context_size:] for stream in streams],
-        device=device,
-        dtype=torch.int64,
-    )
-    # decoder_out is of shape (N, decoder_out_dim)
-    decoder_out = model.decoder(decoder_input, need_pad=False)
-    decoder_out = model.joiner.decoder_proj(decoder_out)
-    # logging.info(f"decoder_out shape : {decoder_out.shape}")
-
-    for t in range(T):
-        # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
-        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa
-
-        logits = model.joiner(
-            current_encoder_out.unsqueeze(2),
-            decoder_out.unsqueeze(1),
-            project_input=False,
-        )
-        # logits'shape (batch_size, vocab_size)
-        logits = logits.squeeze(1).squeeze(1)
-
-        assert logits.ndim == 2, logits.shape
-        y = logits.argmax(dim=1).tolist()
-        emitted = False
-        for i, v in enumerate(y):
-            if v != blank_id:
-                streams[i].hyp.append(v)
-                emitted = True
-        if emitted:
-            # update decoder output
-            decoder_input = torch.tensor(
-                [stream.hyp[-context_size:] for stream in streams],
-                device=device,
-                dtype=torch.int64,
-            )
-            decoder_out = model.decoder(
-                decoder_input,
-                need_pad=False,
-            )
-            decoder_out = model.joiner.decoder_proj(decoder_out)
-
-    hyp_tokens = []
-    for stream in streams:
-        hyp_tokens.append(stream.hyp)
-    return hyp_tokens
-
-
-def fast_beam_search(
-    model: nn.Module,
-    encoder_out: torch.Tensor,
-    processed_lens: torch.Tensor,
-    decoding_streams: k2.RnntDecodingStreams,
-) -> List[List[int]]:
-
-    B, T, C = encoder_out.shape
-    for t in range(T):
-        # shape is a RaggedShape of shape (B, context)
-        # contexts is a Tensor of shape (shape.NumElements(), context_size)
-        shape, contexts = decoding_streams.get_contexts()
-        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
-        contexts = contexts.to(torch.int64)
-        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
-        decoder_out = model.decoder(contexts, need_pad=False)
-        decoder_out = model.joiner.decoder_proj(decoder_out)
-        # current_encoder_out is of shape
-        # (shape.NumElements(), 1, joiner_dim)
-        # fmt: off
-        current_encoder_out = torch.index_select(
-            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
-        )
-        # fmt: on
-        logits = model.joiner(
-            current_encoder_out.unsqueeze(2),
-            decoder_out.unsqueeze(1),
-            project_input=False,
-        )
-        logits = logits.squeeze(1).squeeze(1)
-        log_probs = logits.log_softmax(dim=-1)
-        decoding_streams.advance(log_probs)
-
-    decoding_streams.terminate_and_flush_to_streams()
-
-    lattice = decoding_streams.format_output(processed_lens.tolist())
-    best_path = one_best_decoding(lattice)
-    hyp_tokens = get_texts(best_path)
-    return hyp_tokens
-
-
 def decode_one_chunk(
     params: AttributeDict,
     model: nn.Module,
@@ -325,7 +236,6 @@ def decode_one_chunk(
     feature_lens = []
     states = []

-    rnnt_stream_list = []
     processed_lens = []

     for stream in decode_streams:
@@ -336,8 +246,6 @@
         feature_lens.append(feat_len)
         states.append(stream.states)
         processed_lens.append(stream.done_frames)
-        if params.decoding_method == "fast_beam_search":
-            rnnt_stream_list.append(stream.rnnt_decoding_stream)

     feature_lens = torch.tensor(feature_lens, device=device)
     features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
@@ -349,19 +257,13 @@
     # frames.
     tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
     if features.size(1) < tail_length:
-        feature_lens += tail_length - features.size(1)
-        features = torch.cat(
-            [
-                features,
-                torch.tensor(
-                    LOG_EPS, dtype=features.dtype, device=device
-                ).expand(
-                    features.size(0),
-                    tail_length - features.size(1),
-                    features.size(2),
-                ),
-            ],
-            dim=1,
+        pad_length = tail_length - features.size(1)
+        feature_lens += pad_length
+        features = torch.nn.functional.pad(
+            features,
+            (0, 0, 0, pad_length),
+            mode="constant",
+            value=LOG_EPS,
         )

     states = [
@@ -382,22 +284,31 @@
     encoder_out = model.joiner.encoder_proj(encoder_out)

     if params.decoding_method == "greedy_search":
-        hyp_tokens = greedy_search(model, encoder_out, decode_streams)
-    elif params.decoding_method == "fast_beam_search":
-        config = k2.RnntDecodingConfig(
-            vocab_size=params.vocab_size,
-            decoder_history_len=params.context_size,
-            beam=params.beam,
-            max_contexts=params.max_contexts,
-            max_states=params.max_states,
+        greedy_search(
+            model=model, encoder_out=encoder_out, streams=decode_streams
         )
-        decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
+    elif params.decoding_method == "fast_beam_search":
         processed_lens = processed_lens + encoder_out_lens
-        hyp_tokens = fast_beam_search(
-            model, encoder_out, processed_lens, decoding_streams
+        fast_beam_search_one_best(
+            model=model,
+            encoder_out=encoder_out,
+            processed_lens=processed_lens,
+            streams=decode_streams,
+            beam=params.beam,
+            max_states=params.max_states,
+            max_contexts=params.max_contexts,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        modified_beam_search(
+            model=model,
+            streams=decode_streams,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
         )
     else:
-        assert False
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )

     states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]

@@ -405,8 +316,6 @@
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
         decode_streams[i].done_frames += encoder_out_lens[i]
-        if params.decoding_method == "fast_beam_search":
-            decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
             finished_streams.append(i)

@@ -490,13 +399,10 @@
                 params=params, model=model, decode_streams=decode_streams
             )
             for i in sorted(finished_streams, reverse=True):
-                hyp = decode_streams[i].hyp
-                if params.decoding_method == "greedy_search":
-                    hyp = hyp[params.context_size :]  # noqa
                 decode_results.append(
                     (
                         decode_streams[i].ground_truth.split(),
-                        sp.decode(hyp).split(),
+                        sp.decode(decode_streams[i].decoding_result()).split(),
                     )
                 )
                 del decode_streams[i]
@@ -510,24 +416,28 @@
             params=params, model=model, decode_streams=decode_streams
         )
         for i in sorted(finished_streams, reverse=True):
-            hyp = decode_streams[i].hyp
-            if params.decoding_method == "greedy_search":
-                hyp = hyp[params.context_size :]  # noqa
             decode_results.append(
                 (
                     decode_streams[i].ground_truth.split(),
-                    sp.decode(hyp).split(),
+                    sp.decode(decode_streams[i].decoding_result()).split(),
                 )
             )
             del decode_streams[i]

-    key = "greedy_search"
-    if params.decoding_method == "fast_beam_search":
+    if params.decoding_method == "greedy_search":
+        key = "greedy_search"
+    elif params.decoding_method == "fast_beam_search":
         key = (
             f"beam_{params.beam}_"
             f"max_contexts_{params.max_contexts}_"
             f"max_states_{params.max_states}"
         )
+    elif params.decoding_method == "modified_beam_search":
+        key = f"beam_size_{params.beam_size}"
+    else:
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )
     return {key: decode_results}
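
Note: a quick way to exercise the new code path. Only --decoding-method and --beam-size come from this diff; the checkpoint-selection flags (--epoch, --avg, --exp-dir) are assumptions that follow the usual icefall conventions, so adjust them to your setup:

    # Hypothetical invocation; everything except --decoding-method and
    # --beam-size is an assumption, not part of this diff.
    ./pruned_transducer_stateless4/streaming_decode.py \
      --epoch 30 \
      --avg 15 \
      --exp-dir pruned_transducer_stateless4/exp \
      --decoding-method modified_beam_search \
      --beam-size 4

The same flags apply to pruned_transducer_stateless3/streaming_decode.py, since both scripts now resolve streaming_beam_search.py to the shared copy in pruned_transducer_stateless2 through the symlinks added above.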