mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
Merge branch 'streaming_decoding' into streaming
This commit is contained in:
commit
42f8afd264
@ -18,23 +18,26 @@
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import Hypothesis, HypothesisList, _get_hyps_shape
|
||||
from emformer import LOG_EPSILON, stack_states, unstack_states
|
||||
from streaming_feature_extractor import FeatureExtractionStream
|
||||
from streaming_feature_extractor import (
|
||||
FeatureExtractionStream,
|
||||
GreedySearchStream,
|
||||
ModifiedBeamSearchStream,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||
from icefall.utils import AttributeDict, setup_logger
|
||||
|
||||
|
||||
@ -50,6 +53,7 @@ def get_parser():
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
@ -208,7 +212,7 @@ class StreamList(object):
|
||||
self,
|
||||
batch_size: int,
|
||||
context_size: int,
|
||||
blank_id: int,
|
||||
decoding_method: str,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -216,14 +220,21 @@ class StreamList(object):
|
||||
Size of this batch.
|
||||
context_size:
|
||||
Context size of the RNN-T decoder model.
|
||||
blank_id:
|
||||
The ID of the blank symbol of the BPE model.
|
||||
decoding_method:
|
||||
Decoding method. The possible values are:
|
||||
- greedy_search
|
||||
- modified_beam_search
|
||||
"""
|
||||
decoding_classes = {
|
||||
"greedy_search": GreedySearchStream,
|
||||
"modified_beam_search": ModifiedBeamSearchStream,
|
||||
}
|
||||
|
||||
assert decoding_method in decoding_classes
|
||||
cls = decoding_classes[decoding_method]
|
||||
|
||||
self.streams = [
|
||||
FeatureExtractionStream(
|
||||
context_size=context_size, blank_id=blank_id
|
||||
)
|
||||
for _ in range(batch_size)
|
||||
cls(context_size=context_size) for _ in range(batch_size)
|
||||
]
|
||||
|
||||
@property
|
||||
@ -238,7 +249,7 @@ class StreamList(object):
|
||||
audio_samples: List[torch.Tensor],
|
||||
sampling_rate: float,
|
||||
):
|
||||
"""Feeed audio samples to each stream.
|
||||
"""Feed audio samples to each stream.
|
||||
Args:
|
||||
audio_samples:
|
||||
A list of 1-D tensors containing the audio samples for each
|
||||
@ -314,7 +325,7 @@ class StreamList(object):
|
||||
|
||||
def greedy_search(
|
||||
model: nn.Module,
|
||||
streams: List[FeatureExtractionStream],
|
||||
streams: List[GreedySearchStream],
|
||||
encoder_out: torch.Tensor,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
):
|
||||
@ -333,7 +344,15 @@ def greedy_search(
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
device = model.device
|
||||
assert len(streams) == encoder_out.size(0)
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
for s in streams:
|
||||
if s.hyp is None:
|
||||
s.hyp = Hypothesis(
|
||||
ys=([blank_id] * context_size),
|
||||
log_prob=torch.tensor([0.0], device=device),
|
||||
)
|
||||
if streams[0].decoder_out is None:
|
||||
decoder_input = torch.tensor(
|
||||
[stream.hyp.ys[-context_size:] for stream in streams],
|
||||
@ -351,8 +370,6 @@ def greedy_search(
|
||||
dim=0,
|
||||
)
|
||||
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
T = encoder_out.size(1)
|
||||
for t in range(T):
|
||||
current_encoder_out = encoder_out[:, t]
|
||||
@ -381,20 +398,132 @@ def greedy_search(
|
||||
)
|
||||
|
||||
for k, s in enumerate(streams):
|
||||
logging.info(
|
||||
f"Partial result {k}:\n{sp.decode(s.hyp.ys[context_size:])}"
|
||||
)
|
||||
logging.info(f"Partial result {k}:\n{sp.decode(s.result)}")
|
||||
|
||||
decoder_out_list = decoder_out.unbind(dim=0)
|
||||
|
||||
for i, d in enumerate(decoder_out_list):
|
||||
streams[i].decoder_out = d
|
||||
|
||||
|
||||
def modified_beam_search(
|
||||
model: nn.Module,
|
||||
streams: List[ModifiedBeamSearchStream],
|
||||
encoder_out: torch.Tensor,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
beam: int = 4,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
model:
|
||||
The RNN-T model.
|
||||
stream:
|
||||
A stream object.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
|
||||
the encoder model.
|
||||
sp:
|
||||
The BPE model.
|
||||
beam:
|
||||
Number of active paths during the beam search.
|
||||
"""
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
device = model.device
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
assert len(streams) == encoder_out.size(0)
|
||||
batch_size = len(streams)
|
||||
|
||||
for s in streams:
|
||||
if len(s.hyps) == 0:
|
||||
s.hyps.add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
B = [s.hyps for s in streams]
|
||||
|
||||
T = encoder_out.size(1)
|
||||
for t in range(T):
|
||||
current_encoder_out = encoder_out[:, t]
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
|
||||
hyps_shape = _get_hyps_shape(B).to(device)
|
||||
|
||||
A = [list(b) for b in B]
|
||||
B = [HypothesisList() for _ in range(batch_size)]
|
||||
|
||||
ys_log_probs = torch.cat(
|
||||
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
||||
) # (num_hyps, 1)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (num_hyps, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
|
||||
# decoder_out is of shape (num_hyps, decoder_output_dim)
|
||||
|
||||
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||
# as index, so we use `to(torch.int64)` below.
|
||||
current_encoder_out = torch.index_select(
|
||||
current_encoder_out,
|
||||
dim=0,
|
||||
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||
) # (num_hyps, encoder_out_dim)
|
||||
|
||||
logits = model.joiner(current_encoder_out, decoder_out)
|
||||
# logits is of shape (num_hyps, vocab_size)
|
||||
|
||||
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
|
||||
|
||||
log_probs.add_(ys_log_probs)
|
||||
|
||||
vocab_size = log_probs.size(-1)
|
||||
|
||||
log_probs = log_probs.reshape(-1)
|
||||
|
||||
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||
log_probs_shape = k2.ragged.create_ragged_shape2(
|
||||
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
||||
)
|
||||
ragged_log_probs = k2.RaggedTensor(
|
||||
shape=log_probs_shape, value=log_probs
|
||||
)
|
||||
|
||||
for i in range(batch_size):
|
||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||
|
||||
for k in range(len(topk_hyp_indexes)):
|
||||
hyp_idx = topk_hyp_indexes[k]
|
||||
hyp = A[i][hyp_idx]
|
||||
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[k]
|
||||
if new_token != blank_id:
|
||||
new_ys.append(new_token)
|
||||
|
||||
new_log_prob = topk_log_probs[k]
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
B[i].add(new_hyp)
|
||||
|
||||
streams[i].hyps = B[i]
|
||||
logging.info(f"Partial result {i}:\n{sp.decode(streams[i].result)}")
|
||||
|
||||
|
||||
def process_features(
|
||||
model: nn.Module,
|
||||
features: torch.Tensor,
|
||||
streams: List[FeatureExtractionStream],
|
||||
params: AttributeDict,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
) -> None:
|
||||
"""Process features for each stream in parallel.
|
||||
@ -406,6 +535,8 @@ def process_features(
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
streams:
|
||||
A list of streams of size (N,).
|
||||
params:
|
||||
It is the return value of :func:`get_params`.
|
||||
sp:
|
||||
The BPE model.
|
||||
"""
|
||||
@ -439,12 +570,25 @@ def process_features(
|
||||
for i, s in enumerate(state_list):
|
||||
streams[i].states = s
|
||||
|
||||
greedy_search(
|
||||
model=model,
|
||||
streams=streams,
|
||||
encoder_out=encoder_out,
|
||||
sp=sp,
|
||||
)
|
||||
if params.decoding_method == "greedy_search":
|
||||
greedy_search(
|
||||
model=model,
|
||||
streams=streams,
|
||||
encoder_out=encoder_out,
|
||||
sp=sp,
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
modified_beam_search(
|
||||
model=model,
|
||||
streams=streams,
|
||||
encoder_out=encoder_out,
|
||||
sp=sp,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
|
||||
|
||||
def decode_batch(
|
||||
@ -479,7 +623,7 @@ def decode_batch(
|
||||
stream_list = StreamList(
|
||||
batch_size=batch_size,
|
||||
context_size=params.context_size,
|
||||
blank_id=params.blank_id,
|
||||
decoding_method=params.decoding_method,
|
||||
)
|
||||
|
||||
while not streaming_audio_samples.done:
|
||||
@ -497,11 +641,12 @@ def decode_batch(
|
||||
model=model,
|
||||
features=features,
|
||||
streams=active_streams,
|
||||
params=params,
|
||||
sp=sp,
|
||||
)
|
||||
results = []
|
||||
for s in stream_list.streams:
|
||||
text = sp.decode(s.hyp.ys[params.context_size :])
|
||||
text = sp.decode(s.result)
|
||||
results.append(text)
|
||||
return results
|
||||
|
||||
|
@ -17,7 +17,7 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from beam_search import Hypothesis
|
||||
from beam_search import Hypothesis, HypothesisList
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
|
||||
|
||||
@ -41,14 +41,10 @@ def _create_streaming_feature_extractor() -> OnlineFeature:
|
||||
|
||||
|
||||
class FeatureExtractionStream(object):
|
||||
def __init__(self, context_size: int, blank_id: int = 0) -> None:
|
||||
"""Context size of the RNN-T decoder model."""
|
||||
def __init__(
|
||||
self,
|
||||
) -> None:
|
||||
self.feature_extractor = _create_streaming_feature_extractor()
|
||||
self.hyp = Hypothesis(
|
||||
ys=([blank_id] * context_size),
|
||||
log_prob=torch.tensor([0.0]),
|
||||
) # for greedy search, will extend it to beam search
|
||||
|
||||
# It contains a list of 1-D tensors representing the feature frames.
|
||||
self.feature_frames: List[torch.Tensor] = []
|
||||
|
||||
@ -58,11 +54,6 @@ class FeatureExtractionStream(object):
|
||||
# encoder layer.
|
||||
self.states: Optional[List[List[torch.Tensor]]] = None
|
||||
|
||||
# For the RNN-T decoder, it contains the decoder output
|
||||
# corresponding to the decoder input self.hyp.ys[-context_size:]
|
||||
# Its shape is (decoder_out_dim,)
|
||||
self.decoder_out: Optional[torch.Tensor] = None
|
||||
|
||||
# After calling `self.input_finished()`, we set this flag to True
|
||||
self._done = False
|
||||
|
||||
@ -85,9 +76,9 @@ class FeatureExtractionStream(object):
|
||||
check to ensure that the input sampling rate equals to the one
|
||||
used in the extractor. If they are not equal, then no resampling
|
||||
will be performed; instead an error will be thrown.
|
||||
waveform:
|
||||
A 1-D torch tensor of dtype torch.float32 containing audio samples.
|
||||
It should be on CPU.
|
||||
waveform:
|
||||
A 1-D torch tensor of dtype torch.float32 containing audio samples.
|
||||
It should be on CPU.
|
||||
"""
|
||||
self.feature_extractor.accept_waveform(
|
||||
sampling_rate=sampling_rate,
|
||||
@ -114,3 +105,33 @@ class FeatureExtractionStream(object):
|
||||
frame = self.feature_extractor.get_frame(self.num_fetched_frames)
|
||||
self.feature_frames.append(frame)
|
||||
self.num_fetched_frames += 1
|
||||
|
||||
|
||||
class GreedySearchStream(FeatureExtractionStream):
|
||||
def __init__(self, context_size: int) -> None:
|
||||
"""FeatureExtractionStream class for greedy search."""
|
||||
super().__init__()
|
||||
self.context_size = context_size
|
||||
# For the RNN-T decoder, it contains the decoder output
|
||||
# corresponding to the decoder input self.hyp.ys[-context_size:]
|
||||
# Its shape is (decoder_out_dim,)
|
||||
self.hyp: Hypothesis = None
|
||||
self.decoder_out: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def result(self) -> List[int]:
|
||||
return self.hyp.ys[self.context_size :]
|
||||
|
||||
|
||||
class ModifiedBeamSearchStream(FeatureExtractionStream):
|
||||
def __init__(self, context_size: int) -> None:
|
||||
"""FeatureExtractionStream class for modified beam search decoding."""
|
||||
super().__init__()
|
||||
self.context_size = context_size
|
||||
self.hyps = HypothesisList()
|
||||
self.best_hyp = None
|
||||
|
||||
@property
|
||||
def result(self) -> List[int]:
|
||||
best_hyp = self.hyps.get_most_probable(length_norm=True)
|
||||
return best_hyp.ys[self.context_size :]
|
||||
|
Loading…
x
Reference in New Issue
Block a user