Minor fixes

This commit is contained in:
pkufool 2022-06-06 17:56:22 +08:00
parent 3aacf75652
commit 09b0c54983
14 changed files with 1819 additions and 56 deletions

View File

@ -26,11 +26,15 @@ class DecodeStream(object):
def __init__(
self,
params: AttributeDict,
initial_states: List[torch.Tensor],
decoding_graph: Optional[k2.Fsa] = None,
device: torch.device = torch.device("cpu"),
) -> None:
"""
Args:
initial_states:
Initial decode states of the model, e.g. the return value of
`get_init_state` in conformer.py
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
device:
@ -41,6 +45,8 @@ class DecodeStream(object):
self.params = params
self.states = initial_states
# It contains a 2-D tensors representing the feature frames.
self.features: torch.Tensor = None
# how many frames are processed. (before subsampling).
@ -56,7 +62,6 @@ class DecodeStream(object):
if params.decoding_method == "greedy_search":
self.hyp = [params.blank_id] * params.context_size
elif params.decoding_method == "fast_beam_search":
# feature_len is needed to get partial results.
# The rnnt_decoding_stream for fast_beam_search.
self.rnnt_decoding_stream: k2.RnntDecodingStream = (
k2.RnntDecodingStream(decoding_graph)
@ -66,31 +71,6 @@ class DecodeStream(object):
False
), f"Decoding method :{params.decoding_method} do not support"
# The caches for streaming conformer
# It is a List containing two tensors, the first one is the cache for
# attention which has a shape of
# (num_encoder_layers, left_context, encoder_dim),
# the second one is the cache of conv_module which has a shape of
# (num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
self.states: List[torch.Tensor] = [
torch.zeros(
(
params.num_encoder_layers,
params.left_context,
params.encoder_dim,
),
device=device,
),
torch.zeros(
(
params.num_encoder_layers,
params.cnn_module_kernel - 1,
params.encoder_dim,
),
device=device,
),
]
@property
def done(self) -> bool:
"""Return True if all the features are processed."""
@ -105,6 +85,8 @@ class DecodeStream(object):
def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
"""Consume chunk_size frames of features"""
# plus 3 here because we subsampling features with
# lengths = ((x_lens - 1) // 2 - 1) // 2
ret_chunk_size = min(
self.features.size(0) - self.num_processed_frames, chunk_size + 3
)

View File

@ -378,7 +378,7 @@ def decode_one_chunk(
]
# Note: states will be modified in streaming_forward.
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
@ -462,10 +462,12 @@ def decode_dataset(
decode_results = []
# Contain decode streams currently running.
decode_streams = []
initial_states = model.get_init_state(params.left_context, device=device)
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)

View File

@ -353,7 +353,6 @@ def get_params() -> AttributeDict:
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"cnn_module_kernel": 31,
"vgg_frontend": False,
# parameters for decoder
"embedding_dim": 512,
@ -376,7 +375,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
cnn_module_kernel=params.cnn_module_kernel,
vgg_frontend=params.vgg_frontend,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,

View File

@ -122,6 +122,7 @@ class Conformer(EncoderInterface):
causal,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self._init_state = torch.jit.Attribute([], List[torch.Tensor])
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@ -194,6 +195,55 @@ class Conformer(EncoderInterface):
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths
@torch.jit.export
def get_init_state(
self, left_context: int, device: torch.device
) -> List[torch.Tensor]:
"""Return the initial cache state of the model.
Args:
left_context: The left context size (in frames after subsampling).
Returns:
Return the initial state of the model, it is a list containing two
tensors, the first one is the cache for attentions which has a shape
of (num_encoder_layers, left_context, encoder_dim), the second one
is the cache of conv_modules which has a shape of
(num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
NOTE: the returned tensors are on the given device.
"""
if (
len(self._init_state) == 2
and self._init_state[0].size(1) == left_context
):
# Note: It is OK to share the init state as it is
# not going to be modified by the model
return self._init_state
init_states: List[torch.Tensor] = [
torch.zeros(
(
self.encoder_layers,
left_context,
self.d_model,
),
device=device,
),
torch.zeros(
(
self.encoder_layers,
self.cnn_module_kernel - 1,
self.d_model,
),
device=device,
),
]
self._init_state = init_states
return init_states
@torch.jit.export
def streaming_forward(
self,
@ -206,7 +256,7 @@ class Conformer(EncoderInterface):
right_context: int = 4,
simulate_streaming: bool = False,
processed_lens: Optional[Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
"""
Args:
x:
@ -296,7 +346,7 @@ class Conformer(EncoderInterface):
embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
x = self.encoder.chunk_forward(
x, states = self.encoder.chunk_forward(
embed,
pos_enc,
src_key_padding_mask=src_key_padding_mask,
@ -338,7 +388,7 @@ class Conformer(EncoderInterface):
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths
return x, lengths, states
class ConformerEncoderLayer(nn.Module):
@ -490,7 +540,7 @@ class ConformerEncoderLayer(nn.Module):
warmup: float = 1.0,
left_context: int = 0,
right_context: int = 0,
) -> Tensor:
) -> Tuple[Tensor, List[Tensor]]:
"""
Pass the input through the encoder layer.
@ -527,6 +577,12 @@ class ConformerEncoderLayer(nn.Module):
# macaron style feed forward module
src = src + self.dropout(self.feed_forward_macaron(src))
# We put the attention cache this level (i.e. before linear transformation)
# to save memory consumption, when decoding in streaming fashion, the
# batch size would be thousands (for 32GB machine), if we cache key & val
# separately, it needs extra several GB memory.
# TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val
# separately) if needed.
key = torch.cat([states[0], src], dim=0)
val = key
if right_context > 0:
@ -560,7 +616,7 @@ class ConformerEncoderLayer(nn.Module):
src = self.norm_final(self.balancer(src))
return src
return src, states
class ConformerEncoder(nn.Module):
@ -635,7 +691,7 @@ class ConformerEncoder(nn.Module):
warmup: float = 1.0,
left_context: int = 0,
right_context: int = 0,
) -> Tensor:
) -> Tuple[Tensor, List[Tensor]]:
r"""Pass the input through the encoder layers in turn.
Args:
@ -678,7 +734,7 @@ class ConformerEncoder(nn.Module):
for layer_index, mod in enumerate(self.layers):
cache = [states[0][layer_index], states[1][layer_index]]
output = mod.chunk_forward(
output, cache = mod.chunk_forward(
output,
pos_emb,
states=cache,
@ -691,7 +747,7 @@ class ConformerEncoder(nn.Module):
states[0][layer_index] = cache[0]
states[1][layer_index] = cache[1]
return output
return output, states
class RelPositionalEncoding(torch.nn.Module):

View File

@ -357,7 +357,8 @@ def decode_one_chunk(
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(
(params.decode_chunk_size + 2) * params.subsampling_factor
(params.decode_chunk_size + 2 + params.right_context)
* params.subsampling_factor
)
features.append(feat)
feature_lens.append(feat_len)
@ -394,8 +395,7 @@ def decode_one_chunk(
]
processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
# Note: states will be modified in streaming_forward.
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
@ -475,15 +475,17 @@ def decode_dataset(
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 300
log_interval = 50
decode_results = []
# Contain decode streams currently running.
decode_streams = []
initial_states = model.get_init_state(params.left_context, device=device)
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)

View File

@ -393,7 +393,6 @@ def get_params() -> AttributeDict:
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"cnn_module_kernel": 31,
# parameters for decoder
"decoder_dim": 512,
# parameters for joiner
@ -416,7 +415,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
cnn_module_kernel=params.cnn_module_kernel,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
num_left_chunks=params.num_left_chunks,

View File

@ -58,6 +58,7 @@ Usage:
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -87,9 +88,12 @@ from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
@ -219,6 +223,70 @@ def get_parser():
Used only when the decoding_method is fast_beam_search_nbest_oracle.
""",
)
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="""How many left context can be seen in chunks when calculating attention.
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--right-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
return parser
@ -268,9 +336,27 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
states=[],
chunk_size=params.right_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -509,6 +595,10 @@ def main():
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
@ -544,6 +634,11 @@ def main():
params.unk_id = sp.unk_id()
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")

View File

@ -0,0 +1 @@
../pruned_transducer_stateless/decode_stream.py

View File

@ -0,0 +1,721 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./pruned_transducer_stateless2/streaming_decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--decoding_method greedy_search \
--num-decode-streams 200
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from librispeech import LibriSpeech
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Support only greedy_search and fast_beam_search now.
""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="""How many left context can be seen in chunks when calculating attention.
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=True,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--right-context",
type=int,
default=4,
help="right context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
return parser
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> List[List[int]]:
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# logging.info(f"decoder_out shape : {decoder_out.shape}")
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
hyp_tokens = []
for stream in streams:
hyp_tokens.append(stream.hyp)
return hyp_tokens
def fast_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
decoding_streams: k2.RnntDecodingStreams,
) -> List[List[int]]:
B, T, C = encoder_out.shape
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
return hyp_tokens
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decode_streams:
A List of DecodeStream, each belonging to a utterance.
Returns:
Return a List containing which DecodeStreams are finished.
"""
device = model.device
features = []
feature_lens = []
states = []
rnnt_stream_list = []
processed_feature_lens = []
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(
(params.decode_chunk_size + 2 + params.right_context)
* params.subsampling_factor
)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_feature_lens.append(stream.feature_len)
if params.decoding_method == "fast_beam_search":
rnnt_stream_list.append(stream.rnnt_decoding_stream)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
tail_length = 15 + params.right_context * params.subsampling_factor
if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1)
features = torch.cat(
[
features,
torch.tensor(
LOG_EPS, dtype=features.dtype, device=device
).expand(
features.size(0),
tail_length - features.size(1),
features.size(2),
),
],
dim=1,
)
states = [
torch.stack([x[0] for x in states], dim=2),
torch.stack([x[1] for x in states], dim=2),
]
processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
left_context=params.left_context,
right_context=params.right_context,
processed_lens=processed_feature_lens,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
hyp_tokens = greedy_search(model, encoder_out, decode_streams)
elif params.decoding_method == "fast_beam_search":
config = k2.RnntDecodingConfig(
vocab_size=params.vocab_size,
decoder_history_len=params.context_size,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
processed_lens = processed_feature_lens + encoder_out_lens
hyp_tokens = fast_beam_search(
model, encoder_out, processed_lens, decoding_streams
)
else:
assert False
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]]
decode_streams[i].feature_len += encoder_out_lens[i]
if params.decoding_method == "fast_beam_search":
decode_streams[i].hyp = hyp_tokens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 50
decode_results = []
# Contain decode streams currently running.
decode_streams = []
initial_states = model.get_init_state(params.left_context, device=device)
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
decode_stream.set_features(feature)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params, model, sp, decode_streams
)
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(hyp).split(),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(params, model, sp, decode_streams)
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(hyp).split(),
)
)
del decode_streams[i]
key = "greedy_search"
if params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=sorted(results))
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
params.suffix += f"-right-context-{params.right_context}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeech(params.manifest_dir)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -286,6 +286,59 @@ def get_parser():
help="The probability to select a batch from the GigaSpeech dataset",
)
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="How many left context can be seen in chunks when calculating attention.",
)
parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time masking
encouraging the network to delay symbols.
""",
)
parser.add_argument(
"--return-sym-delay",
type=str2bool,
default=False,
help="""Whether to return `sym_delay` during training, this is a stat
to measure symbols emission delay, especially for time masking training.
""",
)
return parser
@ -372,6 +425,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
num_left_chunks=params.num_left_chunks,
causal=params.causal_convolution,
)
return encoder
@ -905,6 +962,15 @@ def run(rank, world_size, args):
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training:
assert (
params.causal_convolution
), "dynamic_chunk_training requires causal convolution"
else:
assert (
params.delay_penalty == 0.0
), "delay_penalty is intended for dynamic_chunk_training"
logging.info(params)
logging.info("About to create model")

View File

@ -350,6 +350,7 @@ def decode_one_batch(
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
states=[],
chunk_size=params.right_chunk_size,
left_context=params.left_context,
simulate_streaming=True,

View File

@ -0,0 +1 @@
../pruned_transducer_stateless/decode_stream.py

View File

@ -0,0 +1,783 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./pruned_transducer_stateless2/streaming_decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--decoding_method greedy_search \
--num-decode-streams 200
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Support only greedy_search and fast_beam_search now.
""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="""How many left context can be seen in chunks when calculating attention.
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=True,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--right-context",
type=int,
default=4,
help="right context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
return parser
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> List[List[int]]:
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# logging.info(f"decoder_out shape : {decoder_out.shape}")
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
hyp_tokens = []
for stream in streams:
hyp_tokens.append(stream.hyp)
return hyp_tokens
def fast_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
decoding_streams: k2.RnntDecodingStreams,
) -> List[List[int]]:
B, T, C = encoder_out.shape
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
return hyp_tokens
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decode_streams:
A List of DecodeStream, each belonging to a utterance.
Returns:
Return a List containing which DecodeStreams are finished.
"""
device = model.device
features = []
feature_lens = []
states = []
rnnt_stream_list = []
processed_feature_lens = []
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(
(params.decode_chunk_size + 2 + params.right_context)
* params.subsampling_factor
)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_feature_lens.append(stream.feature_len)
if params.decoding_method == "fast_beam_search":
rnnt_stream_list.append(stream.rnnt_decoding_stream)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
tail_length = 15 + params.right_context * params.subsampling_factor
if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1)
features = torch.cat(
[
features,
torch.tensor(
LOG_EPS, dtype=features.dtype, device=device
).expand(
features.size(0),
tail_length - features.size(1),
features.size(2),
),
],
dim=1,
)
states = [
torch.stack([x[0] for x in states], dim=2),
torch.stack([x[1] for x in states], dim=2),
]
processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
left_context=params.left_context,
right_context=params.right_context,
processed_lens=processed_feature_lens,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
hyp_tokens = greedy_search(model, encoder_out, decode_streams)
elif params.decoding_method == "fast_beam_search":
config = k2.RnntDecodingConfig(
vocab_size=params.vocab_size,
decoder_history_len=params.context_size,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
processed_lens = processed_feature_lens + encoder_out_lens
hyp_tokens = fast_beam_search(
model, encoder_out, processed_lens, decoding_streams
)
else:
assert False
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]]
decode_streams[i].feature_len += encoder_out_lens[i]
if params.decoding_method == "fast_beam_search":
decode_streams[i].hyp = hyp_tokens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 50
decode_results = []
# Contain decode streams currently running.
decode_streams = []
initial_states = model.get_init_state(params.left_context, device=device)
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
decode_stream.set_features(feature)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params, model, sp, decode_streams
)
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(hyp).split(),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(params, model, sp, decode_streams)
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(hyp).split(),
)
)
del decode_streams[i]
key = "greedy_search"
if params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=sorted(results))
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
params.suffix += f"-right-context-{params.right_context}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -125,6 +125,8 @@ class Conformer(Transformer):
# and throws an error without this change.
self.after_norm = identity
self._init_state = torch.jit.Attribute([], List[torch.Tensor])
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@ -189,6 +191,55 @@ class Conformer(Transformer):
return logits, lengths
@torch.jit.export
def get_init_state(
self, left_context: int, device: torch.device
) -> List[torch.Tensor]:
"""Return the initial cache state of the model.
Args:
left_context: The left context size (in frames after subsampling).
Returns:
Return the initial state of the model, it is a list containing two
tensors, the first one is the cache for attentions which has a shape
of (num_encoder_layers, left_context, encoder_dim), the second one
is the cache of conv_modules which has a shape of
(num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
NOTE: the returned tensors are on the given device.
"""
if (
len(self._init_state) == 2
and self._init_state[0].size(1) == left_context
):
# Note: It is OK to share the init state as it is
# not going to be modified by the model
return self._init_state
init_states: List[torch.Tensor] = [
torch.zeros(
(
self.encoder_layers,
left_context,
self.d_model,
),
device=device,
),
torch.zeros(
(
self.encoder_layers,
self.cnn_module_kernel - 1,
self.d_model,
),
device=device,
),
]
self._init_state = init_states
return init_states
@torch.jit.export
def streaming_forward(
self,
@ -198,7 +249,7 @@ class Conformer(Transformer):
chunk_size: int = 16,
left_context: int = 64,
simulate_streaming: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
"""
Args:
x:
@ -229,7 +280,7 @@ class Conformer(Transformer):
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
- decode_states, the updated DecodeStates including the information
- states, the updated states(i.e. caches) including the information
of current chunk.
"""
@ -265,7 +316,7 @@ class Conformer(Transformer):
embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
x = self.encoder.chunk_forward(
x, states = self.encoder.chunk_forward(
embed,
pos_enc,
src_key_padding_mask=src_key_padding_mask,
@ -304,7 +355,7 @@ class Conformer(Transformer):
logits = self.encoder_output_layer(x)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return logits, lengths
return logits, lengths, states
class ConformerEncoderLayer(nn.Module):
@ -461,7 +512,7 @@ class ConformerEncoderLayer(nn.Module):
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tensor:
) -> Tuple[Tensor, List[Tensor]]:
"""
Pass the input through the encoder layer.
@ -471,9 +522,9 @@ class ConformerEncoderLayer(nn.Module):
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (encoder_layers, left_context, batch, attention_dim),
a shape of (left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
(cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
@ -503,6 +554,12 @@ class ConformerEncoderLayer(nn.Module):
if self.normalize_before:
src = self.norm_mha(src)
# We put the attention cache this level (i.e. before linear transformation)
# to save memory consumption, when decoding in streaming fashion, the
# batch size would be thousands (for 32GB machine), if we cache key & val
# separately, it needs extra several GB memory.
# TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val
# separately) if needed.
key = torch.cat([states[0], src], dim=0)
val = key
states[0] = key[-left_context:, ...]
@ -543,7 +600,7 @@ class ConformerEncoderLayer(nn.Module):
if self.normalize_before:
src = self.norm_final(src)
return src
return src, states
class ConformerEncoder(nn.Module):
@ -612,7 +669,7 @@ class ConformerEncoder(nn.Module):
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tensor:
) -> Tuple[Tensor, List[Tensor]]:
r"""Pass the input through the encoder layers in turn.
Args:
@ -643,7 +700,7 @@ class ConformerEncoder(nn.Module):
for layer_index, mod in enumerate(self.layers):
cache = [states[0][layer_index], states[1][layer_index]]
output = mod.chunk_forward(
output, cache = mod.chunk_forward(
output,
pos_emb,
states=cache,
@ -654,7 +711,7 @@ class ConformerEncoder(nn.Module):
states[0][layer_index] = cache[0]
states[1][layer_index] = cache[1]
return output
return output, states
class RelPositionalEncoding(torch.nn.Module):