Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 18:12:19 +00:00
Merge branch 'k2-fsa:master' into dev/k2ssl
This commit is contained in: commit 25b6dd26ac
.github/scripts/multi-zh-hans.sh (vendored): 42 lines changed
@@ -16,6 +16,48 @@ log "pwd: $PWD"
 
 cd egs/multi_zh-hans/ASR
 
+repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2
+log "Downloading pre-trained model from $repo_url"
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+pushd $repo
+cd exp
+git lfs pull --include pretrained.pt
+ln -s pretrained.pt epoch-99.pt
+cd ../data/lang_bpe_2000
+ls -lh
+git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model
+git lfs pull --include "*.model"
+ls -lh
+popd
+
+log "--------------------------------------------"
+log "Export non-streaming ONNX transducer models "
+log "--------------------------------------------"
+./zipformer/export-onnx.py \
+  --tokens $repo/data/lang_bpe_2000/tokens.txt \
+  --use-averaged-model 0 \
+  --epoch 99 \
+  --avg 1 \
+  --exp-dir $repo/exp \
+  --causal False
+
+ls -lh $repo/exp
+
+./zipformer/onnx_pretrained.py \
+  --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+  --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+  --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+  --tokens $repo/data/lang_bpe_2000/tokens.txt \
+  $repo/test_wavs/DEV_T0000000000.wav \
+  $repo/test_wavs/DEV_T0000000001.wav \
+  $repo/test_wavs/DEV_T0000000002.wav \
+  $repo/test_wavs/TEST_MEETING_T0000000113.wav \
+  $repo/test_wavs/TEST_MEETING_T0000000219.wav \
+  $repo/test_wavs/TEST_MEETING_T0000000351.wav
+
+rm -rf $repo
+
 repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05
 log "Downloading pre-trained model from $repo_url"
 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
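An aside on the hunk above (not part of this commit): export-onnx.py should leave encoder-epoch-99-avg-1.onnx, decoder-epoch-99-avg-1.onnx and joiner-epoch-99-avg-1.onnx in $repo/exp, which onnx_pretrained.py then uses to decode the test wavs. A minimal sketch for sanity-checking one of those files with onnxruntime, assuming onnxruntime is installed and the file is still on disk:

# Illustrative only; not part of multi-zh-hans.sh.
import onnxruntime as ort

# The path follows the epoch-99/avg-1 naming produced by the export command above.
sess = ort.InferenceSession(
    "icefall-asr-multi-zh-hans-zipformer-2023-9-2/exp/encoder-epoch-99-avg-1.onnx",
    providers=["CPUExecutionProvider"],
)
# List the input/output names, shapes and types the exported encoder exposes.
for inp in sess.get_inputs():
    print("input:", inp.name, inp.shape, inp.type)
for out in sess.get_outputs():
    print("output:", out.name, out.shape, out.type)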
@@ -74,7 +74,6 @@ import onnx
 import torch
 import torch.nn as nn
 from decoder import Decoder
-from onnxconverter_common import float16
 from onnxruntime.quantization import QuantType, quantize_dynamic
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_model, get_params
@@ -756,6 +755,7 @@ def main():
     logging.info(f"Exported joiner to {joiner_filename}")
 
     if(params.fp16) :
+        from onnxconverter_common import float16
         logging.info("Generate fp16 models")
 
         encoder = onnx.load(encoder_filename)
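For context on the two hunks above: the top-level onnxconverter_common import is removed and re-imported inside the fp16 branch, so the package is only required when params.fp16 is set. A minimal sketch of the kind of conversion that branch performs; the file names and the keep_io_types choice are illustrative, not taken from the diff:

# Illustrative only: converting an exported ONNX model to fp16.
import onnx
from onnxconverter_common import float16

model = onnx.load("encoder-epoch-99-avg-1.onnx")  # placeholder file name
# keep_io_types=True keeps float32 graph inputs/outputs while internal
# tensors and weights are cast to float16.
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
onnx.save(model_fp16, "encoder-epoch-99-avg-1.fp16.onnx")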
@@ -191,6 +191,7 @@ class Zipformer2(EncoderInterface):
                     dim=encoder_dim[i],
                     downsample=downsampling_factor[i],
                     dropout=dropout,
+                    causal=causal,
                 )
 
             encoders.append(encoder)
@@ -198,7 +199,10 @@ class Zipformer2(EncoderInterface):
         self.encoders = nn.ModuleList(encoders)
 
         self.downsample_output = SimpleDownsample(
-            max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout
+            max(encoder_dim),
+            downsample=output_downsampling_factor,
+            dropout=dropout,
+            causal=causal,
         )
 
     def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
@@ -1217,11 +1221,16 @@ class DownsampledZipformer2Encoder(nn.Module):
     """
 
     def __init__(
-        self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike
+        self,
+        encoder: nn.Module,
+        dim: int,
+        downsample: int,
+        dropout: FloatLike,
+        causal: bool,
     ):
         super(DownsampledZipformer2Encoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = SimpleDownsample(dim, downsample, dropout)
+        self.downsample = SimpleDownsample(dim, downsample, dropout, causal)
         self.num_layers = encoder.num_layers
         self.encoder = encoder
         self.upsample = SimpleUpsample(dim, downsample)
@@ -1310,9 +1319,12 @@ class SimpleDownsample(torch.nn.Module):
     Does downsampling with attention, by weighted sum, and a projection..
     """
 
-    def __init__(self, channels: int, downsample: int, dropout: FloatLike):
+    def __init__(
+        self, channels: int, downsample: int, dropout: FloatLike, causal: bool
+    ):
         super(SimpleDownsample, self).__init__()
 
+        self.causal = causal
         self.bias = nn.Parameter(torch.zeros(downsample))
 
         self.name = None  # will be set from training code
@@ -1333,9 +1345,18 @@ class SimpleDownsample(torch.nn.Module):
         # Pad to an exact multiple of self.downsample
         # right-pad src, repeating the last element.
         pad = d_seq_len * ds - seq_len
-        src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
-        src = torch.cat((src, src_extra), dim=0)
-        assert src.shape[0] == d_seq_len * ds
+
+        if self.causal and torch.jit.is_tracing():
+            assert (
+                pad == 0
+            ), f"pad should be zero for exporting streaming models. Given {pad}"
+
+        # If we are exporting a streaming model, then we skip the if statement
+        if not self.causal or not torch.jit.is_tracing():
+            src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
+            src = torch.cat((src, src_extra), dim=0)
+
+        assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds)
 
         src = src.reshape(d_seq_len, ds, batch_size, in_channels)
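A small worked illustration of the padding arithmetic in the hunk above (numbers are made up; d_seq_len is presumably the ceiling division of seq_len by ds, computed just before this excerpt). When a causal model is traced for streaming export, pad must already be zero, i.e. the chunk length must be a multiple of ds; otherwise the last frame is repeated pad times before reshaping:

# Illustrative numbers only, not from the diff.
seq_len, ds = 45, 2                   # sequence length and downsample factor
d_seq_len = (seq_len + ds - 1) // ds  # ceiling division -> 23
pad = d_seq_len * ds - seq_len        # 23 * 2 - 45 = 1
# Non-causal (or eager) path: right-pad by repeating the last frame `pad`
# times, so the padded length 46 reshapes cleanly into (23, 2, ...).
# Causal path under torch.jit.is_tracing(): the new assert requires pad == 0.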
@@ -1609,7 +1630,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         k = x[..., query_dim : 2 * query_dim]
         # p is the position-encoding query
         p = x[..., 2 * query_dim :]
-        assert p.shape[-1] == num_heads * pos_head_dim, (p.shape[-1], num_heads, pos_head_dim)
+        assert p.shape[-1] == num_heads * pos_head_dim, (
+            p.shape[-1],
+            num_heads,
+            pos_head_dim,
+        )
 
         q = self.copy_query(q)  # for diagnostics only, does nothing.
         k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
@@ -63,8 +63,8 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   ln -svf $(realpath ./open-commands/CN/small/commands.txt) commands_small.txt
   ln -svf $(realpath ./open-commands/CN/large/commands.txt) commands_large.txt
   pushd open-commands
-  ./script/prepare.sh --stage 1 --stop-stage 1
-  ./script/prepare.sh --stage 3 --stop-stage 5
+  ./scripts/prepare.sh --stage 1 --stop-stage 1
+  ./scripts/prepare.sh --stage 3 --stop-stage 5
   popd
   popd
   pushd data/fbank