Fix conv cache error, support async streaming decoding

pkufool 2022-05-29 07:06:44 +08:00
parent 364bccb2e3
commit b23db42486
8 changed files with 1679 additions and 91 deletions


@@ -0,0 +1,118 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
import k2
import torch
from icefall.utils import AttributeDict
class DecodeStream(object):
def __init__(
self,
params: AttributeDict,
decoding_graph: Optional[k2.Fsa] = None,
device: torch.device = torch.device("cpu"),
) -> None:
"""
Args:
decoding_graph:
Decoding graph used for decoding; it may be a trivial graph or an HLG.
device:
The device to run this stream.
"""
if decoding_graph is not None:
assert device == decoding_graph.device
# It contains a 2-D tensor representing the feature frames.
self.features: torch.Tensor = None
# How many frames have been processed (before subsampling).
self.num_processed_frames: int = 0
self._done: bool = False
# The transcript of the current utterance.
self.ground_truth: str = ""
# The decoding result (partial or final) of the current utterance.
self.hyp: List = []
if params.decoding_method == "greedy_search":
self.hyp = [params.blank_id] * params.context_size
elif params.decoding_method == "fast_beam_search":
# feature_len is needed to get partial results.
self.feature_len: int = 0
# The rnnt_decoding_stream for fast_beam_search.
self.rnnt_decoding_stream: k2.RnntDecodingStream = (
k2.RnntDecodingStream(decoding_graph)
)
else:
assert (
False
), f"Decoding method :{params.decoding_method} do not support"
# The caches for the streaming conformer.
# It is a list containing two tensors: the first is the attention cache,
# which has a shape of (num_encoder_layers, left_context, encoder_dim);
# the second is the conv_module cache, which has a shape of
# (num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
self.states: List[torch.Tensor] = [
torch.zeros(
(
params.num_encoder_layers,
params.left_context,
params.encoder_dim,
),
device=device,
),
torch.zeros(
(
params.num_encoder_layers,
params.cnn_module_kernel - 1,
params.encoder_dim,
),
device=device,
),
]
@property
def done(self) -> bool:
"""Return True if all the features are processed."""
return self._done
def set_features(
self,
features: torch.Tensor,
) -> None:
"""Set features tensor of current utterance."""
self.features = features
def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
"""Consume chunk_size frames of features"""
ret_chunk_size = min(
self.features.size(0) - self.num_processed_frames, chunk_size + 3
)
ret_features = self.features[
self.num_processed_frames : self.num_processed_frames
+ ret_chunk_size,
:,
]
self.num_processed_frames += chunk_size
if self.num_processed_frames >= self.features.size(0):
self._done = True
return ret_features, ret_chunk_size
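
A minimal, self-contained sketch of the chunk consumption implemented by get_feature_frames() above; the helper name consume_chunks and the sizes used here are illustrative only and are not part of the changed files:

import torch

def consume_chunks(features: torch.Tensor, chunk_size: int):
    """Yield (chunk, chunk_len) pairs the way get_feature_frames() does."""
    num_processed = 0
    done = False
    while not done:
        # Up to 3 extra frames are returned so that the convolutional
        # subsampling has enough right context for the last output frame.
        ret_size = min(features.size(0) - num_processed, chunk_size + 3)
        yield features[num_processed : num_processed + ret_size], ret_size
        num_processed += chunk_size
        done = num_processed >= features.size(0)

feats = torch.randn(230, 80)  # e.g. 230 frames of 80-dim fbank
for chunk, chunk_len in consume_chunks(feats, chunk_size=16 * 4):
    print(chunk.shape, chunk_len)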


@@ -0,0 +1,698 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./pruned_transducer_stateless2/streaming_decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--decoding-method greedy_search \
--num-decode-streams 200
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
from kaldifeat import FbankOptions, Fbank
from lhotse import CutSet
from train import get_params, get_transducer_model
from torch.nn.utils.rnn import pad_sequence
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Support only greedy_search and fast_beam_search now.
""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="""How many left context can be seen in chunks when calculating attention.
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=True,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
return parser
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> List[List[int]]:
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :]
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
hyp_tokens = []
for stream in streams:
hyp_tokens.append(stream.hyp)
return hyp_tokens
def fast_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
decoding_streams: k2.RnntDecodingStreams,
) -> List[List[int]]:
B, T, C = encoder_out.shape
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
return hyp_tokens
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decode_streams:
A list of DecodeStream objects, each belonging to an utterance.
Returns:
Return a list containing the indexes of the finished DecodeStreams.
"""
device = model.device
features = []
feature_lens = []
states = []
rnnt_stream_list = []
processed_feature_lens = []
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(
params.decode_chunk_size * params.subsampling_factor
)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
if params.decoding_method == "fast_beam_search":
processed_feature_lens.append(stream.feature_len)
rnnt_stream_list.append(stream.rnnt_decoding_stream)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
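# For example, ((7 - 1) // 2 - 1) // 2 = 1 while ((6 - 1) // 2 - 1) // 2 = 0,
# so at least 7 input frames are needed to get one frame after subsampling.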
if features.size(1) < 7:
feature_lens += 7 - features.size(1)
features = torch.cat(
[
features,
torch.tensor(
LOG_EPS, dtype=features.dtype, device=device
).expand(
features.size(0), 7 - features.size(1), features.size(2)
),
],
dim=1,
)
states = [
torch.stack([x[0] for x in states], dim=2),
torch.stack([x[1] for x in states], dim=2),
]
# Note: states will be modified in streaming_forward.
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
left_context=params.left_context,
)
if params.decoding_method == "greedy_search":
hyp_tokens = greedy_search(model, encoder_out, decode_streams)
elif params.decoding_method == "fast_beam_search":
config = k2.RnntDecodingConfig(
vocab_size=params.vocab_size,
decoder_history_len=params.context_size,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
processed_feature_lens = torch.tensor(
processed_feature_lens, device=device
)
decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
processed_lens = processed_feature_lens + encoder_out_lens
hyp_tokens = fast_beam_search(
model, encoder_out, processed_lens, decoding_streams
)
else:
assert False
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]]
if params.decoding_method == "fast_beam_search":
decode_streams[i].feature_len += encoder_out_lens[i]
decode_streams[i].hyp = hyp_tokens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse CutSet containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
only when --decoding-method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 300
decode_results = []
# Contains the decode streams currently running.
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1]"
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
decode_stream.set_features(fbank(samples.to(device)))
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params, model, sp, decode_streams
)
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :]
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(hyp).split(),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(params, model, sp, decode_streams)
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :]
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(hyp).split(),
)
)
del decode_streams[i]
key = "greedy_search"
if params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()
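
The cache handling in decode_one_chunk() above stacks the per-stream states into a batch before calling streaming_forward and unbinds them afterwards, which is what lets utterances join and leave the batch asynchronously. A small illustrative sketch using the default sizes from this file (the variable names are made up and nothing here is part of the changed files):

import torch

num_layers, left_context, kernel, d_model = 12, 64, 31, 512

# One [attn_cache, conv_cache] pair per active stream, as kept in DecodeStream.states.
stream_states = [
    [
        torch.zeros(num_layers, left_context, d_model),
        torch.zeros(num_layers, kernel - 1, d_model),
    ]
    for _ in range(3)  # three streams currently being decoded
]

# Stack along a new batch dimension (dim=2) before calling streaming_forward ...
batched = [
    torch.stack([s[0] for s in stream_states], dim=2),
    torch.stack([s[1] for s in stream_states], dim=2),
]
print(batched[0].shape)  # torch.Size([12, 64, 3, 512])
print(batched[1].shape)  # torch.Size([12, 30, 3, 512])

# ... and unbind along the same dimension afterwards, so finished streams can be
# dropped while the remaining ones keep their own caches.
per_stream = [torch.unbind(batched[0], dim=2), torch.unbind(batched[1], dim=2)]
new_states = [[per_stream[0][i], per_stream[1][i]] for i in range(3)]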


@@ -330,7 +330,7 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- - attention_dim: Hidden dim for multi-head attention model.
+ - encoder_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
@@ -350,10 +350,11 @@ def get_params() -> AttributeDict:
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
- "attention_dim": 512,
+ "encoder_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
+ "cnn_module_kernel": 31,
"vgg_frontend": False,
# parameters for decoder
"embedding_dim": 512,
@@ -372,10 +373,11 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
num_features=params.feature_dim,
output_dim=params.vocab_size,
subsampling_factor=params.subsampling_factor,
- d_model=params.attention_dim,
+ d_model=params.encoder_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
+ cnn_module_kernel=params.cnn_module_kernel,
vgg_frontend=params.vgg_frontend,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
@@ -862,7 +864,7 @@ def run(rank, world_size, args):
optimizer = Noam(
model.parameters(),
- model_size=params.attention_dim,
+ model_size=params.encoder_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
)


@@ -19,7 +19,7 @@ import logging
import copy
import math
import warnings
- from typing import Optional, Tuple
+ from typing import List, Optional, Tuple, Union
import torch
from encoder_interface import EncoderInterface
@@ -104,11 +104,12 @@ class Conformer(EncoderInterface):
self.encoder_layers = num_encoder_layers
self.d_model = d_model
+ self.cnn_module_kernel = cnn_module_kernel
+ self.causal = causal
self.dynamic_chunk_training = dynamic_chunk_training
self.short_chunk_threshold = short_chunk_threshold
self.short_chunk_size = short_chunk_size
self.num_left_chunks = num_left_chunks
- self.causal = causal
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@@ -192,11 +193,11 @@ class Conformer(EncoderInterface):
x: torch.Tensor,
x_lens: torch.Tensor,
warmup: float = 1.0,
- states: Optional[Tensor] = None,
+ states: Optional[List[Tensor]] = None,
chunk_size: int = 16,
left_context: int = 64,
simulate_streaming: bool = False,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
@@ -210,8 +211,11 @@ class Conformer(EncoderInterface):
to turn modules on sequentially.
states:
The decode states for previous frames which contains the cached data.
- It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
- states[0,...] is the attn_cache, states[1,...] is the conv_cache.
+ It has two elements, the first element is the attn_cache which has
+ a shape of (encoder_layers, left_context, batch, attention_dim),
+ the second element is the conv_cache which has a shape of
+ (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
+ Note: If not None, states will be modified in this function.
chunk_size:
The chunk size for decoding, this will be used to simulate streaming
decoding using masking.
@@ -234,7 +238,11 @@ class Conformer(EncoderInterface):
# x: [N, T, C]
# Caution: We assume the subsampling factor is 4!
- lengths = ((x_lens - 1) // 2 - 1) // 2
+ # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
+ #
+ # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
+ lengths = (((x_lens - 1) >> 1) - 1) >> 1
if not simulate_streaming:
assert (
@@ -242,11 +250,14 @@ class Conformer(EncoderInterface):
), "Require cache when sending data in streaming mode"
assert (
- states.shape == (2, self.encoder_layers, left_context, x.size(0), self.d_model)
- ), f"""The shape of states MUST be equal to
- (2, encoder_layers, left_context, batch, d_model) which is
- {(2, self.encoder_layers, left_context, x.size(0), self.d_model)}
- given {states.shape}."""
+ len(states) == 2 and
+ states[0].shape == (self.encoder_layers, left_context, x.size(0), self.d_model) and
+ states[1].shape == (self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)
+ ), f"""The length of states MUST be equal to 2, and the shape of
+ first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
+ given {states[0].shape}. the shape of second element should be
+ {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
+ given {states[1].shape}."""
src_key_padding_mask = make_pad_mask(lengths + left_context)
@@ -254,7 +265,7 @@ class Conformer(EncoderInterface):
embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
- x, states = self.encoder(
+ x = self.encoder(
embed,
pos_enc,
src_key_padding_mask=src_key_padding_mask,
@@ -284,7 +295,7 @@ class Conformer(EncoderInterface):
num_left_chunks=num_left_chunks,
device=x.device
)
- x, _ = self.encoder(
+ x = self.encoder(
x,
pos_emb,
mask=mask,
@@ -294,7 +305,7 @@ class Conformer(EncoderInterface):
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
- return x, lengths, states
+ return x, lengths
class ConformerEncoderLayer(nn.Module):
@@ -376,9 +387,9 @@ class ConformerEncoderLayer(nn.Module):
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
- states: Optional[Tensor] = None,
+ states: Optional[List[Tensor]] = None,
left_context: int = 0,
- ) -> Tuple[Tensor, Tensor]:
+ ) -> Tuple[Tensor]:
"""
Pass the input through the encoder layer.
@@ -391,8 +402,11 @@ class ConformerEncoderLayer(nn.Module):
bypass layers more frequently.
states:
The decode states for previous frames which contains the cached data.
- It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
- states[0,...] is the attn_cache, states[1,...] is the conv_cache.
+ It has two elements, the first element is the attn_cache which has
+ a shape of (left_context, batch, attention_dim),
+ the second element is the conv_cache which has a shape of
+ (cnn_module_kernel-1, batch, conv_dim).
+ Note: If not None, states will be modified in this function.
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
@@ -425,9 +439,9 @@ class ConformerEncoderLayer(nn.Module):
val = src
if not self.training and states is not None:
# src: [chunk_size, N, F] e.g. [8, 41, 512]
- key = torch.cat([states[0, ...], src], dim=0)
+ key = torch.cat([states[0], src], dim=0)
val = key
- states[0, ...] = key[-left_context:, ...]
+ states[0] = key[-left_context:, ...]
else:
assert left_context == 0
@@ -445,15 +459,12 @@ class ConformerEncoderLayer(nn.Module):
src = src + self.dropout(src_att)
# convolution module
- residual = src
if not self.training and states is not None:
- src = torch.cat([states[1, ...], src], dim=0)
- states[1, ...] = src[-left_context:, ...]
+ conv, conv_cache = self.conv_module(src, states[1])
+ states[1] = conv_cache
+ else:
conv = self.conv_module(src)
- conv = conv[-residual.size(0) :, :, :] # noqa: E203
- src = residual + self.dropout(conv)
+ src = src + self.dropout(conv)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
@@ -463,7 +474,7 @@ class ConformerEncoderLayer(nn.Module):
if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig
- return src, states
+ return src
class ConformerEncoder(nn.Module):
@@ -495,9 +506,9 @@ class ConformerEncoder(nn.Module):
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
- states: Optional[Tensor] = None,
+ states: Optional[List[Tensor]] = None,
left_context: int = 0,
- ) -> Tuple[Tensor, Tensor]:
+ ) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
@@ -509,8 +520,11 @@ class ConformerEncoder(nn.Module):
bypass layers more frequently.
states:
The decode states for previous frames which contains the cached data.
- It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
- states[0,...] is the attn_cache, states[1,...] is the conv_cache.
+ It has two elements, the first element is the attn_cache which has
+ a shape of (encoder_layers, left_context, batch, attention_dim),
+ the second element is the conv_cache which has a shape of
+ (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
+ Note: If not None, states will be modified in this function.
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
@@ -532,19 +546,21 @@ class ConformerEncoder(nn.Module):
assert left_context >= 0
for layer_index, mod in enumerate(self.layers):
- output, cache = mod(
+ cache = None if states is None else [states[0][layer_index], states[1][layer_index]]
+ output = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
- states=None if states is None else states[:, layer_index, ...],
+ states=cache,
left_context=left_context,
)
if states is not None:
- states[:, layer_index, ...] = cache
+ states[0][layer_index] = cache[0]
+ states[1][layer_index] = cache[1]
- return output, states
+ return output
class RelPositionalEncoding(torch.nn.Module):
@@ -1180,14 +1196,23 @@ class ConvolutionModule(nn.Module):
initial_scale=0.25,
)
- def forward(self, x: Tensor) -> Tensor:
+ def forward(
+ self,
+ x: Tensor,
+ cache: Optional[Tensor] = None
+ ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
+ cache: The cache of depthwise_conv, only used in real streaming
+ decoding.
Returns:
- Tensor: Output tensor (#time, batch, channels).
+ If cache is None return the output tensor (#time, batch, channels).
+ If cache is not None, return a tuple of Tensor, the first one is
+ the output tensor (#time, batch, channels), the second one is the
+ new cache for next chunk (#kernel_size - 1, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
@@ -1201,9 +1226,15 @@ class ConvolutionModule(nn.Module):
# 1D Depthwise Conv
if self.causal and self.lorder > 0:
- # Make depthwise_conv causal by
- # manualy padding self.lorder zeros to the left
- x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+ if cache is None:
+ # Make depthwise_conv causal by
+ # manualy padding self.lorder zeros to the left
+ x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+ else:
+ assert not self.training, "Cache should be None in training time"
+ assert cache.size(0) == self.lorder
+ x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
+ cache = x.permute(2, 0, 1)[-self.lorder:,...]
x = self.depthwise_conv(x)
x = self.deriv_balancer2(x)
@@ -1211,7 +1242,7 @@ class ConvolutionModule(nn.Module):
x = self.pointwise_conv2(x) # (batch, channel, time)
- return x.permute(2, 0, 1)
+ return x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)
class Conv2dSubsampling(nn.Module):
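
In the new code the convolution cache holds the last cnn_module_kernel - 1 input frames and is concatenated inside ConvolutionModule.forward() before the depthwise convolution; this is the behaviour the commit message refers to as the conv cache fix. Below is a self-contained sketch of that caching scheme, using a plain depthwise Conv1d instead of the icefall module (names and sizes are illustrative, not part of the changed files); it checks that chunk-by-chunk processing with such a cache matches running the causal convolution over the whole sequence:

import torch
import torch.nn as nn

torch.manual_seed(0)
kernel_size, channels = 5, 4
lorder = kernel_size - 1  # size of the cache, as in the diff above
conv = nn.Conv1d(channels, channels, kernel_size, groups=channels, bias=False)

full = torch.randn(1, channels, 20)  # (batch, channel, time)

# Offline (non-streaming): left-pad with lorder zeros and convolve everything.
offline = conv(nn.functional.pad(full, (lorder, 0)))

# Streaming: carry the last lorder *input* frames as the cache for the next chunk.
cache = torch.zeros(1, channels, lorder)
outs = []
for chunk in full.split(10, dim=2):
    x = torch.cat([cache, chunk], dim=2)
    cache = x[:, :, -lorder:]
    outs.append(conv(x))

print(torch.allclose(offline, torch.cat(outs, dim=2), atol=1e-6))  # True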


@@ -0,0 +1 @@
../pruned_transducer_stateless/decode_stream.py


@@ -0,0 +1,705 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./pruned_transducer_stateless2/streaming_decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--decoding-method greedy_search \
--num-decode-streams 200
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
from kaldifeat import FbankOptions, Fbank
from lhotse import CutSet
from train import get_params, get_transducer_model
from torch.nn.utils.rnn import pad_sequence
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Support only greedy_search and fast_beam_search now.
""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="""How many left context can be seen in chunks when calculating attention.
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=True,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
return parser
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> List[List[int]]:
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# logging.info(f"decoder_out shape : {decoder_out.shape}")
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :]
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
hyp_tokens = []
for stream in streams:
hyp_tokens.append(stream.hyp)
return hyp_tokens
def fast_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
decoding_streams: k2.RnntDecodingStreams,
) -> List[List[int]]:
B, T, C = encoder_out.shape
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
return hyp_tokens
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decode_streams:
A list of DecodeStream objects, each belonging to an utterance.
Returns:
Return a list containing the indexes of the finished DecodeStreams.
"""
device = model.device
features = []
feature_lens = []
states = []
rnnt_stream_list = []
processed_feature_lens = []
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(
params.decode_chunk_size * params.subsampling_factor
)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
if params.decoding_method == "fast_beam_search":
processed_feature_lens.append(stream.feature_len)
rnnt_stream_list.append(stream.rnnt_decoding_stream)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
if features.size(1) < 7:
feature_lens += 7 - features.size(1)
features = torch.cat(
[
features,
torch.tensor(
LOG_EPS, dtype=features.dtype, device=device
).expand(
features.size(0), 7 - features.size(1), features.size(2)
),
],
dim=1,
)
states = [
torch.stack([x[0] for x in states], dim=2),
torch.stack([x[1] for x in states], dim=2),
]
# Note: states will be modified in streaming_forward.
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
left_context=params.left_context,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
hyp_tokens = greedy_search(model, encoder_out, decode_streams)
elif params.decoding_method == "fast_beam_search":
config = k2.RnntDecodingConfig(
vocab_size=params.vocab_size,
decoder_history_len=params.context_size,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
processed_feature_lens = torch.tensor(
processed_feature_lens, device=device
)
decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
processed_lens = processed_feature_lens + encoder_out_lens
hyp_tokens = fast_beam_search(
model, encoder_out, processed_lens, decoding_streams
)
else:
assert False
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]]
if params.decoding_method == "fast_beam_search":
decode_streams[i].feature_len += encoder_out_lens[i]
decode_streams[i].hyp = hyp_tokens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse CutSet containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
only when --decoding-method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 300
decode_results = []
# Contains the decode streams currently running.
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1]"
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
decode_stream.set_features(fbank(samples.to(device)))
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params, model, sp, decode_streams
)
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :]
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(hyp).split(),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(params, model, sp, decode_streams)
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :]
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(hyp).split(),
)
)
del decode_streams[i]
key = "greedy_search"
if params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()
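
Compared with the file above, this version applies the joiner's input projections (model.joiner.encoder_proj / model.joiner.decoder_proj) outside the joiner and then calls the joiner with project_input=False for every frame of the chunk. A minimal sketch of that pattern follows; TinyJoiner is a stand-in written for illustration and is not the icefall Joiner class:

import torch
import torch.nn as nn

class TinyJoiner(nn.Module):
    def __init__(self, enc_dim=512, dec_dim=512, joiner_dim=512, vocab_size=500):
        super().__init__()
        self.encoder_proj = nn.Linear(enc_dim, joiner_dim)
        self.decoder_proj = nn.Linear(dec_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(self, encoder_out, decoder_out, project_input=True):
        if project_input:
            encoder_out = self.encoder_proj(encoder_out)
            decoder_out = self.decoder_proj(decoder_out)
        return self.output_linear(torch.tanh(encoder_out + decoder_out))

joiner = TinyJoiner()
encoder_out = torch.randn(8, 16, 512)  # (batch, chunk_frames, enc_dim)
decoder_out = torch.randn(8, 512)      # (batch, dec_dim)

# Project once per chunk / per decoder update ...
encoder_out = joiner.encoder_proj(encoder_out)
decoder_out = joiner.decoder_proj(decoder_out)

# ... then reuse the projected tensors for every frame of the chunk.
for t in range(encoder_out.size(1)):
    logits = joiner(encoder_out[:, t], decoder_out, project_input=False)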


@@ -393,6 +393,7 @@ def get_params() -> AttributeDict:
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
+ "cnn_module_kernel": 31,
# parameters for decoder
"decoder_dim": 512,
# parameters for joiner
@@ -415,6 +416,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
+ cnn_module_kernel=params.cnn_module_kernel,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
num_left_chunks=params.num_left_chunks,


@ -18,7 +18,7 @@
import copy import copy
import math import math
import warnings import warnings
from typing import Optional, Tuple from typing import List, Optional, Tuple, Union
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
@ -97,11 +97,13 @@ class Conformer(Transformer):
self.encoder_layers = num_encoder_layers self.encoder_layers = num_encoder_layers
self.d_model = d_model self.d_model = d_model
self.cnn_module_kernel = cnn_module_kernel
self.causal = causal
self.dynamic_chunk_training = dynamic_chunk_training self.dynamic_chunk_training = dynamic_chunk_training
self.short_chunk_threshold = short_chunk_threshold self.short_chunk_threshold = short_chunk_threshold
self.short_chunk_size = short_chunk_size self.short_chunk_size = short_chunk_size
self.num_left_chunks = num_left_chunks self.num_left_chunks = num_left_chunks
self.causal = causal
self.encoder_pos = RelPositionalEncoding(d_model, dropout) self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@@ -188,11 +190,11 @@ class Conformer(Transformer):
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        states: Optional[torch.Tensor] = None,
+        states: Optional[List[torch.Tensor]] = None,
         chunk_size: int = 16,
         left_context: int = 64,
         simulate_streaming: bool = False,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
           x:

@@ -202,8 +204,11 @@ class Conformer(Transformer):
            `x` before padding.
           states:
             The decode states for previous frames which contains the cached data.
-            It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
-            states[0,...] is the attn_cache, states[1,...] is the conv_cache.
+            It has two elements, the first element is the attn_cache which has
+            a shape of (encoder_layers, left_context, batch, attention_dim),
+            the second element is the conv_cache which has a shape of
+            (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
+            Note: If not None, states will be modified in this function.
           chunk_size:
             The chunk size for decoding, this will be used to simulate streaming
             decoding using masking.
@@ -226,7 +231,11 @@ class Conformer(Transformer):
         # x: [N, T, C]
         # Caution: We assume the subsampling factor is 4!
-        lengths = ((x_lens - 1) // 2 - 1) // 2
+        # lengths = ((x_lens - 1) // 2 - 1) // 2  # issues a warning
+        #
+        # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
+        lengths = (((x_lens - 1) >> 1) - 1) >> 1

         if not simulate_streaming:
             assert (
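
The rewritten length formula is just the bit-shift form of the original floor division; per the comment above it sidesteps the warning the old expression triggers on PyTorch versions that lack rounding_mode in torch.div(). A quick, illustrative check:

import torch

x_lens = torch.tensor([100, 83, 16])

old_lengths = ((x_lens - 1) // 2 - 1) // 2    # original expression
new_lengths = (((x_lens - 1) >> 1) - 1) >> 1  # shift-based equivalent used above

assert torch.equal(old_lengths, new_lengths)  # e.g. 100 frames -> 24 after 4x subsampling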
@@ -234,11 +243,14 @@ class Conformer(Transformer):
             ), "Require cache when sending data in streaming mode"

             assert (
-                states.shape == (2, self.encoder_layers, left_context, x.size(0), self.d_model)
-            ), f"""The shape of states MUST be equal to
-            (2, encoder_layers, left_context, batch, d_model) which is
-            {(2, self.encoder_layers, left_context, x.size(0), self.d_model)}
-            given {states.shape}."""
+                len(states) == 2 and
+                states[0].shape == (self.encoder_layers, left_context, x.size(0), self.d_model) and
+                states[1].shape == (self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)
+            ), f"""The length of states MUST be equal to 2, and the shape of
+            first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
+            given {states[0].shape}. the shape of second element should be
+            {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
+            given {states[1].shape}."""

             src_key_padding_mask = make_pad_mask(lengths + left_context)
@@ -246,7 +258,7 @@ class Conformer(Transformer):
             embed, pos_enc = self.encoder_pos(embed, left_context)
             embed = embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)

-            x, states = self.encoder(
+            x = self.encoder(
                 embed,
                 pos_enc,
                 src_key_padding_mask=src_key_padding_mask,
@@ -275,7 +287,7 @@ class Conformer(Transformer):
                 num_left_chunks=num_left_chunks,
                 device=x.device,
             )
-            x, _ = self.encoder(
+            x = self.encoder(
                 x,
                 pos_emb,
                 mask=mask,

@@ -288,7 +300,7 @@ class Conformer(Transformer):
         logits = self.encoder_output_layer(x)
         logits = logits.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

-        return logits, lengths, states
+        return logits, lengths


 class ConformerEncoderLayer(nn.Module):
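
The net effect of the hunks above is a new calling convention: the streaming forward now returns only (output, output_lengths) and mutates the caller's two-element states list in place. A hedged, caller-side sketch with a stand-in function (the real method is the Conformer streaming forward changed above):

import torch
from typing import List, Tuple

def fake_streaming_forward(
    x: torch.Tensor, x_lens: torch.Tensor, states: List[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Stand-in for the real encoder: it updates the caller's two-element list
    # in place and returns only (output, lengths), no third states value.
    assert len(states) == 2
    states[0] = torch.roll(states[0], shifts=-1, dims=1)  # toy cache update
    return x, (((x_lens - 1) >> 1) - 1) >> 1

states = [torch.zeros(12, 64, 1, 512), torch.zeros(12, 30, 1, 512)]
for _ in range(3):  # feed three chunks; the same list is reused and updated
    chunk, chunk_lens = torch.randn(1, 23, 80), torch.tensor([23])
    out, out_lens = fake_streaming_forward(chunk, chunk_lens, states)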
@@ -369,9 +381,9 @@ class ConformerEncoderLayer(nn.Module):
         pos_emb: Tensor,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        states: Optional[Tensor] = None,
+        states: Optional[List[Tensor]] = None,
         left_context: int = 0,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> Tensor:
         """
         Pass the input through the encoder layer.

@@ -380,9 +392,13 @@ class ConformerEncoderLayer(nn.Module):
           pos_emb: Positional embedding tensor (required).
           src_mask: the mask for the src sequence (optional).
           src_key_padding_mask: the mask for the src keys per batch (optional).
-          states: The decode states for previous frames which contains the cached data.
-            It has a shape of (2, left_context, batch, attention_dim),
-            states[0,...] is the attn_cache, states[1,...] is the conv_cache.
+          states:
+            The decode states for previous frames which contains the cached data.
+            It has two elements, the first element is the attn_cache which has
+            a shape of (encoder_layers, left_context, batch, attention_dim),
+            the second element is the conv_cache which has a shape of
+            (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
+            Note: If not None, states will be modified in this function.
           left_context: left context (in frames) used during streaming decoding.
             this is used only in real streaming decoding, in other circumstances,
             it MUST be 0.
@@ -413,9 +429,9 @@ class ConformerEncoderLayer(nn.Module):
         val = src
         if not self.training and states is not None:
             # src: [chunk_size, N, F] e.g. [8, 41, 512]
-            key = torch.cat([states[0, ...], src], dim=0)
+            key = torch.cat([states[0], src], dim=0)
             val = key
-            states[0, ...] = key[-left_context:, ...]
+            states[0] = key[-left_context:, ...]
         else:
             assert left_context == 0

@@ -438,13 +454,12 @@ class ConformerEncoderLayer(nn.Module):
             src = self.norm_conv(src)

         if not self.training and states is not None:
-            src = torch.cat([states[1, ...], src], dim=0)
-            states[1, ...] = src[-left_context:, ...]
-        src = self.conv_module(src)
-        src = src[-residual.size(0) :, :, :]  # noqa: E203
+            src, conv_cache = self.conv_module(src, states[1])
+            states[1] = conv_cache
+        else:
+            src = self.conv_module(src)

         src = residual + self.dropout(src)
         if not self.normalize_before:
             src = self.norm_conv(src)

@@ -459,7 +474,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_final(src)

-        return src, states
+        return src


 class ConformerEncoder(nn.Module):
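
The attention-cache handling above amounts to: prepend the cached left-context frames to the current chunk to form the keys and values, then keep the last left_context frames for the next chunk. A small, self-contained illustration, using the (time, batch, dim) layout the layer works in:

import torch

left_context, chunk_size, batch, d_model = 4, 8, 1, 16
attn_cache = torch.zeros(left_context, batch, d_model)  # states[0] for one layer

for _ in range(3):                               # three incoming chunks
    src = torch.randn(chunk_size, batch, d_model)
    key = torch.cat([attn_cache, src], dim=0)    # (left_context + chunk_size, batch, dim)
    val = key
    attn_cache = key[-left_context:]             # frames the next chunk will attend to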
@@ -490,9 +505,9 @@ class ConformerEncoder(nn.Module):
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        states: Optional[Tensor] = None,
+        states: Optional[List[Tensor]] = None,
         left_context: int = 0,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.

         Args:

@@ -500,9 +515,13 @@ class ConformerEncoder(nn.Module):
           pos_emb: Positional embedding tensor (required).
           mask: the mask for the src sequence (optional).
           src_key_padding_mask: the mask for the src keys per batch (optional).
-          states: The decode states for previous frames which contains the cached data.
-            It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
-            states[0,...] is the attn_cache, states[1,...] is the conv_cache.
+          states:
+            The decode states for previous frames which contains the cached data.
+            It has two elements, the first element is the attn_cache which has
+            a shape of (encoder_layers, left_context, batch, attention_dim),
+            the second element is the conv_cache which has a shape of
+            (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
+            Note: If not None, states will be modified in this function.
           left_context: left context (in frames) used during streaming decoding.
             this is used only in real streaming decoding, in other circumstances,
             it MUST be 0.
@@ -525,18 +544,20 @@ class ConformerEncoder(nn.Module):
         assert left_context >= 0

         for layer_index, mod in enumerate(self.layers):
-            output, cache = mod(
+            cache = None if states is None else [states[0][layer_index], states[1][layer_index]]
+            output = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
-                states=None if states is None else states[:, layer_index, ...],
+                states=cache,
                 left_context=left_context,
             )
             if states is not None:
-                states[:, layer_index, ...] = cache
+                states[0][layer_index] = cache[0]
+                states[1][layer_index] = cache[1]

-        return output, states
+        return output


 class RelPositionalEncoding(torch.nn.Module):
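
The explicit write-back at the end of the loop matters because the per-layer cache handed to each layer is a plain Python list: the layer rebinds its entries (states[0] = ..., states[1] = conv_cache) rather than writing into the tensor views, so the encoder must copy the updated tensors back into the stacked states. A toy illustration of that distinction:

import torch

stacked = [torch.zeros(2, 4, 1, 8), torch.zeros(2, 3, 1, 8)]  # per-layer attn/conv caches

def fake_layer(cache):
    # Mimics the layer: rebinds the list slots instead of mutating the views.
    cache[0] = torch.ones_like(cache[0])
    cache[1] = torch.ones_like(cache[1])

for layer_index in range(2):
    cache = [stacked[0][layer_index], stacked[1][layer_index]]
    fake_layer(cache)
    stacked[0][layer_index] = cache[0]  # without these two lines the update is lost
    stacked[1][layer_index] = cache[1]

assert bool((stacked[0] == 1).all()) and bool((stacked[1] == 1).all())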
@@ -1146,7 +1167,11 @@ class ConvolutionModule(nn.Module):
         )
         self.activation = Swish()

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(
+        self,
+        x: Tensor,
+        cache: Optional[Tensor] = None,
+    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
         """Compute convolution module.

         Args:

@@ -1165,9 +1190,15 @@ class ConvolutionModule(nn.Module):
         # 1D Depthwise Conv
         if self.causal and self.lorder > 0:
-            # Make depthwise_conv causal by
-            # manually padding self.lorder zeros to the left
-            x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+            if cache is None:
+                # Make depthwise_conv causal by
+                # manually padding self.lorder zeros to the left
+                x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+            else:
+                assert not self.training, "Cache should be None in training time"
+                assert cache.size(0) == self.lorder
+                x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
+                cache = x.permute(2, 0, 1)[-self.lorder:, ...]
         x = self.depthwise_conv(x)
         # x is (batch, channels, time)

@@ -1179,7 +1210,7 @@ class ConvolutionModule(nn.Module):
         x = self.pointwise_conv2(x)  # (batch, channel, time)

-        return x.permute(2, 0, 1)
+        return x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)


 class Swish(torch.nn.Module):
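
Finally, the convolution-cache fix itself: instead of splicing cached frames into the input outside the module, the causal depthwise convolution now consumes a cache of the last kernel-minus-one frames and returns the updated cache. The reason this is sufficient is that chunked processing with such a cache reproduces the full-sequence causal convolution exactly. An illustrative, self-contained check of that property, written in the (batch, channels, time) layout of nn.Conv1d rather than the (time, batch, channels) layout the module permutes from:

import torch
import torch.nn as nn

channels, kernel = 4, 5
lorder = kernel - 1
conv = nn.Conv1d(channels, channels, kernel, groups=channels, bias=False)

x = torch.randn(1, channels, 20)  # (batch, channels, time)

# Whole utterance: make the depthwise conv causal with left zero-padding.
full = conv(nn.functional.pad(x, (lorder, 0)))

# Chunked: carry the last `lorder` frames of each padded input as the cache.
cache = torch.zeros(1, channels, lorder)
outs = []
for chunk in x.split(10, dim=2):
    padded = torch.cat([cache, chunk], dim=2)
    cache = padded[:, :, -lorder:]   # cache update, as in the diff above
    outs.append(conv(padded))

assert torch.allclose(full, torch.cat(outs, dim=2), atol=1e-6)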