mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
Fixed the streaming decoding code for the Emformer model.
This commit is contained in:
parent
e74654c2a2
commit
cf0ce8db32
@ -367,7 +367,7 @@ class HypothesisList(object):
|
||||
return ", ".join(s)
|
||||
|
||||
|
||||
def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
||||
def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
||||
"""Return a ragged shape with axes [utt][num_hyps].
|
||||
|
||||
Args:
|
||||
@ -431,7 +431,7 @@ def modified_beam_search(
|
||||
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
||||
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||
|
||||
hyps_shape = _get_hyps_shape(B).to(device)
|
||||
hyps_shape = get_hyps_shape(B).to(device)
|
||||
|
||||
A = [list(b) for b in B]
|
||||
B = [HypothesisList() for _ in range(batch_size)]
|
||||
|
@ -28,16 +28,16 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import Hypothesis, HypothesisList, _get_hyps_shape
|
||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from emformer import LOG_EPSILON, stack_states, unstack_states
|
||||
from streaming_feature_extractor import (
|
||||
FeatureExtractionStream,
|
||||
GreedySearchStream,
|
||||
ModifiedBeamSearchStream,
|
||||
)
|
||||
from streaming_feature_extractor import FeatureExtractionStream
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import AttributeDict, setup_logger
|
||||
|
||||
|
||||
@ -225,16 +225,12 @@ class StreamList(object):
|
||||
- greedy_search
|
||||
- modified_beam_search
|
||||
"""
|
||||
decoding_classes = {
|
||||
"greedy_search": GreedySearchStream,
|
||||
"modified_beam_search": ModifiedBeamSearchStream,
|
||||
}
|
||||
|
||||
assert decoding_method in decoding_classes
|
||||
cls = decoding_classes[decoding_method]
|
||||
|
||||
self.streams = [
|
||||
cls(context_size=context_size) for _ in range(batch_size)
|
||||
FeatureExtractionStream(
|
||||
context_size=context_size, decoding_method=decoding_method
|
||||
)
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
|
||||
@property
|
||||
@ -325,7 +321,7 @@ class StreamList(object):
|
||||
|
||||
def greedy_search(
|
||||
model: nn.Module,
|
||||
streams: List[GreedySearchStream],
|
||||
streams: List[FeatureExtractionStream],
|
||||
encoder_out: torch.Tensor,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
):
|
||||
@ -333,36 +329,31 @@ def greedy_search(
|
||||
Args:
|
||||
model:
|
||||
The RNN-T model.
|
||||
stream:
|
||||
A stream object.
|
||||
streams:
|
||||
A list of FeatureExtractionStream objects created with decoding_method="greedy_search".
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
|
||||
the encoder model.
|
||||
sp:
|
||||
The BPE model.
|
||||
"""
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
device = model.device
|
||||
assert len(streams) == encoder_out.size(0)
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
for s in streams:
|
||||
if s.hyp is None:
|
||||
s.hyp = Hypothesis(
|
||||
ys=([blank_id] * context_size),
|
||||
log_prob=torch.tensor([0.0], device=device),
|
||||
)
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
device = model.device
|
||||
T = encoder_out.size(1)
|
||||
|
||||
if streams[0].decoder_out is None:
|
||||
for stream in streams:
|
||||
stream.hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor(
|
||||
[stream.hyp.ys[-context_size:] for stream in streams],
|
||||
[stream.hyp[-context_size:] for stream in streams],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=False,
|
||||
).squeeze(1)
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
|
||||
# decoder_out is of shape (N, decoder_out_dim)
|
||||
else:
|
||||
decoder_out = torch.stack(
|
||||
@ -370,7 +361,6 @@ def greedy_search(
|
||||
dim=0,
|
||||
)
|
||||
|
||||
T = encoder_out.size(1)
|
||||
for t in range(T):
|
||||
current_encoder_out = encoder_out[:, t]
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
@ -383,22 +373,23 @@ def greedy_search(
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
streams[i].hyp.ys.append(v)
|
||||
streams[i].hyp.append(v)
|
||||
emitted = True
|
||||
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = torch.tensor(
|
||||
[stream.hyp.ys[-context_size:] for stream in streams],
|
||||
[stream.hyp[-context_size:] for stream in streams],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(
|
||||
1
|
||||
)
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=False,
|
||||
).squeeze(1)
|
||||
|
||||
for k, s in enumerate(streams):
|
||||
logging.info(f"Partial result {k}:\n{sp.decode(s.result)}")
|
||||
for k, stream in enumerate(streams):
|
||||
result = sp.decode(stream.decoding_result())
|
||||
logging.info(f"Partial result {k}:\n{result}")
|
||||
|
||||
decoder_out_list = decoder_out.unbind(dim=0)
|
||||
for i, d in enumerate(decoder_out_list):
|
||||
@ -407,7 +398,7 @@ def greedy_search(
|
||||
|
||||
def modified_beam_search(
|
||||
model: nn.Module,
|
||||
streams: List[ModifiedBeamSearchStream],
|
||||
streams: List[FeatureExtractionStream],
|
||||
encoder_out: torch.Tensor,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
beam: int = 4,
|
||||
@ -426,36 +417,35 @@ def modified_beam_search(
|
||||
beam:
|
||||
Number of active paths during the beam search.
|
||||
"""
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
assert len(streams) == encoder_out.size(0)
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
device = model.device
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
assert len(streams) == encoder_out.size(0)
|
||||
batch_size = len(streams)
|
||||
T = encoder_out.size(1)
|
||||
|
||||
for s in streams:
|
||||
if len(s.hyps) == 0:
|
||||
s.hyps.add(
|
||||
for stream in streams:
|
||||
if len(stream.hyps) == 0:
|
||||
stream.hyps.add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
B = [s.hyps for s in streams]
|
||||
|
||||
T = encoder_out.size(1)
|
||||
B = [stream.hyps for stream in streams]
|
||||
for t in range(T):
|
||||
current_encoder_out = encoder_out[:, t]
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
|
||||
hyps_shape = _get_hyps_shape(B).to(device)
|
||||
hyps_shape = get_hyps_shape(B).to(device)
|
||||
|
||||
A = [list(b) for b in B]
|
||||
B = [HypothesisList() for _ in range(batch_size)]
|
||||
|
||||
ys_log_probs = torch.cat(
|
||||
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
||||
ys_log_probs = torch.stack(
|
||||
[hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
|
||||
) # (num_hyps, 1)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
@ -516,7 +506,8 @@ def modified_beam_search(
|
||||
B[i].add(new_hyp)
|
||||
|
||||
streams[i].hyps = B[i]
|
||||
logging.info(f"Partial result {i}:\n{sp.decode(streams[i].result)}")
|
||||
result = sp.decode(streams[i].decoding_result())
|
||||
logging.info(f"Partial result {i}:\n{result}")
|
||||
|
||||
|
||||
def process_features(
|
||||
@ -645,8 +636,8 @@ def decode_batch(
|
||||
sp=sp,
|
||||
)
|
||||
results = []
|
||||
for s in stream_list.streams:
|
||||
text = sp.decode(s.result)
|
||||
for stream in stream_list.streams:
|
||||
text = sp.decode(stream.decoding_result())
|
||||
results.append(text)
|
||||
return results
|
||||
|
||||
|
@ -14,11 +14,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from beam_search import Hypothesis, HypothesisList
|
||||
from beam_search import HypothesisList
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
|
||||
|
||||
def _create_streaming_feature_extractor() -> OnlineFeature:
|
||||
@ -41,21 +40,28 @@ def _create_streaming_feature_extractor() -> OnlineFeature:
|
||||
|
||||
|
||||
class FeatureExtractionStream(object):
|
||||
def __init__(
|
||||
self,
|
||||
) -> None:
|
||||
def __init__(self, context_size: int, decoding_method: str) -> None:
|
||||
self.feature_extractor = _create_streaming_feature_extractor()
|
||||
# It contains a list of 1-D tensors representing the feature frames.
|
||||
self.feature_frames: List[torch.Tensor] = []
|
||||
|
||||
self.num_fetched_frames = 0
|
||||
# After calling `self.input_finished()`, we set this flag to True
|
||||
self._done = False
|
||||
|
||||
# For the emformer model, it contains the states of each
|
||||
# encoder layer.
|
||||
self.states: Optional[List[List[torch.Tensor]]] = None
|
||||
|
||||
# After calling `self.input_finished()`, we set this flag to True
|
||||
self._done = False
|
||||
# Different decoding methods use different attributes.
|
||||
self.context_size = context_size
|
||||
self.decoding_method = decoding_method
|
||||
if decoding_method == "greedy_search":
|
||||
self.hyp: List[int] = None
|
||||
self.decoder_out: Optional[torch.Tensor] = None
|
||||
elif decoding_method == "modified_beam_search":
|
||||
self.hyps = HypothesisList()
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {decoding_method}")
|
||||
|
||||
def accept_waveform(
|
||||
self,
|
||||
@ -106,32 +112,11 @@ class FeatureExtractionStream(object):
|
||||
self.feature_frames.append(frame)
|
||||
self.num_fetched_frames += 1
|
||||
|
||||
|
||||
class GreedySearchStream(FeatureExtractionStream):
|
||||
def __init__(self, context_size: int) -> None:
|
||||
"""FeatureExtractionStream class for greedy search."""
|
||||
super().__init__()
|
||||
self.context_size = context_size
|
||||
# For the RNN-T decoder, it contains the decoder output
|
||||
# corresponding to the decoder input self.hyp.ys[-context_size:]
|
||||
# Its shape is (decoder_out_dim,)
|
||||
self.hyp: Hypothesis = None
|
||||
self.decoder_out: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def result(self) -> List[int]:
|
||||
return self.hyp.ys[self.context_size :]
|
||||
|
||||
|
||||
class ModifiedBeamSearchStream(FeatureExtractionStream):
|
||||
def __init__(self, context_size: int) -> None:
|
||||
"""FeatureExtractionStream class for modified beam search decoding."""
|
||||
super().__init__()
|
||||
self.context_size = context_size
|
||||
self.hyps = HypothesisList()
|
||||
self.best_hyp = None
|
||||
|
||||
@property
|
||||
def result(self) -> List[int]:
|
||||
best_hyp = self.hyps.get_most_probable(length_norm=True)
|
||||
return best_hyp.ys[self.context_size :]
|
||||
def decoding_result(self) -> List[int]:
|
||||
"""Obtain the current decoding result."""
|
||||
if self.decoding_method == "greedy_search":
|
||||
return self.hyp[self.context_size :]
|
||||
else:
|
||||
assert self.decoding_method == "modified_beam_search"
|
||||
best_hyp = self.hyps.get_most_probable(length_norm=True)
|
||||
return best_hyp.ys[self.context_size :]
|
||||
|
Loading…
x
Reference in New Issue
Block a user