First upload of the conv_emformer_transducer recipe, integrating a convolution module into the Emformer layers.

yaozengwei 2022-04-10 20:24:20 +08:00
parent 3e131891a2
commit 8129470586
12 changed files with 3470 additions and 0 deletions
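The diff that adds the Emformer implementation itself is suppressed further down because of its size. As a rough, hypothetical illustration of what "integrating a convolution module into Emformer layers" typically means, here is a minimal sketch of a Conformer-style causal convolution module; the class name, layer choices, and hyperparameters are assumptions, not the committed code:

import torch
import torch.nn as nn


class ConvolutionModule(nn.Module):
    # Illustrative only: pointwise conv -> GLU -> causal depthwise conv
    # -> norm + activation -> pointwise conv, operating on (T, B, D) input.
    def __init__(self, channels: int, kernel_size: int) -> None:
        super().__init__()
        self.left_pad = kernel_size - 1  # left-pad only, to stay causal
        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1)
        self.depthwise_conv = nn.Conv1d(
            channels, channels, kernel_size, groups=channels
        )
        self.norm = nn.BatchNorm1d(channels)
        self.activation = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(1, 2, 0)  # (T, B, D) -> (B, D, T) for Conv1d
        x = nn.functional.glu(self.pointwise_conv1(x), dim=1)
        x = nn.functional.pad(x, (self.left_pad, 0))  # causal left padding
        x = self.activation(self.norm(self.depthwise_conv(x)))
        x = self.pointwise_conv2(x)
        return x.permute(2, 0, 1)  # back to (T, B, D)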

@@ -0,0 +1 @@
../pruned_transducer_stateless/asr_datamodule.py

@@ -0,0 +1 @@
../pruned_transducer_stateless/beam_search.py

@@ -0,0 +1,549 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./transducer_emformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./transducer_emformer/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./transducer_emformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./transducer_emformer/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./transducer_emformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./transducer_emformer/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./transducer_emformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./transducer_emformer/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_emformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
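# Example: with beam=4.0, if the best score on a frame is -1.2, then
# states scoring below cutoff = -1.2 - 4.0 = -5.2 are pruned.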
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search".
If beam search with a beam size of 7 is used, it would be
"beam_7".
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
Used only when --decoding-method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
Used only when --decoding-method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if a beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg_last_n > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dls = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

@@ -0,0 +1 @@
../pruned_transducer_stateless/decoder.py

File diff suppressed because it is too large.

@@ -0,0 +1 @@
../transducer_stateless/encoder_interface.py

@@ -0,0 +1 @@
../pruned_transducer_stateless/joiner.py

@@ -0,0 +1 @@
../pruned_transducer_stateless/model.py

@@ -0,0 +1,104 @@
# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class Noam(object):
"""
Implements the Noam optimizer.
Proposed in
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
Args:
params:
iterable of parameters to optimize or dicts defining parameter groups
model_size:
attention dimension of the transformer model
factor:
learning rate factor
warm_step:
warmup steps
"""
def __init__(
self,
params,
model_size: int = 256,
factor: float = 10.0,
warm_step: int = 25000,
weight_decay=0,
) -> None:
"""Construct an Noam object."""
self.optimizer = torch.optim.Adam(
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
)
self._step = 0
self.warmup = warm_step
self.factor = factor
self.model_size = model_size
self._rate = 0
@property
def param_groups(self):
"""Return param_groups."""
return self.optimizer.param_groups
def step(self):
"""Update parameters and rate."""
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p["lr"] = rate
self._rate = rate
self.optimizer.step()
def rate(self, step=None):
"""Implement `lrate` above."""
if step is None:
step = self._step
return (
self.factor
* self.model_size ** (-0.5)
* min(step ** (-0.5), step * self.warmup ** (-1.5))
)
def zero_grad(self):
"""Reset gradient."""
self.optimizer.zero_grad()
def state_dict(self):
"""Return state_dict."""
return {
"_step": self._step,
"warmup": self.warmup,
"factor": self.factor,
"model_size": self.model_size,
"_rate": self._rate,
"optimizer": self.optimizer.state_dict(),
}
def load_state_dict(self, state_dict):
"""Load state_dict."""
for key, value in state_dict.items():
if key == "optimizer":
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
setattr(self, key, value)
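A minimal usage sketch of the Noam wrapper above, assuming a toy model and a plain training loop (the model, data, and step count here are placeholders, not part of the recipe):

import torch

model = torch.nn.Linear(256, 256)
# Noam is the class defined above.
optimizer = Noam(
    model.parameters(), model_size=256, factor=10.0, warm_step=25000
)

for _ in range(10):
    x = torch.randn(8, 256)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # updates the learning rate, then the parameters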

@@ -0,0 +1 @@
../conformer_ctc/subsampling.py

@@ -0,0 +1,359 @@
import torch
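# Symbol conventions used throughout these tests (inferred from the shapes
# asserted below):
#   B: batch size            D: attention / model dimension
#   U: utterance length      R: right-context length
#   L: left-context length   S: number of summary vectors
#   M: memory bank size
#   Q = R + U + S (query length), KV = M + R + U (key/value length)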
def test_emformer_attention_forward():
from emformer import EmformerAttention
B, D = 2, 256
U, R = 12, 2
chunk_length = 2
attention = EmformerAttention(embed_dim=D, nhead=8)
for use_memory in [True, False]:
if use_memory:
S = U // chunk_length
M = S - 1
else:
S, M = 0, 0
Q, KV = R + U + S, M + R + U
utterance = torch.randn(U, B, D)
lengths = torch.randint(1, U + 1, (B,))
lengths[0] = U
right_context = torch.randn(R, B, D)
summary = torch.randn(S, B, D)
memory = torch.randn(M, B, D)
attention_mask = torch.rand(Q, KV) >= 0.5
output_right_context_utterance, output_memory = attention(
utterance,
lengths,
right_context,
summary,
memory,
attention_mask,
)
assert output_right_context_utterance.shape == (R + U, B, D)
assert output_memory.shape == (M, B, D)
def test_emformer_attention_infer():
from emformer import EmformerAttention
B, D = 2, 256
R, L = 4, 2
chunk_length = 2
U = chunk_length
attention = EmformerAttention(embed_dim=D, nhead=8)
for use_memory in [True, False]:
if use_memory:
S, M = 1, 3
else:
S, M = 0, 0
utterance = torch.randn(U, B, D)
lengths = torch.randint(1, U + 1, (B,))
lengths[0] = U
right_context = torch.randn(R, B, D)
summary = torch.randn(S, B, D)
memory = torch.randn(M, B, D)
left_context_key = torch.randn(L, B, D)
left_context_val = torch.randn(L, B, D)
(
output_right_context_utterance,
output_memory,
next_key,
next_val,
) = attention.infer(
utterance,
lengths,
right_context,
summary,
memory,
left_context_key,
left_context_val,
)
assert output_right_context_utterance.shape == (R + U, B, D)
assert output_memory.shape == (S, B, D)
assert next_key.shape == (L + U, B, D)
assert next_val.shape == (L + U, B, D)
def test_emformer_layer_forward():
from emformer import EmformerLayer
B, D = 2, 256
U, R, L = 12, 2, 5
chunk_length = 2
for use_memory in [True, False]:
if use_memory:
S = U // chunk_length
M = S - 1
else:
S, M = 0, 0
layer = EmformerLayer(
d_model=D,
nhead=8,
dim_feedforward=1024,
chunk_length=chunk_length,
cnn_module_kernel=3,
left_context_length=L,
max_memory_size=M,
)
Q, KV = R + U + S, M + R + U
utterance = torch.randn(U, B, D)
lengths = torch.randint(1, U + 1, (B,))
lengths[0] = U
right_context = torch.randn(R, B, D)
memory = torch.randn(M, B, D)
attention_mask = torch.rand(Q, KV) >= 0.5
output_utterance, output_right_context, output_memory = layer(
utterance,
lengths,
right_context,
memory,
attention_mask,
)
assert output_utterance.shape == (U, B, D)
assert output_right_context.shape == (R, B, D)
assert output_memory.shape == (M, B, D)
def test_emformer_layer_infer():
from emformer import EmformerLayer
B, D = 2, 256
R, L = 2, 5
chunk_length = 2
U = chunk_length
for use_memory in [True, False]:
if use_memory:
M = 3
else:
M = 0
layer = EmformerLayer(
d_model=D,
nhead=8,
dim_feedforward=1024,
chunk_length=chunk_length,
cnn_module_kernel=3,
left_context_length=L,
max_memory_size=M,
)
utterance = torch.randn(U, B, D)
lengths = torch.randint(1, U + 1, (B,))
lengths[0] = U
right_context = torch.randn(R, B, D)
memory = torch.randn(M, B, D)
state = None
(
output_utterance,
output_right_context,
output_memory,
output_state,
) = layer.infer(
utterance,
lengths,
right_context,
memory,
state,
)
assert output_utterance.shape == (U, B, D)
assert output_right_context.shape == (R, B, D)
if use_memory:
assert output_memory.shape == (1, B, D)
else:
assert output_memory.shape == (0, B, D)
assert len(output_state) == 4
assert output_state[0].shape == (M, B, D)
assert output_state[1].shape == (L, B, D)
assert output_state[2].shape == (L, B, D)
assert output_state[3].shape == (1, B)
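# Based on the shapes asserted above (an inference from the tests, not a
# documented contract), the per-layer state appears to hold:
#   state[0]: memory bank                 (M, B, D)
#   state[1]: cached left-context keys    (L, B, D)
#   state[2]: cached left-context values  (L, B, D)
#   state[3]: number of frames already processed per batch element (1, B)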
def test_emformer_encoder_forward():
from emformer import EmformerEncoder
B, D = 2, 256
U, R, L = 12, 2, 5
chunk_length = 2
for use_memory in [True, False]:
if use_memory:
S = U // chunk_length
M = S - 1
else:
S, M = 0, 0
encoder = EmformerEncoder(
chunk_length=chunk_length,
d_model=D,
dim_feedforward=1024,
num_encoder_layers=2,
cnn_module_kernel=3,
left_context_length=L,
right_context_length=R,
max_memory_size=M,
)
x = torch.randn(U + R, B, D)
lengths = torch.randint(1, U + R + 1, (B,))
lengths[0] = U + R
output, output_lengths = encoder(x, lengths)
assert output.shape == (U, B, D)
assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0))
def test_emformer_encoder_infer():
from emformer import EmformerEncoder
B, D = 2, 256
R, L = 2, 5
chunk_length = 2
U = chunk_length
num_chunks = 3
num_encoder_layers = 2
for use_memory in [True, False]:
if use_memory:
M = 3
else:
M = 0
encoder = EmformerEncoder(
chunk_length=chunk_length,
d_model=D,
dim_feedforward=1024,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=3,
left_context_length=L,
right_context_length=R,
max_memory_size=M,
)
states = None
for chunk_idx in range(num_chunks):
x = torch.randn(U + R, B, D)
lengths = torch.randint(1, U + R + 1, (B,))
lengths[0] = U + R
output, output_lengths, states = encoder.infer(x, lengths, states)
assert output.shape == (U, B, D)
assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0))
assert len(states) == num_encoder_layers
for state in states:
assert len(state) == 4
assert state[0].shape == (M, B, D)
assert state[1].shape == (L, B, D)
assert state[2].shape == (L, B, D)
assert torch.equal(
state[3], (chunk_idx + 1) * U * torch.ones_like(state[3])
)
def test_emformer_forward():
from emformer import Emformer
num_features = 80
output_dim = 1000
chunk_length = 8
L, R = 128, 4
B, D, U = 2, 256, 80
for use_memory in [True, False]:
if use_memory:
M = 3
else:
M = 0
model = Emformer(
num_features=num_features,
output_dim=output_dim,
chunk_length=chunk_length,
subsampling_factor=4,
d_model=D,
cnn_module_kernel=3,
left_context_length=L,
right_context_length=R,
max_memory_size=M,
vgg_frontend=False,
)
x = torch.randn(B, U + R + 3, num_features)
x_lens = torch.randint(1, U + R + 3 + 1, (B,))
x_lens[0] = U + R + 3
logits, output_lengths = model(x, x_lens)
assert logits.shape == (B, U // 4, output_dim)
assert torch.equal(
output_lengths,
torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0),
)
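# ((x_lens - 1) // 2 - 1) // 2 is the frame count after the factor-4
# convolutional subsampling frontend; the R // 4 subsampled right-context
# frames are then stripped from the encoder output.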
def test_emformer_infer():
from emformer import Emformer
num_features = 80
output_dim = 1000
chunk_length = 8
U = chunk_length
L, R = 128, 4
B, D = 2, 256
num_chunks = 3
num_encoder_layers = 2
for use_memory in [True, False]:
if use_memory:
M = 3
else:
M = 0
model = Emformer(
num_features=num_features,
output_dim=output_dim,
chunk_length=chunk_length,
subsampling_factor=4,
d_model=D,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=3,
left_context_length=L,
right_context_length=R,
max_memory_size=M,
vgg_frontend=False,
)
states = None
for chunk_idx in range(num_chunks):
x = torch.randn(B, U + R + 3, num_features)
x_lens = torch.randint(1, U + R + 3 + 1, (B,))
x_lens[0] = U + R + 3
logits, output_lengths, states = model.infer(x, x_lens, states)
assert logits.shape == (B, U // 4, output_dim)
assert torch.equal(
output_lengths,
torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0),
)
assert len(states) == num_encoder_layers
for state in states:
assert len(state) == 4
assert state[0].shape == (M, B, D)
assert state[1].shape == (L // 4, B, D)
assert state[2].shape == (L // 4, B, D)
assert torch.equal(
state[3],
U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
)
if __name__ == "__main__":
test_emformer_attention_forward()
test_emformer_attention_infer()
test_emformer_layer_forward()
test_emformer_layer_infer()
test_emformer_encoder_forward()
test_emformer_encoder_infer()
test_emformer_forward()
test_emformer_infer()

File diff suppressed because it is too large.