Merge 74200583be92b96726de086023db2a0f45a1d401 into 8eff7a4642da89bace5ea60ea9c5cae7ef5043b0

Zengwei Yao 2022-05-25 09:48:59 +00:00 committed by GitHub
commit b056e55ead
12 changed files with 4529 additions and 0 deletions

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1,634 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless4/decode.py \
    --epoch 30 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless4/exp \
    --max-duration 100 \
    --decoding-method greedy_search

(2) beam search
./pruned_transducer_stateless4/decode.py \
    --epoch 30 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless4/exp \
    --max-duration 100 \
    --decoding-method beam_search \
    --beam-size 4

(3) modified beam search
./pruned_transducer_stateless4/decode.py \
    --epoch 30 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless4/exp \
    --max-duration 100 \
    --decoding-method modified_beam_search \
    --beam-size 4

(4) fast beam search
./pruned_transducer_stateless4/decode.py \
    --epoch 30 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless4/exp \
    --max-duration 1500 \
    --decoding-method fast_beam_search \
    --beam 4 \
    --max-contexts 4 \
    --max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)


def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
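
    # A sketch of how that averaged model can be recovered from just the two
    # endpoint checkpoints (assuming each checkpoint stores a running average
    # ``model_avg`` over epochs 1..n, which is what
    # average_checkpoints_with_averaged_model relies on): the average over
    # the range (epoch-avg, epoch] is
    #   (model_avg[epoch] * epoch - model_avg[epoch-avg] * (epoch-avg)) / avg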
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless4/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
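
    # For example, with --beam 4: if the best active state in the lattice has
    # score -3.2, states scoring below -3.2 - 4 = -7.2 are pruned.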
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
add_model_arguments(parser)
return parser


def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
    - key: It indicates the setting used for decoding. For example,
      if greedy_search is used, it would be "greedy_search".
      If beam search with a beam size of 7 is used, it would be
      "beam_7".
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
        Used only when --decoding-method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}


def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
        Used only when --decoding-method is fast_beam_search.
Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if a beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results


def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)


@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
    # <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0
start = params.epoch - params.avg
assert start >= 1
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
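        # k2.trivial_graph(N) builds an FSA with a self-loop for each of the
        # token ids 1..N, i.e. it accepts any token sequence; N is set to
        # vocab_size - 1, the largest token id.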
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
    test_dls = [test_clean_dl, test_other_dl]
    for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")


if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/decoder.py

File diff suppressed because it is too large.

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/model.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/scaling.py

View File

@ -0,0 +1,585 @@
import torch
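
# Shape-name conventions used throughout these tests:
#   B: batch size;  D: embedding dimension;  U: utterance length in frames;
#   R: total right-context length;  S: summary length (one vector per chunk);
#   M: memory length;  L: left-context length (used only in the infer tests);
#   Q, KV: query and key/value lengths of the attention;
#   PE: length of the relative positional embedding.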


def test_rel_positional_encoding():
from emformer import RelPositionalEncoding
D = 256
pos_enc = RelPositionalEncoding(D, dropout=0.1)
pos_len = 100
neg_len = 100
x = torch.randn(2, D)
x, pos_emb = pos_enc(x, pos_len, neg_len)
assert pos_emb.shape == (pos_len + neg_len - 1, D)


def test_emformer_attention_forward():
from emformer import EmformerAttention
B, D = 2, 256
chunk_length = 4
right_context_length = 2
num_chunks = 3
U = num_chunks * chunk_length
R = num_chunks * right_context_length
attention = EmformerAttention(
embed_dim=D,
nhead=8,
chunk_length=chunk_length,
right_context_length=right_context_length,
)
for use_memory in [True, False]:
if use_memory:
S = num_chunks
M = S - 1
else:
S, M = 0, 0
Q, KV = R + U + S, M + R + U
utterance = torch.randn(U, B, D)
lengths = torch.randint(1, U + 1, (B,))
lengths[0] = U
right_context = torch.randn(R, B, D)
summary = torch.randn(S, B, D)
memory = torch.randn(M, B, D)
attention_mask = torch.rand(Q, KV) >= 0.5
PE = 2 * U + right_context_length - 1
pos_emb = torch.randn(PE, D)
(
output_right_context_utterance,
output_memory,
probs_memory,
probs_frames,
) = attention(
utterance,
lengths,
right_context,
summary,
memory,
attention_mask,
pos_emb,
)
assert output_right_context_utterance.shape == (R + U, B, D)
assert output_memory.shape == (M, B, D)
assert probs_memory.shape == (B, U)
assert probs_frames.shape == (B, U)


def test_emformer_attention_infer():
from emformer import EmformerAttention
B, D = 2, 256
chunk_length = 4
right_context_length = 2
num_chunks = 1
U = chunk_length * num_chunks
R = right_context_length * num_chunks
L = 3
attention = EmformerAttention(
embed_dim=D,
nhead=8,
chunk_length=chunk_length,
right_context_length=right_context_length,
)
for use_memory in [True, False]:
if use_memory:
S, M = 1, 3
else:
S, M = 0, 0
utterance = torch.randn(U, B, D)
lengths = torch.randint(1, U + 1, (B,))
lengths[0] = U
right_context = torch.randn(R, B, D)
summary = torch.randn(S, B, D)
memory = torch.randn(M, B, D)
left_context_key = torch.randn(L, B, D)
left_context_val = torch.randn(L, B, D)
PE = (
2 * U
+ right_context_length
- 1
+ (M * chunk_length if M > 0 else L)
)
pos_emb = torch.randn(PE, D)
(
output_right_context_utterance,
output_memory,
next_key,
next_val,
) = attention.infer(
utterance,
lengths,
right_context,
summary,
memory,
left_context_key,
left_context_val,
pos_emb,
)
assert output_right_context_utterance.shape == (R + U, B, D)
assert output_memory.shape == (S, B, D)
assert next_key.shape == (L + U, B, D)
assert next_val.shape == (L + U, B, D)


def test_convolution_module_forward():
from emformer import ConvolutionModule
B, D = 2, 256
chunk_length = 4
right_context_length = 2
num_chunks = 3
U = num_chunks * chunk_length
R = num_chunks * right_context_length
kernel_size = 31
conv_module = ConvolutionModule(
chunk_length,
right_context_length,
D,
kernel_size,
)
utterance = torch.randn(U, B, D)
right_context = torch.randn(R, B, D)
cache = torch.randn(B, D, kernel_size - 1)
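    # presumably a causal-convolution cache: it carries the last
    # (kernel_size - 1) frames from the previous chunk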
utterance, right_context, new_cache = conv_module(
utterance, right_context, cache
)
assert utterance.shape == (U, B, D)
assert right_context.shape == (R, B, D)
assert new_cache.shape == (B, D, kernel_size - 1)


def test_convolution_module_infer():
from emformer import ConvolutionModule
B, D = 2, 256
chunk_length = 4
right_context_length = 2
num_chunks = 1
U = num_chunks * chunk_length
R = num_chunks * right_context_length
kernel_size = 31
conv_module = ConvolutionModule(
chunk_length,
right_context_length,
D,
kernel_size,
)
utterance = torch.randn(U, B, D)
right_context = torch.randn(R, B, D)
cache = torch.randn(B, D, kernel_size - 1)
utterance, right_context, new_cache = conv_module.infer(
utterance, right_context, cache
)
assert utterance.shape == (U, B, D)
assert right_context.shape == (R, B, D)
assert new_cache.shape == (B, D, kernel_size - 1)


def test_emformer_encoder_layer_forward():
from emformer import EmformerEncoderLayer
B, D = 2, 256
chunk_length = 8
right_context_length = 2
left_context_length = 8
kernel_size = 31
num_chunks = 3
U = num_chunks * chunk_length
R = num_chunks * right_context_length
for use_memory in [True, False]:
if use_memory:
S = num_chunks
M = S - 1
else:
S, M = 0, 0
layer = EmformerEncoderLayer(
d_model=D,
nhead=8,
dim_feedforward=1024,
chunk_length=chunk_length,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
)
Q, KV = R + U + S, M + R + U
utterance = torch.randn(U, B, D)
lengths = torch.randint(1, U + 1, (B,))
lengths[0] = U
right_context = torch.randn(R, B, D)
memory = torch.randn(M, B, D)
attention_mask = torch.rand(Q, KV) >= 0.5
PE = 2 * U + right_context_length - 1
pos_emb = torch.randn(PE, D)
output_utterance, output_right_context, output_memory = layer(
utterance,
lengths,
right_context,
memory,
attention_mask,
pos_emb,
)
assert output_utterance.shape == (U, B, D)
assert output_right_context.shape == (R, B, D)
assert output_memory.shape == (M, B, D)


def test_emformer_encoder_layer_infer():
from emformer import EmformerEncoderLayer
B, D = 2, 256
chunk_length = 8
right_context_length = 2
left_context_length = 8
kernel_size = 31
num_chunks = 1
U = num_chunks * chunk_length
R = num_chunks * right_context_length
for use_memory in [True, False]:
if use_memory:
max_memory_size = 3
M = 1
else:
max_memory_size = 0
M = 0
layer = EmformerEncoderLayer(
d_model=D,
nhead=8,
dim_feedforward=1024,
chunk_length=chunk_length,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=max_memory_size,
)
utterance = torch.randn(U, B, D)
lengths = torch.randint(1, U + 1, (B,))
lengths[0] = U
right_context = torch.randn(R, B, D)
memory = torch.randn(M, B, D)
state = None
PE = (
2 * U
+ right_context_length
- 1
+ (
max_memory_size * chunk_length
if max_memory_size > 0
else left_context_length
)
)
pos_emb = torch.randn(PE, D)
conv_cache = None
(
output_utterance,
output_right_context,
output_memory,
output_state,
conv_cache,
) = layer.infer(
utterance,
lengths,
right_context,
memory,
pos_emb,
state,
conv_cache,
)
assert output_utterance.shape == (U, B, D)
assert output_right_context.shape == (R, B, D)
if use_memory:
assert output_memory.shape == (1, B, D)
else:
assert output_memory.shape == (0, B, D)
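        # per-layer state appears to be: [memory bank, left-context key,
        # left-context value, number of frames processed so far]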
assert len(output_state) == 4
assert output_state[0].shape == (max_memory_size, B, D)
assert output_state[1].shape == (left_context_length, B, D)
assert output_state[2].shape == (left_context_length, B, D)
assert output_state[3].shape == (1, B)
assert conv_cache.shape == (B, D, kernel_size - 1)


def test_emformer_encoder_forward():
from emformer import EmformerEncoder
B, D = 2, 256
chunk_length = 4
right_context_length = 2
left_context_length = 2
num_chunks = 3
U = num_chunks * chunk_length
kernel_size = 31
num_encoder_layers = 2
for use_memory in [True, False]:
if use_memory:
S = num_chunks
M = S - 1
else:
S, M = 0, 0
encoder = EmformerEncoder(
chunk_length=chunk_length,
d_model=D,
dim_feedforward=1024,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
)
x = torch.randn(U + right_context_length, B, D)
lengths = torch.randint(1, U + right_context_length + 1, (B,))
lengths[0] = U + right_context_length
output, output_lengths = encoder(x, lengths)
assert output.shape == (U, B, D)
assert torch.equal(
output_lengths, torch.clamp(lengths - right_context_length, min=0)
)


def test_emformer_encoder_infer():
from emformer import EmformerEncoder
B, D = 2, 256
num_encoder_layers = 2
chunk_length = 4
right_context_length = 2
left_context_length = 2
num_chunks = 3
kernel_size = 31
for use_memory in [True, False]:
if use_memory:
max_memory_size = 3
else:
max_memory_size = 0
encoder = EmformerEncoder(
chunk_length=chunk_length,
d_model=D,
dim_feedforward=1024,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=max_memory_size,
)
states = None
conv_caches = None
for chunk_idx in range(num_chunks):
x = torch.randn(chunk_length + right_context_length, B, D)
lengths = torch.randint(
1, chunk_length + right_context_length + 1, (B,)
)
lengths[0] = chunk_length + right_context_length
output, output_lengths, states, conv_caches = encoder.infer(
x, lengths, states, conv_caches
)
assert output.shape == (chunk_length, B, D)
assert torch.equal(
output_lengths,
torch.clamp(lengths - right_context_length, min=0),
)
assert len(states) == num_encoder_layers
for state in states:
assert len(state) == 4
assert state[0].shape == (max_memory_size, B, D)
assert state[1].shape == (left_context_length, B, D)
assert state[2].shape == (left_context_length, B, D)
assert torch.equal(
state[3],
(chunk_idx + 1) * chunk_length * torch.ones_like(state[3]),
)
for conv_cache in conv_caches:
assert conv_cache.shape == (B, D, kernel_size - 1)


def test_emformer_encoder_forward_infer_consistency():
from emformer import EmformerEncoder
chunk_length = 4
num_chunks = 3
U = chunk_length * num_chunks
left_context_length, right_context_length = 1, 2
D = 256
num_encoder_layers = 3
kernel_size = 31
memory_sizes = [0, 3]
for max_memory_size in memory_sizes:
encoder = EmformerEncoder(
chunk_length=chunk_length,
d_model=D,
dim_feedforward=1024,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=max_memory_size,
)
encoder.eval()
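        # dropout must be disabled; otherwise the full-utterance forward pass
        # and the chunk-wise infer pass would not be numerically comparable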
x = torch.randn(U + right_context_length, 1, D)
lengths = torch.tensor([U + right_context_length])
# training mode with full utterance
forward_output, forward_output_lengths = encoder(x, lengths)
# streaming inference mode with individual chunks
states = None
conv_caches = None
for chunk_idx in range(num_chunks):
start_idx = chunk_idx * chunk_length
end_idx = start_idx + chunk_length
            chunk = x[start_idx : end_idx + right_context_length]  # noqa
            # use a fresh name here; reassigning ``chunk_length`` would
            # clobber the loop-invariant integer on later iterations
            chunk_lengths = torch.tensor([chunk_length + right_context_length])
(
infer_output_chunk,
infer_output_lengths,
states,
conv_caches,
            ) = encoder.infer(chunk, chunk_lengths, states, conv_caches)
forward_output_chunk = forward_output[start_idx:end_idx]
assert torch.allclose(
infer_output_chunk,
forward_output_chunk,
atol=1e-4,
rtol=0.0,
), (
infer_output_chunk - forward_output_chunk
)


def test_emformer_forward():
from emformer import Emformer
num_features = 80
chunk_length = 16
right_context_length = 8
left_context_length = 8
num_chunks = 3
U = num_chunks * chunk_length
B, D = 2, 256
kernel_size = 31
for use_memory in [True, False]:
if use_memory:
max_memory_size = 3
else:
max_memory_size = 0
model = Emformer(
num_features=num_features,
chunk_length=chunk_length,
subsampling_factor=4,
d_model=D,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=max_memory_size,
)
x = torch.randn(B, U + right_context_length + 3, num_features)
x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
x_lens[0] = U + right_context_length + 3
output, output_lengths = model(x, x_lens)
assert output.shape == (B, U // 4, D)
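        # the length formula below reflects subsampling_factor=4, apparently
        # implemented as two halving conv layers: T -> ((T - 1) // 2 - 1) // 2,
        # after which the right context is removed in subsampled frames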
assert torch.equal(
output_lengths,
torch.clamp(
((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4, min=0
),
)


def test_emformer_infer():
from emformer import Emformer
num_features = 80
chunk_length = 8
U = chunk_length
left_context_length, right_context_length = 32, 4
B, D = 2, 256
num_chunks = 3
num_encoder_layers = 2
kernel_size = 31
for use_memory in [True, False]:
if use_memory:
max_memory_size = 32
else:
max_memory_size = 0
model = Emformer(
num_features=num_features,
chunk_length=chunk_length,
subsampling_factor=4,
d_model=D,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=max_memory_size,
)
states = None
conv_caches = None
for chunk_idx in range(num_chunks):
x = torch.randn(B, U + right_context_length + 3, num_features)
x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
x_lens[0] = U + right_context_length + 3
output, output_lengths, states, conv_caches = model.infer(
x, x_lens, states, conv_caches
)
assert output.shape == (B, U // 4, D)
assert torch.equal(
output_lengths,
torch.clamp(
((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4,
min=0,
),
)
assert len(states) == num_encoder_layers
for state in states:
assert len(state) == 4
assert state[0].shape == (max_memory_size, B, D)
assert state[1].shape == (left_context_length // 4, B, D)
assert state[2].shape == (left_context_length // 4, B, D)
assert torch.equal(
state[3],
U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
)
for conv_cache in conv_caches:
assert conv_cache.shape == (B, D, kernel_size - 1)


if __name__ == "__main__":
test_rel_positional_encoding()
test_emformer_attention_forward()
test_emformer_attention_infer()
test_convolution_module_forward()
test_convolution_module_infer()
test_emformer_encoder_layer_forward()
test_emformer_encoder_layer_infer()
test_emformer_encoder_forward()
test_emformer_encoder_infer()
test_emformer_encoder_forward_infer_consistency()
test_emformer_forward()
test_emformer_infer()

File diff suppressed because it is too large.