Export streaming zipformer to ncnn (#906)

Fangjun Kuang 2023-02-13 23:41:43 +08:00 committed by GitHub
parent e63a8c27f8
commit c5e687ddf5
14 changed files with 5805 additions and 30 deletions

View File

@ -131,3 +131,102 @@ python3 ./lstm_transducer_stateless2/ncnn-decode.py \
rm -rf $repo
log "--------------------------------------------------------------------------"
log "=========================================================================="
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
\
--decode-chunk-len 32 \
--num-encoder-layers "2,4,3,2,4" \
--feedforward-dims "1024,1024,2048,2048,1024" \
--nhead "8,8,8,8,8" \
--encoder-dims "384,384,384,384,384" \
--attention-dims "192,192,192,192,192" \
--encoder-unmasked-dims "256,256,256,256,256" \
--zipformer-downsampling-factors "1,2,4,8,2" \
--cnn-module-kernels "31,31,31,31,31" \
--decoder-dim 512 \
--joiner-dim 512
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1089-134686-0001.wav
rm -rf $repo
log "--------------------------------------------------------------------------"
log "=========================================================================="
repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_char_bpe/L.pt"
git lfs pull --include "data/lang_char_bpe/L_disambig.pt"
git lfs pull --include "data/lang_char_bpe/Linv.pt"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
--lang-dir $repo/data/lang_char_bpe \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--decode-chunk-len 32 \
--num-encoder-layers "2,4,3,2,4" \
--feedforward-dims "1024,1024,1536,1536,1024" \
--nhead "8,8,8,8,8" \
--encoder-dims "384,384,384,384,384" \
--attention-dims "192,192,192,192,192" \
--encoder-unmasked-dims "256,256,256,256,256" \
--zipformer-downsampling-factors "1,2,4,8,2" \
--cnn-module-kernels "31,31,31,31,31" \
--decoder-dim 512 \
--joiner-dim 512
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
--tokens $repo/data/lang_char_bpe/tokens.txt \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/0.wav
rm -rf $repo
log "--------------------------------------------------------------------------"

View File

@ -310,6 +310,16 @@ def main():
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
logging.info("Using torch.jit.trace()") logging.info("Using torch.jit.trace()")
logging.info("Exporting encoder") logging.info("Exporting encoder")

View File

@ -203,11 +203,8 @@ class Model:
# (1, 512, 2) -> (512, 2)
ex.input(name, ncnn.Mat(states[i * 4 + 3].numpy()).clone())
import pdb
# pdb.set_trace()
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
out_states: List[torch.Tensor] = []

View File

@ -99,7 +99,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="lstm_transducer_stateless2/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
@ -316,6 +316,16 @@ def main():
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
logging.info("Using torch.jit.trace()") logging.info("Using torch.jit.trace()")
logging.info("Exporting encoder") logging.info("Exporting encoder")

View File

@ -87,7 +87,11 @@ class Decoder(nn.Module):
y = y.to(torch.int64)
# this stuff about clamp() is a temporary fix for a mismatch
# at utterance start, we use negative ids in beam_search.py
if torch.jit.is_tracing():
# This is for exporting to PNNX via ONNX
embedding_out = self.embedding(y)
else:
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
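For background on this change, here is a standalone sketch (illustrative only, not part of the diff; the embedding sizes are made up) contrasting the eager-mode clamp-and-mask path, which must zero out embeddings for the negative ids that beam_search.py can produce at utterance start, with the plain lookup used under torch.jit.is_tracing() for PNNX export, where the ids are guaranteed non-negative:

import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=500, embedding_dim=4)
y = torch.tensor([[3, -1], [2, 5]])  # -1 mimics a negative id at utterance start

# Eager-mode path: clamp to a valid index, then zero the rows whose id was negative.
eager_out = embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
assert torch.all(eager_out[0, 1] == 0)  # the row for id -1 is forced to zero

# Traced/PNNX path: export only ever sees non-negative ids, so a plain lookup is
# enough and avoids ops that are awkward to convert via ONNX.
y_export = torch.tensor([[3, 0], [2, 5]])
traced_out = embedding(y_export)
assert traced_out.shape == (2, 2, 4)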

View File

@ -53,7 +53,6 @@ class Joiner(nn.Module):
""" """
assert encoder_out.ndim == decoder_out.ndim assert encoder_out.ndim == decoder_out.ndim
assert encoder_out.ndim in (2, 4) assert encoder_out.ndim in (2, 4)
assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)

View File

@ -22,11 +22,101 @@ BasicNorm is replaced by a module with `exp` removed.
""" """
import copy import copy
from typing import List from typing import List, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from scaling import ActivationBalancer, BasicNorm, Whiten from scaling import ActivationBalancer, BasicNorm, Whiten
from zipformer import PoolingModule
class PoolingModuleNoProj(nn.Module):
def forward(
self,
x: torch.Tensor,
cached_len: torch.Tensor,
cached_avg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (T, N, C)
cached_len:
A tensor of shape (N,)
cached_avg:
A tensor of shape (N, C)
Returns:
Return a tuple containing:
- new_x
- new_cached_len
- new_cached_avg
"""
x = x.cumsum(dim=0) # (T, N, C)
x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0)
# Cumulated numbers of frames from start
cum_mask = torch.arange(1, x.size(0) + 1, device=x.device)
cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N)
pooling_mask = (1.0 / cum_mask).unsqueeze(2)
# now pooling_mask: (T, N, 1)
x = x * pooling_mask # (T, N, C)
cached_len = cached_len + x.size(0)
cached_avg = x[-1]
return x, cached_len, cached_avg
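To make the streaming pooling above concrete, here is a small sanity-check sketch (illustrative, not part of the diff): feeding two chunks through PoolingModuleNoProj should reproduce the cumulative mean computed over the concatenated sequence in one shot.

import torch

pool = PoolingModuleNoProj()
T1, T2, N, C = 3, 5, 2, 4
x1, x2 = torch.randn(T1, N, C), torch.randn(T2, N, C)

cached_len = torch.zeros(N)     # number of frames seen so far, per stream
cached_avg = torch.zeros(N, C)  # running average so far, per stream
y1, cached_len, cached_avg = pool(x1, cached_len, cached_avg)
y2, cached_len, cached_avg = pool(x2, cached_len, cached_avg)

# Reference: cumulative mean over the whole sequence at once.
full = torch.cat([x1, x2], dim=0)
ref = full.cumsum(dim=0) / torch.arange(1, T1 + T2 + 1).view(-1, 1, 1)

assert torch.allclose(torch.cat([y1, y2], dim=0), ref, atol=1e-5)
assert torch.allclose(cached_avg, ref[-1], atol=1e-5)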
class PoolingModuleWithProj(nn.Module):
def __init__(self, proj: torch.nn.Module):
super().__init__()
self.proj = proj
self.pooling = PoolingModuleNoProj()
def forward(
self,
x: torch.Tensor,
cached_len: torch.Tensor,
cached_avg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (T, N, C)
cached_len:
A tensor of shape (N,)
cached_avg:
A tensor of shape (N, C)
Returns:
Return a tuple containing:
- new_x
- new_cached_len
- new_cached_avg
"""
x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg)
return self.proj(x), cached_len, cached_avg
def streaming_forward(
self,
x: torch.Tensor,
cached_len: torch.Tensor,
cached_avg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (T, N, C)
cached_len:
A tensor of shape (N,)
cached_avg:
A tensor of shape (N, C)
Returns:
Return a tuple containing:
- new_x
- new_cached_len
- new_cached_avg
"""
x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg)
return self.proj(x), cached_len, cached_avg
class NonScaledNorm(nn.Module):
@ -53,7 +143,7 @@ class NonScaledNorm(nn.Module):
def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
assert isinstance(basic_norm, BasicNorm), type(basic_norm)
norm = NonScaledNorm(
num_channels=basic_norm.num_channels,
eps_exp=basic_norm.eps.data.exp().item(),
@ -62,6 +152,11 @@ def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
return norm
def convert_pooling_module(pooling: PoolingModule) -> PoolingModuleWithProj:
assert isinstance(pooling, PoolingModule), type(pooling)
return PoolingModuleWithProj(proj=pooling.proj)
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
# get_submodule was added to nn.Module at v1.9.0
def get_submodule(model, target):
@ -83,6 +178,7 @@ def get_submodule(model, target):
def convert_scaled_to_non_scaled(
model: nn.Module,
inplace: bool = False,
is_pnnx: bool = False,
):
"""
Args:
@ -91,6 +187,8 @@ def convert_scaled_to_non_scaled(
inplace:
If True, the input model is modified inplace.
If False, the input model is copied and we modify the copied version.
is_pnnx:
True if we are going to export the model for PNNX.
Return:
Return a model without scaled layers.
"""
@ -103,6 +201,8 @@ def convert_scaled_to_non_scaled(
d[name] = convert_basic_norm(m)
elif isinstance(m, (ActivationBalancer, Whiten)):
d[name] = nn.Identity()
elif isinstance(m, PoolingModule) and is_pnnx:
d[name] = convert_pooling_module(m)
for k, v in d.items():
if "." in k:

View File

@ -1,3 +1,10 @@
This recipe implements a streaming Zipformer-Transducer model.
See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials.
[./zipformer.py](./zipformer.py) and [./train.py](./train.py)
are basically the same as
[./zipformer2.py](./zipformer2.py) and [./train2.py](./train2.py), respectively.
The only purpose of [./zipformer2.py](./zipformer2.py) and [./train2.py](./train2.py)
is to support exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn).

View File

@ -0,0 +1,367 @@
#!/usr/bin/env python3
"""
Please see
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
for more details about how to use this file.
We use
https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
to demonstrate the usage of this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_char_bpe/L.pt"
git lfs pull --include "data/lang_char_bpe/L_disambig.pt"
git lfs pull --include "data/lang_char_bpe/Linv.pt"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export to ncnn
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
--lang-dir $repo/data/lang_char_bpe \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--decode-chunk-len 32 \
--num-encoder-layers "2,4,3,2,4" \
--feedforward-dims "1024,1024,1536,1536,1024" \
--nhead "8,8,8,8,8" \
--encoder-dims "384,384,384,384,384" \
--attention-dims "192,192,192,192,192" \
--encoder-unmasked-dims "256,256,256,256,256" \
--zipformer-downsampling-factors "1,2,4,8,2" \
--cnn-module-kernels "31,31,31,31,31" \
--decoder-dim 512 \
--joiner-dim 512
cd $repo/exp
pnnx encoder_jit_trace-pnnx.pt
pnnx decoder_jit_trace-pnnx.pt
pnnx joiner_jit_trace-pnnx.pt
You can find converted models at
https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13
See ./streaming-ncnn-decode.py
and
https://github.com/k2-fsa/sherpa-ncnn
for usage.
"""
import argparse
import logging
from pathlib import Path
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_streaming/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="The lang dir",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
add_model_arguments(parser)
return parser
def export_encoder_model_jit_trace(
encoder_model: torch.nn.Module,
encoder_filename: str,
) -> None:
"""Export the given encoder model with torch.jit.trace()
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported model.
"""
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
decode_chunk_len = encoder_model.decode_chunk_size * 2
pad_length = 7
T = decode_chunk_len + pad_length # 32 + 7 = 39
logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"T: {T}")
x = torch.zeros(1, T, 80, dtype=torch.float32)
states = encoder_model.get_init_state()
traced_model = torch.jit.trace(encoder_model, (x, states))
traced_model.save(encoder_filename)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_jit_trace(
decoder_model: torch.nn.Module,
decoder_filename: str,
) -> None:
"""Export the given decoder model with torch.jit.trace()
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The input decoder model
decoder_filename:
The filename to save the exported model.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_jit_trace(
joiner_model: torch.nn.Module,
joiner_filename: str,
) -> None:
"""Export the given joiner model with torch.jit.trace()
Note: The argument project_input is fixed to True. A user should not
project the encoder_out/decoder_out by himself/herself. The exported joiner
will do that for the user.
Args:
joiner_model:
The input joiner model
joiner_filename:
The filename to save the exported model.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
traced_model.save(joiner_filename)
logging.info(f"Saved to {joiner_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True)
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
logging.info("Using torch.jit.trace()")
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
main()

View File

@ -0,0 +1,369 @@
#!/usr/bin/env python3
"""
Please see
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
for more details about how to use this file.
We use
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
to demonstrate the usage of this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export to ncnn
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
\
--decode-chunk-len 32 \
--num-encoder-layers "2,4,3,2,4" \
--feedforward-dims "1024,1024,2048,2048,1024" \
--nhead "8,8,8,8,8" \
--encoder-dims "384,384,384,384,384" \
--attention-dims "192,192,192,192,192" \
--encoder-unmasked-dims "256,256,256,256,256" \
--zipformer-downsampling-factors "1,2,4,8,2" \
--cnn-module-kernels "31,31,31,31,31" \
--decoder-dim 512 \
--joiner-dim 512
cd $repo/exp
pnnx encoder_jit_trace-pnnx.pt
pnnx decoder_jit_trace-pnnx.pt
pnnx joiner_jit_trace-pnnx.pt
You can find converted models at
https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13
See ./streaming-ncnn-decode.py
and
https://github.com/k2-fsa/sherpa-ncnn
for usage.
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_streaming/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
add_model_arguments(parser)
return parser
def export_encoder_model_jit_trace(
encoder_model: torch.nn.Module,
encoder_filename: str,
) -> None:
"""Export the given encoder model with torch.jit.trace()
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported model.
"""
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
decode_chunk_len = encoder_model.decode_chunk_size * 2
pad_length = 7
T = decode_chunk_len + pad_length # 32 + 7 = 39
logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"T: {T}")
x = torch.zeros(1, T, 80, dtype=torch.float32)
states = encoder_model.get_init_state()
traced_model = torch.jit.trace(encoder_model, (x, states))
traced_model.save(encoder_filename)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_jit_trace(
decoder_model: torch.nn.Module,
decoder_filename: str,
) -> None:
"""Export the given decoder model with torch.jit.trace()
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The input decoder model
decoder_filename:
The filename to save the exported model.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_jit_trace(
joiner_model: torch.nn.Module,
joiner_filename: str,
) -> None:
"""Export the given joiner model with torch.jit.trace()
Note: The argument project_input is fixed to True. A user should not
project the encoder_out/decoder_out by himself/herself. The exported joiner
will do that for the user.
Args:
joiner_model:
The input joiner model
joiner_filename:
The filename to save the exported model.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
traced_model.save(joiner_filename)
logging.info(f"Saved to {joiner_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True)
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
logging.info("Using torch.jit.trace()")
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
main()

View File

@ -0,0 +1,419 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
--tokens ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/tokens.txt \
--encoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.bin \
./sherpa-ncnn-streaming-zipformer-en-2023-02-13/test_wavs/1089-134686-0001.wav
You can find pretrained models at
- English: https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13
- Bilingual (Chinese + English): https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13
"""
import argparse
import logging
from typing import List, Optional, Tuple
import k2
import ncnn
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--tokens",
type=str,
help="Path to tokens.txt",
)
parser.add_argument(
"--encoder-param-filename",
type=str,
help="Path to encoder.ncnn.param",
)
parser.add_argument(
"--encoder-bin-filename",
type=str,
help="Path to encoder.ncnn.bin",
)
parser.add_argument(
"--decoder-param-filename",
type=str,
help="Path to decoder.ncnn.param",
)
parser.add_argument(
"--decoder-bin-filename",
type=str,
help="Path to decoder.ncnn.bin",
)
parser.add_argument(
"--joiner-param-filename",
type=str,
help="Path to joiner.ncnn.param",
)
parser.add_argument(
"--joiner-bin-filename",
type=str,
help="Path to joiner.ncnn.bin",
)
parser.add_argument(
"sound_filename",
type=str,
help="Path to foo.wav",
)
return parser.parse_args()
def to_int_tuple(s: str):
return tuple(map(int, s.split(",")))
class Model:
def __init__(self, args):
self.init_encoder(args)
self.init_decoder(args)
self.init_joiner(args)
# Please change the parameters according to your model
self.num_encoder_layers = to_int_tuple("2,4,3,2,4")
self.encoder_dims = to_int_tuple("384,384,384,384,384") # also known as d_model
self.attention_dims = to_int_tuple("192,192,192,192,192")
self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2")
self.cnn_module_kernels = to_int_tuple("31,31,31,31,31")
self.decode_chunk_size = 32 // 2
num_left_chunks = 4
self.left_context_length = self.decode_chunk_size * num_left_chunks # 64
self.chunk_length = self.decode_chunk_size * 2
pad_length = 7
self.T = self.chunk_length + pad_length
def get_init_states(self) -> List[torch.Tensor]:
cached_len_list = []
cached_avg_list = []
cached_key_list = []
cached_val_list = []
cached_val2_list = []
cached_conv1_list = []
cached_conv2_list = []
for i in range(len(self.num_encoder_layers)):
num_layers = self.num_encoder_layers[i]
ds = self.zipformer_downsampling_factors[i]
attention_dim = self.attention_dims[i]
left_context_length = self.left_context_length // ds
encoder_dim = self.encoder_dims[i]
cnn_module_kernel = self.cnn_module_kernels[i]
cached_len_list.append(torch.zeros(num_layers))
cached_avg_list.append(torch.zeros(num_layers, encoder_dim))
cached_key_list.append(
torch.zeros(num_layers, left_context_length, attention_dim)
)
cached_val_list.append(
torch.zeros(num_layers, left_context_length, attention_dim // 2)
)
cached_val2_list.append(
torch.zeros(num_layers, left_context_length, attention_dim // 2)
)
cached_conv1_list.append(
torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1)
)
cached_conv2_list.append(
torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1)
)
states = (
cached_len_list
+ cached_avg_list
+ cached_key_list
+ cached_val_list
+ cached_val2_list
+ cached_conv1_list
+ cached_conv2_list
)
return states
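As a side note (not part of the diff), the size of this state list follows directly from the constants above: get_init_states() concatenates 7 kinds of caches (cached_len, cached_avg, cached_key, cached_val, cached_val2, cached_conv1, cached_conv2), one tensor per encoder stack, which is the "number of states" printed later in main(). A short illustrative sketch of the arithmetic:

# Illustrative arithmetic only; the values mirror the hard-coded model config above.
num_encoder_stacks = len((2, 4, 3, 2, 4))  # 5 zipformer stacks
num_cache_kinds = 7                        # len, avg, key, val, val2, conv1, conv2
expected_num_states = num_cache_kinds * num_encoder_stacks
assert expected_num_states == 35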
def init_encoder(self, args):
encoder_net = ncnn.Net()
encoder_net.opt.use_packing_layout = False
encoder_net.opt.use_fp16_storage = False
encoder_net.opt.num_threads = 4
encoder_param = args.encoder_param_filename
encoder_model = args.encoder_bin_filename
encoder_net.load_param(encoder_param)
encoder_net.load_model(encoder_model)
self.encoder_net = encoder_net
def init_decoder(self, args):
decoder_param = args.decoder_param_filename
decoder_model = args.decoder_bin_filename
decoder_net = ncnn.Net()
decoder_net.opt.num_threads = 4
decoder_net.load_param(decoder_param)
decoder_net.load_model(decoder_model)
self.decoder_net = decoder_net
def init_joiner(self, args):
joiner_param = args.joiner_param_filename
joiner_model = args.joiner_bin_filename
joiner_net = ncnn.Net()
joiner_net.opt.num_threads = 4
joiner_net.load_param(joiner_param)
joiner_net.load_model(joiner_model)
self.joiner_net = joiner_net
def run_encoder(
self,
x: torch.Tensor,
states: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Args:
x:
A tensor of shape (T, C)
states:
A list of tensors. len(states) == 7 * len(self.num_encoder_layers)
Returns:
Return a tuple containing:
- encoder_out, a tensor of shape (T, encoder_dim).
- next_states, a list of tensors containing the next states
"""
with self.encoder_net.create_extractor() as ex:
ex.input("in0", ncnn.Mat(x.numpy()).clone())
for i in range(len(states)):
name = f"in{i+1}"
ex.input(name, ncnn.Mat(states[i].squeeze().numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
out_states: List[torch.Tensor] = []
for i in range(len(states)):
name = f"out{i+1}"
ret, ncnn_out_state = ex.extract(name)
assert ret == 0, ret
ncnn_out_state = torch.from_numpy(ncnn_out_state.numpy())
if i < len(self.num_encoder_layers):
# for cached_len, we need to discard the last dim
ncnn_out_state = ncnn_out_state.squeeze(1)
out_states.append(ncnn_out_state)
return encoder_out, out_states
def run_decoder(self, decoder_input):
assert decoder_input.dtype == torch.int32
with self.decoder_net.create_extractor() as ex:
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return decoder_out
def run_joiner(self, encoder_out, decoder_out):
with self.joiner_net.create_extractor() as ex:
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return joiner_out
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
def greedy_search(
model: Model,
encoder_out: torch.Tensor,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
):
context_size = 2
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
else:
assert decoder_out.ndim == 1
assert hyp is not None, hyp
T = encoder_out.size(0)
for t in range(T):
cur_encoder_out = encoder_out[t]
joiner_out = model.run_joiner(cur_encoder_out, decoder_out)
y = joiner_out.argmax(dim=0).item()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
return hyp, decoder_out
def main():
args = get_args()
logging.info(vars(args))
model = Model(args)
sound_file = args.sound_filename
sample_rate = 16000
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor()
logging.info(f"Reading sound files: {sound_file}")
wave_samples = read_sound_files(
filenames=[sound_file],
expected_sample_rate=sample_rate,
)[0]
logging.info(wave_samples.shape)
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
wave_samples = torch.cat([wave_samples, tail_padding])
states = model.get_init_states()
logging.info(f"number of states: {len(states)}")
hyp = None
decoder_out = None
num_processed_frames = 0
segment = model.T
offset = model.chunk_length
chunk = int(1 * sample_rate)  # 1 second
start = 0
while start < wave_samples.numel():
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
encoder_out, states = model.run_encoder(frames, states)
hyp, decoder_out = greedy_search(model, encoder_out, decoder_out, hyp)
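To make the chunking arithmetic in this loop concrete, here is a short illustrative sketch (values taken from the constants set in Model.__init__; not part of the diff): each encoder call is fed segment frames, but the read pointer only advances by offset frames, so consecutive chunks overlap by the pad length.

decode_chunk_size = 16                 # 32 // 2, as in Model.__init__
chunk_length = decode_chunk_size * 2   # offset: 32 feature frames advanced per step
pad_length = 7                         # extra right-context frames
segment = chunk_length + pad_length    # 39 frames fed to the encoder per step
overlap = segment - chunk_length       # 7 frames re-read by the next step
assert (segment, chunk_length, overlap) == (39, 32, 7)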
symbol_table = k2.SymbolTable.from_file(args.tokens)
context_size = 2
text = ""
for i in hyp[context_size:]:
text += symbol_table[i]
text = text.replace("▁", " ").strip()
logging.info(sound_file)
logging.info(text)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

File diff suppressed because it is too large.

View File

@ -44,7 +44,6 @@ from scaling import (
)
from torch import Tensor, nn
from icefall.dist import get_rank
from icefall.utils import make_pad_mask, subsequent_chunk_mask
@ -271,7 +270,6 @@ class Zipformer(EncoderInterface):
num_encoder_layers (int): number of encoder layers
dropout (float): dropout rate
cnn_module_kernels (int): Kernel size of convolution module
vgg_frontend (bool): whether to use vgg frontend.
warmup_batches (float): number of batches to warm up over
"""
@ -388,9 +386,9 @@ class Zipformer(EncoderInterface):
def _init_skip_modules(self):
"""
If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of
layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2,
we combine the outputs of layers 1 and 4.
"""
skip_layers = []
skip_modules = []
@ -1272,8 +1270,7 @@ class ZipformerEncoder(nn.Module):
Shape:
src: (S, N, E).
cached_len: (num_layers,)
N is the batch size.
cached_avg: (num_layers, N, C).
N is the batch size, C is the feature dimension.
cached_key: (num_layers, left_context_len, N, K).
@ -1289,8 +1286,8 @@ class ZipformerEncoder(nn.Module):
Returns: A tuple of 8 tensors:
- output tensor
- updated cached number of past frames.
- updated cached average of past frames.
- updated cached key tensor of the first attention module.
- updated cached value tensor of the first attention module.
- updated cached value tensor of the second attention module.
@ -1522,9 +1519,6 @@ class AttentionDownsample(torch.nn.Module):
""" """
def __init__(self, in_channels: int, out_channels: int, downsample: int): def __init__(self, in_channels: int, out_channels: int, downsample: int):
"""
Require out_channels > in_channels.
"""
super(AttentionDownsample, self).__init__()
self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5))
@ -1902,8 +1896,6 @@ class RelPositionMultiheadAttention(nn.Module):
Args:
x: input to be projected to query, key, value
pos_emb: Positional embedding tensor
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
- Inputs:
@ -1911,13 +1903,6 @@ class RelPositionMultiheadAttention(nn.Module):
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension.
- cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension.

File diff suppressed because it is too large.