mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 22:54:18 +00:00
Parallel streaming decode with greedy search.
This commit is contained in:
parent
85db12caac
commit
985707f38a
@ -27,6 +27,66 @@ from torchaudio.models import Emformer as _Emformer
|
||||
LOG_EPSILON = math.log(1e-10)
|
||||
|
||||
|
||||
def unstack_states(
|
||||
states: List[List[torch.Tensor]],
|
||||
) -> List[List[List[torch.Tensor]]]:
|
||||
"""Unstack the emformer state corresponding to a batch of utterances
|
||||
into a list of states, were the i-th entry is the state from the i-th
|
||||
utterance in the batch.
|
||||
|
||||
Args:
|
||||
states:
|
||||
A list-of-list of tensors. ``len(states)`` equals to number of
|
||||
layers in the emformer. ``states[i]]`` contains the states for
|
||||
the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
|
||||
``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``
|
||||
"""
|
||||
batch_size = states[0][0].size(1)
|
||||
num_layers = len(states)
|
||||
|
||||
ans = [None] * batch_size
|
||||
for i in range(batch_size):
|
||||
ans[i] = [[] for _ in range(num_layers)]
|
||||
|
||||
for li, layer in enumerate(states):
|
||||
for s in layer:
|
||||
s_list = s.unbind(dim=1)
|
||||
for bi, b in enumerate(ans):
|
||||
b[li].append(s_list[bi].unsqueeze(dim=1))
|
||||
return ans
|
||||
|
||||
|
||||
def stack_states(
|
||||
state_list: List[List[List[torch.Tensor]]],
|
||||
) -> List[List[torch.Tensor]]:
|
||||
"""Stack list of emformer states that correspond to separate utterances
|
||||
into a single emformer state so that it can be used as an input for
|
||||
emformer when those utterances are formed into a batch.
|
||||
|
||||
Note:
|
||||
It is the inverse of :func:`unstack_states`.
|
||||
|
||||
Args:
|
||||
state_list:
|
||||
Each element in state_list corresponding to the internal state
|
||||
of the emformer model for a single utterance.
|
||||
Returns:
|
||||
Return a new state corresponding to a batch of utterances.
|
||||
See the input argument of :func:`unstack_states` for the meaning
|
||||
of the returned tensor.
|
||||
"""
|
||||
ans = []
|
||||
for layer in state_list[0]:
|
||||
# layer is a list of tensors
|
||||
ans.append([s for s in layer])
|
||||
|
||||
for states in state_list[1:]:
|
||||
for li, layer in enumerate(states):
|
||||
for si, s in enumerate(layer):
|
||||
ans[li][si] = torch.cat([ans[li][si], s], dim=1)
|
||||
return ans
|
||||
|
||||
|
||||
class Emformer(EncoderInterface):
|
||||
"""This is just a simple wrapper around torchaudio.models.Emformer.
|
||||
We may replace it with our own implementation some time later.
|
||||
|
@ -18,16 +18,16 @@
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from emformer import LOG_EPSILON
|
||||
from streaming_feature_extractor import Stream
|
||||
from emformer import LOG_EPSILON, stack_states, unstack_states
|
||||
from streaming_feature_extractor import FeatureExtractionStream
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
@ -158,9 +158,166 @@ def get_parser():
|
||||
return parser
|
||||
|
||||
|
||||
class StreamingAudioSamples(object):
|
||||
"""This class takes as input a list of audio samples and returns
|
||||
them in a streaming fashion.
|
||||
"""
|
||||
|
||||
def __init__(self, samples: List[torch.Tensor]) -> None:
|
||||
"""
|
||||
Args:
|
||||
samples:
|
||||
A list of audio samples. Each entry is a 1-D tensor of dtype
|
||||
torch.float32, containing the audio samples of an utterance.
|
||||
"""
|
||||
self.samples = samples
|
||||
self.cur_indexes = [0] * len(self.samples)
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
"""Return True if all samples have been processed.
|
||||
Return False otherwise.
|
||||
"""
|
||||
for i, samples in zip(self.cur_indexes, self.samples):
|
||||
if i < samples.numel():
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_next(self) -> List[torch.Tensor]:
|
||||
"""Return a list of audio samples. Each entry may have different
|
||||
lengths. It is OK if an entry contains no samples at all, which
|
||||
means it reaches the end of the utterance.
|
||||
"""
|
||||
ans = []
|
||||
|
||||
# Note: Either branch is fine. The purpose is to simulate streaming
|
||||
if False:
|
||||
num = torch.randint(2000, 5000, (len(self.samples),)).tolist()
|
||||
else:
|
||||
num = [1024] * len(self.samples)
|
||||
|
||||
for i in range(len(self.samples)):
|
||||
start = self.cur_indexes[i]
|
||||
end = start + num[i]
|
||||
self.cur_indexes[i] = end
|
||||
|
||||
s = self.samples[i][start:end]
|
||||
ans.append(s)
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
class StreamList(object):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
context_size: int,
|
||||
blank_id: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
batch_size:
|
||||
Size of this batch.
|
||||
context_size:
|
||||
Context size of the RNN-T decoder model.
|
||||
blank_id:
|
||||
The ID of the blank symbol of the BPE model.
|
||||
"""
|
||||
self.streams = [
|
||||
FeatureExtractionStream(
|
||||
context_size=context_size, blank_id=blank_id
|
||||
)
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
"""Return True if all streams have reached end of utterance.
|
||||
That is, no more audio samples are available for all utterances.
|
||||
"""
|
||||
return all(stream.done for stream in self.streams)
|
||||
|
||||
def accept_waveform(
|
||||
self,
|
||||
audio_samples: List[torch.Tensor],
|
||||
sampling_rate: float,
|
||||
):
|
||||
"""Feeed audio samples to each stream.
|
||||
Args:
|
||||
audio_samples:
|
||||
A list of 1-D tensors containing the audio samples for each
|
||||
utterance in the batch. If an entry is empty, it means
|
||||
end-of-utterance has been reached.
|
||||
sampling_rate:
|
||||
Sampling rate of the given audio samples.
|
||||
"""
|
||||
assert len(audio_samples) == len(self.streams)
|
||||
for stream, samples in zip(self.streams, audio_samples):
|
||||
|
||||
if stream.done:
|
||||
assert samples.numel() == 0
|
||||
continue
|
||||
|
||||
stream.accept_waveform(
|
||||
sampling_rate=sampling_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
|
||||
if samples.numel() == 0:
|
||||
stream.input_finished()
|
||||
|
||||
def build_batch(
|
||||
self,
|
||||
chunk_length: int,
|
||||
segment_length: int,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[List[FeatureExtractionStream]]]:
|
||||
"""
|
||||
Args:
|
||||
chunk_length:
|
||||
Number of frames for each chunk. It equals to
|
||||
``segment_length + right_context_length``.
|
||||
segment_length
|
||||
Number of frames for each segment.
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- features, a 3-D tensor of shape ``(num_active_streams, T, C)``
|
||||
- active_streams, a list of active streams. We say a stream is
|
||||
active when it has enough feature frames to be fed into the
|
||||
encoder model.
|
||||
"""
|
||||
feature_list = []
|
||||
stream_list = []
|
||||
for stream in self.streams:
|
||||
if len(stream.feature_frames) >= chunk_length:
|
||||
# this_chunk is a list of tensors, each of which
|
||||
# has a shape (1, feature_dim)
|
||||
chunk = stream.feature_frames[:chunk_length]
|
||||
stream.feature_frames = stream.feature_frames[segment_length:]
|
||||
features = torch.cat(chunk, dim=0).unsqueeze(0)
|
||||
feature_list.append(features)
|
||||
stream_list.append(stream)
|
||||
elif stream.done and len(stream.feature_frames) > 0:
|
||||
chunk = stream.feature_frames[:chunk_length]
|
||||
stream.feature_frames = []
|
||||
features = torch.cat(chunk, dim=0).unsqueeze(0)
|
||||
features = torch.nn.functional.pad(
|
||||
features,
|
||||
(0, 0, 0, chunk_length - features.size(1)),
|
||||
value=LOG_EPSILON,
|
||||
)
|
||||
feature_list.append(features)
|
||||
stream_list.append(stream)
|
||||
|
||||
if len(feature_list) == 0:
|
||||
return None, None
|
||||
|
||||
features = torch.cat(feature_list, dim=0)
|
||||
return features, stream_list
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: nn.Module,
|
||||
stream: Stream,
|
||||
streams: List[FeatureExtractionStream],
|
||||
encoder_out: torch.Tensor,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
):
|
||||
@ -171,7 +328,7 @@ def greedy_search(
|
||||
stream:
|
||||
A stream object.
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (T, encoder_out_dim) containing the output of
|
||||
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
|
||||
the encoder model.
|
||||
sp:
|
||||
The BPE model.
|
||||
@ -180,59 +337,130 @@ def greedy_search(
|
||||
context_size = model.decoder.context_size
|
||||
device = model.device
|
||||
|
||||
if stream.decoder_out is None:
|
||||
if streams[0].decoder_out is None:
|
||||
decoder_input = torch.tensor(
|
||||
[stream.hyp.ys[-context_size:]],
|
||||
[stream.hyp.ys[-context_size:] for stream in streams],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
stream.decoder_out = model.decoder(
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=False,
|
||||
).unsqueeze(1)
|
||||
# stream.decoder_out is of shape (1, 1, decoder_out_dim)
|
||||
# decoder_out is of shape (N, 1, decoder_out_dim)
|
||||
else:
|
||||
decoder_out = torch.cat(
|
||||
[stream.decoder_out for stream in streams],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
assert encoder_out.ndim == 2
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
T = encoder_out.size(0)
|
||||
T = encoder_out.size(1)
|
||||
for t in range(T):
|
||||
current_encoder_out = encoder_out[t].reshape(
|
||||
1, 1, 1, encoder_out.size(-1)
|
||||
)
|
||||
logits = model.joiner(current_encoder_out, stream.decoder_out)
|
||||
# logits is of shape (1, 1, 1, vocab_size)
|
||||
y = logits.argmax().item()
|
||||
if y == blank_id:
|
||||
continue
|
||||
stream.hyp.ys.append(y)
|
||||
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
||||
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[stream.hyp.ys[-context_size:]],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
logits = model.joiner(current_encoder_out, decoder_out)
|
||||
# logits'shape (batch_size, 1, 1, vocab_size)
|
||||
|
||||
stream.decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=False,
|
||||
).unsqueeze(1)
|
||||
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
streams[i].hyp.ys.append(v)
|
||||
emitted = True
|
||||
|
||||
logging.info(
|
||||
f"Partial result:\n{sp.decode(stream.hyp.ys[context_size:])}"
|
||||
)
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = torch.tensor(
|
||||
[stream.hyp.ys[-context_size:] for stream in streams],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(
|
||||
decoder_input, need_pad=False
|
||||
).unsqueeze(1)
|
||||
|
||||
for k, s in enumerate(streams):
|
||||
logging.info(
|
||||
f"Partial result {k}:\n{sp.decode(s.hyp.ys[context_size:])}"
|
||||
)
|
||||
|
||||
decoder_out_list = decoder_out.unbind(dim=0)
|
||||
|
||||
for i, d in enumerate(decoder_out_list):
|
||||
streams[i].decoder_out = d.unsqueeze(0)
|
||||
|
||||
|
||||
def process_feature_frames(
|
||||
def process_features(
|
||||
model: nn.Module,
|
||||
stream: Stream,
|
||||
features: torch.Tensor,
|
||||
streams: List[FeatureExtractionStream],
|
||||
sp: spm.SentencePieceProcessor,
|
||||
):
|
||||
"""Process the feature frames contained in ``stream.feature_frames``.
|
||||
) -> None:
|
||||
"""Process features for each stream in parallel.
|
||||
|
||||
Args:
|
||||
model:
|
||||
The RNN-T model.
|
||||
stream:
|
||||
The stream corresponding to the input audio samples.
|
||||
features:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
streams:
|
||||
A list of streams of size (N,).
|
||||
sp:
|
||||
The BPE model.
|
||||
"""
|
||||
assert features.ndim == 3
|
||||
assert features.size(0) == len(streams)
|
||||
batch_size = features.size(0)
|
||||
|
||||
device = model.device
|
||||
features = features.to(device)
|
||||
feature_lens = torch.full(
|
||||
(batch_size,),
|
||||
fill_value=features.size(1),
|
||||
device=device,
|
||||
)
|
||||
if streams[0].states is None:
|
||||
states = None
|
||||
else:
|
||||
state_list = [stream.states for stream in streams]
|
||||
states = stack_states(state_list)
|
||||
|
||||
(encoder_out, encoder_out_lens, states,) = model.encoder.streaming_forward(
|
||||
features,
|
||||
feature_lens,
|
||||
states,
|
||||
)
|
||||
state_list = unstack_states(states)
|
||||
for i, s in enumerate(state_list):
|
||||
streams[i].states = s
|
||||
|
||||
greedy_search(
|
||||
model=model,
|
||||
streams=streams,
|
||||
encoder_out=encoder_out,
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
|
||||
def decode_batch(
|
||||
batched_samples: List[torch.Tensor],
|
||||
model: nn.Module,
|
||||
params: AttributeDict,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Args:
|
||||
batched_samples:
|
||||
A list of 1-D tensors containing the audio samples of each utterance.
|
||||
model:
|
||||
The RNN-T model.
|
||||
params:
|
||||
It is the return value of :func:`get_params`.
|
||||
sp:
|
||||
The BPE model.
|
||||
"""
|
||||
@ -241,102 +469,41 @@ def process_feature_frames(
|
||||
|
||||
right_context_length = model.encoder.right_context_length
|
||||
|
||||
# We add 3 here since the subsampling method is using
|
||||
# ((len - 1) // 2 - 1) // 2)
|
||||
chunk_length = (segment_length + 3) + right_context_length
|
||||
|
||||
device = model.device
|
||||
while len(stream.feature_frames) >= chunk_length:
|
||||
# a list of tensor, each with a shape (1, feature_dim)
|
||||
this_chunk = stream.feature_frames[:chunk_length]
|
||||
batch_size = len(batched_samples)
|
||||
streaming_audio_samples = StreamingAudioSamples(batched_samples)
|
||||
|
||||
stream.feature_frames = stream.feature_frames[segment_length:]
|
||||
features = torch.cat(this_chunk, dim=0).to(device) # (T, feature_dim)
|
||||
features = features.unsqueeze(0) # (1, T, feature_dim)
|
||||
feature_lens = torch.tensor([features.size(1)], device=device)
|
||||
(
|
||||
encoder_out,
|
||||
encoder_out_lens,
|
||||
stream.states,
|
||||
) = model.encoder.streaming_forward(
|
||||
features,
|
||||
feature_lens,
|
||||
stream.states,
|
||||
stream_list = StreamList(
|
||||
batch_size=batch_size,
|
||||
context_size=params.context_size,
|
||||
blank_id=params.blank_id,
|
||||
)
|
||||
|
||||
while not streaming_audio_samples.done:
|
||||
samples = streaming_audio_samples.get_next()
|
||||
stream_list.accept_waveform(
|
||||
audio_samples=samples,
|
||||
sampling_rate=params.sampling_rate,
|
||||
)
|
||||
greedy_search(
|
||||
model=model,
|
||||
stream=stream,
|
||||
encoder_out=encoder_out[0],
|
||||
sp=sp,
|
||||
features, active_streams = stream_list.build_batch(
|
||||
chunk_length=chunk_length,
|
||||
segment_length=segment_length,
|
||||
)
|
||||
|
||||
if stream.feature_extractor.is_last_frame(stream.num_fetched_frames - 1):
|
||||
assert len(stream.feature_frames) < chunk_length
|
||||
|
||||
if len(stream.feature_frames) > 0:
|
||||
this_chunk = stream.feature_frames[:chunk_length]
|
||||
stream.feature_frames = []
|
||||
features = torch.cat(this_chunk, dim=0) # (T, feature_dim)
|
||||
features = features.to(device).unsqueeze(0) # (1, T, feature_dim)
|
||||
features = torch.nn.functional.pad(
|
||||
features,
|
||||
(0, 0, 0, chunk_length - features.size(1)),
|
||||
value=LOG_EPSILON,
|
||||
)
|
||||
feature_lens = torch.tensor([features.size(1)], device=device)
|
||||
(
|
||||
encoder_out,
|
||||
encoder_out_lens,
|
||||
stream.states,
|
||||
) = model.encoder.streaming_forward(
|
||||
features,
|
||||
feature_lens,
|
||||
stream.states,
|
||||
)
|
||||
greedy_search(
|
||||
if features is not None:
|
||||
process_features(
|
||||
model=model,
|
||||
stream=stream,
|
||||
encoder_out=encoder_out[0],
|
||||
features=features,
|
||||
streams=active_streams,
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
|
||||
def decode_one_utterance(
|
||||
audio_samples: torch.Tensor,
|
||||
model: nn.Module,
|
||||
stream: Stream,
|
||||
params: AttributeDict,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
):
|
||||
"""Decode one utterance.
|
||||
Args:
|
||||
audio_samples:
|
||||
A 1-D float32 tensor of shape (num_samples,) containing the
|
||||
audio samples.
|
||||
model:
|
||||
The RNN-T model.
|
||||
feature_extractor:
|
||||
The feature extractor.
|
||||
params:
|
||||
It is the return value of :func:`get_params`.
|
||||
sp:
|
||||
The BPE model.
|
||||
"""
|
||||
i = 0
|
||||
num_samples = audio_samples.size(0)
|
||||
while i < num_samples:
|
||||
# Simulate streaming.
|
||||
this_chunk_num_samples = torch.randint(2000, 5000, (1,)).item()
|
||||
|
||||
thiks_chunk_samples = audio_samples[i : (i + this_chunk_num_samples)]
|
||||
i += this_chunk_num_samples
|
||||
|
||||
stream.accept_waveform(
|
||||
sampling_rate=params.sampling_rate,
|
||||
waveform=thiks_chunk_samples,
|
||||
)
|
||||
process_feature_frames(model=model, stream=stream, sp=sp)
|
||||
|
||||
stream.input_finished()
|
||||
process_feature_frames(model=model, stream=stream, sp=sp)
|
||||
results = []
|
||||
for s in stream_list.streams:
|
||||
text = sp.decode(s.hyp.ys[params.context_size :])
|
||||
results.append(text)
|
||||
return results
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@ -403,31 +570,41 @@ def main():
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
|
||||
for num, cut in enumerate(test_clean_cuts):
|
||||
logging.info(f"Processing {num}")
|
||||
stream = Stream(
|
||||
context_size=model.decoder.context_size,
|
||||
blank_id=model.decoder.blank_id,
|
||||
)
|
||||
batch_size = 3
|
||||
|
||||
ground_truth = []
|
||||
batched_samples = []
|
||||
for num, cut in enumerate(test_clean_cuts):
|
||||
audio: np.ndarray = cut.load_audio()
|
||||
# audio.shape: (1, num_samples)
|
||||
assert len(audio.shape) == 2
|
||||
assert audio.shape[0] == 1, "Should be single channel"
|
||||
assert audio.dtype == np.float32, audio.dtype
|
||||
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
|
||||
decode_one_utterance(
|
||||
audio_samples=torch.from_numpy(audio).squeeze(0).to(device),
|
||||
model=model,
|
||||
stream=stream,
|
||||
params=params,
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
logging.info(f"The ground truth is:\n{cut.supervisions[0].text}")
|
||||
if num >= 2:
|
||||
# The trained model is using normalized samples
|
||||
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
|
||||
|
||||
samples = torch.from_numpy(audio).squeeze(0)
|
||||
|
||||
batched_samples.append(samples)
|
||||
ground_truth.append(cut.supervisions[0].text)
|
||||
|
||||
if len(batched_samples) >= batch_size:
|
||||
decoded_results = decode_batch(
|
||||
batched_samples=batched_samples,
|
||||
model=model,
|
||||
params=params,
|
||||
sp=sp,
|
||||
)
|
||||
s = "\n"
|
||||
for i, (hyp, ref) in enumerate(zip(decoded_results, ground_truth)):
|
||||
s += f"hyp {i}:\n{hyp}\n"
|
||||
s += f"ref {i}:\n{ref}\n\n"
|
||||
logging.info(s)
|
||||
batched_samples = []
|
||||
ground_truth = []
|
||||
# break after processing the first batch for test purposes
|
||||
break
|
||||
time.sleep(2) # So that you can see the decoded results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -40,7 +40,7 @@ def _create_streaming_feature_extractr() -> OnlineFeature:
|
||||
return OnlineFbank(opts)
|
||||
|
||||
|
||||
class Stream(object):
|
||||
class FeatureExtractionStream(object):
|
||||
def __init__(self, context_size: int, blank_id: int = 0) -> None:
|
||||
"""Context size of the RNN-T decoder model."""
|
||||
self.feature_extractor = _create_streaming_feature_extractr()
|
||||
@ -62,6 +62,9 @@ class Stream(object):
|
||||
# corresponding to the decoder input self.hyp.ys[-context_size:]
|
||||
self.decoder_out: Optional[torch.Tensor] = None
|
||||
|
||||
# After calling `self.input_finished()`, we set this flag to True
|
||||
self._done = False
|
||||
|
||||
def accept_waveform(
|
||||
self,
|
||||
sampling_rate: float,
|
||||
@ -97,6 +100,12 @@ class Stream(object):
|
||||
"""
|
||||
self.feature_extractor.input_finished()
|
||||
self._fetch_frames()
|
||||
self._done = True
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
"""Return True if `self.input_finished()` has been invoked"""
|
||||
return self._done
|
||||
|
||||
def _fetch_frames(self) -> None:
|
||||
"""Fetch frames from the feature extractor"""
|
||||
|
@ -25,7 +25,7 @@ To run this file, do:
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from emformer import Emformer
|
||||
from emformer import Emformer, stack_states, unstack_states
|
||||
|
||||
|
||||
def test_emformer():
|
||||
@ -65,8 +65,41 @@ def test_emformer():
|
||||
print(f"Number of encoder parameters: {num_param}")
|
||||
|
||||
|
||||
def test_emformer_streaming_forward():
|
||||
N = 3
|
||||
C = 80
|
||||
|
||||
output_dim = 500
|
||||
|
||||
encoder = Emformer(
|
||||
num_features=C,
|
||||
output_dim=output_dim,
|
||||
d_model=512,
|
||||
nhead=8,
|
||||
dim_feedforward=2048,
|
||||
num_encoder_layers=20,
|
||||
segment_length=16,
|
||||
left_context_length=120,
|
||||
right_context_length=4,
|
||||
vgg_frontend=False,
|
||||
)
|
||||
|
||||
x = torch.rand(N, 23, C)
|
||||
x_lens = torch.full((N,), 23)
|
||||
y, y_lens, states = encoder.streaming_forward(x=x, x_lens=x_lens)
|
||||
|
||||
state_list = unstack_states(states)
|
||||
states2 = stack_states(state_list)
|
||||
|
||||
for ss, ss2 in zip(states, states2):
|
||||
for s, s2 in zip(ss, ss2):
|
||||
assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
test_emformer()
|
||||
# test_emformer()
|
||||
test_emformer_streaming_forward()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -24,11 +24,11 @@ To run this file, do:
|
||||
"""
|
||||
|
||||
import torch
|
||||
from streaming_feature_extractor import Stream
|
||||
from streaming_feature_extractor import FeatureExtractionStream
|
||||
|
||||
|
||||
def test_streaming_feature_extractor():
|
||||
stream = Stream(context_size=2, blank_id=0)
|
||||
stream = FeatureExtractionStream(context_size=2, blank_id=0)
|
||||
samples = torch.rand(16000)
|
||||
start = 0
|
||||
while True:
|
||||
|
Loading…
x
Reference in New Issue
Block a user