mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Merge branch 'streaming_decoding' into streaming
This commit is contained in:
commit
42f8afd264
@ -18,23 +18,26 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from beam_search import Hypothesis, HypothesisList, _get_hyps_shape
|
||||||
from emformer import LOG_EPSILON, stack_states, unstack_states
|
from emformer import LOG_EPSILON, stack_states, unstack_states
|
||||||
from streaming_feature_extractor import FeatureExtractionStream
|
from streaming_feature_extractor import (
|
||||||
|
FeatureExtractionStream,
|
||||||
|
GreedySearchStream,
|
||||||
|
ModifiedBeamSearchStream,
|
||||||
|
)
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||||
average_checkpoints,
|
|
||||||
find_checkpoints,
|
|
||||||
load_checkpoint,
|
|
||||||
)
|
|
||||||
from icefall.utils import AttributeDict, setup_logger
|
from icefall.utils import AttributeDict, setup_logger
|
||||||
|
|
||||||
|
|
||||||
@ -50,6 +53,7 @@ def get_parser():
|
|||||||
help="It specifies the checkpoint to use for decoding."
|
help="It specifies the checkpoint to use for decoding."
|
||||||
"Note: Epoch counts from 0.",
|
"Note: Epoch counts from 0.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
@ -208,7 +212,7 @@ class StreamList(object):
|
|||||||
self,
|
self,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
context_size: int,
|
context_size: int,
|
||||||
blank_id: int,
|
decoding_method: str,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -216,14 +220,21 @@ class StreamList(object):
|
|||||||
Size of this batch.
|
Size of this batch.
|
||||||
context_size:
|
context_size:
|
||||||
Context size of the RNN-T decoder model.
|
Context size of the RNN-T decoder model.
|
||||||
blank_id:
|
decoding_method:
|
||||||
The ID of the blank symbol of the BPE model.
|
Decoding method. The possible values are:
|
||||||
|
- greedy_search
|
||||||
|
- modified_beam_search
|
||||||
"""
|
"""
|
||||||
|
decoding_classes = {
|
||||||
|
"greedy_search": GreedySearchStream,
|
||||||
|
"modified_beam_search": ModifiedBeamSearchStream,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert decoding_method in decoding_classes
|
||||||
|
cls = decoding_classes[decoding_method]
|
||||||
|
|
||||||
self.streams = [
|
self.streams = [
|
||||||
FeatureExtractionStream(
|
cls(context_size=context_size) for _ in range(batch_size)
|
||||||
context_size=context_size, blank_id=blank_id
|
|
||||||
)
|
|
||||||
for _ in range(batch_size)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -238,7 +249,7 @@ class StreamList(object):
|
|||||||
audio_samples: List[torch.Tensor],
|
audio_samples: List[torch.Tensor],
|
||||||
sampling_rate: float,
|
sampling_rate: float,
|
||||||
):
|
):
|
||||||
"""Feeed audio samples to each stream.
|
"""Feed audio samples to each stream.
|
||||||
Args:
|
Args:
|
||||||
audio_samples:
|
audio_samples:
|
||||||
A list of 1-D tensors containing the audio samples for each
|
A list of 1-D tensors containing the audio samples for each
|
||||||
@ -314,7 +325,7 @@ class StreamList(object):
|
|||||||
|
|
||||||
def greedy_search(
|
def greedy_search(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
streams: List[FeatureExtractionStream],
|
streams: List[GreedySearchStream],
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
):
|
):
|
||||||
@ -333,7 +344,15 @@ def greedy_search(
|
|||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
device = model.device
|
device = model.device
|
||||||
|
assert len(streams) == encoder_out.size(0)
|
||||||
|
assert encoder_out.ndim == 3
|
||||||
|
|
||||||
|
for s in streams:
|
||||||
|
if s.hyp is None:
|
||||||
|
s.hyp = Hypothesis(
|
||||||
|
ys=([blank_id] * context_size),
|
||||||
|
log_prob=torch.tensor([0.0], device=device),
|
||||||
|
)
|
||||||
if streams[0].decoder_out is None:
|
if streams[0].decoder_out is None:
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
[stream.hyp.ys[-context_size:] for stream in streams],
|
[stream.hyp.ys[-context_size:] for stream in streams],
|
||||||
@ -351,8 +370,6 @@ def greedy_search(
|
|||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert encoder_out.ndim == 3
|
|
||||||
|
|
||||||
T = encoder_out.size(1)
|
T = encoder_out.size(1)
|
||||||
for t in range(T):
|
for t in range(T):
|
||||||
current_encoder_out = encoder_out[:, t]
|
current_encoder_out = encoder_out[:, t]
|
||||||
@ -381,20 +398,132 @@ def greedy_search(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for k, s in enumerate(streams):
|
for k, s in enumerate(streams):
|
||||||
logging.info(
|
logging.info(f"Partial result {k}:\n{sp.decode(s.result)}")
|
||||||
f"Partial result {k}:\n{sp.decode(s.hyp.ys[context_size:])}"
|
|
||||||
)
|
|
||||||
|
|
||||||
decoder_out_list = decoder_out.unbind(dim=0)
|
decoder_out_list = decoder_out.unbind(dim=0)
|
||||||
|
|
||||||
for i, d in enumerate(decoder_out_list):
|
for i, d in enumerate(decoder_out_list):
|
||||||
streams[i].decoder_out = d
|
streams[i].decoder_out = d
|
||||||
|
|
||||||
|
|
||||||
|
def modified_beam_search(
|
||||||
|
model: nn.Module,
|
||||||
|
streams: List[ModifiedBeamSearchStream],
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
beam: int = 4,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The RNN-T model.
|
||||||
|
stream:
|
||||||
|
A stream object.
|
||||||
|
encoder_out:
|
||||||
|
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
|
||||||
|
the encoder model.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
beam:
|
||||||
|
Number of active paths during the beam search.
|
||||||
|
"""
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
device = model.device
|
||||||
|
assert encoder_out.ndim == 3, encoder_out.shape
|
||||||
|
assert len(streams) == encoder_out.size(0)
|
||||||
|
batch_size = len(streams)
|
||||||
|
|
||||||
|
for s in streams:
|
||||||
|
if len(s.hyps) == 0:
|
||||||
|
s.hyps.add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
B = [s.hyps for s in streams]
|
||||||
|
|
||||||
|
T = encoder_out.size(1)
|
||||||
|
for t in range(T):
|
||||||
|
current_encoder_out = encoder_out[:, t]
|
||||||
|
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||||
|
|
||||||
|
hyps_shape = _get_hyps_shape(B).to(device)
|
||||||
|
|
||||||
|
A = [list(b) for b in B]
|
||||||
|
B = [HypothesisList() for _ in range(batch_size)]
|
||||||
|
|
||||||
|
ys_log_probs = torch.cat(
|
||||||
|
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
||||||
|
) # (num_hyps, 1)
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (num_hyps, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
|
||||||
|
# decoder_out is of shape (num_hyps, decoder_output_dim)
|
||||||
|
|
||||||
|
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||||
|
# as index, so we use `to(torch.int64)` below.
|
||||||
|
current_encoder_out = torch.index_select(
|
||||||
|
current_encoder_out,
|
||||||
|
dim=0,
|
||||||
|
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||||
|
) # (num_hyps, encoder_out_dim)
|
||||||
|
|
||||||
|
logits = model.joiner(current_encoder_out, decoder_out)
|
||||||
|
# logits is of shape (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
log_probs.add_(ys_log_probs)
|
||||||
|
|
||||||
|
vocab_size = log_probs.size(-1)
|
||||||
|
|
||||||
|
log_probs = log_probs.reshape(-1)
|
||||||
|
|
||||||
|
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||||
|
log_probs_shape = k2.ragged.create_ragged_shape2(
|
||||||
|
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
||||||
|
)
|
||||||
|
ragged_log_probs = k2.RaggedTensor(
|
||||||
|
shape=log_probs_shape, value=log_probs
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||||
|
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||||
|
|
||||||
|
for k in range(len(topk_hyp_indexes)):
|
||||||
|
hyp_idx = topk_hyp_indexes[k]
|
||||||
|
hyp = A[i][hyp_idx]
|
||||||
|
|
||||||
|
new_ys = hyp.ys[:]
|
||||||
|
new_token = topk_token_indexes[k]
|
||||||
|
if new_token != blank_id:
|
||||||
|
new_ys.append(new_token)
|
||||||
|
|
||||||
|
new_log_prob = topk_log_probs[k]
|
||||||
|
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||||
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
|
streams[i].hyps = B[i]
|
||||||
|
logging.info(f"Partial result {i}:\n{sp.decode(streams[i].result)}")
|
||||||
|
|
||||||
|
|
||||||
def process_features(
|
def process_features(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
features: torch.Tensor,
|
features: torch.Tensor,
|
||||||
streams: List[FeatureExtractionStream],
|
streams: List[FeatureExtractionStream],
|
||||||
|
params: AttributeDict,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process features for each stream in parallel.
|
"""Process features for each stream in parallel.
|
||||||
@ -406,6 +535,8 @@ def process_features(
|
|||||||
A 3-D tensor of shape (N, T, C).
|
A 3-D tensor of shape (N, T, C).
|
||||||
streams:
|
streams:
|
||||||
A list of streams of size (N,).
|
A list of streams of size (N,).
|
||||||
|
params:
|
||||||
|
It is the return value of :func:`get_params`.
|
||||||
sp:
|
sp:
|
||||||
The BPE model.
|
The BPE model.
|
||||||
"""
|
"""
|
||||||
@ -439,12 +570,25 @@ def process_features(
|
|||||||
for i, s in enumerate(state_list):
|
for i, s in enumerate(state_list):
|
||||||
streams[i].states = s
|
streams[i].states = s
|
||||||
|
|
||||||
greedy_search(
|
if params.decoding_method == "greedy_search":
|
||||||
model=model,
|
greedy_search(
|
||||||
streams=streams,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
streams=streams,
|
||||||
sp=sp,
|
encoder_out=encoder_out,
|
||||||
)
|
sp=sp,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
streams=streams,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
sp=sp,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def decode_batch(
|
def decode_batch(
|
||||||
@ -479,7 +623,7 @@ def decode_batch(
|
|||||||
stream_list = StreamList(
|
stream_list = StreamList(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
context_size=params.context_size,
|
context_size=params.context_size,
|
||||||
blank_id=params.blank_id,
|
decoding_method=params.decoding_method,
|
||||||
)
|
)
|
||||||
|
|
||||||
while not streaming_audio_samples.done:
|
while not streaming_audio_samples.done:
|
||||||
@ -497,11 +641,12 @@ def decode_batch(
|
|||||||
model=model,
|
model=model,
|
||||||
features=features,
|
features=features,
|
||||||
streams=active_streams,
|
streams=active_streams,
|
||||||
|
params=params,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
)
|
)
|
||||||
results = []
|
results = []
|
||||||
for s in stream_list.streams:
|
for s in stream_list.streams:
|
||||||
text = sp.decode(s.hyp.ys[params.context_size :])
|
text = sp.decode(s.result)
|
||||||
results.append(text)
|
results.append(text)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from beam_search import Hypothesis
|
from beam_search import Hypothesis, HypothesisList
|
||||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||||
|
|
||||||
|
|
||||||
@ -41,14 +41,10 @@ def _create_streaming_feature_extractor() -> OnlineFeature:
|
|||||||
|
|
||||||
|
|
||||||
class FeatureExtractionStream(object):
|
class FeatureExtractionStream(object):
|
||||||
def __init__(self, context_size: int, blank_id: int = 0) -> None:
|
def __init__(
|
||||||
"""Context size of the RNN-T decoder model."""
|
self,
|
||||||
|
) -> None:
|
||||||
self.feature_extractor = _create_streaming_feature_extractor()
|
self.feature_extractor = _create_streaming_feature_extractor()
|
||||||
self.hyp = Hypothesis(
|
|
||||||
ys=([blank_id] * context_size),
|
|
||||||
log_prob=torch.tensor([0.0]),
|
|
||||||
) # for greedy search, will extend it to beam search
|
|
||||||
|
|
||||||
# It contains a list of 1-D tensors representing the feature frames.
|
# It contains a list of 1-D tensors representing the feature frames.
|
||||||
self.feature_frames: List[torch.Tensor] = []
|
self.feature_frames: List[torch.Tensor] = []
|
||||||
|
|
||||||
@ -58,11 +54,6 @@ class FeatureExtractionStream(object):
|
|||||||
# encoder layer.
|
# encoder layer.
|
||||||
self.states: Optional[List[List[torch.Tensor]]] = None
|
self.states: Optional[List[List[torch.Tensor]]] = None
|
||||||
|
|
||||||
# For the RNN-T decoder, it contains the decoder output
|
|
||||||
# corresponding to the decoder input self.hyp.ys[-context_size:]
|
|
||||||
# Its shape is (decoder_out_dim,)
|
|
||||||
self.decoder_out: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
# After calling `self.input_finished()`, we set this flag to True
|
# After calling `self.input_finished()`, we set this flag to True
|
||||||
self._done = False
|
self._done = False
|
||||||
|
|
||||||
@ -85,9 +76,9 @@ class FeatureExtractionStream(object):
|
|||||||
check to ensure that the input sampling rate equals to the one
|
check to ensure that the input sampling rate equals to the one
|
||||||
used in the extractor. If they are not equal, then no resampling
|
used in the extractor. If they are not equal, then no resampling
|
||||||
will be performed; instead an error will be thrown.
|
will be performed; instead an error will be thrown.
|
||||||
waveform:
|
waveform:
|
||||||
A 1-D torch tensor of dtype torch.float32 containing audio samples.
|
A 1-D torch tensor of dtype torch.float32 containing audio samples.
|
||||||
It should be on CPU.
|
It should be on CPU.
|
||||||
"""
|
"""
|
||||||
self.feature_extractor.accept_waveform(
|
self.feature_extractor.accept_waveform(
|
||||||
sampling_rate=sampling_rate,
|
sampling_rate=sampling_rate,
|
||||||
@ -114,3 +105,33 @@ class FeatureExtractionStream(object):
|
|||||||
frame = self.feature_extractor.get_frame(self.num_fetched_frames)
|
frame = self.feature_extractor.get_frame(self.num_fetched_frames)
|
||||||
self.feature_frames.append(frame)
|
self.feature_frames.append(frame)
|
||||||
self.num_fetched_frames += 1
|
self.num_fetched_frames += 1
|
||||||
|
|
||||||
|
|
||||||
|
class GreedySearchStream(FeatureExtractionStream):
|
||||||
|
def __init__(self, context_size: int) -> None:
|
||||||
|
"""FeatureExtractionStream class for greedy search."""
|
||||||
|
super().__init__()
|
||||||
|
self.context_size = context_size
|
||||||
|
# For the RNN-T decoder, it contains the decoder output
|
||||||
|
# corresponding to the decoder input self.hyp.ys[-context_size:]
|
||||||
|
# Its shape is (decoder_out_dim,)
|
||||||
|
self.hyp: Hypothesis = None
|
||||||
|
self.decoder_out: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def result(self) -> List[int]:
|
||||||
|
return self.hyp.ys[self.context_size :]
|
||||||
|
|
||||||
|
|
||||||
|
class ModifiedBeamSearchStream(FeatureExtractionStream):
|
||||||
|
def __init__(self, context_size: int) -> None:
|
||||||
|
"""FeatureExtractionStream class for modified beam search decoding."""
|
||||||
|
super().__init__()
|
||||||
|
self.context_size = context_size
|
||||||
|
self.hyps = HypothesisList()
|
||||||
|
self.best_hyp = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def result(self) -> List[int]:
|
||||||
|
best_hyp = self.hyps.get_most_probable(length_norm=True)
|
||||||
|
return best_hyp.ys[self.context_size :]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user