Merge branch 'streaming_decoding' into streaming

This commit is contained in:
yaozengwei 2022-04-20 11:10:53 +08:00
commit 42f8afd264
2 changed files with 211 additions and 45 deletions

View File

@@ -18,23 +18,26 @@
import argparse
import logging
import warnings
from pathlib import Path
from typing import List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import Hypothesis, HypothesisList, _get_hyps_shape
from emformer import LOG_EPSILON, stack_states, unstack_states
from streaming_feature_extractor import FeatureExtractionStream
from streaming_feature_extractor import (
FeatureExtractionStream,
GreedySearchStream,
ModifiedBeamSearchStream,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, setup_logger
@@ -50,6 +53,7 @@ def get_parser():
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
@@ -208,7 +212,7 @@ class StreamList(object):
self,
batch_size: int,
context_size: int,
blank_id: int,
decoding_method: str,
):
"""
Args:
@@ -216,14 +220,21 @@ class StreamList(object):
Size of this batch.
context_size:
Context size of the RNN-T decoder model.
blank_id:
The ID of the blank symbol of the BPE model.
decoding_method:
Decoding method. The possible values are:
- greedy_search
- modified_beam_search
"""
decoding_classes = {
"greedy_search": GreedySearchStream,
"modified_beam_search": ModifiedBeamSearchStream,
}
assert decoding_method in decoding_classes
cls = decoding_classes[decoding_method]
self.streams = [
FeatureExtractionStream(
context_size=context_size, blank_id=blank_id
)
for _ in range(batch_size)
cls(context_size=context_size) for _ in range(batch_size)
]
@property
@@ -238,7 +249,7 @@ class StreamList(object):
audio_samples: List[torch.Tensor],
sampling_rate: float,
):
"""Feeed audio samples to each stream.
"""Feed audio samples to each stream.
Args:
audio_samples:
A list of 1-D tensors containing the audio samples for each
@@ -314,7 +325,7 @@ class StreamList(object):
def greedy_search(
model: nn.Module,
streams: List[FeatureExtractionStream],
streams: List[GreedySearchStream],
encoder_out: torch.Tensor,
sp: spm.SentencePieceProcessor,
):
@@ -333,7 +344,15 @@ def greedy_search(
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
for s in streams:
if s.hyp is None:
s.hyp = Hypothesis(
ys=([blank_id] * context_size),
log_prob=torch.tensor([0.0], device=device),
)
if streams[0].decoder_out is None:
decoder_input = torch.tensor(
[stream.hyp.ys[-context_size:] for stream in streams],
@@ -351,8 +370,6 @@ def greedy_search(
dim=0,
)
assert encoder_out.ndim == 3
T = encoder_out.size(1)
for t in range(T):
current_encoder_out = encoder_out[:, t]
@@ -381,20 +398,132 @@ def greedy_search(
)
for k, s in enumerate(streams):
logging.info(
f"Partial result {k}:\n{sp.decode(s.hyp.ys[context_size:])}"
)
logging.info(f"Partial result {k}:\n{sp.decode(s.result)}")
decoder_out_list = decoder_out.unbind(dim=0)
for i, d in enumerate(decoder_out_list):
streams[i].decoder_out = d
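Caching `decoder_out` on each stream here is what the `if streams[0].decoder_out is None` branch above relies on: the decoder is only run from scratch on the very first chunk of an utterance. A toy sketch of that caching contract, using a stand-in embedding for the real decoder and a hypothetical helper `decoder_out_for`:

import torch
import torch.nn as nn

# Stand-in for the RNN-T decoder: embeds the last `context_size`
# tokens and pools them into one vector. Illustrative only.
toy_decoder = nn.Embedding(num_embeddings=500, embedding_dim=4)

def decoder_out_for(ys, context_size, cached=None):
    # Reuse the cached output when the hypothesis has not grown;
    # this mirrors the `streams[0].decoder_out is None` check above.
    if cached is not None:
        return cached
    decoder_input = torch.tensor(ys[-context_size:])
    return toy_decoder(decoder_input).sum(dim=0)  # (decoder_out_dim,)

ys = [0, 0, 15, 42]  # blank-padded context plus emitted tokens
first = decoder_out_for(ys, context_size=2)  # first chunk: computed
second = decoder_out_for(ys, context_size=2, cached=first)  # cache hit
assert torch.equal(first, second)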
def modified_beam_search(
model: nn.Module,
streams: List[ModifiedBeamSearchStream],
encoder_out: torch.Tensor,
sp: spm.SentencePieceProcessor,
beam: int = 4,
):
"""
Args:
model:
The RNN-T model.
streams:
A list of stream objects, one per utterance in the batch.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
sp:
The BPE model.
beam:
Number of active paths during the beam search.
"""
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
assert encoder_out.ndim == 3, encoder_out.shape
assert len(streams) == encoder_out.size(0)
batch_size = len(streams)
for s in streams:
if len(s.hyps) == 0:
s.hyps.add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
B = [s.hyps for s in streams]
T = encoder_out.size(1)
for t in range(T):
current_encoder_out = encoder_out[:, t]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
hyps_shape = _get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
# decoder_out is of shape (num_hyps, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out)
# logits is of shape (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
streams[i].hyps = B[i]
logging.info(f"Partial result {i}:\n{sp.decode(streams[i].result)}")
def process_features(
model: nn.Module,
features: torch.Tensor,
streams: List[FeatureExtractionStream],
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> None:
"""Process features for each stream in parallel.
@@ -406,6 +535,8 @@ def process_features(
A 3-D tensor of shape (N, T, C).
streams:
A list of streams of size (N,).
params:
It is the return value of :func:`get_params`.
sp:
The BPE model.
"""
@@ -439,12 +570,25 @@ def process_features(
for i, s in enumerate(state_list):
streams[i].states = s
greedy_search(
model=model,
streams=streams,
encoder_out=encoder_out,
sp=sp,
)
if params.decoding_method == "greedy_search":
greedy_search(
model=model,
streams=streams,
encoder_out=encoder_out,
sp=sp,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=streams,
encoder_out=encoder_out,
sp=sp,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
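The `params.decoding_method` and `params.beam_size` values consumed here would normally come from command-line flags; their definitions are not part of this diff, so the following argparse wiring is an assumption rather than the script's actual flags:

import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag definitions; the real ones live outside this diff.
parser.add_argument(
    "--decoding-method",
    type=str,
    default="greedy_search",
    choices=["greedy_search", "modified_beam_search"],
    help="Decoding method to use.",
)
parser.add_argument(
    "--beam-size",
    type=int,
    default=4,
    help="Number of active paths in modified_beam_search.",
)
args = parser.parse_args(["--decoding-method", "modified_beam_search"])
assert args.decoding_method == "modified_beam_search"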
def decode_batch(
@@ -479,7 +623,7 @@ def decode_batch(
stream_list = StreamList(
batch_size=batch_size,
context_size=params.context_size,
blank_id=params.blank_id,
decoding_method=params.decoding_method,
)
while not streaming_audio_samples.done:
@ -497,11 +641,12 @@ def decode_batch(
model=model,
features=features,
streams=active_streams,
params=params,
sp=sp,
)
results = []
for s in stream_list.streams:
text = sp.decode(s.hyp.ys[params.context_size :])
text = sp.decode(s.result)
results.append(text)
return results
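Both stream types expose `result` as a list of BPE token IDs with the leading blank context already stripped, so this final loop needs no per-method branching. A minimal sketch of the decode step; the model path is hypothetical:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # hypothetical path

token_ids = [23, 7, 105]  # e.g. the value of stream.result
text = sp.decode(token_ids)  # a list of token IDs decodes to one string
print(text)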

View File

@@ -17,7 +17,7 @@
from typing import List, Optional
import torch
from beam_search import Hypothesis
from beam_search import Hypothesis, HypothesisList
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
@@ -41,14 +41,10 @@ def _create_streaming_feature_extractor() -> OnlineFeature:
class FeatureExtractionStream(object):
def __init__(self, context_size: int, blank_id: int = 0) -> None:
"""Context size of the RNN-T decoder model."""
def __init__(
self,
) -> None:
self.feature_extractor = _create_streaming_feature_extractor()
self.hyp = Hypothesis(
ys=([blank_id] * context_size),
log_prob=torch.tensor([0.0]),
) # for greedy search, will extend it to beam search
# It contains a list of 1-D tensors representing the feature frames.
self.feature_frames: List[torch.Tensor] = []
@@ -58,11 +54,6 @@ class FeatureExtractionStream(object):
# encoder layer.
self.states: Optional[List[List[torch.Tensor]]] = None
# For the RNN-T decoder, it contains the decoder output
# corresponding to the decoder input self.hyp.ys[-context_size:]
# Its shape is (decoder_out_dim,)
self.decoder_out: Optional[torch.Tensor] = None
# After calling `self.input_finished()`, we set this flag to True
self._done = False
@@ -85,9 +76,9 @@ class FeatureExtractionStream(object):
check to ensure that the input sampling rate equals the one
used in the extractor. If they are not equal, no resampling
will be performed; instead an error will be thrown.
waveform:
A 1-D torch tensor of dtype torch.float32 containing audio samples.
It should be on CPU.
waveform:
A 1-D torch tensor of dtype torch.float32 containing audio samples.
It should be on CPU.
"""
self.feature_extractor.accept_waveform(
sampling_rate=sampling_rate,
@@ -114,3 +105,33 @@ class FeatureExtractionStream(object):
frame = self.feature_extractor.get_frame(self.num_fetched_frames)
self.feature_frames.append(frame)
self.num_fetched_frames += 1
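For reference, the base class's accept_waveform/frame-fetch logic boils down to the following standalone use of kaldifeat's online fbank; the option values here are illustrative, not taken from this diff:

import torch
from kaldifeat import FbankOptions, OnlineFbank

opts = FbankOptions()
opts.frame_opts.samp_freq = 16000  # illustrative settings
opts.mel_opts.num_bins = 80
fbank = OnlineFbank(opts)

num_fetched = 0
frames = []

# Feed 0.1 s of (random) audio, then drain the frames that are ready.
fbank.accept_waveform(sampling_rate=16000, waveform=torch.randn(1600))
while num_fetched < fbank.num_frames_ready:
    frames.append(fbank.get_frame(num_fetched))
    num_fetched += 1

fbank.input_finished()  # flush; the last partial frames become ready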
class GreedySearchStream(FeatureExtractionStream):
def __init__(self, context_size: int) -> None:
"""FeatureExtractionStream class for greedy search."""
super().__init__()
self.context_size = context_size
self.hyp: Optional[Hypothesis] = None
# For the RNN-T decoder, it contains the decoder output
# corresponding to the decoder input self.hyp.ys[-context_size:].
# Its shape is (decoder_out_dim,)
self.decoder_out: Optional[torch.Tensor] = None
@property
def result(self) -> List[int]:
return self.hyp.ys[self.context_size :]
class ModifiedBeamSearchStream(FeatureExtractionStream):
def __init__(self, context_size: int) -> None:
"""FeatureExtractionStream class for modified beam search decoding."""
super().__init__()
self.context_size = context_size
self.hyps = HypothesisList()
self.best_hyp: Optional[Hypothesis] = None
@property
def result(self) -> List[int]:
best_hyp = self.hyps.get_most_probable(length_norm=True)
return best_hyp.ys[self.context_size :]
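A small usage sketch of the shared `result` contract of the two subclasses (token IDs and scores are made up); with length normalization, the shorter but higher-scoring hypothesis wins:

import torch
from beam_search import Hypothesis
from streaming_feature_extractor import ModifiedBeamSearchStream

stream = ModifiedBeamSearchStream(context_size=2)
# Two made-up hypotheses; ys starts with context_size blanks (id 0).
stream.hyps.add(Hypothesis(ys=[0, 0, 15, 42], log_prob=torch.tensor([-1.2])))
stream.hyps.add(Hypothesis(ys=[0, 0, 15], log_prob=torch.tensor([-0.3])))

# Length-normalized scores: -1.2 / 4 = -0.3 vs. -0.3 / 3 = -0.1,
# so the second hypothesis wins and the blank context is stripped.
print(stream.result)  # [15]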