Merge remote-tracking branch 'k2-fsa/streaming' into streaming_new

yaozengwei 2022-04-18 14:48:36 +08:00
commit b343cb51dc
17 changed files with 956 additions and 336 deletions

View File

@ -15,3 +15,7 @@ exclude =
**/data/**,
icefall/shared/make_kn_lm.py,
icefall/__init__.py
ignore =
# E203 whitespace before ':'
E203,
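For context, a short illustration (not part of this change, the slice is borrowed from code elsewhere in this diff) of what E203 flags; black formats long slices with a space before the colon, which is why the rule is ignored:

# flake8 reports "E203 whitespace before ':'" on this line,
# even though this is exactly how black formats the slice.
chunk = audio_samples[i : i + chunk_size]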

View File

@ -1,98 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class LabelSmoothingLoss(torch.nn.Module):
"""
Implement the LabelSmoothingLoss proposed in the following paper
https://arxiv.org/pdf/1512.00567.pdf
(Rethinking the Inception Architecture for Computer Vision)
"""
def __init__(
self,
ignore_index: int = -1,
label_smoothing: float = 0.1,
reduction: str = "sum",
) -> None:
"""
Args:
ignore_index:
ignored class id
label_smoothing:
smoothing rate (0.0 means the conventional cross entropy loss)
reduction:
It has the same meaning as the reduction in
`torch.nn.CrossEntropyLoss`. It can be one of the following three
values: (1) "none": No reduction will be applied. (2) "mean": the
mean of the output is taken. (3) "sum": the output will be summed.
"""
super().__init__()
assert 0.0 <= label_smoothing < 1.0
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute loss between x and target.
Args:
x:
prediction of dimension
(batch_size, input_length, number_of_classes).
target:
target masked with self.ignore_index of
dimension (batch_size, input_length).
Returns:
A scalar tensor containing the loss without normalization.
"""
assert x.ndim == 3
assert target.ndim == 2
assert x.shape[:2] == target.shape
num_classes = x.size(-1)
x = x.reshape(-1, num_classes)
# Now x is of shape (N*T, C)
# We don't want to change target in-place below,
# so we make a copy of it here
target = target.clone().reshape(-1)
ignored = target == self.ignore_index
target[ignored] = 0
true_dist = torch.nn.functional.one_hot(
target, num_classes=num_classes
).to(x)
true_dist = (
true_dist * (1 - self.label_smoothing)
+ self.label_smoothing / num_classes
)
# Set the value of ignored indexes to 0
true_dist[ignored] = 0
loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
if self.reduction == "sum":
return loss.sum()
elif self.reduction == "mean":
return loss.sum() / (~ignored).sum()
else:
return loss.sum(dim=-1)
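A brief usage sketch (not part of the diff; shapes and values are made up, and it assumes the LabelSmoothingLoss class above is in scope): logits are (batch, time, classes), integer targets are (batch, time), and padded positions are marked with ignore_index:

import torch

criterion = LabelSmoothingLoss(ignore_index=-1, label_smoothing=0.1, reduction="sum")
logits = torch.randn(2, 5, 10)            # (batch_size, input_length, num_classes)
targets = torch.randint(0, 10, (2, 5))    # (batch_size, input_length)
targets[0, -1] = -1                       # this position is ignored in the loss
loss = criterion(logits, targets)         # scalar tensor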

View File

@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py

View File

@ -1,98 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class LabelSmoothingLoss(torch.nn.Module):
"""
Implement the LabelSmoothingLoss proposed in the following paper
https://arxiv.org/pdf/1512.00567.pdf
(Rethinking the Inception Architecture for Computer Vision)
"""
def __init__(
self,
ignore_index: int = -1,
label_smoothing: float = 0.1,
reduction: str = "sum",
) -> None:
"""
Args:
ignore_index:
ignored class id
label_smoothing:
smoothing rate (0.0 means the conventional cross entropy loss)
reduction:
It has the same meaning as the reduction in
`torch.nn.CrossEntropyLoss`. It can be one of the following three
values: (1) "none": No reduction will be applied. (2) "mean": the
mean of the output is taken. (3) "sum": the output will be summed.
"""
super().__init__()
assert 0.0 <= label_smoothing < 1.0
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute loss between x and target.
Args:
x:
prediction of dimension
(batch_size, input_length, number_of_classes).
target:
target masked with self.ignore_index of
dimension (batch_size, input_length).
Returns:
A scalar tensor containing the loss without normalization.
"""
assert x.ndim == 3
assert target.ndim == 2
assert x.shape[:2] == target.shape
num_classes = x.size(-1)
x = x.reshape(-1, num_classes)
# Now x is of shape (N*T, C)
# We don't want to change target in-place below,
# so we make a copy of it here
target = target.clone().reshape(-1)
ignored = target == self.ignore_index
target[ignored] = 0
true_dist = torch.nn.functional.one_hot(
target, num_classes=num_classes
).to(x)
true_dist = (
true_dist * (1 - self.label_smoothing)
+ self.label_smoothing / num_classes
)
# Set the value of ignored indexes to 0
true_dist[ignored] = 0
loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
if self.reduction == "sum":
return loss.sum()
elif self.reduction == "mean":
return loss.sum() / (~ignored).sum()
else:
return loss.sum(dim=-1)

View File

@ -0,0 +1 @@
../conformer_ctc/label_smoothing.py

View File

@ -70,7 +70,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
# |-- lexicon.txt
# `-- speaker.info
if [ ! -d $dl_dir/aishell/data_aishell/wav ]; then
if [ ! -d $dl_dir/aishell/data_aishell/wav/train ]; then
lhotse download aishell $dl_dir
fi

View File

@ -76,7 +76,11 @@ class LabelSmoothingLoss(torch.nn.Module):
target = target.clone().reshape(-1)
ignored = target == self.ignore_index
target[ignored] = 0
# See https://github.com/k2-fsa/icefall/issues/240
# and https://github.com/k2-fsa/icefall/issues/297
# for why we don't use target[ignored] = 0 here
target = torch.where(ignored, torch.zeros_like(target), target)
true_dist = torch.nn.functional.one_hot(
target, num_classes=num_classes
@ -86,8 +90,17 @@ class LabelSmoothingLoss(torch.nn.Module):
true_dist * (1 - self.label_smoothing)
+ self.label_smoothing / num_classes
)
# Set the value of ignored indexes to 0
true_dist[ignored] = 0
#
# See https://github.com/k2-fsa/icefall/issues/240
# and https://github.com/k2-fsa/icefall/issues/297
# for why we don't use true_dist[ignored] = 0 here
true_dist = torch.where(
ignored.unsqueeze(1).repeat(1, true_dist.shape[1]),
torch.zeros_like(true_dist),
true_dist,
)
loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
if self.reduction == "sum":
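A minimal sketch (not part of the commit) of the out-of-place masking used above; the result matches indexed assignment, but no tensor is modified in place, which is the behaviour the linked issues report problems with:

import torch

target = torch.tensor([3, 1, -1, 2])   # -1 marks ignored positions
ignored = target == -1

masked = torch.where(ignored, torch.zeros_like(target), target)
print(masked)  # tensor([3, 1, 0, 2])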

View File

@ -98,27 +98,28 @@ def get_parser():
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
""",
"'--epoch' and '--iter'",
)
parser.add_argument(
@ -453,13 +454,19 @@ def main():
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -485,8 +492,20 @@ def main():
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg_last_n > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))

View File

@ -32,13 +32,16 @@ class Joiner(nn.Module):
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
Output from the encoder. Its shape is (N, T, s_range, C) for
training and (N, C) for streaming decoding.
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
Output from the decoder. Its shape is (N, T, s_range, C) for
training and (N, C) for streaming decoding.
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim == 4
assert encoder_out.ndim == decoder_out.ndim
assert encoder_out.ndim in (2, 4)
assert encoder_out.shape == decoder_out.shape
logit = encoder_out + decoder_out
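A hedged shape check (illustrative only, tensor sizes are made up) mirroring the relaxed assertions above: the joiner now accepts either 4-D inputs during training or 2-D inputs during streaming decoding, as long as the two tensors have identical shapes:

import torch

enc_train, dec_train = torch.rand(2, 10, 5, 512), torch.rand(2, 10, 5, 512)  # (N, T, s_range, C)
enc_stream, dec_stream = torch.rand(2, 512), torch.rand(2, 512)              # (N, C)

for enc, dec in [(enc_train, dec_train), (enc_stream, dec_stream)]:
    assert enc.ndim == dec.ndim and enc.ndim in (2, 4)
    assert enc.shape == dec.shape
    logit = enc + dec  # the same elementwise combination works in both cases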

View File

@ -27,6 +27,75 @@ from torchaudio.models import Emformer as _Emformer
LOG_EPSILON = math.log(1e-10)
def unstack_states(
states: List[List[torch.Tensor]],
) -> List[List[List[torch.Tensor]]]:
"""Unstack the emformer state corresponding to a batch of utterances
into a list of states, were the i-th entry is the state from the i-th
utterance in the batch.
Args:
states:
A list-of-list of tensors. ``len(states)`` equals the number of
layers in the emformer. ``states[i]`` contains the states for
the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``.
"""
batch_size = states[0][0].size(1)
num_layers = len(states)
ans = [None] * batch_size
for i in range(batch_size):
ans[i] = [[] for _ in range(num_layers)]
for li, layer in enumerate(states):
for s in layer:
s_list = s.unbind(dim=1)
# We will use stack(dim=1) later in stack_states()
for bi, b in enumerate(ans):
b[li].append(s_list[bi])
return ans
def stack_states(
state_list: List[List[List[torch.Tensor]]],
) -> List[List[torch.Tensor]]:
"""Stack list of emformer states that correspond to separate utterances
into a single emformer state so that it can be used as an input for
emformer when those utterances are formed into a batch.
Note:
It is the inverse of :func:`unstack_states`.
Args:
state_list:
Each element in state_list corresponds to the internal state
of the emformer model for a single utterance.
Returns:
Return a new state corresponding to a batch of utterances.
See the input argument of :func:`unstack_states` for the meaning
of the returned tensor.
"""
batch_size = len(state_list)
ans = []
for layer in state_list[0]:
# layer is a list of tensors
if batch_size > 1:
ans.append([[s] for s in layer])
# Note: We will stack ans[layer][s][] later to get ans[layer][s]
else:
ans.append([s.unsqueeze(1) for s in layer])
for b, states in enumerate(state_list[1:], 1):
for li, layer in enumerate(states):
for si, s in enumerate(layer):
ans[li][si].append(s)
if b == batch_size - 1:
ans[li][si] = torch.stack(ans[li][si], dim=1)
# We will use unbind(dim=1) later in unstack_states()
return ans
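A standalone round-trip check (a sketch with made-up state shapes, mirroring the test added later in this commit): unstacking a batched state into per-utterance states and stacking them again reproduces the original tensors:

import torch
from emformer import stack_states, unstack_states

num_layers, T, N, C = 2, 4, 3, 8
states = [
    [torch.rand(T, N, C), torch.rand(C, N)]  # one (T, N, C) and one (C, N) tensor per layer
    for _ in range(num_layers)
]
states2 = stack_states(unstack_states(states))
for layer, layer2 in zip(states, states2):
    for s, s2 in zip(layer, layer2):
        assert torch.allclose(s, s2)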
class Emformer(EncoderInterface):
"""This is just a simple wrapper around torchaudio.models.Emformer.
We may replace it with our own implementation some time later.
@ -63,11 +132,11 @@ class Emformer(EncoderInterface):
num_encoder_layers:
Number of encoder layers.
segment_length:
Number of frames per segment.
Number of frames per segment before subsampling.
left_context_length:
Number of frames in the left context.
Number of frames in the left context before subsampling.
right_context_length:
Number of frames in the right context.
Number of frames in the right context before subsampling.
max_memory_size:
TODO.
dropout:
@ -94,6 +163,7 @@ class Emformer(EncoderInterface):
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.segment_length = segment_length
self.right_context_length = right_context_length
assert right_context_length % subsampling_factor == 0
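A hedged numeric check of the "before subsampling" convention (assuming subsampling_factor = 4, as with Conv2dSubsampling): the front end maps n input frames to ((n - 1) // 2 - 1) // 2 output frames, slightly fewer than n // 4, hence the 3 extra frames added per chunk in the streaming decoding script later in this diff:

segment_length = 16        # input frames per segment
right_context_length = 4   # input frames of right context
n = segment_length + right_context_length
print(n // 4)                       # 5 frames wanted after subsampling
print(((n - 1) // 2 - 1) // 2)      # 4 frames actually produced
print(((n + 3 - 1) // 2 - 1) // 2)  # 5 frames once 3 extra frames are added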

View File

@ -0,0 +1,184 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./transducer_emformer/export.py \
--exp-dir ./transducer_emformer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
To use the generated file with `transducer_emformer/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./transducer_emformer/decode.py \
--exp-dir ./transducer_emformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 1000 \
--bpe-model data/lang_bpe_500/bpe.model
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support for torchscript will be added later"
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -18,16 +18,16 @@
import argparse
import logging
import time
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple
import kaldifeat
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from emformer import LOG_EPSILON, stack_states, unstack_states
from streaming_feature_extractor import FeatureExtractionStream
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
@ -147,10 +147,10 @@ def get_parser():
)
parser.add_argument(
"--sample-rate",
type=int,
"--sampling-rate",
type=float,
default=16000,
help="The sample rate of the input sound file",
help="Sample rate of the audio",
)
add_model_arguments(parser)
@ -158,115 +158,352 @@ def get_parser():
return parser
def get_feature_extractor(
params: AttributeDict,
) -> kaldifeat.Fbank:
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = params.device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = True
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
class StreamingAudioSamples(object):
"""This class takes as input a list of audio samples and returns
them in a streaming fashion.
"""
return kaldifeat.Fbank(opts)
def __init__(self, samples: List[torch.Tensor]) -> None:
"""
Args:
samples:
A list of audio samples. Each entry is a 1-D tensor of dtype
torch.float32, containing the audio samples of an utterance.
"""
self.samples = samples
self.cur_indexes = [0] * len(self.samples)
@property
def done(self) -> bool:
"""Return True if all samples have been processed.
Return False otherwise.
"""
for i, samples in zip(self.cur_indexes, self.samples):
if i < samples.numel():
return False
return True
def get_next(self) -> List[torch.Tensor]:
"""Return a list of audio samples. Each entry may have different
lengths. It is OK if an entry contains no samples at all, which
means it reaches the end of the utterance.
"""
ans = []
num = [1024] * len(self.samples)
for i in range(len(self.samples)):
start = self.cur_indexes[i]
end = start + num[i]
self.cur_indexes[i] = end
s = self.samples[i][start:end]
ans.append(s)
return ans
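# Illustrative walk-through (not part of the diff): with two fake
# utterances of 3000 and 1500 samples,
#   s = StreamingAudioSamples([torch.rand(3000), torch.rand(1500)])
# successive s.get_next() calls return chunks of lengths
#   [1024, 1024], [1024, 476], [952, 0], ...
# and s.done becomes True once both utterances are exhausted.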
def decode_one_utterance(
audio_samples: torch.Tensor,
class StreamList(object):
def __init__(
self,
batch_size: int,
context_size: int,
blank_id: int,
):
"""
Args:
batch_size:
Size of this batch.
context_size:
Context size of the RNN-T decoder model.
blank_id:
The ID of the blank symbol of the BPE model.
"""
self.streams = [
FeatureExtractionStream(
context_size=context_size, blank_id=blank_id
)
for _ in range(batch_size)
]
@property
def done(self) -> bool:
"""Return True if all streams have reached end of utterance.
That is, no more audio samples are available for all utterances.
"""
return all(stream.done for stream in self.streams)
def accept_waveform(
self,
audio_samples: List[torch.Tensor],
sampling_rate: float,
):
"""Feeed audio samples to each stream.
Args:
audio_samples:
A list of 1-D tensors containing the audio samples for each
utterance in the batch. If an entry is empty, it means
end-of-utterance has been reached.
sampling_rate:
Sampling rate of the given audio samples.
"""
assert len(audio_samples) == len(self.streams)
for stream, samples in zip(self.streams, audio_samples):
if stream.done:
assert samples.numel() == 0
continue
stream.accept_waveform(
sampling_rate=sampling_rate,
waveform=samples,
)
if samples.numel() == 0:
stream.input_finished()
def build_batch(
self,
chunk_length: int,
segment_length: int,
) -> Tuple[Optional[torch.Tensor], Optional[List[FeatureExtractionStream]]]:
"""
Args:
chunk_length:
Number of frames for each chunk. It equals
``segment_length + right_context_length``.
segment_length:
Number of frames for each segment.
Returns:
Return a tuple containing:
- features, a 3-D tensor of shape ``(num_active_streams, T, C)``
- active_streams, a list of active streams. We say a stream is
active when it has enough feature frames to be fed into the
encoder model.
"""
feature_list = []
stream_list = []
for stream in self.streams:
if len(stream.feature_frames) >= chunk_length:
# chunk is a list of tensors, each of which
# has a shape (1, feature_dim)
chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = stream.feature_frames[segment_length:]
features = torch.cat(chunk, dim=0)
feature_list.append(features)
stream_list.append(stream)
elif stream.done and len(stream.feature_frames) > 0:
chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = []
features = torch.cat(chunk, dim=0)
features = torch.nn.functional.pad(
features,
(0, 0, 0, chunk_length - features.size(0)),
mode="constant",
value=LOG_EPSILON,
)
feature_list.append(features)
stream_list.append(stream)
if len(feature_list) == 0:
return None, None
features = torch.stack(feature_list, dim=0)
return features, stream_list
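# Sliding-window note (illustrative, with assumed values): if
# chunk_length = 23 and segment_length = 16, successive calls hand the
# encoder feature frames [0, 23), [16, 39), [32, 55), ... for each
# active stream, i.e. each chunk keeps the look-ahead/right-context
# frames of the previous one, because only segment_length frames are
# discarded per call.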
def greedy_search(
model: nn.Module,
fbank: kaldifeat.Fbank,
params: AttributeDict,
streams: List[FeatureExtractionStream],
encoder_out: torch.Tensor,
sp: spm.SentencePieceProcessor,
):
"""Decode one utterance.
"""
Args:
audio_samples:
A 1-D float32 tensor of shape (num_samples,) containing the normalized
audio samples. Normalized means the samples are in the range [-1, 1].
model:
The RNN-T model.
feature_extractor:
The feature extractor.
stream:
A stream object.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
sp:
The BPE model.
"""
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
if streams[0].decoder_out is None:
decoder_input = torch.tensor(
[stream.hyp.ys[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
).squeeze(1)
# decoder_out is of shape (N, decoder_out_dim)
else:
decoder_out = torch.stack(
[stream.decoder_out for stream in streams],
dim=0,
)
assert encoder_out.ndim == 3
T = encoder_out.size(1)
for t in range(T):
current_encoder_out = encoder_out[:, t]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out)
# logits'shape (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.ys.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp.ys[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(
1
)
for k, s in enumerate(streams):
logging.info(
f"Partial result {k}:\n{sp.decode(s.hyp.ys[context_size:])}"
)
decoder_out_list = decoder_out.unbind(dim=0)
for i, d in enumerate(decoder_out_list):
streams[i].decoder_out = d
def process_features(
model: nn.Module,
features: torch.Tensor,
streams: List[FeatureExtractionStream],
sp: spm.SentencePieceProcessor,
) -> None:
"""Process features for each stream in parallel.
Args:
model:
The RNN-T model.
features:
A 3-D tensor of shape (N, T, C).
streams:
A list of streams of size (N,).
sp:
The BPE model.
"""
assert features.ndim == 3
assert features.size(0) == len(streams)
batch_size = features.size(0)
device = model.device
features = features.to(device)
feature_lens = torch.full(
(batch_size,),
fill_value=features.size(1),
device=device,
)
# Caution: It has a limitation as it assumes that
# if one of the streams has an empty state, then all other
# streams also have empty states.
if streams[0].states is None:
states = None
else:
state_list = [stream.states for stream in streams]
states = stack_states(state_list)
(encoder_out, encoder_out_lens, states,) = model.encoder.streaming_forward(
features,
feature_lens,
states,
)
state_list = unstack_states(states)
for i, s in enumerate(state_list):
streams[i].states = s
greedy_search(
model=model,
streams=streams,
encoder_out=encoder_out,
sp=sp,
)
def decode_batch(
batched_samples: List[torch.Tensor],
model: nn.Module,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> List[str]:
"""
Args:
batched_samples:
A list of 1-D tensors containing the audio samples of each utterance.
model:
The RNN-T model.
params:
It is the return value of :func:`get_params`.
sp:
The BPE model.
"""
sample_rate = params.sample_rate
frame_shift = sample_rate * fbank.opts.frame_opts.frame_shift_ms / 1000
# number of frames before subsampling
segment_length = model.encoder.segment_length
frame_shift = int(frame_shift) # number of samples
right_context_length = model.encoder.right_context_length
# Note: We add 3 here because the subsampling method ((n-1)//2-1)//2
# is not equal to n//4. We will switch to a subsampling method that
# satisfies n//4, where n is the number of input frames.
segment_length = (params.segment_length + 3) * frame_shift
# We add 3 here since the subsampling method is using
# ((len - 1) // 2 - 1) // 2
chunk_length = (segment_length + 3) + right_context_length
right_context_length = params.right_context_length * frame_shift
chunk_size = segment_length + right_context_length
batch_size = len(batched_samples)
streaming_audio_samples = StreamingAudioSamples(batched_samples)
opts = fbank.opts.frame_opts
chunk_size += (
(opts.frame_length_ms - opts.frame_shift_ms) / 1000 * sample_rate
stream_list = StreamList(
batch_size=batch_size,
context_size=params.context_size,
blank_id=params.blank_id,
)
chunk_size = int(chunk_size)
states: Optional[List[List[torch.Tensor]]] = None
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
hyp = [blank_id] * context_size
decoder_input = torch.tensor(hyp, device=device, dtype=torch.int64).reshape(
1, context_size
)
decoder_out = model.decoder(decoder_input, need_pad=False)
i = 0
num_samples = audio_samples.size(0)
while i < num_samples:
# Note: The current approach of computing the features is not ideal
# since it re-computes the features for the right context.
chunk = audio_samples[i : i + chunk_size] # noqa
i += segment_length
if chunk.size(0) < chunk_size:
chunk = torch.nn.functional.pad(
chunk, pad=(0, chunk_size - chunk.size(0))
)
features = fbank(chunk)
feature_lens = torch.tensor([features.size(0)], device=params.device)
features = features.unsqueeze(0) # (1, T, C)
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
features,
feature_lens,
states,
while not streaming_audio_samples.done:
samples = streaming_audio_samples.get_next()
stream_list.accept_waveform(
audio_samples=samples,
sampling_rate=params.sampling_rate,
)
for t in range(encoder_out_lens.item()):
# fmt: off
current_encoder_out = encoder_out[0:1, t:t+1, :].unsqueeze(2)
# fmt: on
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
# logits is (1, 1, 1, vocab_size)
y = logits.argmax().item()
if y == blank_id:
continue
hyp.append(y)
decoder_input = torch.tensor(
[hyp[-context_size:]], device=device, dtype=torch.int64
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
logging.info(f"Partial result:\n{sp.decode(hyp[context_size:])}")
features, active_streams = stream_list.build_batch(
chunk_length=chunk_length,
segment_length=segment_length,
)
if features is not None:
process_features(
model=model,
features=features,
streams=active_streams,
sp=sp,
)
results = []
for s in stream_list.streams:
text = sp.decode(s.hyp.ys[params.context_size :])
results.append(text)
return results
@torch.no_grad()
@ -333,30 +570,43 @@ def main():
test_clean_cuts = librispeech.test_clean_cuts()
fbank = get_feature_extractor(params)
batch_size = 3
ground_truth = []
batched_samples = []
for num, cut in enumerate(test_clean_cuts):
logging.info("Processing {num}")
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
decode_one_utterance(
audio_samples=torch.from_numpy(audio).squeeze(0).to(device),
model=model,
fbank=fbank,
params=params,
sp=sp,
)
logging.info(f"The ground truth is:\n{cut.supervisions[0].text}")
if num >= 0:
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
samples = torch.from_numpy(audio).squeeze(0)
batched_samples.append(samples)
ground_truth.append(cut.supervisions[0].text)
if len(batched_samples) >= batch_size:
decoded_results = decode_batch(
batched_samples=batched_samples,
model=model,
params=params,
sp=sp,
)
s = "\n"
for i, (hyp, ref) in enumerate(zip(decoded_results, ground_truth)):
s += f"hyp {i}:\n{hyp}\n"
s += f"ref {i}:\n{ref}\n\n"
logging.info(s)
batched_samples = []
ground_truth = []
# break after processing the first batch for test purposes
break
time.sleep(2) # So that you can see the decoded results
if __name__ == "__main__":
torch.manual_seed(20220410)
main()

View File

@ -0,0 +1,116 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
from beam_search import Hypothesis
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def _create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
class FeatureExtractionStream(object):
def __init__(self, context_size: int, blank_id: int = 0) -> None:
"""Context size of the RNN-T decoder model."""
self.feature_extractor = _create_streaming_feature_extractor()
self.hyp = Hypothesis(
ys=([blank_id] * context_size),
log_prob=torch.tensor([0.0]),
) # for greedy search, will extend it to beam search
# It contains a list of 1-D tensors representing the feature frames.
self.feature_frames: List[torch.Tensor] = []
self.num_fetched_frames = 0
# For the emformer model, it contains the states of each
# encoder layer.
self.states: Optional[List[List[torch.Tensor]]] = None
# For the RNN-T decoder, it contains the decoder output
# corresponding to the decoder input self.hyp.ys[-context_size:]
# Its shape is (decoder_out_dim,)
self.decoder_out: Optional[torch.Tensor] = None
# After calling `self.input_finished()`, we set this flag to True
self._done = False
def accept_waveform(
self,
sampling_rate: float,
waveform: torch.Tensor,
) -> None:
"""Feed audio samples to the feature extractor and compute features
if there are enough samples available.
Caution:
The range of the audio samples should match the one used in the
training. That is, if you use the range [-1, 1] in the training, then
the input audio samples should also be normalized to [-1, 1].
Args:
sampling_rate:
The sampling rate of the input audio samples. It is used for a sanity
check to ensure that the input sampling rate equals the one
used in the extractor. If they are not equal, then no resampling
will be performed; instead an error will be thrown.
waveform:
A 1-D torch tensor of dtype torch.float32 containing audio samples.
It should be on CPU.
"""
self.feature_extractor.accept_waveform(
sampling_rate=sampling_rate,
waveform=waveform,
)
self._fetch_frames()
def input_finished(self) -> None:
"""Signal that no more audio samples available and the feature
extractor should flush the buffered samples to compute frames.
"""
self.feature_extractor.input_finished()
self._fetch_frames()
self._done = True
@property
def done(self) -> bool:
"""Return True if `self.input_finished()` has been invoked"""
return self._done
def _fetch_frames(self) -> None:
"""Fetch frames from the feature extractor"""
while self.num_fetched_frames < self.feature_extractor.num_frames_ready:
frame = self.feature_extractor.get_frame(self.num_fetched_frames)
self.feature_frames.append(frame)
self.num_fetched_frames += 1

View File

@ -25,7 +25,7 @@ To run this file, do:
import warnings
import torch
from emformer import Emformer
from emformer import Emformer, stack_states, unstack_states
def test_emformer():
@ -191,9 +191,43 @@ def test_emformer_infer_batch_single_consistency():
assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0)
def test_emformer_streaming_forward():
N = 3
C = 80
output_dim = 500
encoder = Emformer(
num_features=C,
output_dim=output_dim,
d_model=512,
nhead=8,
dim_feedforward=2048,
num_encoder_layers=20,
segment_length=16,
left_context_length=120,
right_context_length=4,
vgg_frontend=False,
)
x = torch.rand(N, 23, C)
x_lens = torch.full((N,), 23)
y, y_lens, states = encoder.streaming_forward(x=x, x_lens=x_lens)
state_list = unstack_states(states)
states2 = stack_states(state_list)
for ss, ss2 in zip(states, states2):
for s, s2 in zip(ss, ss2):
assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
@torch.no_grad()
def main():
# test_emformer()
test_emformer()
test_emformer_infer_batch_single_consistency()
test_emformer_streaming_forward()
if __name__ == "__main__":

View File

@ -0,0 +1,53 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_emformer/test_streaming_feature_extractor.py
"""
import torch
from streaming_feature_extractor import FeatureExtractionStream
def test_streaming_feature_extractor():
stream = FeatureExtractionStream(context_size=2, blank_id=0)
samples = torch.rand(16000)
start = 0
while True:
n = torch.randint(50, 500, (1,)).item()
end = start + n
this_chunk = samples[start:end]
start = end
if len(this_chunk) == 0:
break
stream.accept_waveform(sampling_rate=16000, waveform=this_chunk)
print(len(stream.feature_frames))
stream.input_finished()
print(len(stream.feature_frames))
def main():
test_streaming_feature_extractor()
if __name__ == "__main__":
main()

View File

@ -1,3 +1,34 @@
from .checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
remove_checkpoints,
save_checkpoint,
save_checkpoint_with_global_batch_idx,
)
from .decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from .dist import (
cleanup_dist,
setup_dist,
)
from .env import (
get_env_info,
get_git_branch_name,
get_git_date,
get_git_sha1,
)
from .utils import (
AttributeDict,
MetricsTracker,
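With these package-level re-exports, recipe scripts can pull the common helpers straight from the top-level package; a hedged sketch using only names visible in this hunk (the exp_dir value is made up):

from icefall import AttributeDict, average_checkpoints, find_checkpoints, setup_dist

params = AttributeDict({"exp_dir": "pruned_transducer_stateless/exp"})
filenames = find_checkpoints(params.exp_dir)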

View File

@ -216,27 +216,62 @@ def save_checkpoint_with_global_batch_idx(
)
def find_checkpoints(out_dir: Path) -> List[str]:
def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
"""Find all available checkpoints in a directory.
The checkpoint filenames have the form: `checkpoint-xxx.pt`
where xxx is a numerical value.
Assume you have the following checkpoints in the folder `foo`:
- checkpoint-1.pt
- checkpoint-20.pt
- checkpoint-300.pt
- checkpoint-4000.pt
Case 1 (Return all checkpoints)::
find_checkpoints(out_dir='foo')
Case 2 (Return checkpoints newer than checkpoint-20.pt, i.e.,
checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt)::
find_checkpoints(out_dir='foo', iteration=20)
Case 3 (Return checkpoints older than checkpoint-20.pt, i.e.,
checkpoint-20.pt, checkpoint-1.pt)::
find_checkpoints(out_dir='foo', iteration=-20)
Args:
out_dir:
The directory where to search for checkpoints.
iteration:
If it is 0, return all available checkpoints.
If it is positive, return the checkpoints whose iteration number is
greater than or equal to `iteration`.
If it is negative, return the checkpoints whose iteration number is
less than or equal to `-iteration`.
Returns:
Return a list of checkpoint filenames, sorted in descending
order by the numerical value in the filename.
"""
checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
pattern = re.compile(r"checkpoint-([0-9]+).pt")
idx_checkpoints = [
iter_checkpoints = [
(int(pattern.search(c).group(1)), c) for c in checkpoints
]
# iter_checkpoints is a list of tuples. Each tuple contains
# two elements: (iteration_number, checkpoint-iteration_number.pt)
iter_checkpoints = sorted(
iter_checkpoints, reverse=True, key=lambda x: x[0]
)
if iteration >= 0:
ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
else:
ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration]
idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
ans = [ic[1] for ic in idx_checkpoints]
return ans
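A short sketch (hypothetical values and directory) of how the new iteration argument is used by decode.py in this commit to implement --iter/--avg: take the checkpoints whose iteration number is at most --iter and average the newest --avg of them:

import torch
from icefall.checkpoint import average_checkpoints, find_checkpoints

iter_num, avg = 105000, 10   # hypothetical values of --iter and --avg
filenames = find_checkpoints("exp_dir", iteration=-iter_num)[:avg]
averaged = average_checkpoints(filenames, device=torch.device("cpu"))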

View File

@ -95,6 +95,7 @@ def get_env_info() -> Dict[str, Any]:
"k2-git-sha1": k2.version.__git_sha1__,
"k2-git-date": k2.version.__git_date__,
"lhotse-version": lhotse.__version__,
"torch-version": torch.__version__,
"torch-cuda-available": torch.cuda.is_available(),
"torch-cuda-version": torch.version.cuda,
"python-version": sys.version[:3],

View File

@ -1,5 +1,6 @@
[tool.isort]
profile = "black"
skip = ["icefall/__init__.py"]
[tool.black]
line-length = 80