Add modified_beam_search for streaming decode (#489)

* Add modified_beam_search for pruned_transducer_stateless/streaming_decode.py

* refactor

* Add modified beam search for stateless3 and stateless4

* Fix comments

* Add real streaming CI test
Wei Kang 2022-07-25 16:53:23 +08:00 committed by GitHub
parent 8203d10be7
commit b1d0956855
11 changed files with 843 additions and 592 deletions


@@ -70,7 +70,7 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
-log "Decoding with $method"
+log "Simulate streaming decoding with $method"
./pruned_transducer_stateless2/decode.py \
--decoding-method $method \
@@ -82,5 +82,19 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
--causal-convolution 1
done
+for method in greedy_search fast_beam_search modified_beam_search; do
+log "Real streaming decoding with $method"
+./pruned_transducer_stateless2/streaming_decode.py \
+--decoding-method $method \
+--epoch 999 \
+--avg 1 \
+--num-decode-streams 100 \
+--exp-dir pruned_transducer_stateless2/exp \
+--left-context 32 \
+--decode-chunk-size 8 \
+--right-context 0
+done
rm pruned_transducer_stateless2/exp/*.pt
fi


@@ -751,7 +751,7 @@ class HypothesisList(object):
return ", ".join(s)
-def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
"""Return a ragged shape with axes [utt][num_hyps].
Args:
@@ -847,7 +847,7 @@ def modified_beam_search(
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
-hyps_shape = _get_hyps_shape(B).to(device)
+hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]


@@ -19,6 +19,7 @@ from typing import List, Optional, Tuple
import k2
import torch
+from beam_search import Hypothesis, HypothesisList
from icefall.utils import AttributeDict
@@ -42,7 +43,8 @@ class DecodeStream(object):
device:
The device to run this stream.
"""
-if decoding_graph is not None:
+if params.decoding_method == "fast_beam_search":
+assert decoding_graph is not None
assert device == decoding_graph.device
self.params = params
@@ -77,15 +79,23 @@ class DecodeStream(object):
if params.decoding_method == "greedy_search":
self.hyp = [params.blank_id] * params.context_size
+elif params.decoding_method == "modified_beam_search":
+self.hyps = HypothesisList()
+self.hyps.add(
+Hypothesis(
+ys=[params.blank_id] * params.context_size,
+log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+)
+)
elif params.decoding_method == "fast_beam_search":
# The rnnt_decoding_stream for fast_beam_search.
self.rnnt_decoding_stream: k2.RnntDecodingStream = (
k2.RnntDecodingStream(decoding_graph)
)
else:
-assert (
-False
-), f"Decoding method :{params.decoding_method} do not support."
+raise ValueError(
+f"Unsupported decoding method: {params.decoding_method}"
+)
@property
def done(self) -> bool:
@@ -124,3 +134,14 @@ class DecodeStream(object):
self._done = True
return ret_features, ret_length
def decoding_result(self) -> List[int]:
"""Obtain current decoding result."""
if self.params.decoding_method == "greedy_search":
return self.hyp[self.params.context_size :] # noqa
elif self.params.decoding_method == "modified_beam_search":
best_hyp = self.hyps.get_most_probable(length_norm=True)
return best_hyp.ys[self.params.context_size :] # noqa
else:
assert self.params.decoding_method == "fast_beam_search"
return self.hyp
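For orientation, a minimal sketch of how a caller might turn decoding_result() into text once a stream has finished; the sentencepiece model path below is a placeholder and not part of this change:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("path/to/bpe.model")  # placeholder path, assumed BPE model

    # `stream` is a finished DecodeStream; decoding_result() already strips
    # the blank/context prefix, so the token IDs can be decoded directly.
    text = sp.decode(stream.decoding_result())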


@@ -0,0 +1,280 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List
import k2
import torch
import torch.nn as nn
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from decode_stream import DecodeStream
from icefall.decode import one_best_decoding
from icefall.utils import get_texts
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> None:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
streams:
A list of Stream objects.
"""
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
)
# logits' shape: (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
num_active_paths: int = 4,
) -> None:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The RNN-T model.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
streams:
A list of stream objects.
num_active_paths:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert len(streams) == encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size = len(streams)
T = encoder_out.size(1)
B = [stream.hyps for stream in streams]
for t in range(T):
current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.stack(
[hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
# decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
num_active_paths
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
for i in range(batch_size):
streams[i].hyps = B[i]
def fast_beam_search_one_best(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
streams: List[DecodeStream],
beam: float,
max_states: int,
max_contexts: int,
) -> None:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first generated by Fsa-based beam search, then we get the
recognition by applying shortest path on the lattice.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
processed_lens:
A tensor of shape (N,) containing the number of processed frames
in `encoder_out` before padding.
streams:
A list of stream objects.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
"""
assert encoder_out.ndim == 3
B, T, C = encoder_out.shape
assert B == len(streams)
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(streams[i].rnnt_decoding_stream)
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
for i in range(B):
streams[i].hyp = hyp_tokens[i]
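The pruning step in modified_beam_search above flattens each utterance's (num_hyps, vocab_size) scores into one row of a ragged tensor and takes a per-row topk; the parent hypothesis and the extending token are then recovered with integer division and modulo. A small, self-contained illustration with toy sizes (all numbers here are made up; only the k2 calls mirror the function above):

    import k2
    import torch

    vocab_size = 5
    num_active_paths = 4
    # two utterances with 2 and 3 active hypotheses respectively
    row_splits = torch.tensor([0, 2, 5], dtype=torch.int32) * vocab_size
    log_probs = torch.randn(5 * vocab_size)  # flattened (num_hyps * vocab) scores
    shape = k2.ragged.create_ragged_shape2(
        row_splits=row_splits, cached_tot_size=log_probs.numel()
    )
    ragged_log_probs = k2.RaggedTensor(shape=shape, value=log_probs)
    for i in range(2):
        topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
        hyp_indexes = (topk_indexes // vocab_size).tolist()   # parent hypothesis
        token_indexes = (topk_indexes % vocab_size).tolist()  # token that extends it
        print(hyp_indexes, token_indexes)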


@@ -17,13 +17,13 @@
"""
Usage:
-./pruned_transducer_stateless2/streaming_decode.py \
+./pruned_transducer_stateless/streaming_decode.py \
--epoch 28 \
--avg 15 \
--decode-chunk-size 8 \
--left-context 32 \
--right-context 0 \
---exp-dir ./pruned_transducer_stateless2/exp \
+--exp-dir ./pruned_transducer_stateless/exp \
--decoding-method greedy_search \
--num-decode-streams 1000
"""
@@ -43,6 +43,11 @@ from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
+from streaming_beam_search import (
+fast_beam_search_one_best,
+greedy_search,
+modified_beam_search,
+)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
@@ -51,10 +56,8 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
-get_texts,
setup_logger,
store_transcripts,
write_error_stats,
@@ -114,10 +117,21 @@ def get_parser():
"--decoding-method",
type=str,
default="greedy_search",
-help="""Support only greedy_search and fast_beam_search now.
+help="""Supported decoding methods are:
+greedy_search
+modified_beam_search
+fast_beam_search
""",
)
+parser.add_argument(
+"--num-active-paths",
+type=int,
+default=4,
+help="""An integer indicating how many candidates we will keep for each
+frame. Used only when --decoding-method is modified_beam_search.""",
+)
parser.add_argument(
"--beam",
type=float,
@@ -185,103 +199,6 @@ def get_parser():
return parser
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> List[List[int]]:
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
hyp_tokens = []
for stream in streams:
hyp_tokens.append(stream.hyp)
return hyp_tokens
def fast_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
decoding_streams: k2.RnntDecodingStreams,
) -> List[List[int]]:
B, T, C = encoder_out.shape
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
return hyp_tokens
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
@@ -305,8 +222,6 @@ decode_one_chunk(
features = []
feature_lens = []
states = []
-rnnt_stream_list = []
processed_lens = []
for stream in decode_streams:
@@ -317,8 +232,6 @@ decode_one_chunk(
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
-if params.decoding_method == "fast_beam_search":
-rnnt_stream_list.append(stream.rnnt_decoding_stream)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
@@ -330,19 +243,13 @@ decode_one_chunk(
# frames.
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length:
-feature_lens += tail_length - features.size(1)
-features = torch.cat(
-[
-features,
-torch.tensor(
-LOG_EPS, dtype=features.dtype, device=device
-).expand(
-features.size(0),
-tail_length - features.size(1),
-features.size(2),
-),
-],
-dim=1,
-)
+pad_length = tail_length - features.size(1)
+feature_lens += pad_length
+features = torch.nn.functional.pad(
+features,
+(0, 0, 0, pad_length),
+mode="constant",
+value=LOG_EPS,
+)
states = [
@@ -362,22 +269,31 @@ decode_one_chunk(
)
if params.decoding_method == "greedy_search":
-hyp_tokens = greedy_search(model, encoder_out, decode_streams)
+greedy_search(
+model=model, encoder_out=encoder_out, streams=decode_streams
+)
elif params.decoding_method == "fast_beam_search":
-config = k2.RnntDecodingConfig(
-vocab_size=params.vocab_size,
-decoder_history_len=params.context_size,
-beam=params.beam,
-max_contexts=params.max_contexts,
-max_states=params.max_states,
-)
-decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
processed_lens = processed_lens + encoder_out_lens
-hyp_tokens = fast_beam_search(
-model, encoder_out, processed_lens, decoding_streams
-)
+fast_beam_search_one_best(
+model=model,
+encoder_out=encoder_out,
+processed_lens=processed_lens,
+streams=decode_streams,
+beam=params.beam,
+max_states=params.max_states,
+max_contexts=params.max_contexts,
+)
+elif params.decoding_method == "modified_beam_search":
+modified_beam_search(
+model=model,
+streams=decode_streams,
+encoder_out=encoder_out,
+num_active_paths=params.num_active_paths,
+)
else:
-assert False
+raise ValueError(
+f"Unsupported decoding method: {params.decoding_method}"
+)
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
@@ -385,8 +301,6 @@ decode_one_chunk(
for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]]
decode_streams[i].done_frames += encoder_out_lens[i]
-if params.decoding_method == "fast_beam_search":
-decode_streams[i].hyp = hyp_tokens[i]
if decode_streams[i].done:
finished_streams.append(i)
@@ -469,13 +383,10 @@ decode_dataset(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
-hyp = decode_streams[i].hyp
-if params.decoding_method == "greedy_search":
-hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
-sp.decode(hyp).split(),
+sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
@@ -489,24 +400,29 @@ decode_dataset(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
-hyp = decode_streams[i].hyp
-if params.decoding_method == "greedy_search":
-hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
-sp.decode(hyp).split(),
+sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
-key = "greedy_search"
-if params.decoding_method == "fast_beam_search":
+if params.decoding_method == "greedy_search":
+key = "greedy_search"
+elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
+elif params.decoding_method == "modified_beam_search":
+key = f"num_active_paths_{params.num_active_paths}"
+else:
+raise ValueError(
+f"Unsupported decoding method: {params.decoding_method}"
+)
return {key: decode_results}


@@ -0,0 +1,288 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List
import k2
import torch
import torch.nn as nn
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from decode_stream import DecodeStream
from icefall.decode import one_best_decoding
from icefall.utils import get_texts
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> None:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
streams:
A list of Stream objects.
"""
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits' shape: (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
num_active_paths: int = 4,
) -> None:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The RNN-T model.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
streams:
A list of stream objects.
num_active_paths:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert len(streams) == encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size = len(streams)
T = encoder_out.size(1)
B = [stream.hyps for stream in streams]
for t in range(T):
current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.stack(
[hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(
current_encoder_out, decoder_out, project_input=False
)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
num_active_paths
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
for i in range(batch_size):
streams[i].hyps = B[i]
def fast_beam_search_one_best(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
streams: List[DecodeStream],
beam: float,
max_states: int,
max_contexts: int,
) -> None:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first generated by Fsa-based beam search, then we get the
recognition by applying shortest path on the lattice.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
processed_lens:
A tensor of shape (N,) containing the number of processed frames
in `encoder_out` before padding.
streams:
A list of stream objects.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
"""
assert encoder_out.ndim == 3
B, T, C = encoder_out.shape
assert B == len(streams)
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(streams[i].rnnt_decoding_stream)
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
for i in range(B):
streams[i].hyp = hyp_tokens[i]
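Relative to the pruned_transducer_stateless version of this file, the main change is that encoder and decoder outputs are projected into the joiner dimension once, outside the per-frame loop, and the joiner is then called with project_input=False. A rough, self-contained sketch of that pattern with stand-in linear layers (layer names and dimensions are illustrative, not the recipe's actual modules):

    import torch
    import torch.nn as nn

    enc_dim, dec_dim, joiner_dim, vocab_size = 512, 512, 512, 500
    encoder_proj = nn.Linear(enc_dim, joiner_dim)  # stands in for model.joiner.encoder_proj
    decoder_proj = nn.Linear(dec_dim, joiner_dim)  # stands in for model.joiner.decoder_proj
    output_linear = nn.Linear(joiner_dim, vocab_size)

    encoder_out = encoder_proj(torch.randn(4, 16, enc_dim))  # projected once per chunk
    decoder_out = decoder_proj(torch.randn(4, 1, dec_dim))   # projected once per decoder update

    # With project_input=False the joiner only needs to combine the two already
    # projected tensors and map them to vocabulary logits, roughly like this:
    t = 0
    logit = encoder_out[:, t : t + 1, :].unsqueeze(2) + decoder_out.unsqueeze(1)
    logits = output_linear(torch.tanh(logit)).squeeze(1).squeeze(1)  # (batch, vocab)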


@@ -43,6 +43,11 @@ from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
+from streaming_beam_search import (
+fast_beam_search_one_best,
+greedy_search,
+modified_beam_search,
+)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
@@ -51,10 +56,8 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
-get_texts,
setup_logger,
store_transcripts,
write_error_stats,
@@ -114,10 +117,21 @@ def get_parser():
"--decoding-method",
type=str,
default="greedy_search",
-help="""Support only greedy_search and fast_beam_search now.
+help="""Supported decoding methods are:
+greedy_search
+modified_beam_search
+fast_beam_search
""",
)
+parser.add_argument(
+"--num_active_paths",
+type=int,
+default=4,
+help="""An integer indicating how many candidates we will keep for each
+frame. Used only when --decoding-method is modified_beam_search.""",
+)
parser.add_argument(
"--beam",
type=float,
@@ -185,109 +199,6 @@ def get_parser():
return parser
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> List[List[int]]:
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# logging.info(f"decoder_out shape : {decoder_out.shape}")
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
hyp_tokens = []
for stream in streams:
hyp_tokens.append(stream.hyp)
return hyp_tokens
def fast_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
decoding_streams: k2.RnntDecodingStreams,
) -> List[List[int]]:
B, T, C = encoder_out.shape
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
return hyp_tokens
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
@@ -312,7 +223,6 @@ decode_one_chunk(
feature_lens = []
states = []
-rnnt_stream_list = []
processed_lens = []
for stream in decode_streams:
@@ -323,8 +233,6 @@ decode_one_chunk(
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
-if params.decoding_method == "fast_beam_search":
-rnnt_stream_list.append(stream.rnnt_decoding_stream)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
@@ -336,19 +244,13 @@ decode_one_chunk(
# frames.
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length:
-feature_lens += tail_length - features.size(1)
-features = torch.cat(
-[
-features,
-torch.tensor(
-LOG_EPS, dtype=features.dtype, device=device
-).expand(
-features.size(0),
-tail_length - features.size(1),
-features.size(2),
-),
-],
-dim=1,
-)
+pad_length = tail_length - features.size(1)
+feature_lens += pad_length
+features = torch.nn.functional.pad(
+features,
+(0, 0, 0, pad_length),
+mode="constant",
+value=LOG_EPS,
+)
states = [
@@ -369,22 +271,31 @@ decode_one_chunk(
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
-hyp_tokens = greedy_search(model, encoder_out, decode_streams)
+greedy_search(
+model=model, encoder_out=encoder_out, streams=decode_streams
+)
elif params.decoding_method == "fast_beam_search":
-config = k2.RnntDecodingConfig(
-vocab_size=params.vocab_size,
-decoder_history_len=params.context_size,
-beam=params.beam,
-max_contexts=params.max_contexts,
-max_states=params.max_states,
-)
-decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
processed_lens = processed_lens + encoder_out_lens
-hyp_tokens = fast_beam_search(
-model, encoder_out, processed_lens, decoding_streams
-)
+fast_beam_search_one_best(
+model=model,
+encoder_out=encoder_out,
+processed_lens=processed_lens,
+streams=decode_streams,
+beam=params.beam,
+max_states=params.max_states,
+max_contexts=params.max_contexts,
+)
+elif params.decoding_method == "modified_beam_search":
+modified_beam_search(
+model=model,
+streams=decode_streams,
+encoder_out=encoder_out,
+num_active_paths=params.num_active_paths,
+)
else:
-assert False
+raise ValueError(
+f"Unsupported decoding method: {params.decoding_method}"
+)
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
@@ -392,8 +303,6 @@ decode_one_chunk(
for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]]
decode_streams[i].done_frames += encoder_out_lens[i]
-if params.decoding_method == "fast_beam_search":
-decode_streams[i].hyp = hyp_tokens[i]
if decode_streams[i].done:
finished_streams.append(i)
@@ -477,13 +386,10 @@ decode_dataset(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
-hyp = decode_streams[i].hyp
-if params.decoding_method == "greedy_search":
-hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
-sp.decode(hyp).split(),
+sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
@@ -497,24 +403,28 @@ decode_dataset(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
-hyp = decode_streams[i].hyp
-if params.decoding_method == "greedy_search":
-hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
-sp.decode(hyp).split(),
+sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
-key = "greedy_search"
-if params.decoding_method == "fast_beam_search":
+if params.decoding_method == "greedy_search":
+key = "greedy_search"
+elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
+elif params.decoding_method == "modified_beam_search":
+key = f"num_active_paths_{params.num_active_paths}"
+else:
+raise ValueError(
+f"Unsupported decoding method: {params.decoding_method}"
+)
return {key: decode_results}


@@ -0,0 +1 @@
../pruned_transducer_stateless2/streaming_beam_search.py


@@ -17,13 +17,13 @@
"""
Usage:
-./pruned_transducer_stateless2/streaming_decode.py \
+./pruned_transducer_stateless3/streaming_decode.py \
--epoch 28 \
--avg 15 \
--left-context 32 \
--decode-chunk-size 8 \
--right-context 0 \
---exp-dir ./pruned_transducer_stateless2/exp \
+--exp-dir ./pruned_transducer_stateless3/exp \
--decoding-method greedy_search \
--num-decode-streams 1000
"""
@@ -44,6 +44,11 @@ from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from librispeech import LibriSpeech
+from streaming_beam_search import (
+fast_beam_search_one_best,
+greedy_search,
+modified_beam_search,
+)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
@@ -52,10 +57,8 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
-get_texts,
setup_logger,
store_transcripts,
write_error_stats,
@@ -115,10 +118,21 @@ def get_parser():
"--decoding-method",
type=str,
default="greedy_search",
-help="""Support only greedy_search and fast_beam_search now.
+help="""Supported decoding methods are:
+greedy_search
+modified_beam_search
+fast_beam_search
""",
)
+parser.add_argument(
+"--num_active_paths",
+type=int,
+default=4,
+help="""An integer indicating how many candidates we will keep for each
+frame. Used only when --decoding-method is modified_beam_search.""",
+)
parser.add_argument(
"--beam",
type=float,
@@ -186,109 +200,6 @@ def get_parser():
return parser
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> List[List[int]]:
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# logging.info(f"decoder_out shape : {decoder_out.shape}")
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
hyp_tokens = []
for stream in streams:
hyp_tokens.append(stream.hyp)
return hyp_tokens
def fast_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
decoding_streams: k2.RnntDecodingStreams,
) -> List[List[int]]:
B, T, C = encoder_out.shape
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
return hyp_tokens
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
@@ -313,7 +224,6 @@ decode_one_chunk(
feature_lens = []
states = []
-rnnt_stream_list = []
processed_lens = []
for stream in decode_streams:
@@ -324,8 +234,6 @@ decode_one_chunk(
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
-if params.decoding_method == "fast_beam_search":
-rnnt_stream_list.append(stream.rnnt_decoding_stream)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
@@ -337,19 +245,13 @@ decode_one_chunk(
# frames.
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length:
-feature_lens += tail_length - features.size(1)
-features = torch.cat(
-[
-features,
-torch.tensor(
-LOG_EPS, dtype=features.dtype, device=device
-).expand(
-features.size(0),
-tail_length - features.size(1),
-features.size(2),
-),
-],
-dim=1,
-)
+pad_length = tail_length - features.size(1)
+feature_lens += pad_length
+features = torch.nn.functional.pad(
+features,
+(0, 0, 0, pad_length),
+mode="constant",
+value=LOG_EPS,
+)
states = [
@@ -370,22 +272,31 @@ decode_one_chunk(
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
-hyp_tokens = greedy_search(model, encoder_out, decode_streams)
+greedy_search(
+model=model, encoder_out=encoder_out, streams=decode_streams
+)
elif params.decoding_method == "fast_beam_search":
-config = k2.RnntDecodingConfig(
-vocab_size=params.vocab_size,
-decoder_history_len=params.context_size,
-beam=params.beam,
-max_contexts=params.max_contexts,
-max_states=params.max_states,
-)
-decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
processed_lens = processed_lens + encoder_out_lens
-hyp_tokens = fast_beam_search(
-model, encoder_out, processed_lens, decoding_streams
-)
+fast_beam_search_one_best(
+model=model,
+encoder_out=encoder_out,
+processed_lens=processed_lens,
+streams=decode_streams,
+beam=params.beam,
+max_states=params.max_states,
+max_contexts=params.max_contexts,
+)
+elif params.decoding_method == "modified_beam_search":
+modified_beam_search(
+model=model,
+streams=decode_streams,
+encoder_out=encoder_out,
+num_active_paths=params.num_active_paths,
+)
else:
-assert False
+raise ValueError(
+f"Unsupported decoding method: {params.decoding_method}"
+)
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
@@ -393,8 +304,6 @@ decode_one_chunk(
for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]]
decode_streams[i].done_frames += encoder_out_lens[i]
-if params.decoding_method == "fast_beam_search":
-decode_streams[i].hyp = hyp_tokens[i]
if decode_streams[i].done:
finished_streams.append(i)
@@ -478,13 +387,10 @@ decode_dataset(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
-hyp = decode_streams[i].hyp
-if params.decoding_method == "greedy_search":
-hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
-sp.decode(hyp).split(),
+sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
@@ -498,24 +404,28 @@ decode_dataset(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
-hyp = decode_streams[i].hyp
-if params.decoding_method == "greedy_search":
-hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
-sp.decode(hyp).split(),
+sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
-key = "greedy_search"
-if params.decoding_method == "fast_beam_search":
+if params.decoding_method == "greedy_search":
+key = "greedy_search"
+elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
+elif params.decoding_method == "modified_beam_search":
+key = f"num_active_paths_{params.num_active_paths}"
+else:
+raise ValueError(
+f"Unsupported decoding method: {params.decoding_method}"
+)
return {key: decode_results}


@@ -0,0 +1 @@
../pruned_transducer_stateless2/streaming_beam_search.py


@@ -17,13 +17,13 @@
"""
Usage:
-./pruned_transducer_stateless2/streaming_decode.py \
+./pruned_transducer_stateless4/streaming_decode.py \
--epoch 28 \
--avg 15 \
--left-context 32 \
--decode-chunk-size 8 \
--right-context 0 \
---exp-dir ./pruned_transducer_stateless2/exp \
+--exp-dir ./pruned_transducer_stateless4/exp \
--decoding-method greedy_search \
--num-decode-streams 200
"""
@@ -43,6 +43,11 @@ from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
+from streaming_beam_search import (
+fast_beam_search_one_best,
+greedy_search,
+modified_beam_search,
+)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
@@ -52,10 +57,8 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
-get_texts,
setup_logger,
store_transcripts,
str2bool,
@@ -127,10 +130,21 @@ def get_parser():
"--decoding-method",
type=str,
default="greedy_search",
-help="""Support only greedy_search and fast_beam_search now.
+help="""Supported decoding methods are:
+greedy_search
+modified_beam_search
+fast_beam_search
""",
)
+parser.add_argument(
+"--num_active_paths",
+type=int,
+default=4,
+help="""An integer indicating how many candidates we will keep for each
+frame. Used only when --decoding-method is modified_beam_search.""",
+)
parser.add_argument(
"--beam",
type=float,
@@ -198,109 +212,6 @@ def get_parser():
return parser
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> List[List[int]]:
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# logging.info(f"decoder_out shape : {decoder_out.shape}")
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
hyp_tokens = []
for stream in streams:
hyp_tokens.append(stream.hyp)
return hyp_tokens
def fast_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
decoding_streams: k2.RnntDecodingStreams,
) -> List[List[int]]:
B, T, C = encoder_out.shape
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
return hyp_tokens
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
@@ -325,7 +236,6 @@ decode_one_chunk(
feature_lens = []
states = []
-rnnt_stream_list = []
processed_lens = []
for stream in decode_streams:
@@ -336,8 +246,6 @@ decode_one_chunk(
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
-if params.decoding_method == "fast_beam_search":
-rnnt_stream_list.append(stream.rnnt_decoding_stream)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
@@ -349,19 +257,13 @@ decode_one_chunk(
# frames.
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length:
-feature_lens += tail_length - features.size(1)
-features = torch.cat(
-[
-features,
-torch.tensor(
-LOG_EPS, dtype=features.dtype, device=device
-).expand(
-features.size(0),
-tail_length - features.size(1),
-features.size(2),
-),
-],
-dim=1,
-)
+pad_length = tail_length - features.size(1)
+feature_lens += pad_length
+features = torch.nn.functional.pad(
+features,
+(0, 0, 0, pad_length),
+mode="constant",
+value=LOG_EPS,
+)
states = [
@@ -382,22 +284,31 @@ decode_one_chunk(
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
-hyp_tokens = greedy_search(model, encoder_out, decode_streams)
+greedy_search(
+model=model, encoder_out=encoder_out, streams=decode_streams
+)
elif params.decoding_method == "fast_beam_search":
-config = k2.RnntDecodingConfig(
-vocab_size=params.vocab_size,
-decoder_history_len=params.context_size,
-beam=params.beam,
-max_contexts=params.max_contexts,
-max_states=params.max_states,
-)
-decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
processed_lens = processed_lens + encoder_out_lens
-hyp_tokens = fast_beam_search(
-model, encoder_out, processed_lens, decoding_streams
-)
+fast_beam_search_one_best(
+model=model,
+encoder_out=encoder_out,
+processed_lens=processed_lens,
+streams=decode_streams,
+beam=params.beam,
+max_states=params.max_states,
+max_contexts=params.max_contexts,
+)
+elif params.decoding_method == "modified_beam_search":
+modified_beam_search(
+model=model,
+streams=decode_streams,
+encoder_out=encoder_out,
+num_active_paths=params.num_active_paths,
+)
else:
-assert False
+raise ValueError(
+f"Unsupported decoding method: {params.decoding_method}"
+)
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
@@ -405,8 +316,6 @@ decode_one_chunk(
for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]]
decode_streams[i].done_frames += encoder_out_lens[i]
-if params.decoding_method == "fast_beam_search":
-decode_streams[i].hyp = hyp_tokens[i]
if decode_streams[i].done:
finished_streams.append(i)
@@ -490,13 +399,10 @@ decode_dataset(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
-hyp = decode_streams[i].hyp
-if params.decoding_method == "greedy_search":
-hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
-sp.decode(hyp).split(),
+sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
@@ -510,24 +416,28 @@ decode_dataset(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
-hyp = decode_streams[i].hyp
-if params.decoding_method == "greedy_search":
-hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
-sp.decode(hyp).split(),
+sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
-key = "greedy_search"
-if params.decoding_method == "fast_beam_search":
+if params.decoding_method == "greedy_search":
+key = "greedy_search"
+elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
+elif params.decoding_method == "modified_beam_search":
+key = f"num_active_paths_{params.num_active_paths}"
+else:
+raise ValueError(
+f"Unsupported decoding method: {params.decoding_method}"
+)
return {key: decode_results}