refactor streaming decoding

yaozengwei 2022-06-09 20:37:16 +08:00
parent 734d97c47b
commit 7f09720403
4 changed files with 421 additions and 434 deletions

View File

@ -41,8 +41,8 @@ LOG_EPSILON = math.log(1e-10)
def unstack_states(
states,
) -> List[List[List[torch.Tensor]]]:
states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]:
# TODO: modify doc
"""Unstack the emformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state from the i-th
@ -50,18 +50,14 @@ def unstack_states(
Args:
states:
A list-of-list of tensors. ``len(states)`` equals the number of
layers in the emformer. ``states[i]`` contains the states for
the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``
A list-of-list of tensors.
``len(states[0])`` and ``len(states[1])`` equal the number of layers.
"""
past_lens, attn_caches, conv_caches = states
batch_size = past_lens.size(0)
attn_caches, conv_caches = states
batch_size = conv_caches[0].size(0)
num_layers = len(attn_caches)
list_past_len = past_lens.tolist()
list_attn_caches = [None] * batch_size
for i in range(batch_size):
list_attn_caches[i] = [[] for _ in range(num_layers)]
@ -81,14 +77,14 @@ def unstack_states(
ans = [None] * batch_size
for i in range(batch_size):
ans[i] = [list_past_len[i], list_attn_caches[i], list_conv_caches[i]]
ans[i] = [list_attn_caches[i], list_conv_caches[i]]
return ans
def stack_states(
state_list,
) -> List[List[torch.Tensor]]:
state_list: List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]
) -> Tuple[List[List[torch.Tensor]], List[torch.Tensor]]:
# TODO: modify doc
"""Stack list of emformer states that correspond to separate utterances
into a single emformer state so that it can be used as an input for
@ -108,18 +104,15 @@ def stack_states(
"""
batch_size = len(state_list)
past_lens = [states[0] for states in state_list]
past_lens = torch.tensor([past_lens])
attn_caches = []
for layer in state_list[0][1]:
for layer in state_list[0][0]:
if batch_size > 1:
# Note: We will stack attn_caches[layer][s][] later to get attn_caches[layer][s] # noqa
attn_caches.append([[s] for s in layer])
else:
attn_caches.append([s.unsqueeze(1) for s in layer])
for b, states in enumerate(state_list[1:], 1):
for li, layer in enumerate(states[1]):
for li, layer in enumerate(states[0]):
for si, s in enumerate(layer):
attn_caches[li][si].append(s)
if b == batch_size - 1:
@ -128,19 +121,19 @@ def stack_states(
)
conv_caches = []
for layer in state_list[0][2]:
for layer in state_list[0][1]:
if batch_size > 1:
# Note: We will stack conv_caches[layer][] later to get conv_caches[layer] # noqa
conv_caches.append([layer])
else:
conv_caches.append(layer.unsqueeze(0))
for b, states in enumerate(state_list[1:], 1):
for li, layer in enumerate(states[2]):
for li, layer in enumerate(states[1]):
conv_caches[li].append(layer)
if b == batch_size - 1:
conv_caches[li] = torch.stack(conv_caches[li], dim=0)
return [past_lens, attn_caches, conv_caches]
return [attn_caches, conv_caches]
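A minimal sketch (not part of this diff) of the new state layout: after this change a per-utterance state is just (attn_caches, conv_caches), with no separate past_lens tensor. The tensor shapes below are assumptions based on the asserts in EmformerEncoder.infer further down; only stack_states/unstack_states from this file are real.

import torch
from emformer import stack_states, unstack_states

num_layers, d_model, kernel, memory_size, left_context = 2, 8, 3, 4, 4

def one_utterance_state():
    # Per layer: [memory bank, cached left-context key, cached left-context value]
    attn = [
        [
            torch.zeros(memory_size, d_model),
            torch.zeros(left_context, d_model),
            torch.zeros(left_context, d_model),
        ]
        for _ in range(num_layers)
    ]
    # Per layer: cached convolution context of length kernel - 1
    conv = [torch.zeros(d_model, kernel - 1) for _ in range(num_layers)]
    return [attn, conv]

batched = stack_states([one_utterance_state(), one_utterance_state()])
per_utt = unstack_states(batched)  # round trip back to 2 per-utterance states
assert len(per_utt) == 2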
class ConvolutionModule(nn.Module):
@ -1489,13 +1482,12 @@ class EmformerEncoder(nn.Module):
self,
x: torch.Tensor,
lengths: torch.Tensor,
states: List[
torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]
],
num_processed_frames: torch.Tensor,
states: Tuple[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]],
) -> Tuple[
torch.Tensor,
torch.Tensor,
List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]],
Tuple[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]],
]:
"""Forward pass for streaming inference.
@ -1526,10 +1518,9 @@ class EmformerEncoder(nn.Module):
right_context at the end.
- updated states from current chunk's computation.
"""
past_lens = states[0]
assert past_lens.shape == (x.size(1),), past_lens.shape
assert num_processed_frames.shape == (x.size(1),)
attn_caches = states[1]
attn_caches = states[0]
assert len(attn_caches) == self.num_encoder_layers, len(attn_caches)
for i in range(len(attn_caches)):
assert attn_caches[i][0].shape == (
@ -1548,24 +1539,23 @@ class EmformerEncoder(nn.Module):
self.d_model,
), attn_caches[i][2].shape
conv_caches = states[2]
conv_caches = states[1]
assert len(conv_caches) == self.num_encoder_layers, len(conv_caches)
for i in range(len(conv_caches)):
assert conv_caches[i].shape == (
x.size(1),
self.d_model,
self.cnn_module_kernel,
self.cnn_module_kernel - 1,
), conv_caches[i].shape
assert x.size(0) == self.chunk_length + self.right_context_length, (
"Per configured chunk_length and right_context_length, "
f"expected size of {self.chunk_length + self.right_context_length} "
f"for dimension 1 of x, but got {x.size(1)}."
)
# assert x.size(0) == self.chunk_length + self.right_context_length, (
# "Per configured chunk_length and right_context_length, "
# f"expected size of {self.chunk_length + self.right_context_length} "
# f"for dimension 1 of x, but got {x.size(0)}."
# )
right_context_start_idx = x.size(0) - self.right_context_length
right_context = x[right_context_start_idx:]
utterance = x[:right_context_start_idx]
right_context = x[-self.right_context_length :]
utterance = x[: -self.right_context_length]
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
memory = (
self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
@ -1574,29 +1564,29 @@ class EmformerEncoder(nn.Module):
)
# calculate the padding mask to mask out initial zero caches
chunk_mask = make_pad_mask(output_lengths)
memory_mask = (
(past_lens // self.chunk_length).view(x.size(1), 1)
<= torch.arange(self.memory_size, device=x.device).expand(
x.size(1), self.memory_size
)
).flip(1)
left_context_mask = (
past_lens.view(x.size(1), 1)
<= torch.arange(self.left_context_length, device=x.device).expand(
x.size(1), self.left_context_length
)
).flip(1)
right_context_mask = torch.zeros(
x.size(1),
self.right_context_length,
dtype=torch.bool,
device=x.device,
)
padding_mask = torch.cat(
[memory_mask, left_context_mask, right_context_mask, chunk_mask],
dim=1,
)
# chunk_mask = make_pad_mask(output_lengths).to(x.device)
# memory_mask = (
# (past_lens // self.chunk_length).view(x.size(1), 1)
# <= torch.arange(self.memory_size, device=x.device).expand(
# x.size(1), self.memory_size
# )
# ).flip(1)
# left_context_mask = (
# past_lens.view(x.size(1), 1)
# <= torch.arange(self.left_context_length, device=x.device).expand(
# x.size(1), self.left_context_length
# )
# ).flip(1)
# right_context_mask = torch.zeros(
# x.size(1),
# self.right_context_length,
# dtype=torch.bool,
# device=x.device,
# )
# padding_mask = torch.cat(
# [memory_mask, left_context_mask, right_context_mask, chunk_mask],
# dim=1,
# )
output = utterance
output_attn_caches: List[List[torch.Tensor]] = []
@ -1612,19 +1602,14 @@ class EmformerEncoder(nn.Module):
output,
right_context,
memory,
padding_mask=padding_mask,
# padding_mask=padding_mask,
attn_cache=attn_caches[layer_idx],
conv_cache=conv_caches[layer_idx],
)
output_attn_caches.append(output_attn_cache)
output_conv_caches.append(output_conv_cache)
output_past_lens = past_lens + output_lengths
output_states = [
output_past_lens,
output_attn_caches,
output_conv_caches,
]
output_states = [output_attn_caches, output_conv_caches]
return output, output_lengths, output_states
@ -1738,6 +1723,7 @@ class Emformer(EncoderInterface):
self,
x: torch.Tensor,
x_lens: torch.Tensor,
num_processed_frames: torch.Tensor,
states: Optional[List[List[torch.Tensor]]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
"""Forward pass for streaming inference.
@ -1770,16 +1756,17 @@ class Emformer(EncoderInterface):
- updated states from current chunk's computation.
"""
x = self.encoder_embed(x)
# drop the first and last frames
x = x[:, 1:-1, :]
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4!
with warnings.catch_warnings():
warnings.simplefilter("ignore")
x_lens = ((x_lens - 1) // 2 - 1) // 2
x_lens = (((x_lens - 1) >> 1) - 1) >> 1
x_lens -= 2
assert x.size(0) == x_lens.max().item()
output, output_lengths, output_states = self.encoder.infer(
x, x_lens, states
x, x_lens, num_processed_frames, states
)
output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
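For reference, a quick check of the length arithmetic above (subsampling factor 4 plus the two dropped boundary frames, as stated in the comments):

def subsampled_len(x_len: int) -> int:
    # (((len - 1) >> 1) - 1) >> 1 from the convolutional subsampling,
    # then subtract the first and last frame that are dropped above.
    return ((((x_len - 1) >> 1) - 1) >> 1) - 2

assert subsampled_len(36) == 6  # a 36-frame chunk yields 6 encoder frames
assert subsampled_len(45) == 8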

View File

@ -18,7 +18,7 @@ import math
from typing import List, Optional, Tuple
import torch
from beam_search import HypothesisList
from beam_search import Hypothesis, HypothesisList
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
from icefall.utils import AttributeDict
@ -48,6 +48,7 @@ class Stream(object):
self,
params: AttributeDict,
device: torch.device = torch.device("cpu"),
LOG_EPS: float = math.log(1e-10),
) -> None:
"""
Args:
@ -57,11 +58,14 @@ class Stream(object):
The device to run this stream.
"""
self.device = device
self.LOG_EPS = LOG_EPS
# Containing attention caches and convolution caches
self.states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = None
self.states: Optional[
Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
] = None
# Initialize zero states.
self.init_states()
self.init_states(params)
# It uses different attributes for different decoding methods.
self.context_size = params.context_size
@ -70,6 +74,12 @@ class Stream(object):
self.hyp = [params.blank_id] * params.context_size
elif params.decoding_method == "modified_beam_search":
self.hyps = HypothesisList()
self.hyps.add(
Hypothesis(
ys=[params.blank_id] * params.context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
@ -77,7 +87,7 @@ class Stream(object):
self.ground_truth: str = ""
self.feature: torch.Tensor = None
self.feature: Optional[torch.Tensor] = None
# Make sure all feature frames can be used.
# Add 2 here since we will drop the first and last frames after subsampling.
self.chunk_length = params.chunk_length
@ -91,14 +101,14 @@ class Stream(object):
self._done = False
def set_feature(self, feature: torch.Tensor) -> None:
assert feature.dim == 2, feature.dim
assert feature.dim() == 2, feature.dim()
self.num_frames = feature.size(0)
# tail padding
self.feature = torch.nn.functional.pad(
feature,
(0, 0, 0, self.pad_length),
mode="constant",
value=math.log(1e-10),
value=self.LOG_EPS,
)
def set_ground_truth(self, ground_truth: str) -> None:
@ -140,9 +150,11 @@ class Stream(object):
)
ret_length = update_length + self.pad_length
ret_feature = self.feature[:ret_length]
ret_feature = self.feature[
self.num_processed_frames : self.num_processed_frames + ret_length
]
# Cut off used frames.
self.feature = self.feature[update_length:]
# self.feature = self.feature[update_length:]
self.num_processed_frames += update_length
if self.num_processed_frames >= self.num_frames:
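A small standalone sketch (made-up sizes) of the indexing convention introduced here: the padded feature tensor is kept intact, and each call reads an overlapping window starting at num_processed_frames, so the trailing pad/right-context frames are read again by the next chunk instead of being discarded.

import torch

chunk_length, pad_length, num_frames = 4, 2, 10      # made-up sizes
feature = torch.zeros(num_frames + pad_length, 80)   # already tail-padded
num_processed_frames = 0
while num_processed_frames < num_frames:
    ret = feature[num_processed_frames : num_processed_frames + chunk_length + pad_length]
    # ... feed `ret` to the encoder ...
    num_processed_frames += chunk_length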

View File

@ -18,9 +18,10 @@
import argparse
import logging
import math
import warnings
from pathlib import Path
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import k2
from lhotse import CutSet
@ -31,15 +32,24 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from emformer import LOG_EPSILON, stack_states, unstack_states
from streaming_feature_extractor import Stream
from kaldifeat import Fbank, FbankOptions
from stream import Stream
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import AttributeDict, setup_logger
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
@ -55,6 +65,16 @@ def get_parser():
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
@ -65,14 +85,14 @@ def get_parser():
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
""",
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
@ -172,52 +192,39 @@ def get_parser():
def greedy_search(
model: nn.Module,
streams: List[Stream],
encoder_out: torch.Tensor,
sp: spm.SentencePieceProcessor,
):
"""
Args:
model:
The RNN-T model.
streams:
A list of stream objects.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
sp:
The BPE model.
"""
streams: List[Stream],
) -> List[List[int]]:
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
T = encoder_out.size(1)
if streams[0].decoder_out is None:
for stream in streams:
stream.hyp = [blank_id] * context_size
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
# decoder_out is of shape (N, decoder_out_dim)
else:
decoder_out = torch.stack(
[stream.decoder_out for stream in streams],
dim=0,
)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# logging.info(f"decoder_out shape : {decoder_out.shape}")
for t in range(T):
current_encoder_out = encoder_out[:, t]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(current_encoder_out, decoder_out)
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits' shape: (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
@ -236,227 +243,64 @@ def greedy_search(
decoder_out = model.decoder(
decoder_input,
need_pad=False,
).squeeze(1)
for k, stream in enumerate(streams):
result = sp.decode(stream.decoding_result())
logging.info(f"Partial result {k}:\n{result}")
decoder_out_list = decoder_out.unbind(dim=0)
for i, d in enumerate(decoder_out_list):
streams[i].decoder_out = d
def modified_beam_search(
model: nn.Module,
streams: List[Stream],
encoder_out: torch.Tensor,
sp: spm.SentencePieceProcessor,
beam: int = 4,
):
"""
Args:
model:
The RNN-T model.
streams:
A list of stream objects.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
sp:
The BPE model.
beam:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert len(streams) == encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
batch_size = len(streams)
T = encoder_out.size(1)
for stream in streams:
if len(stream.hyps) == 0:
stream.hyps.add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
B = [stream.hyps for stream in streams]
for t in range(T):
current_encoder_out = encoder_out[:, t]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.stack(
[hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
# decoder_out is of shape (num_hyps, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out)
# logits is of shape (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
streams[i].hyps = B[i]
result = sp.decode(streams[i].decoding_result())
logging.info(f"Partial result {i}:\n{result}")
decoder_out = model.joiner.decoder_proj(decoder_out)
def build_batch(
decode_streams: List[Stream],
chunk_length: int,
segment_length: int,
) -> Tuple[
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[List[Stream]],
]:
"""
Args:
chunk_length:
Number of frames for each chunk. It equals
``segment_length + right_context_length``.
segment_length:
Number of frames for each segment.
Returns:
Return a tuple containing:
- features, a 3-D tensor of shape ``(num_active_streams, T, C)``
- active_streams, a list of active streams. We say a stream is
active when it has enough feature frames to be fed into the
encoder model.
"""
feature_list = []
length_list = []
stream_list = []
for stream in decode_streams:
if len(stream.feature_frames) >= chunk_length:
# this_chunk is a list of tensors, each of which
# has a shape (1, feature_dim)
chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = stream.feature_frames[segment_length:]
features = torch.cat(chunk, dim=0)
feature_list.append(features)
length_list.append(chunk_length)
stream_list.append(stream)
elif stream.done and len(stream.feature_frames) > 0:
chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = []
features = torch.cat(chunk, dim=0)
length_list.append(features.size(0))
features = torch.nn.functional.pad(
features,
(0, 0, 0, chunk_length - features.size(0)),
mode="constant",
value=LOG_EPSILON,
)
feature_list.append(features)
stream_list.append(stream)
if len(feature_list) == 0:
return None, None, None
features = torch.stack(feature_list, dim=0)
lengths = torch.cat(length_list)
return features, lengths, stream_list
def process_features(
def decode_one_chunk(
model: nn.Module,
features: torch.Tensor,
feature_lens: torch.Tensor,
streams: List[Stream],
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> None:
"""Process features for each stream in parallel.
) -> List[int]:
device = next(model.parameters()).device
Args:
model:
The RNN-T model.
features:
A 3-D tensor of shape (N, T, C).
streams:
A list of streams of size (N,).
params:
It is the return value of :func:`get_params`.
sp:
The BPE model.
"""
assert features.ndim == 3
assert features.size(0) == len(streams)
assert feature_lens.size(0) == len(streams)
feature_list = []
feature_len_list = []
state_list = []
num_processed_frames_list = []
device = model.device
features = features.to(device)
for stream in streams:
feature, feature_len = stream.get_feature_chunk()
feature_list.append(feature)
feature_len_list.append(feature_len)
state_list.append(stream.states)
num_processed_frames_list.append(stream.num_processed_frames)
state_list = [stream.states for stream in streams]
features = pad_sequence(
feature_list, batch_first=True, padding_value=LOG_EPSILON
).to(device)
feature_lens = torch.tensor(feature_len_list, device=device)
num_processed_frames = torch.tensor(
num_processed_frames_list, device=device
)
# Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa
tail_length = (
3 * params.subsampling_factor + params.right_context_length + 3
)
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPSILON,
)
# print(features.shape)
# stack states of all streams
states = stack_states(state_list)
encoder_out, encoder_out_lens, states = model.encoder.infer(
features,
feature_lens,
states,
x=features,
x_lens=feature_lens,
states=states,
num_processed_frames=num_processed_frames,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
# update cached states of each stream
state_list = unstack_states(states)
for i, s in enumerate(state_list):
streams[i].states = s
@ -466,26 +310,47 @@ def process_features(
model=model,
streams=streams,
encoder_out=encoder_out,
sp=sp,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=streams,
encoder_out=encoder_out,
sp=sp,
beam=params.beam_size,
)
# elif params.decoding_method == "modified_beam_search":
# modified_beam_search(
# model=model,
# streams=streams,
# encoder_out=encoder_out,
# sp=sp,
# beam=params.beam_size,
# )
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
finished_streams = [i for i, stream in enumerate(streams) if stream.done]
return finished_streams
def create_streaming_feature_extractor() -> Fbank:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return Fbank(opts)
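For reference, a minimal use of this extractor, mirroring the call pattern in decode_dataset below (the silent waveform is only a placeholder):

fbank = create_streaming_feature_extractor()
samples = torch.zeros(16000)   # 1 second of silence at 16 kHz, placeholder input
feature = fbank(samples)       # (num_frames, 80) log-mel features on CPU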
def decode_dataset(
params: AttributeDict,
cuts: CutSet,
model: nn.Module,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
):
"""Decode dataset.
@ -493,72 +358,126 @@ def decode_dataset(
"""
device = next(model.parameters()).device
# number of frames before subsampling
segment_length = model.encoder.segment_length
right_context_length = model.encoder.right_context_length
# 5 = 3 + 2
# 1) add 3 here since the subsampling method is using
# ((len - 1) // 2 - 1) // 2
# 2) add 2 here since we will drop the first and last frames after subsampling
chunk_length = (segment_length + 5) + right_context_length
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 300
decode_results = []
streams = []
for num, cut in enumerate(cuts):
# Each utterance has a Stream.
stream = Stream(params=params, device=device, LOG_EPS=LOG_EPSILON)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1]"
samples = torch.from_numpy(audio).squeeze(0)
fbank = create_streaming_feature_extractor()
feature = fbank(samples)
stream.set_feature(feature)
stream.set_ground_truth(cut.supervisions[0].text)
# Each utterance has a Stream
stream = Stream(
params=params,
audio_sample=samples,
ground_truth=cut.supervisions[0].text,
device=device,
)
streams.append(stream)
while len(streams) >= params.num_decode_streams:
for stream in streams:
stream.accept_waveform()
# try to build batch
features, active_streams = build_batch(
chunk_length=chunk_length,
segment_length=segment_length,
finished_streams = decode_one_chunk(
model=model,
streams=streams,
params=params,
sp=sp,
)
if features is not None:
process_features(
model=model,
features=features,
streams=active_streams,
params=params,
sp=sp,
)
new_streams = []
for stream in streams:
if stream.done:
decode_results.append(
(
stream.ground_truth.split(),
sp.decode(stream.decoding_result()).split(),
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
)
else:
new_streams.append(stream)
del streams
streams = new_streams
)
print(decode_results[-1])
del streams[i]
# print("delete", i, len(streams))
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
while len(streams) > 0:
finished_streams = decode_one_chunk(
model=model,
streams=streams,
params=params,
sp=sp,
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
)
)
del streams[i]
if params.decoding_method == "greedy_search":
return {"greedy_search": decode_results}
else:
return {f"beam_size_{params.beam_size}": decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=sorted(results))
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
@ -571,6 +490,32 @@ def main():
# Note: params.decoding_method is currently not used.
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-length-{params.chunk_length}"
params.suffix += f"-left-context-length-{params.left_context_length}"
params.suffix += f"-right-context-length-{params.right_context_length}"
params.suffix += f"-memory-size-{params.memory_size}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-streaming-decode")
logging.info("Decoding started")
@ -595,24 +540,83 @@ def main():
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg_last_n > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
@ -622,42 +626,26 @@ def main():
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
batch_size = 3
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
ground_truth = []
batched_samples = []
for num, cut in enumerate(test_clean_cuts):
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
model=model,
params=params,
sp=sp,
)
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1]"
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
samples = torch.from_numpy(audio).squeeze(0)
# batched_samples.append(samples)
# ground_truth.append(cut.supervisions[0].text)
if len(batched_samples) >= batch_size:
decoded_results = decode_batch(
batched_samples=batched_samples,
model=model,
params=params,
sp=sp,
)
s = "\n"
for i, (hyp, ref) in enumerate(zip(decoded_results, ground_truth)):
s += f"hyp {i}:\n{hyp}\n"
s += f"ref {i}:\n{ref}\n\n"
logging.info(s)
batched_samples = []
ground_truth = []
# break after processing the first batch for test purposes
break
logging.info("Done!")
if __name__ == "__main__":

View File

@ -449,7 +449,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
cnn_module_kernel=params.cnn_module_kernel,
left_context_length=params.left_context_length,
right_context_length=params.right_context_length,
max_memory_size=params.memory_size,
memory_size=params.memory_size,
)
return encoder