mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Export streaming zipformer to ncnn (#906)
This commit is contained in:
parent
e63a8c27f8
commit
c5e687ddf5
99
.github/scripts/test-ncnn-export.sh
vendored
99
.github/scripts/test-ncnn-export.sh
vendored
@ -131,3 +131,102 @@ python3 ./lstm_transducer_stateless2/ncnn-decode.py \
|
|||||||
|
|
||||||
rm -rf $repo
|
rm -rf $repo
|
||||||
log "--------------------------------------------------------------------------"
|
log "--------------------------------------------------------------------------"
|
||||||
|
|
||||||
|
log "=========================================================================="
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
\
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--num-encoder-layers "2,4,3,2,4" \
|
||||||
|
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||||
|
--nhead "8,8,8,8,8" \
|
||||||
|
--encoder-dims "384,384,384,384,384" \
|
||||||
|
--attention-dims "192,192,192,192,192" \
|
||||||
|
--encoder-unmasked-dims "256,256,256,256,256" \
|
||||||
|
--zipformer-downsampling-factors "1,2,4,8,2" \
|
||||||
|
--cnn-module-kernels "31,31,31,31,31" \
|
||||||
|
--decoder-dim 512 \
|
||||||
|
--joiner-dim 512
|
||||||
|
|
||||||
|
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
|
||||||
|
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
|
||||||
|
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
|
||||||
|
|
||||||
|
python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
|
||||||
|
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
|
||||||
|
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
|
||||||
|
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
|
||||||
|
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
|
||||||
|
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
log "--------------------------------------------------------------------------"
|
||||||
|
|
||||||
|
log "=========================================================================="
|
||||||
|
repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_char_bpe/L.pt"
|
||||||
|
git lfs pull --include "data/lang_char_bpe/L_disambig.pt"
|
||||||
|
git lfs pull --include "data/lang_char_bpe/Linv.pt"
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
|
||||||
|
--lang-dir $repo/data/lang_char_bpe \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--num-encoder-layers "2,4,3,2,4" \
|
||||||
|
--feedforward-dims "1024,1024,1536,1536,1024" \
|
||||||
|
--nhead "8,8,8,8,8" \
|
||||||
|
--encoder-dims "384,384,384,384,384" \
|
||||||
|
--attention-dims "192,192,192,192,192" \
|
||||||
|
--encoder-unmasked-dims "256,256,256,256,256" \
|
||||||
|
--zipformer-downsampling-factors "1,2,4,8,2" \
|
||||||
|
--cnn-module-kernels "31,31,31,31,31" \
|
||||||
|
--decoder-dim 512 \
|
||||||
|
--joiner-dim 512
|
||||||
|
|
||||||
|
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
|
||||||
|
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
|
||||||
|
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
|
||||||
|
|
||||||
|
python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
|
||||||
|
--tokens $repo/data/lang_char_bpe/tokens.txt \
|
||||||
|
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
|
||||||
|
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
|
||||||
|
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
|
||||||
|
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
|
||||||
|
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
|
||||||
|
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
|
||||||
|
$repo/test_wavs/0.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
log "--------------------------------------------------------------------------"
|
||||||
|
@ -310,6 +310,16 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
|
||||||
|
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
|
||||||
|
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
|
||||||
|
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
|
||||||
|
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||||
|
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||||
|
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||||
|
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||||
|
logging.info(f"total parameters: {total_num_param}")
|
||||||
|
|
||||||
logging.info("Using torch.jit.trace()")
|
logging.info("Using torch.jit.trace()")
|
||||||
|
|
||||||
logging.info("Exporting encoder")
|
logging.info("Exporting encoder")
|
||||||
|
@ -203,11 +203,8 @@ class Model:
|
|||||||
# (1, 512, 2) -> (512, 2)
|
# (1, 512, 2) -> (512, 2)
|
||||||
ex.input(name, ncnn.Mat(states[i * 4 + 3].numpy()).clone())
|
ex.input(name, ncnn.Mat(states[i * 4 + 3].numpy()).clone())
|
||||||
|
|
||||||
import pdb
|
|
||||||
|
|
||||||
# pdb.set_trace()
|
|
||||||
ret, ncnn_out0 = ex.extract("out0")
|
ret, ncnn_out0 = ex.extract("out0")
|
||||||
# assert ret == 0, ret
|
assert ret == 0, ret
|
||||||
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||||
|
|
||||||
out_states: List[torch.Tensor] = []
|
out_states: List[torch.Tensor] = []
|
||||||
|
@ -99,7 +99,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="pruned_transducer_stateless2/exp",
|
default="lstm_transducer_stateless2/exp",
|
||||||
help="""It specifies the directory where all training related
|
help="""It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
""",
|
""",
|
||||||
@ -316,6 +316,16 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
|
||||||
|
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
|
||||||
|
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
|
||||||
|
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
|
||||||
|
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||||
|
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||||
|
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||||
|
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||||
|
logging.info(f"total parameters: {total_num_param}")
|
||||||
|
|
||||||
logging.info("Using torch.jit.trace()")
|
logging.info("Using torch.jit.trace()")
|
||||||
|
|
||||||
logging.info("Exporting encoder")
|
logging.info("Exporting encoder")
|
||||||
|
@ -87,6 +87,10 @@ class Decoder(nn.Module):
|
|||||||
y = y.to(torch.int64)
|
y = y.to(torch.int64)
|
||||||
# this stuff about clamp() is a temporary fix for a mismatch
|
# this stuff about clamp() is a temporary fix for a mismatch
|
||||||
# at utterance start, we use negative ids in beam_search.py
|
# at utterance start, we use negative ids in beam_search.py
|
||||||
|
if torch.jit.is_tracing():
|
||||||
|
# This is for exporting to PNNX via ONNX
|
||||||
|
embedding_out = self.embedding(y)
|
||||||
|
else:
|
||||||
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
|
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
|
||||||
if self.context_size > 1:
|
if self.context_size > 1:
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
|
@ -53,7 +53,6 @@ class Joiner(nn.Module):
|
|||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == decoder_out.ndim
|
assert encoder_out.ndim == decoder_out.ndim
|
||||||
assert encoder_out.ndim in (2, 4)
|
assert encoder_out.ndim in (2, 4)
|
||||||
assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
|
|
||||||
|
|
||||||
if project_input:
|
if project_input:
|
||||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
|
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
|
||||||
|
@ -22,11 +22,101 @@ BasicNorm is replaced by a module with `exp` removed.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from typing import List
|
from typing import List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling import ActivationBalancer, BasicNorm, Whiten
|
from scaling import ActivationBalancer, BasicNorm, Whiten
|
||||||
|
from zipformer import PoolingModule
|
||||||
|
|
||||||
|
|
||||||
|
class PoolingModuleNoProj(nn.Module):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cached_len: torch.Tensor,
|
||||||
|
cached_avg: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A tensor of shape (T, N, C)
|
||||||
|
cached_len:
|
||||||
|
A tensor of shape (N,)
|
||||||
|
cached_avg:
|
||||||
|
A tensor of shape (N, C)
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- new_x
|
||||||
|
- new_cached_len
|
||||||
|
- new_cached_avg
|
||||||
|
"""
|
||||||
|
x = x.cumsum(dim=0) # (T, N, C)
|
||||||
|
x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0)
|
||||||
|
# Cumulated numbers of frames from start
|
||||||
|
cum_mask = torch.arange(1, x.size(0) + 1, device=x.device)
|
||||||
|
cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N)
|
||||||
|
pooling_mask = (1.0 / cum_mask).unsqueeze(2)
|
||||||
|
# now pooling_mask: (T, N, 1)
|
||||||
|
x = x * pooling_mask # (T, N, C)
|
||||||
|
|
||||||
|
cached_len = cached_len + x.size(0)
|
||||||
|
cached_avg = x[-1]
|
||||||
|
|
||||||
|
return x, cached_len, cached_avg
|
||||||
|
|
||||||
|
|
||||||
|
class PoolingModuleWithProj(nn.Module):
|
||||||
|
def __init__(self, proj: torch.nn.Module):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = proj
|
||||||
|
self.pooling = PoolingModuleNoProj()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cached_len: torch.Tensor,
|
||||||
|
cached_avg: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A tensor of shape (T, N, C)
|
||||||
|
cached_len:
|
||||||
|
A tensor of shape (N,)
|
||||||
|
cached_avg:
|
||||||
|
A tensor of shape (N, C)
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- new_x
|
||||||
|
- new_cached_len
|
||||||
|
- new_cached_avg
|
||||||
|
"""
|
||||||
|
x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg)
|
||||||
|
return self.proj(x), cached_len, cached_avg
|
||||||
|
|
||||||
|
def streaming_forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cached_len: torch.Tensor,
|
||||||
|
cached_avg: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A tensor of shape (T, N, C)
|
||||||
|
cached_len:
|
||||||
|
A tensor of shape (N,)
|
||||||
|
cached_avg:
|
||||||
|
A tensor of shape (N, C)
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- new_x
|
||||||
|
- new_cached_len
|
||||||
|
- new_cached_avg
|
||||||
|
"""
|
||||||
|
x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg)
|
||||||
|
return self.proj(x), cached_len, cached_avg
|
||||||
|
|
||||||
|
|
||||||
class NonScaledNorm(nn.Module):
|
class NonScaledNorm(nn.Module):
|
||||||
@ -53,7 +143,7 @@ class NonScaledNorm(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
|
def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
|
||||||
assert isinstance(basic_norm, BasicNorm), type(BasicNorm)
|
assert isinstance(basic_norm, BasicNorm), type(basic_norm)
|
||||||
norm = NonScaledNorm(
|
norm = NonScaledNorm(
|
||||||
num_channels=basic_norm.num_channels,
|
num_channels=basic_norm.num_channels,
|
||||||
eps_exp=basic_norm.eps.data.exp().item(),
|
eps_exp=basic_norm.eps.data.exp().item(),
|
||||||
@ -62,6 +152,11 @@ def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
|
|||||||
return norm
|
return norm
|
||||||
|
|
||||||
|
|
||||||
|
def convert_pooling_module(pooling: PoolingModule) -> PoolingModuleWithProj:
|
||||||
|
assert isinstance(pooling, PoolingModule), type(pooling)
|
||||||
|
return PoolingModuleWithProj(proj=pooling.proj)
|
||||||
|
|
||||||
|
|
||||||
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
|
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
|
||||||
# get_submodule was added to nn.Module at v1.9.0
|
# get_submodule was added to nn.Module at v1.9.0
|
||||||
def get_submodule(model, target):
|
def get_submodule(model, target):
|
||||||
@ -83,6 +178,7 @@ def get_submodule(model, target):
|
|||||||
def convert_scaled_to_non_scaled(
|
def convert_scaled_to_non_scaled(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
|
is_pnnx: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -91,6 +187,8 @@ def convert_scaled_to_non_scaled(
|
|||||||
inplace:
|
inplace:
|
||||||
If True, the input model is modified inplace.
|
If True, the input model is modified inplace.
|
||||||
If False, the input model is copied and we modify the copied version.
|
If False, the input model is copied and we modify the copied version.
|
||||||
|
is_pnnx:
|
||||||
|
True if we are going to export the model for PNNX.
|
||||||
Return:
|
Return:
|
||||||
Return a model without scaled layers.
|
Return a model without scaled layers.
|
||||||
"""
|
"""
|
||||||
@ -103,6 +201,8 @@ def convert_scaled_to_non_scaled(
|
|||||||
d[name] = convert_basic_norm(m)
|
d[name] = convert_basic_norm(m)
|
||||||
elif isinstance(m, (ActivationBalancer, Whiten)):
|
elif isinstance(m, (ActivationBalancer, Whiten)):
|
||||||
d[name] = nn.Identity()
|
d[name] = nn.Identity()
|
||||||
|
elif isinstance(m, PoolingModule) and is_pnnx:
|
||||||
|
d[name] = convert_pooling_module(m)
|
||||||
|
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
if "." in k:
|
if "." in k:
|
||||||
|
@ -1,3 +1,10 @@
|
|||||||
This recipe implements Streaming Zipformer-Transducer model.
|
This recipe implements Streaming Zipformer-Transducer model.
|
||||||
|
|
||||||
See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials.
|
See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials.
|
||||||
|
|
||||||
|
[./emformer.py](./emformer.py) and [./train.py](./train.py)
|
||||||
|
are basically the same as
|
||||||
|
[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py).
|
||||||
|
The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py)
|
||||||
|
is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn).
|
||||||
|
|
||||||
|
367
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
Executable file
367
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
Executable file
@ -0,0 +1,367 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
Please see
|
||||||
|
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
|
||||||
|
for more details about how to use this file.
|
||||||
|
|
||||||
|
We use
|
||||||
|
https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
|
||||||
|
to demonstrate the usage of this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_char_bpe/L.pt"
|
||||||
|
git lfs pull --include "data/lang_char_bpe/L_disambig.pt"
|
||||||
|
git lfs pull --include "data/lang_char_bpe/Linv.pt"
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export to ncnn
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
|
||||||
|
--lang-dir $repo/data/lang_char_bpe \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--num-encoder-layers "2,4,3,2,4" \
|
||||||
|
--feedforward-dims "1024,1024,1536,1536,1024" \
|
||||||
|
--nhead "8,8,8,8,8" \
|
||||||
|
--encoder-dims "384,384,384,384,384" \
|
||||||
|
--attention-dims "192,192,192,192,192" \
|
||||||
|
--encoder-unmasked-dims "256,256,256,256,256" \
|
||||||
|
--zipformer-downsampling-factors "1,2,4,8,2" \
|
||||||
|
--cnn-module-kernels "31,31,31,31,31" \
|
||||||
|
--decoder-dim 512 \
|
||||||
|
--joiner-dim 512
|
||||||
|
|
||||||
|
cd $repo/exp
|
||||||
|
|
||||||
|
pnnx encoder_jit_trace-pnnx.pt
|
||||||
|
pnnx decoder_jit_trace-pnnx.pt
|
||||||
|
pnnx joiner_jit_trace-pnnx.pt
|
||||||
|
|
||||||
|
You can find converted models at
|
||||||
|
https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13
|
||||||
|
|
||||||
|
See ./streaming-ncnn-decode.py
|
||||||
|
and
|
||||||
|
https://github.com/k2-fsa/sherpa-ncnn
|
||||||
|
for usage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from train2 import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless7_streaming/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_char",
|
||||||
|
help="The lang dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_jit_trace(
|
||||||
|
encoder_model: torch.nn.Module,
|
||||||
|
encoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The warmup argument is fixed to 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
|
||||||
|
|
||||||
|
decode_chunk_len = encoder_model.decode_chunk_size * 2
|
||||||
|
pad_length = 7
|
||||||
|
T = decode_chunk_len + pad_length # 32 + 7 = 39
|
||||||
|
|
||||||
|
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||||
|
logging.info(f"T: {T}")
|
||||||
|
|
||||||
|
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||||
|
states = encoder_model.get_init_state()
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(encoder_model, (x, states))
|
||||||
|
traced_model.save(encoder_filename)
|
||||||
|
logging.info(f"Saved to {encoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_jit_trace(
|
||||||
|
decoder_model: torch.nn.Module,
|
||||||
|
decoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given decoder model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The argument need_pad is fixed to False.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The input decoder model
|
||||||
|
decoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||||
|
need_pad = torch.tensor([False])
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||||
|
traced_model.save(decoder_filename)
|
||||||
|
logging.info(f"Saved to {decoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_jit_trace(
|
||||||
|
joiner_model: torch.nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given joiner model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The argument project_input is fixed to True. A user should not
|
||||||
|
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||||
|
will do that for the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
joiner_model:
|
||||||
|
The input joiner model
|
||||||
|
joiner_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||||
|
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||||
|
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||||
|
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
||||||
|
traced_model.save(joiner_filename)
|
||||||
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
params.blank_id = 0
|
||||||
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True)
|
||||||
|
|
||||||
|
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
|
||||||
|
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
|
||||||
|
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
|
||||||
|
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||||
|
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||||
|
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||||
|
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||||
|
logging.info(f"total parameters: {total_num_param}")
|
||||||
|
|
||||||
|
logging.info("Using torch.jit.trace()")
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
|
||||||
|
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||||
|
|
||||||
|
logging.info("Exporting decoder")
|
||||||
|
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
|
||||||
|
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||||
|
|
||||||
|
logging.info("Exporting joiner")
|
||||||
|
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
|
||||||
|
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
main()
|
369
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
Executable file
369
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
Executable file
@ -0,0 +1,369 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
Please see
|
||||||
|
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
|
||||||
|
for more details about how to use this file.
|
||||||
|
|
||||||
|
We use
|
||||||
|
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
|
to demonstrate the usage of this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export to ncnn
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
\
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--num-encoder-layers "2,4,3,2,4" \
|
||||||
|
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||||
|
--nhead "8,8,8,8,8" \
|
||||||
|
--encoder-dims "384,384,384,384,384" \
|
||||||
|
--attention-dims "192,192,192,192,192" \
|
||||||
|
--encoder-unmasked-dims "256,256,256,256,256" \
|
||||||
|
--zipformer-downsampling-factors "1,2,4,8,2" \
|
||||||
|
--cnn-module-kernels "31,31,31,31,31" \
|
||||||
|
--decoder-dim 512 \
|
||||||
|
--joiner-dim 512
|
||||||
|
|
||||||
|
cd $repo/exp
|
||||||
|
|
||||||
|
pnnx encoder_jit_trace-pnnx.pt
|
||||||
|
pnnx decoder_jit_trace-pnnx.pt
|
||||||
|
pnnx joiner_jit_trace-pnnx.pt
|
||||||
|
|
||||||
|
You can find converted models at
|
||||||
|
https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13
|
||||||
|
|
||||||
|
See ./streaming-ncnn-decode.py
|
||||||
|
and
|
||||||
|
https://github.com/k2-fsa/sherpa-ncnn
|
||||||
|
for usage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from train2 import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.utils import setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless7_streaming/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/bpe.model",
|
||||||
|
help="Path to the BPE model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_jit_trace(
|
||||||
|
encoder_model: torch.nn.Module,
|
||||||
|
encoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The warmup argument is fixed to 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
|
||||||
|
|
||||||
|
decode_chunk_len = encoder_model.decode_chunk_size * 2
|
||||||
|
pad_length = 7
|
||||||
|
T = decode_chunk_len + pad_length # 32 + 7 = 39
|
||||||
|
|
||||||
|
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||||
|
logging.info(f"T: {T}")
|
||||||
|
|
||||||
|
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||||
|
states = encoder_model.get_init_state()
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(encoder_model, (x, states))
|
||||||
|
traced_model.save(encoder_filename)
|
||||||
|
logging.info(f"Saved to {encoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_jit_trace(
|
||||||
|
decoder_model: torch.nn.Module,
|
||||||
|
decoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given decoder model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The argument need_pad is fixed to False.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The input decoder model
|
||||||
|
decoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||||
|
need_pad = torch.tensor([False])
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||||
|
traced_model.save(decoder_filename)
|
||||||
|
logging.info(f"Saved to {decoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_jit_trace(
|
||||||
|
joiner_model: torch.nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given joiner model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The argument project_input is fixed to True. A user should not
|
||||||
|
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||||
|
will do that for the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
joiner_model:
|
||||||
|
The input joiner model
|
||||||
|
joiner_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||||
|
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||||
|
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||||
|
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
||||||
|
traced_model.save(joiner_filename)
|
||||||
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True)
|
||||||
|
|
||||||
|
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
|
||||||
|
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
|
||||||
|
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
|
||||||
|
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||||
|
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||||
|
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||||
|
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||||
|
logging.info(f"total parameters: {total_num_param}")
|
||||||
|
|
||||||
|
logging.info("Using torch.jit.trace()")
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
|
||||||
|
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||||
|
|
||||||
|
logging.info("Exporting decoder")
|
||||||
|
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
|
||||||
|
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||||
|
|
||||||
|
logging.info("Exporting joiner")
|
||||||
|
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
|
||||||
|
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
main()
|
@ -0,0 +1,419 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
|
||||||
|
--tokens ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/tokens.txt \
|
||||||
|
--encoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.param \
|
||||||
|
--encoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.bin \
|
||||||
|
--decoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.param \
|
||||||
|
--decoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.bin \
|
||||||
|
--joiner-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.param \
|
||||||
|
--joiner-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.bin \
|
||||||
|
./sherpa-ncnn-streaming-zipformer-en-2023-02-13/test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
You can find pretrained models at
|
||||||
|
- English: https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13
|
||||||
|
- Bilingual (Chinese + English): https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import ncnn
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
help="Path to tokens.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder-param-filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to encoder.ncnn.param",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder-bin-filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to encoder.ncnn.bin",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-param-filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to decoder.ncnn.param",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-bin-filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to decoder.ncnn.bin",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-param-filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to joiner.ncnn.param",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-bin-filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to joiner.ncnn.bin",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to foo.wav",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def to_int_tuple(s: str):
|
||||||
|
return tuple(map(int, s.split(",")))
|
||||||
|
|
||||||
|
|
||||||
|
class Model:
|
||||||
|
def __init__(self, args):
|
||||||
|
self.init_encoder(args)
|
||||||
|
self.init_decoder(args)
|
||||||
|
self.init_joiner(args)
|
||||||
|
|
||||||
|
# Please change the parameters according to your model
|
||||||
|
self.num_encoder_layers = to_int_tuple("2,4,3,2,4")
|
||||||
|
self.encoder_dims = to_int_tuple("384,384,384,384,384") # also known as d_model
|
||||||
|
self.attention_dims = to_int_tuple("192,192,192,192,192")
|
||||||
|
self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2")
|
||||||
|
self.cnn_module_kernels = to_int_tuple("31,31,31,31,31")
|
||||||
|
|
||||||
|
self.decode_chunk_size = 32 // 2
|
||||||
|
num_left_chunks = 4
|
||||||
|
self.left_context_length = self.decode_chunk_size * num_left_chunks # 64
|
||||||
|
|
||||||
|
self.chunk_length = self.decode_chunk_size * 2
|
||||||
|
pad_length = 7
|
||||||
|
self.T = self.chunk_length + pad_length
|
||||||
|
|
||||||
|
def get_init_states(self) -> List[torch.Tensor]:
|
||||||
|
cached_len_list = []
|
||||||
|
cached_avg_list = []
|
||||||
|
cached_key_list = []
|
||||||
|
cached_val_list = []
|
||||||
|
cached_val2_list = []
|
||||||
|
cached_conv1_list = []
|
||||||
|
cached_conv2_list = []
|
||||||
|
|
||||||
|
for i in range(len(self.num_encoder_layers)):
|
||||||
|
num_layers = self.num_encoder_layers[i]
|
||||||
|
ds = self.zipformer_downsampling_factors[i]
|
||||||
|
attention_dim = self.attention_dims[i]
|
||||||
|
left_context_length = self.left_context_length // ds
|
||||||
|
encoder_dim = self.encoder_dims[i]
|
||||||
|
cnn_module_kernel = self.cnn_module_kernels[i]
|
||||||
|
|
||||||
|
cached_len_list.append(torch.zeros(num_layers))
|
||||||
|
cached_avg_list.append(torch.zeros(num_layers, encoder_dim))
|
||||||
|
cached_key_list.append(
|
||||||
|
torch.zeros(num_layers, left_context_length, attention_dim)
|
||||||
|
)
|
||||||
|
cached_val_list.append(
|
||||||
|
torch.zeros(num_layers, left_context_length, attention_dim // 2)
|
||||||
|
)
|
||||||
|
cached_val2_list.append(
|
||||||
|
torch.zeros(num_layers, left_context_length, attention_dim // 2)
|
||||||
|
)
|
||||||
|
cached_conv1_list.append(
|
||||||
|
torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1)
|
||||||
|
)
|
||||||
|
cached_conv2_list.append(
|
||||||
|
torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
states = (
|
||||||
|
cached_len_list
|
||||||
|
+ cached_avg_list
|
||||||
|
+ cached_key_list
|
||||||
|
+ cached_val_list
|
||||||
|
+ cached_val2_list
|
||||||
|
+ cached_conv1_list
|
||||||
|
+ cached_conv2_list
|
||||||
|
)
|
||||||
|
|
||||||
|
return states
|
||||||
|
|
||||||
|
def init_encoder(self, args):
|
||||||
|
encoder_net = ncnn.Net()
|
||||||
|
encoder_net.opt.use_packing_layout = False
|
||||||
|
encoder_net.opt.use_fp16_storage = False
|
||||||
|
encoder_net.opt.num_threads = 4
|
||||||
|
|
||||||
|
encoder_param = args.encoder_param_filename
|
||||||
|
encoder_model = args.encoder_bin_filename
|
||||||
|
|
||||||
|
encoder_net.load_param(encoder_param)
|
||||||
|
encoder_net.load_model(encoder_model)
|
||||||
|
|
||||||
|
self.encoder_net = encoder_net
|
||||||
|
|
||||||
|
def init_decoder(self, args):
|
||||||
|
decoder_param = args.decoder_param_filename
|
||||||
|
decoder_model = args.decoder_bin_filename
|
||||||
|
|
||||||
|
decoder_net = ncnn.Net()
|
||||||
|
decoder_net.opt.num_threads = 4
|
||||||
|
|
||||||
|
decoder_net.load_param(decoder_param)
|
||||||
|
decoder_net.load_model(decoder_model)
|
||||||
|
|
||||||
|
self.decoder_net = decoder_net
|
||||||
|
|
||||||
|
def init_joiner(self, args):
|
||||||
|
joiner_param = args.joiner_param_filename
|
||||||
|
joiner_model = args.joiner_bin_filename
|
||||||
|
joiner_net = ncnn.Net()
|
||||||
|
joiner_net.opt.num_threads = 4
|
||||||
|
|
||||||
|
joiner_net.load_param(joiner_param)
|
||||||
|
joiner_net.load_model(joiner_model)
|
||||||
|
|
||||||
|
self.joiner_net = joiner_net
|
||||||
|
|
||||||
|
def run_encoder(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
states: List[torch.Tensor],
|
||||||
|
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A tensor of shape (T, C)
|
||||||
|
states:
|
||||||
|
A list of tensors. len(states) == self.num_layers * 4
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- encoder_out, a tensor of shape (T, encoder_dim).
|
||||||
|
- next_states, a list of tensors containing the next states
|
||||||
|
"""
|
||||||
|
with self.encoder_net.create_extractor() as ex:
|
||||||
|
ex.input("in0", ncnn.Mat(x.numpy()).clone())
|
||||||
|
|
||||||
|
for i in range(len(states)):
|
||||||
|
name = f"in{i+1}"
|
||||||
|
ex.input(name, ncnn.Mat(states[i].squeeze().numpy()).clone())
|
||||||
|
|
||||||
|
ret, ncnn_out0 = ex.extract("out0")
|
||||||
|
assert ret == 0, ret
|
||||||
|
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||||
|
|
||||||
|
out_states: List[torch.Tensor] = []
|
||||||
|
for i in range(len(states)):
|
||||||
|
name = f"out{i+1}"
|
||||||
|
ret, ncnn_out_state = ex.extract(name)
|
||||||
|
assert ret == 0, ret
|
||||||
|
ncnn_out_state = torch.from_numpy(ncnn_out_state.numpy())
|
||||||
|
|
||||||
|
if i < len(self.num_encoder_layers):
|
||||||
|
# for cached_len, we need to discard the last dim
|
||||||
|
ncnn_out_state = ncnn_out_state.squeeze(1)
|
||||||
|
|
||||||
|
out_states.append(ncnn_out_state)
|
||||||
|
|
||||||
|
return encoder_out, out_states
|
||||||
|
|
||||||
|
def run_decoder(self, decoder_input):
|
||||||
|
assert decoder_input.dtype == torch.int32
|
||||||
|
|
||||||
|
with self.decoder_net.create_extractor() as ex:
|
||||||
|
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
|
||||||
|
ret, ncnn_out0 = ex.extract("out0")
|
||||||
|
assert ret == 0, ret
|
||||||
|
decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||||
|
return decoder_out
|
||||||
|
|
||||||
|
def run_joiner(self, encoder_out, decoder_out):
|
||||||
|
with self.joiner_net.create_extractor() as ex:
|
||||||
|
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
|
||||||
|
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
|
||||||
|
ret, ncnn_out0 = ex.extract("out0")
|
||||||
|
assert ret == 0, ret
|
||||||
|
joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||||
|
return joiner_out
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
|
||||||
|
filenames: List[str], expected_sample_rate: float
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||||
|
Args:
|
||||||
|
filenames:
|
||||||
|
A list of sound filenames.
|
||||||
|
expected_sample_rate:
|
||||||
|
The expected sample rate of the sound files.
|
||||||
|
Returns:
|
||||||
|
Return a list of 1-D float32 torch tensors.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
for f in filenames:
|
||||||
|
wave, sample_rate = torchaudio.load(f)
|
||||||
|
assert (
|
||||||
|
sample_rate == expected_sample_rate
|
||||||
|
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||||
|
# We use only the first channel
|
||||||
|
ans.append(wave[0])
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def create_streaming_feature_extractor() -> OnlineFeature:
|
||||||
|
"""Create a CPU streaming feature extractor.
|
||||||
|
|
||||||
|
At present, we assume it returns a fbank feature extractor with
|
||||||
|
fixed options. In the future, we will support passing in the options
|
||||||
|
from outside.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a CPU streaming feature extractor.
|
||||||
|
"""
|
||||||
|
opts = FbankOptions()
|
||||||
|
opts.device = "cpu"
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = 16000
|
||||||
|
opts.mel_opts.num_bins = 80
|
||||||
|
return OnlineFbank(opts)
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search(
|
||||||
|
model: Model,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: Optional[torch.Tensor] = None,
|
||||||
|
hyp: Optional[List[int]] = None,
|
||||||
|
):
|
||||||
|
context_size = 2
|
||||||
|
blank_id = 0
|
||||||
|
|
||||||
|
if decoder_out is None:
|
||||||
|
assert hyp is None, hyp
|
||||||
|
hyp = [blank_id] * context_size
|
||||||
|
decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size)
|
||||||
|
decoder_out = model.run_decoder(decoder_input).squeeze(0)
|
||||||
|
else:
|
||||||
|
assert decoder_out.ndim == 1
|
||||||
|
assert hyp is not None, hyp
|
||||||
|
|
||||||
|
T = encoder_out.size(0)
|
||||||
|
for t in range(T):
|
||||||
|
cur_encoder_out = encoder_out[t]
|
||||||
|
|
||||||
|
joiner_out = model.run_joiner(cur_encoder_out, decoder_out)
|
||||||
|
y = joiner_out.argmax(dim=0).item()
|
||||||
|
if y != blank_id:
|
||||||
|
hyp.append(y)
|
||||||
|
decoder_input = hyp[-context_size:]
|
||||||
|
decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
|
||||||
|
decoder_out = model.run_decoder(decoder_input).squeeze(0)
|
||||||
|
|
||||||
|
return hyp, decoder_out
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
logging.info(vars(args))
|
||||||
|
|
||||||
|
model = Model(args)
|
||||||
|
|
||||||
|
sound_file = args.sound_filename
|
||||||
|
|
||||||
|
sample_rate = 16000
|
||||||
|
|
||||||
|
logging.info("Constructing Fbank computer")
|
||||||
|
online_fbank = create_streaming_feature_extractor()
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {sound_file}")
|
||||||
|
wave_samples = read_sound_files(
|
||||||
|
filenames=[sound_file],
|
||||||
|
expected_sample_rate=sample_rate,
|
||||||
|
)[0]
|
||||||
|
logging.info(wave_samples.shape)
|
||||||
|
|
||||||
|
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
|
||||||
|
|
||||||
|
wave_samples = torch.cat([wave_samples, tail_padding])
|
||||||
|
|
||||||
|
states = model.get_init_states()
|
||||||
|
logging.info(f"number of states: {len(states)}")
|
||||||
|
|
||||||
|
hyp = None
|
||||||
|
decoder_out = None
|
||||||
|
|
||||||
|
num_processed_frames = 0
|
||||||
|
segment = model.T
|
||||||
|
offset = model.chunk_length
|
||||||
|
|
||||||
|
chunk = int(1 * sample_rate) # 0.2 second
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
while start < wave_samples.numel():
|
||||||
|
end = min(start + chunk, wave_samples.numel())
|
||||||
|
samples = wave_samples[start:end]
|
||||||
|
start += chunk
|
||||||
|
|
||||||
|
online_fbank.accept_waveform(
|
||||||
|
sampling_rate=sample_rate,
|
||||||
|
waveform=samples,
|
||||||
|
)
|
||||||
|
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||||
|
frames = []
|
||||||
|
for i in range(segment):
|
||||||
|
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||||
|
num_processed_frames += offset
|
||||||
|
frames = torch.cat(frames, dim=0)
|
||||||
|
encoder_out, states = model.run_encoder(frames, states)
|
||||||
|
hyp, decoder_out = greedy_search(model, encoder_out, decoder_out, hyp)
|
||||||
|
|
||||||
|
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||||
|
|
||||||
|
context_size = 2
|
||||||
|
text = ""
|
||||||
|
for i in hyp[context_size:]:
|
||||||
|
text += symbol_table[i]
|
||||||
|
text = text.replace("▁", " ").strip()
|
||||||
|
|
||||||
|
logging.info(sound_file)
|
||||||
|
logging.info(text)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
|
main()
|
1265
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py
Executable file
1265
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py
Executable file
File diff suppressed because it is too large
Load Diff
@ -44,7 +44,6 @@ from scaling import (
|
|||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from icefall.dist import get_rank
|
|
||||||
from icefall.utils import make_pad_mask, subsequent_chunk_mask
|
from icefall.utils import make_pad_mask, subsequent_chunk_mask
|
||||||
|
|
||||||
|
|
||||||
@ -271,7 +270,6 @@ class Zipformer(EncoderInterface):
|
|||||||
num_encoder_layers (int): number of encoder layers
|
num_encoder_layers (int): number of encoder layers
|
||||||
dropout (float): dropout rate
|
dropout (float): dropout rate
|
||||||
cnn_module_kernels (int): Kernel size of convolution module
|
cnn_module_kernels (int): Kernel size of convolution module
|
||||||
vgg_frontend (bool): whether to use vgg frontend.
|
|
||||||
warmup_batches (float): number of batches to warm up over
|
warmup_batches (float): number of batches to warm up over
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -388,9 +386,9 @@ class Zipformer(EncoderInterface):
|
|||||||
def _init_skip_modules(self):
|
def _init_skip_modules(self):
|
||||||
"""
|
"""
|
||||||
If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
|
If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
|
||||||
indexed 4 (in zero indexing), with has subsapling_factor=4, we combine the output of
|
indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of
|
||||||
layers 2 and 3; and at the input of layer indexed 5, which which has subsampling_factor=2,
|
layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2,
|
||||||
we combine the outputs of layers 1 and 5.
|
we combine the outputs of layers 1 and 4.
|
||||||
"""
|
"""
|
||||||
skip_layers = []
|
skip_layers = []
|
||||||
skip_modules = []
|
skip_modules = []
|
||||||
@ -1272,8 +1270,7 @@ class ZipformerEncoder(nn.Module):
|
|||||||
|
|
||||||
Shape:
|
Shape:
|
||||||
src: (S, N, E).
|
src: (S, N, E).
|
||||||
cached_len: (N,)
|
cached_len: (num_layers,)
|
||||||
N is the batch size.
|
|
||||||
cached_avg: (num_layers, N, C).
|
cached_avg: (num_layers, N, C).
|
||||||
N is the batch size, C is the feature dimension.
|
N is the batch size, C is the feature dimension.
|
||||||
cached_key: (num_layers, left_context_len, N, K).
|
cached_key: (num_layers, left_context_len, N, K).
|
||||||
@ -1289,8 +1286,8 @@ class ZipformerEncoder(nn.Module):
|
|||||||
|
|
||||||
Returns: A tuple of 8 tensors:
|
Returns: A tuple of 8 tensors:
|
||||||
- output tensor
|
- output tensor
|
||||||
- updated cached number of past frmaes.
|
- updated cached number of past frames.
|
||||||
- updated cached average of past frmaes.
|
- updated cached average of past frames.
|
||||||
- updated cached key tensor of of the first attention module.
|
- updated cached key tensor of of the first attention module.
|
||||||
- updated cached value tensor of of the first attention module.
|
- updated cached value tensor of of the first attention module.
|
||||||
- updated cached value tensor of of the second attention module.
|
- updated cached value tensor of of the second attention module.
|
||||||
@ -1522,9 +1519,6 @@ class AttentionDownsample(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels: int, out_channels: int, downsample: int):
|
def __init__(self, in_channels: int, out_channels: int, downsample: int):
|
||||||
"""
|
|
||||||
Require out_channels > in_channels.
|
|
||||||
"""
|
|
||||||
super(AttentionDownsample, self).__init__()
|
super(AttentionDownsample, self).__init__()
|
||||||
self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5))
|
self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5))
|
||||||
|
|
||||||
@ -1902,8 +1896,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
x: input to be projected to query, key, value
|
x: input to be projected to query, key, value
|
||||||
pos_emb: Positional embedding tensor
|
pos_emb: Positional embedding tensor
|
||||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
|
||||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
|
||||||
|
|
||||||
Shape:
|
Shape:
|
||||||
- Inputs:
|
- Inputs:
|
||||||
@ -1911,13 +1903,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
the embedding dimension.
|
the embedding dimension.
|
||||||
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
|
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
|
||||||
the embedding dimension.
|
the embedding dimension.
|
||||||
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
|
||||||
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
|
||||||
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
|
|
||||||
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
|
||||||
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
|
||||||
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
|
||||||
is provided, it will be added to the attention weight.
|
|
||||||
- cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension.
|
- cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension.
|
||||||
- cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension.
|
- cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension.
|
||||||
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user