Export streaming zipformer to ncnn (#906)

This commit is contained in:
Fangjun Kuang 2023-02-13 23:41:43 +08:00 committed by GitHub
parent e63a8c27f8
commit c5e687ddf5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 5805 additions and 30 deletions

View File

@ -131,3 +131,102 @@ python3 ./lstm_transducer_stateless2/ncnn-decode.py \
rm -rf $repo
log "--------------------------------------------------------------------------"
log "=========================================================================="
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
\
--decode-chunk-len 32 \
--num-encoder-layers "2,4,3,2,4" \
--feedforward-dims "1024,1024,2048,2048,1024" \
--nhead "8,8,8,8,8" \
--encoder-dims "384,384,384,384,384" \
--attention-dims "192,192,192,192,192" \
--encoder-unmasked-dims "256,256,256,256,256" \
--zipformer-downsampling-factors "1,2,4,8,2" \
--cnn-module-kernels "31,31,31,31,31" \
--decoder-dim 512 \
--joiner-dim 512
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1089-134686-0001.wav
rm -rf $repo
log "--------------------------------------------------------------------------"
log "=========================================================================="
repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_char_bpe/L.pt"
git lfs pull --include "data/lang_char_bpe/L_disambig.pt"
git lfs pull --include "data/lang_char_bpe/Linv.pt"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
--lang-dir $repo/data/lang_char_bpe \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--decode-chunk-len 32 \
--num-encoder-layers "2,4,3,2,4" \
--feedforward-dims "1024,1024,1536,1536,1024" \
--nhead "8,8,8,8,8" \
--encoder-dims "384,384,384,384,384" \
--attention-dims "192,192,192,192,192" \
--encoder-unmasked-dims "256,256,256,256,256" \
--zipformer-downsampling-factors "1,2,4,8,2" \
--cnn-module-kernels "31,31,31,31,31" \
--decoder-dim 512 \
--joiner-dim 512
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
--tokens $repo/data/lang_char_bpe/tokens.txt \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/0.wav
rm -rf $repo
log "--------------------------------------------------------------------------"

View File

@ -310,6 +310,16 @@ def main():
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
logging.info("Using torch.jit.trace()")
logging.info("Exporting encoder")

View File

@ -203,11 +203,8 @@ class Model:
# (1, 512, 2) -> (512, 2)
ex.input(name, ncnn.Mat(states[i * 4 + 3].numpy()).clone())
import pdb
# pdb.set_trace()
ret, ncnn_out0 = ex.extract("out0")
# assert ret == 0, ret
assert ret == 0, ret
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
out_states: List[torch.Tensor] = []

View File

@ -99,7 +99,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="lstm_transducer_stateless2/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
@ -316,6 +316,16 @@ def main():
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
logging.info("Using torch.jit.trace()")
logging.info("Exporting encoder")

View File

@ -87,7 +87,11 @@ class Decoder(nn.Module):
y = y.to(torch.int64)
# this stuff about clamp() is a temporary fix for a mismatch
# at utterance start, we use negative ids in beam_search.py
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
if torch.jit.is_tracing():
# This is for exporting to PNNX via ONNX
embedding_out = self.embedding(y)
else:
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:

View File

@ -53,7 +53,6 @@ class Joiner(nn.Module):
"""
assert encoder_out.ndim == decoder_out.ndim
assert encoder_out.ndim in (2, 4)
assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)

View File

@ -22,11 +22,101 @@ BasicNorm is replaced by a module with `exp` removed.
"""
import copy
from typing import List
from typing import List, Tuple
import torch
import torch.nn as nn
from scaling import ActivationBalancer, BasicNorm, Whiten
from zipformer import PoolingModule
class PoolingModuleNoProj(nn.Module):
def forward(
self,
x: torch.Tensor,
cached_len: torch.Tensor,
cached_avg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (T, N, C)
cached_len:
A tensor of shape (N,)
cached_avg:
A tensor of shape (N, C)
Returns:
Return a tuple containing:
- new_x
- new_cached_len
- new_cached_avg
"""
x = x.cumsum(dim=0) # (T, N, C)
x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0)
# Cumulated numbers of frames from start
cum_mask = torch.arange(1, x.size(0) + 1, device=x.device)
cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N)
pooling_mask = (1.0 / cum_mask).unsqueeze(2)
# now pooling_mask: (T, N, 1)
x = x * pooling_mask # (T, N, C)
cached_len = cached_len + x.size(0)
cached_avg = x[-1]
return x, cached_len, cached_avg
class PoolingModuleWithProj(nn.Module):
def __init__(self, proj: torch.nn.Module):
super().__init__()
self.proj = proj
self.pooling = PoolingModuleNoProj()
def forward(
self,
x: torch.Tensor,
cached_len: torch.Tensor,
cached_avg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (T, N, C)
cached_len:
A tensor of shape (N,)
cached_avg:
A tensor of shape (N, C)
Returns:
Return a tuple containing:
- new_x
- new_cached_len
- new_cached_avg
"""
x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg)
return self.proj(x), cached_len, cached_avg
def streaming_forward(
self,
x: torch.Tensor,
cached_len: torch.Tensor,
cached_avg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (T, N, C)
cached_len:
A tensor of shape (N,)
cached_avg:
A tensor of shape (N, C)
Returns:
Return a tuple containing:
- new_x
- new_cached_len
- new_cached_avg
"""
x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg)
return self.proj(x), cached_len, cached_avg
class NonScaledNorm(nn.Module):
@ -53,7 +143,7 @@ class NonScaledNorm(nn.Module):
def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
assert isinstance(basic_norm, BasicNorm), type(BasicNorm)
assert isinstance(basic_norm, BasicNorm), type(basic_norm)
norm = NonScaledNorm(
num_channels=basic_norm.num_channels,
eps_exp=basic_norm.eps.data.exp().item(),
@ -62,6 +152,11 @@ def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
return norm
def convert_pooling_module(pooling: PoolingModule) -> PoolingModuleWithProj:
assert isinstance(pooling, PoolingModule), type(pooling)
return PoolingModuleWithProj(proj=pooling.proj)
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
# get_submodule was added to nn.Module at v1.9.0
def get_submodule(model, target):
@ -83,6 +178,7 @@ def get_submodule(model, target):
def convert_scaled_to_non_scaled(
model: nn.Module,
inplace: bool = False,
is_pnnx: bool = False,
):
"""
Args:
@ -91,6 +187,8 @@ def convert_scaled_to_non_scaled(
inplace:
If True, the input model is modified inplace.
If False, the input model is copied and we modify the copied version.
is_pnnx:
True if we are going to export the model for PNNX.
Return:
Return a model without scaled layers.
"""
@ -103,6 +201,8 @@ def convert_scaled_to_non_scaled(
d[name] = convert_basic_norm(m)
elif isinstance(m, (ActivationBalancer, Whiten)):
d[name] = nn.Identity()
elif isinstance(m, PoolingModule) and is_pnnx:
d[name] = convert_pooling_module(m)
for k, v in d.items():
if "." in k:

View File

@ -1,3 +1,10 @@
This recipe implements Streaming Zipformer-Transducer model.
See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials.
[./emformer.py](./emformer.py) and [./train.py](./train.py)
are basically the same as
[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py).
The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py)
is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn).

View File

@ -0,0 +1,367 @@
#!/usr/bin/env python3
"""
Please see
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
for more details about how to use this file.
We use
https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
to demonstrate the usage of this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_char_bpe/L.pt"
git lfs pull --include "data/lang_char_bpe/L_disambig.pt"
git lfs pull --include "data/lang_char_bpe/Linv.pt"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export to ncnn
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
--lang-dir $repo/data/lang_char_bpe \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--decode-chunk-len 32 \
--num-encoder-layers "2,4,3,2,4" \
--feedforward-dims "1024,1024,1536,1536,1024" \
--nhead "8,8,8,8,8" \
--encoder-dims "384,384,384,384,384" \
--attention-dims "192,192,192,192,192" \
--encoder-unmasked-dims "256,256,256,256,256" \
--zipformer-downsampling-factors "1,2,4,8,2" \
--cnn-module-kernels "31,31,31,31,31" \
--decoder-dim 512 \
--joiner-dim 512
cd $repo/exp
pnnx encoder_jit_trace-pnnx.pt
pnnx decoder_jit_trace-pnnx.pt
pnnx joiner_jit_trace-pnnx.pt
You can find converted models at
https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13
See ./streaming-ncnn-decode.py
and
https://github.com/k2-fsa/sherpa-ncnn
for usage.
"""
import argparse
import logging
from pathlib import Path
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_streaming/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="The lang dir",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
add_model_arguments(parser)
return parser
def export_encoder_model_jit_trace(
encoder_model: torch.nn.Module,
encoder_filename: str,
) -> None:
"""Export the given encoder model with torch.jit.trace()
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported model.
"""
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
decode_chunk_len = encoder_model.decode_chunk_size * 2
pad_length = 7
T = decode_chunk_len + pad_length # 32 + 7 = 39
logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"T: {T}")
x = torch.zeros(1, T, 80, dtype=torch.float32)
states = encoder_model.get_init_state()
traced_model = torch.jit.trace(encoder_model, (x, states))
traced_model.save(encoder_filename)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_jit_trace(
decoder_model: torch.nn.Module,
decoder_filename: str,
) -> None:
"""Export the given decoder model with torch.jit.trace()
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The input decoder model
decoder_filename:
The filename to save the exported model.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_jit_trace(
joiner_model: torch.nn.Module,
joiner_filename: str,
) -> None:
"""Export the given joiner model with torch.jit.trace()
Note: The argument project_input is fixed to True. A user should not
project the encoder_out/decoder_out by himself/herself. The exported joiner
will do that for the user.
Args:
joiner_model:
The input joiner model
joiner_filename:
The filename to save the exported model.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
traced_model.save(joiner_filename)
logging.info(f"Saved to {joiner_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True)
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
logging.info("Using torch.jit.trace()")
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
main()

View File

@ -0,0 +1,369 @@
#!/usr/bin/env python3
"""
Please see
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
for more details about how to use this file.
We use
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
to demonstrate the usage of this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export to ncnn
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
\
--decode-chunk-len 32 \
--num-encoder-layers "2,4,3,2,4" \
--feedforward-dims "1024,1024,2048,2048,1024" \
--nhead "8,8,8,8,8" \
--encoder-dims "384,384,384,384,384" \
--attention-dims "192,192,192,192,192" \
--encoder-unmasked-dims "256,256,256,256,256" \
--zipformer-downsampling-factors "1,2,4,8,2" \
--cnn-module-kernels "31,31,31,31,31" \
--decoder-dim 512 \
--joiner-dim 512
cd $repo/exp
pnnx encoder_jit_trace-pnnx.pt
pnnx decoder_jit_trace-pnnx.pt
pnnx joiner_jit_trace-pnnx.pt
You can find converted models at
https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13
See ./streaming-ncnn-decode.py
and
https://github.com/k2-fsa/sherpa-ncnn
for usage.
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_streaming/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
add_model_arguments(parser)
return parser
def export_encoder_model_jit_trace(
encoder_model: torch.nn.Module,
encoder_filename: str,
) -> None:
"""Export the given encoder model with torch.jit.trace()
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported model.
"""
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
decode_chunk_len = encoder_model.decode_chunk_size * 2
pad_length = 7
T = decode_chunk_len + pad_length # 32 + 7 = 39
logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"T: {T}")
x = torch.zeros(1, T, 80, dtype=torch.float32)
states = encoder_model.get_init_state()
traced_model = torch.jit.trace(encoder_model, (x, states))
traced_model.save(encoder_filename)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_jit_trace(
decoder_model: torch.nn.Module,
decoder_filename: str,
) -> None:
"""Export the given decoder model with torch.jit.trace()
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The input decoder model
decoder_filename:
The filename to save the exported model.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_jit_trace(
joiner_model: torch.nn.Module,
joiner_filename: str,
) -> None:
"""Export the given joiner model with torch.jit.trace()
Note: The argument project_input is fixed to True. A user should not
project the encoder_out/decoder_out by himself/herself. The exported joiner
will do that for the user.
Args:
joiner_model:
The input joiner model
joiner_filename:
The filename to save the exported model.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
traced_model.save(joiner_filename)
logging.info(f"Saved to {joiner_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True)
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
logging.info("Using torch.jit.trace()")
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
main()

View File

@ -0,0 +1,419 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
--tokens ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/tokens.txt \
--encoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.bin \
./sherpa-ncnn-streaming-zipformer-en-2023-02-13/test_wavs/1089-134686-0001.wav
You can find pretrained models at
- English: https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13
- Bilingual (Chinese + English): https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13
"""
import argparse
import logging
from typing import List, Optional, Tuple
import k2
import ncnn
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--tokens",
type=str,
help="Path to tokens.txt",
)
parser.add_argument(
"--encoder-param-filename",
type=str,
help="Path to encoder.ncnn.param",
)
parser.add_argument(
"--encoder-bin-filename",
type=str,
help="Path to encoder.ncnn.bin",
)
parser.add_argument(
"--decoder-param-filename",
type=str,
help="Path to decoder.ncnn.param",
)
parser.add_argument(
"--decoder-bin-filename",
type=str,
help="Path to decoder.ncnn.bin",
)
parser.add_argument(
"--joiner-param-filename",
type=str,
help="Path to joiner.ncnn.param",
)
parser.add_argument(
"--joiner-bin-filename",
type=str,
help="Path to joiner.ncnn.bin",
)
parser.add_argument(
"sound_filename",
type=str,
help="Path to foo.wav",
)
return parser.parse_args()
def to_int_tuple(s: str):
return tuple(map(int, s.split(",")))
class Model:
def __init__(self, args):
self.init_encoder(args)
self.init_decoder(args)
self.init_joiner(args)
# Please change the parameters according to your model
self.num_encoder_layers = to_int_tuple("2,4,3,2,4")
self.encoder_dims = to_int_tuple("384,384,384,384,384") # also known as d_model
self.attention_dims = to_int_tuple("192,192,192,192,192")
self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2")
self.cnn_module_kernels = to_int_tuple("31,31,31,31,31")
self.decode_chunk_size = 32 // 2
num_left_chunks = 4
self.left_context_length = self.decode_chunk_size * num_left_chunks # 64
self.chunk_length = self.decode_chunk_size * 2
pad_length = 7
self.T = self.chunk_length + pad_length
def get_init_states(self) -> List[torch.Tensor]:
cached_len_list = []
cached_avg_list = []
cached_key_list = []
cached_val_list = []
cached_val2_list = []
cached_conv1_list = []
cached_conv2_list = []
for i in range(len(self.num_encoder_layers)):
num_layers = self.num_encoder_layers[i]
ds = self.zipformer_downsampling_factors[i]
attention_dim = self.attention_dims[i]
left_context_length = self.left_context_length // ds
encoder_dim = self.encoder_dims[i]
cnn_module_kernel = self.cnn_module_kernels[i]
cached_len_list.append(torch.zeros(num_layers))
cached_avg_list.append(torch.zeros(num_layers, encoder_dim))
cached_key_list.append(
torch.zeros(num_layers, left_context_length, attention_dim)
)
cached_val_list.append(
torch.zeros(num_layers, left_context_length, attention_dim // 2)
)
cached_val2_list.append(
torch.zeros(num_layers, left_context_length, attention_dim // 2)
)
cached_conv1_list.append(
torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1)
)
cached_conv2_list.append(
torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1)
)
states = (
cached_len_list
+ cached_avg_list
+ cached_key_list
+ cached_val_list
+ cached_val2_list
+ cached_conv1_list
+ cached_conv2_list
)
return states
def init_encoder(self, args):
encoder_net = ncnn.Net()
encoder_net.opt.use_packing_layout = False
encoder_net.opt.use_fp16_storage = False
encoder_net.opt.num_threads = 4
encoder_param = args.encoder_param_filename
encoder_model = args.encoder_bin_filename
encoder_net.load_param(encoder_param)
encoder_net.load_model(encoder_model)
self.encoder_net = encoder_net
def init_decoder(self, args):
decoder_param = args.decoder_param_filename
decoder_model = args.decoder_bin_filename
decoder_net = ncnn.Net()
decoder_net.opt.num_threads = 4
decoder_net.load_param(decoder_param)
decoder_net.load_model(decoder_model)
self.decoder_net = decoder_net
def init_joiner(self, args):
joiner_param = args.joiner_param_filename
joiner_model = args.joiner_bin_filename
joiner_net = ncnn.Net()
joiner_net.opt.num_threads = 4
joiner_net.load_param(joiner_param)
joiner_net.load_model(joiner_model)
self.joiner_net = joiner_net
def run_encoder(
self,
x: torch.Tensor,
states: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Args:
x:
A tensor of shape (T, C)
states:
A list of tensors. len(states) == self.num_layers * 4
Returns:
Return a tuple containing:
- encoder_out, a tensor of shape (T, encoder_dim).
- next_states, a list of tensors containing the next states
"""
with self.encoder_net.create_extractor() as ex:
ex.input("in0", ncnn.Mat(x.numpy()).clone())
for i in range(len(states)):
name = f"in{i+1}"
ex.input(name, ncnn.Mat(states[i].squeeze().numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
out_states: List[torch.Tensor] = []
for i in range(len(states)):
name = f"out{i+1}"
ret, ncnn_out_state = ex.extract(name)
assert ret == 0, ret
ncnn_out_state = torch.from_numpy(ncnn_out_state.numpy())
if i < len(self.num_encoder_layers):
# for cached_len, we need to discard the last dim
ncnn_out_state = ncnn_out_state.squeeze(1)
out_states.append(ncnn_out_state)
return encoder_out, out_states
def run_decoder(self, decoder_input):
assert decoder_input.dtype == torch.int32
with self.decoder_net.create_extractor() as ex:
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return decoder_out
def run_joiner(self, encoder_out, decoder_out):
with self.joiner_net.create_extractor() as ex:
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return joiner_out
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
def greedy_search(
model: Model,
encoder_out: torch.Tensor,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
):
context_size = 2
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
else:
assert decoder_out.ndim == 1
assert hyp is not None, hyp
T = encoder_out.size(0)
for t in range(T):
cur_encoder_out = encoder_out[t]
joiner_out = model.run_joiner(cur_encoder_out, decoder_out)
y = joiner_out.argmax(dim=0).item()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
return hyp, decoder_out
def main():
args = get_args()
logging.info(vars(args))
model = Model(args)
sound_file = args.sound_filename
sample_rate = 16000
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor()
logging.info(f"Reading sound files: {sound_file}")
wave_samples = read_sound_files(
filenames=[sound_file],
expected_sample_rate=sample_rate,
)[0]
logging.info(wave_samples.shape)
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
wave_samples = torch.cat([wave_samples, tail_padding])
states = model.get_init_states()
logging.info(f"number of states: {len(states)}")
hyp = None
decoder_out = None
num_processed_frames = 0
segment = model.T
offset = model.chunk_length
chunk = int(1 * sample_rate) # 0.2 second
start = 0
while start < wave_samples.numel():
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
encoder_out, states = model.run_encoder(frames, states)
hyp, decoder_out = greedy_search(model, encoder_out, decoder_out, hyp)
symbol_table = k2.SymbolTable.from_file(args.tokens)
context_size = 2
text = ""
for i in hyp[context_size:]:
text += symbol_table[i]
text = text.replace("", " ").strip()
logging.info(sound_file)
logging.info(text)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

File diff suppressed because it is too large Load Diff

View File

@ -44,7 +44,6 @@ from scaling import (
)
from torch import Tensor, nn
from icefall.dist import get_rank
from icefall.utils import make_pad_mask, subsequent_chunk_mask
@ -271,7 +270,6 @@ class Zipformer(EncoderInterface):
num_encoder_layers (int): number of encoder layers
dropout (float): dropout rate
cnn_module_kernels (int): Kernel size of convolution module
vgg_frontend (bool): whether to use vgg frontend.
warmup_batches (float): number of batches to warm up over
"""
@ -388,9 +386,9 @@ class Zipformer(EncoderInterface):
def _init_skip_modules(self):
"""
If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
indexed 4 (in zero indexing), with has subsapling_factor=4, we combine the output of
layers 2 and 3; and at the input of layer indexed 5, which which has subsampling_factor=2,
we combine the outputs of layers 1 and 5.
indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of
layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2,
we combine the outputs of layers 1 and 4.
"""
skip_layers = []
skip_modules = []
@ -1272,8 +1270,7 @@ class ZipformerEncoder(nn.Module):
Shape:
src: (S, N, E).
cached_len: (N,)
N is the batch size.
cached_len: (num_layers,)
cached_avg: (num_layers, N, C).
N is the batch size, C is the feature dimension.
cached_key: (num_layers, left_context_len, N, K).
@ -1289,8 +1286,8 @@ class ZipformerEncoder(nn.Module):
Returns: A tuple of 8 tensors:
- output tensor
- updated cached number of past frmaes.
- updated cached average of past frmaes.
- updated cached number of past frames.
- updated cached average of past frames.
- updated cached key tensor of of the first attention module.
- updated cached value tensor of of the first attention module.
- updated cached value tensor of of the second attention module.
@ -1522,9 +1519,6 @@ class AttentionDownsample(torch.nn.Module):
"""
def __init__(self, in_channels: int, out_channels: int, downsample: int):
"""
Require out_channels > in_channels.
"""
super(AttentionDownsample, self).__init__()
self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5))
@ -1902,8 +1896,6 @@ class RelPositionMultiheadAttention(nn.Module):
Args:
x: input to be projected to query, key, value
pos_emb: Positional embedding tensor
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
- Inputs:
@ -1911,13 +1903,6 @@ class RelPositionMultiheadAttention(nn.Module):
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension.
- cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension.

File diff suppressed because it is too large Load Diff