diff --git a/.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh b/.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh
index 45324cb27..f4e2124b1 100755
--- a/.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh
+++ b/.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh
@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/jit_script_chunk_16_left_128.pt"
git lfs pull --include "exp/pretrained.pt"
ln -s pretrained.pt epoch-99.pt
@@ -33,7 +34,7 @@ log "Export to torchscript model"
./zipformer/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
@@ -46,7 +47,7 @@ ls -lh $repo/exp/*.pt
log "Decode with models exported by torch.jit.script()"
./zipformer/jit_pretrained_streaming.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
--nn-model-filename $repo/exp/jit_script_chunk_16_left_128.pt \
$repo/test_wavs/1089-134686-0001.wav
@@ -60,7 +61,7 @@ for method in greedy_search modified_beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
diff --git a/.github/scripts/run-librispeech-zipformer-2023-05-18.sh b/.github/scripts/run-librispeech-zipformer-2023-05-18.sh
index 6aac1793e..fb1a0149d 100755
--- a/.github/scripts/run-librispeech-zipformer-2023-05-18.sh
+++ b/.github/scripts/run-librispeech-zipformer-2023-05-18.sh
@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/jit_script.pt"
git lfs pull --include "exp/pretrained.pt"
ln -s pretrained.pt epoch-99.pt
@@ -33,7 +34,7 @@ log "Export to torchscript model"
./zipformer/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
@@ -43,7 +44,7 @@ ls -lh $repo/exp/*.pt
log "Decode with models exported by torch.jit.script()"
./zipformer/jit_pretrained.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
--nn-model-filename $repo/exp/jit_script.pt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
@@ -56,7 +57,7 @@ for method in greedy_search modified_beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
diff --git a/.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh b/.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh
index cfa9c420c..0026d2109 100755
--- a/.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh
+++ b/.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh
@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "data/lang_bpe_500/HLG.pt"
git lfs pull --include "data/lang_bpe_500/L.pt"
git lfs pull --include "data/lang_bpe_500/LG.pt"
@@ -40,7 +41,7 @@ log "Export to torchscript model"
--use-transducer 1 \
--use-ctc 1 \
--use-averaged-model false \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
@@ -51,7 +52,7 @@ log "Decode with models exported by torch.jit.script()"
for method in ctc-decoding 1best; do
./zipformer/jit_pretrained_ctc.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
--model-filename $repo/exp/jit_script.pt \
--HLG $repo/data/lang_bpe_500/HLG.pt \
--words-file $repo/data/lang_bpe_500/words.txt \
@@ -71,8 +72,7 @@ for method in ctc-decoding 1best; do
--use-ctc 1 \
--method $method \
--checkpoint $repo/exp/pretrained.pt \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --words-file $repo/data/lang_bpe_500/words.txt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
--HLG $repo/data/lang_bpe_500/HLG.pt \
--G $repo/data/lm/G_4_gram.pt \
--words-file $repo/data/lang_bpe_500/words.txt \
diff --git a/.github/scripts/test-ncnn-export.sh b/.github/scripts/test-ncnn-export.sh
index 52491d2ea..ac16131d0 100755
--- a/.github/scripts/test-ncnn-export.sh
+++ b/.github/scripts/test-ncnn-export.sh
@@ -195,14 +195,14 @@ git lfs pull --include "data/lang_char_bpe/Linv.pt"
git lfs pull --include "exp/pretrained.pt"
cd exp
-ln -s pretrained.pt epoch-99.pt
+ln -s pretrained.pt epoch-9999.pt
popd
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
--lang-dir $repo/data/lang_char_bpe \
--exp-dir $repo/exp \
--use-averaged-model 0 \
- --epoch 99 \
+ --epoch 9999 \
--avg 1 \
--decode-chunk-len 32 \
--num-encoder-layers "2,4,3,2,4" \
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py
index ef536c035..cbb7db086 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py
@@ -240,7 +240,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
- default="pruned_transducer_stateless3/exp",
+ default="pruned_transducer_stateless7/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7/train2.py
index fb35a6c95..c30f6f960 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/train2.py
@@ -243,7 +243,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
- default="pruned_transducer_stateless3/exp",
+ default="pruned_transducer_stateless7/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
index fcb0ebc4e..da9000164 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
@@ -397,7 +397,6 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
- subtract_ilme=True,
ilme_scale=params.ilme_scale,
)
for hyp in hyp_tokens:
diff --git a/egs/libricss/SURT/README.md b/egs/libricss/SURT/README.md
index dd460906e..10a1aaad1 100644
--- a/egs/libricss/SURT/README.md
+++ b/egs/libricss/SURT/README.md
@@ -41,7 +41,7 @@ The model is a combination of a speech separation model and a speech recognition
but trained end-to-end with a single loss function. The overall architecture is shown
in the figure below. Note that this architecture is slightly different from the one
in the above papers. A detailed description of the model can be found in the following
-paper: [SURT 2.0: Advanced in transducer-based multi-talker ASR]().
+paper: [SURT 2.0: Advances in transducer-based multi-talker ASR](https://arxiv.org/abs/2306.10559).
@@ -50,7 +50,7 @@ paper: [SURT 2.0: Advanced in transducer-based multi-talker ASR]().
-In the `dprnn_zipformer` recipe, for example, we use a DPRNN-based masking network
+In the [dprnn_zipformer](./dprnn_zipformer) recipe, for example, we use a DPRNN-based masking network
and a Zipformer-based recognition network. But other combinations are possible as well.
## Training objective
diff --git a/egs/libricss/SURT/dprnn_zipformer/decode.py b/egs/libricss/SURT/dprnn_zipformer/decode.py
index 2054c2dc1..6abbffe00 100755
--- a/egs/libricss/SURT/dprnn_zipformer/decode.py
+++ b/egs/libricss/SURT/dprnn_zipformer/decode.py
@@ -233,22 +233,23 @@ def decode_one_batch(
masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
x_masked = [feature * m for m in masks]
- # To save the masks, we split them by batch and trim each mask to the length of
- # the corresponding feature. We save them in a dict, where the key is the
- # cut ID and the value is the mask.
masks_dict = {}
- for i in range(B):
- mask = torch.cat(
- [x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
- dim=-1,
- )
- mask = mask.cpu().numpy()
- masks_dict[batch["cuts"][i].id] = mask
+ if params.save_masks:
+ # To save the masks, we split them by batch and trim each mask to the length of
+ # the corresponding feature. We save them in a dict, where the key is the
+ # cut ID and the value is the mask.
+ for i in range(B):
+ mask = torch.cat(
+ [x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
+ dim=-1,
+ )
+ mask = mask.cpu().numpy()
+ masks_dict[batch["cuts"][i].id] = mask
# Recognition
- # Stack the inputs along the batch axis
+ # Concatenate the inputs along the batch axis
h = torch.cat(x_masked, dim=0)
- h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
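+ # feature_lens.repeat(num_channels) matches the ordering of torch.cat(x_masked, dim=0): channel 0's batch first, then channel 1's, and so on.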
+ h_lens = feature_lens.repeat(params.num_channels)
encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
if model.joint_encoder_layer is not None:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 1bbad6946..fd59d4b7f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -22,6 +22,7 @@ from typing import Dict, List, Optional, Tuple, Union
import k2
import sentencepiece as spm
import torch
+from torch import nn
from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
@@ -35,7 +36,6 @@ from icefall.utils import (
get_texts,
get_texts_with_timestamp,
)
-from torch import nn
def fast_beam_search_one_best(
@@ -47,9 +47,10 @@ def fast_beam_search_one_best(
max_states: int,
max_contexts: int,
temperature: float = 1.0,
- subtract_ilme: bool = False,
- ilme_scale: float = 0.1,
+ ilme_scale: float = 0.0,
+ blank_penalty: float = 0.0,
return_timestamps: bool = False,
+ allow_partial: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -90,8 +91,9 @@ def fast_beam_search_one_best(
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
- subtract_ilme=subtract_ilme,
ilme_scale=ilme_scale,
+ allow_partial=allow_partial,
+ blank_penalty=blank_penalty,
)
best_path = one_best_decoding(lattice)
@@ -114,7 +116,10 @@ def fast_beam_search_nbest_LG(
nbest_scale: float = 0.5,
use_double_scores: bool = True,
temperature: float = 1.0,
+ blank_penalty: float = 0.0,
+ ilme_scale: float = 0.0,
return_timestamps: bool = False,
+ allow_partial: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -168,6 +173,9 @@ def fast_beam_search_nbest_LG(
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
+ allow_partial=allow_partial,
+ blank_penalty=blank_penalty,
+ ilme_scale=ilme_scale,
)
nbest = Nbest.from_lattice(
@@ -240,7 +248,9 @@ def fast_beam_search_nbest(
nbest_scale: float = 0.5,
use_double_scores: bool = True,
temperature: float = 1.0,
+ blank_penalty: float = 0.0,
return_timestamps: bool = False,
+ allow_partial: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -293,7 +303,9 @@ def fast_beam_search_nbest(
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
+ blank_penalty=blank_penalty,
temperature=temperature,
+ allow_partial=allow_partial,
)
nbest = Nbest.from_lattice(
@@ -331,7 +343,9 @@ def fast_beam_search_nbest_oracle(
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
+ blank_penalty: float = 0.0,
return_timestamps: bool = False,
+ allow_partial: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -389,6 +403,8 @@ def fast_beam_search_nbest_oracle(
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
+ allow_partial=allow_partial,
+ blank_penalty=blank_penalty,
)
nbest = Nbest.from_lattice(
@@ -434,6 +450,8 @@ def fast_beam_search(
temperature: float = 1.0,
subtract_ilme: bool = False,
ilme_scale: float = 0.1,
+ allow_partial: bool = False,
+ blank_penalty: float = 0.0,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.
@@ -503,8 +521,13 @@ def fast_beam_search(
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
+
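+ # A positive blank_penalty lowers the blank logit (index 0) before the softmax, discouraging blank emissions.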
+ if blank_penalty != 0:
+ logits[:, 0] -= blank_penalty
+
log_probs = (logits / temperature).log_softmax(dim=-1)
- if subtract_ilme:
+
+ if ilme_scale != 0:
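+ # ILME: approximate the internal LM by feeding a zeroed encoder output to the joiner, then subtract its scaled log-probs below.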
ilme_logits = model.joiner(
torch.zeros_like(
current_encoder_out, device=current_encoder_out.device
@@ -513,11 +536,16 @@ def fast_beam_search(
project_input=False,
)
ilme_logits = ilme_logits.squeeze(1).squeeze(1)
+ if blank_penalty != 0:
+ ilme_logits[:, 0] -= blank_penalty
ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)
log_probs -= ilme_scale * ilme_log_probs
+
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
- lattice = decoding_streams.format_output(encoder_out_lens.tolist())
+ lattice = decoding_streams.format_output(
+ encoder_out_lens.tolist(), allow_partial=allow_partial
+ )
return lattice
@@ -526,6 +554,7 @@ def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
max_sym_per_frame: int,
+ blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""Greedy search for a single utterance.
@@ -595,6 +624,9 @@ def greedy_search(
)
# logits is (1, 1, 1, vocab_size)
+ if blank_penalty != 0:
+ logits[:, :, :, 0] -= blank_penalty
+
y = logits.argmax().item()
if y not in (blank_id, unk_id):
hyp.append(y)
@@ -626,6 +658,7 @@ def greedy_search_batch(
model: nn.Module,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
+ blank_penalty: float = 0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
@@ -703,6 +736,10 @@ def greedy_search_batch(
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
+
+ if blank_penalty != 0:
+ logits[:, 0] -= blank_penalty
+
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
@@ -923,6 +960,7 @@ def modified_beam_search(
context_graph: Optional[ContextGraph] = None,
beam: int = 4,
temperature: float = 1.0,
+ blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@@ -1028,6 +1066,9 @@ def modified_beam_search(
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
+ if blank_penalty != 0:
+ logits[:, 0] -= blank_penalty
+
log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
@@ -1662,6 +1703,7 @@ def beam_search(
encoder_out: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
+ blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""
@@ -1758,6 +1800,9 @@ def beam_search(
project_input=False,
)
+ if blank_penalty != 0:
+ logits[:, :, :, 0] -= blank_penalty
+
# TODO(fangjun): Scale the blank posterior
log_prob = (logits / temperature).log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
index b76272e66..a0f54b6e1 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
@@ -22,7 +22,7 @@ Usage:
--avg 15 \
--decode-chunk-len 32 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
- --decoding_method greedy_search \
+ --decoding-method greedy_search \
--num-decode-streams 2000
"""
diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
index 8cec09869..3eb06f68c 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
#
-# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
@@ -19,7 +19,7 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
-git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"
cd exp
@@ -29,7 +29,7 @@ popd
2. Export the model to ONNX
./zipformer/export-onnx-streaming.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@@ -57,9 +57,9 @@ whose value is "64,128,256,-1".
It will generate the following 3 files inside $repo/exp:
- - encoder-epoch-99-avg-1.onnx
- - decoder-epoch-99-avg-1.onnx
- - joiner-epoch-99-avg-1.onnx
+ - encoder-epoch-99-avg-1-chunk-16-left-64.onnx
+ - decoder-epoch-99-avg-1-chunk-16-left-64.onnx
+ - joiner-epoch-99-avg-1-chunk-16-left-64.onnx
See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
"""
@@ -69,14 +69,15 @@ import logging
from pathlib import Path
from typing import Dict, List, Tuple
+import k2
import onnx
-import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
+from export import num_tokens
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params
from zipformer import Zipformer2
from icefall.checkpoint import (
@@ -85,7 +86,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.utils import str2bool, make_pad_mask
+from icefall.utils import str2bool
def get_parser():
@@ -142,10 +143,10 @@ def get_parser():
)
parser.add_argument(
- "--bpe-model",
+ "--tokens",
type=str,
- default="data/lang_bpe_500/bpe.model",
- help="Path to the BPE model",
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -217,7 +218,7 @@ class OnnxEncoder(nn.Module):
)
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
- src_key_padding_mask = make_pad_mask(x_lens)
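+ # In streaming export each chunk holds exactly chunk_size valid frames (see the assert above), so an all-False padding mask replaces make_pad_mask.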
+ src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
@@ -271,6 +272,7 @@ class OnnxEncoder(nn.Module):
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
+
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
@@ -585,12 +587,9 @@ def main():
logging.info(f"device: {device}")
- sp = spm.SentencePieceProcessor()
- sp.load(params.bpe_model)
-
- # <blk> is defined in local/train_bpe_model.py
- params.blank_id = sp.piece_to_id("<blk>")
- params.vocab_size = sp.get_piece_size()
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
@@ -709,6 +708,8 @@ def main():
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
+ suffix += f"-chunk-{params.chunk_size}"
+ suffix += f"-left-{params.left_context_frames}"
opset_version = 13
@@ -756,7 +757,7 @@ def main():
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
- op_types_to_quantize=["MatMul"],
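+ # Also quantize Gather so that the decoder's embedding table is stored in int8, not only the MatMul weights.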
+ op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)
diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py
index f5b01ce71..724fdd2a6 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
#
-# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
@@ -19,7 +19,7 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
-git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"
cd exp
@@ -29,12 +29,11 @@ popd
2. Export the model to ONNX
./zipformer/export-onnx.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
- \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
@@ -67,14 +66,15 @@ import logging
from pathlib import Path
from typing import Dict, Tuple
+import k2
import onnx
-import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
+from export import num_tokens
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params
from zipformer import Zipformer2
from icefall.checkpoint import (
@@ -83,7 +83,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.utils import str2bool, make_pad_mask
+from icefall.utils import make_pad_mask, str2bool
def get_parser():
@@ -140,10 +140,10 @@ def get_parser():
)
parser.add_argument(
- "--bpe-model",
+ "--tokens",
type=str,
- default="data/lang_bpe_500/bpe.model",
- help="Path to the BPE model",
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -434,12 +434,9 @@ def main():
logging.info(f"device: {device}")
- sp = spm.SentencePieceProcessor()
- sp.load(params.bpe_model)
-
- # <blk> is defined in local/train_bpe_model.py
- params.blank_id = sp.piece_to_id("<blk>")
- params.vocab_size = sp.get_piece_size()
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
@@ -605,7 +602,7 @@ def main():
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
- op_types_to_quantize=["MatMul"],
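+ # Include Gather so that embedding lookups in the decoder are quantized to int8 as well.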
+ op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)
diff --git a/egs/librispeech/ASR/zipformer/export.py b/egs/librispeech/ASR/zipformer/export.py
index a100cbb8d..4a48d5bad 100755
--- a/egs/librispeech/ASR/zipformer/export.py
+++ b/egs/librispeech/ASR/zipformer/export.py
@@ -1,6 +1,8 @@
#!/usr/bin/env python3
#
-# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -22,13 +24,16 @@
Usage:
+Note: This is an example for the LibriSpeech dataset. If you are using a different
+dataset, you should change the argument values according to your dataset.
+
(1) Export to torchscript model using torch.jit.script()
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@@ -48,7 +53,7 @@ for how to use the exported models outside of icefall.
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@@ -67,7 +72,7 @@ for how to use the exported models outside of icefall.
./zipformer/export.py \
--exp-dir ./zipformer/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
@@ -76,7 +81,7 @@ for how to use the exported models outside of icefall.
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
@@ -155,13 +160,15 @@ with the following commands:
import argparse
import logging
+import re
from pathlib import Path
from typing import List, Tuple
-import sentencepiece as spm
+import k2
import torch
+from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
@@ -170,7 +177,26 @@ from icefall.checkpoint import (
load_checkpoint,
)
from icefall.utils import make_pad_mask, str2bool
-from scaling_converter import convert_scaled_to_non_scaled
+
+
+def num_tokens(
+ token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
+) -> int:
+ """Return the number of tokens excluding those from
+ disambiguation symbols.
+
+ Caution:
+ 0 is not a token ID so it is excluded from the return value.
+ """
+ symbols = token_table.symbols
+ ans = []
+ for s in symbols:
+ if not disambig_pattern.match(s):
+ ans.append(token_table[s])
+ num_tokens = len(ans)
+ if 0 in ans:
+ num_tokens -= 1
+ return num_tokens
def get_parser():
@@ -227,10 +253,10 @@ def get_parser():
)
parser.add_argument(
- "--bpe-model",
+ "--tokens",
type=str,
- default="data/lang_bpe_500/bpe.model",
- help="Path to the BPE model",
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -238,7 +264,7 @@ def get_parser():
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
- It will generate a file named cpu_jit.pt.
+ It will generate a file named jit_script.pt.
Check ./jit_pretrained.py for how to use it.
""",
)
@@ -257,6 +283,7 @@ def get_parser():
class EncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
+
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
self.encoder = encoder
@@ -275,9 +302,7 @@ class EncoderModel(nn.Module):
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
- encoder_out, encoder_out_lens = self.encoder(
- x, x_lens, src_key_padding_mask
- )
+ encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
@@ -398,12 +423,9 @@ def main():
logging.info(f"device: {device}")
- sp = spm.SentencePieceProcessor()
- sp.load(params.bpe_model)
-
- # <blk> is defined in local/train_bpe_model.py
- params.blank_id = sp.piece_to_id("<blk>")
- params.vocab_size = sp.get_piece_size()
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
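+ # num_tokens() excludes ID 0 (the blank), so add 1 to recover the full vocabulary size.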
logging.info(params)
diff --git a/egs/librispeech/ASR/zipformer/generate_averaged_model.py b/egs/librispeech/ASR/zipformer/generate_averaged_model.py
index e0c7b52cb..68111fad7 100755
--- a/egs/librispeech/ASR/zipformer/generate_averaged_model.py
+++ b/egs/librispeech/ASR/zipformer/generate_averaged_model.py
@@ -40,16 +40,11 @@ You can later load it by `torch.load("iter-22000-avg-5.pt")`.
import argparse
from pathlib import Path
-import sentencepiece as spm
+import k2
import torch
-from asr_datamodule import LibriSpeechAsrDataModule
+from train import add_model_arguments, get_model, get_params
-from train import add_model_arguments, get_params, get_model
-
-from icefall.checkpoint import (
- average_checkpoints_with_averaged_model,
- find_checkpoints,
-)
+from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
def get_parser():
@@ -93,10 +88,10 @@ def get_parser():
)
parser.add_argument(
- "--bpe-model",
+ "--tokens",
type=str,
- default="data/lang_bpe_500/bpe.model",
- help="Path to the BPE model",
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -114,7 +109,6 @@ def get_parser():
@torch.no_grad()
def main():
parser = get_parser()
- LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@@ -131,13 +125,10 @@ def main():
device = torch.device("cpu")
print(f"Device: {device}")
- sp = spm.SentencePieceProcessor()
- sp.load(params.bpe_model)
-
- # <blk> is defined in local/train_bpe_model.py
- params.blank_id = sp.piece_to_id("<blk>")
- params.unk_id = sp.piece_to_id("<unk>")
- params.vocab_size = sp.get_piece_size()
+ symbol_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = symbol_table["<blk>"]
+ params.unk_id = symbol_table["<unk>"]
+ params.vocab_size = len(symbol_table)
print("About to create model")
model = get_model(params)
diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained.py b/egs/librispeech/ASR/zipformer/jit_pretrained.py
index 87cd5102c..a41fbc1c9 100755
--- a/egs/librispeech/ASR/zipformer/jit_pretrained.py
+++ b/egs/librispeech/ASR/zipformer/jit_pretrained.py
@@ -21,7 +21,7 @@ You can use the following command to get the exported models:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@@ -30,7 +30,7 @@ Usage of this script:
./zipformer/jit_pretrained.py \
--nn-model-filename ./zipformer/exp/cpu_jit.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
+ --tokens ./data/lang_bpe_500/tokens.txt \
/path/to/foo.wav \
/path/to/bar.wav
"""
@@ -40,8 +40,8 @@ import logging
import math
from typing import List
+import k2
import kaldifeat
-import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
@@ -60,9 +60,9 @@ def get_parser():
)
parser.add_argument(
- "--bpe-model",
+ "--tokens",
type=str,
- help="""Path to bpe.model.""",
+ help="""Path to tokens.txt.""",
)
parser.add_argument(
@@ -128,7 +128,7 @@ def greedy_search(
)
device = encoder_out.device
- blank_id = 0 # hard-code to 0
+ blank_id = model.decoder.blank_id
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
@@ -215,9 +215,6 @@ def main():
model.to(device)
- sp = spm.SentencePieceProcessor()
- sp.load(args.bpe_model)
-
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
@@ -256,10 +253,21 @@ def main():
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
+
s = "\n"
+
+ token_table = k2.SymbolTable.from_file(args.tokens)
+
+ def token_ids_to_words(token_ids: List[int]) -> str:
+ text = ""
+ for i in token_ids:
+ text += token_table[i]
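+ # "▁" is the SentencePiece word-boundary marker; it is mapped back to a space below.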
+ return text.replace("▁", " ").strip()
+
for filename, hyp in zip(args.sound_files, hyps):
- words = sp.decode(hyp)
- s += f"{filename}:\n{words}\n\n"
+ words = token_ids_to_words(hyp)
+ s += f"{filename}:\n{words}\n"
+
logging.info(s)
logging.info("Decoding Done")
diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py
index 1ec390d5b..14faeedd1 100755
--- a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py
+++ b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py
@@ -24,7 +24,7 @@ You can generate the checkpoint with the following command:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@@ -35,7 +35,7 @@ You can generate the checkpoint with the following command:
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--causal 1 \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@@ -45,7 +45,7 @@ Usage of this script:
(1) ctc-decoding
./zipformer/jit_pretrained_ctc.py \
--model-filename ./zipformer/exp/jit_script.pt \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--method ctc-decoding \
--sample-rate 16000 \
/path/to/foo.wav \
@@ -91,10 +91,10 @@ from typing import List
import k2
import kaldifeat
-import sentencepiece as spm
import torch
import torchaudio
from ctc_decode import get_decoding_params
+from export import num_tokens
from torch.nn.utils.rnn import pad_sequence
from train import get_params
@@ -136,9 +136,9 @@ def get_parser():
)
parser.add_argument(
- "--bpe-model",
+ "--tokens",
type=str,
- help="""Path to bpe.model.
+ help="""Path to tokens.txt.
Used only when method is ctc-decoding.
""",
)
@@ -149,8 +149,8 @@ def get_parser():
default="1best",
help="""Decoding method.
Possible values are:
- (0) ctc-decoding - Use CTC decoding. It uses a sentence
- piece model, i.e., lang_dir/bpe.model, to convert
+ (0) ctc-decoding - Use CTC decoding. It uses a token table,
+ i.e., lang_dir/tokens.txt, to convert
word pieces to words. It needs neither a lexicon
nor an n-gram LM.
(1) 1best - Use the best path as decoding output. Only
@@ -263,10 +263,8 @@ def main():
params.update(get_decoding_params())
params.update(vars(args))
- sp = spm.SentencePieceProcessor()
- sp.load(params.bpe_model)
-
- params.vocab_size = sp.get_piece_size()
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.vocab_size = num_tokens(token_table)
logging.info(f"{params}")
@@ -340,8 +338,7 @@ def main():
lattice=lattice, use_double_scores=params.use_double_scores
)
token_ids = get_texts(best_path)
- hyps = sp.decode(token_ids)
- hyps = [s.split() for s in hyps]
+ hyps = [[token_table[i] for i in ids] for ids in token_ids]
elif params.method in [
"1best",
"nbest-rescoring",
@@ -415,6 +412,7 @@ def main():
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
+ words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
logging.info(s)
diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py b/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py
index 58d736685..d4ceacefd 100755
--- a/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py
+++ b/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py
@@ -25,7 +25,7 @@ You can use the following command to get the exported models:
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@@ -34,7 +34,7 @@ Usage of this script:
./zipformer/jit_pretrained_streaming.py \
--nn-model-filename ./zipformer/exp-causal/jit_script_chunk_16_left_128.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
+ --tokens ./data/lang_bpe_500/tokens.txt \
/path/to/foo.wav \
"""
@@ -43,8 +43,8 @@ import logging
import math
from typing import List, Optional
+import k2
import kaldifeat
-import sentencepiece as spm
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
@@ -60,13 +60,13 @@ def get_parser():
"--nn-model-filename",
type=str,
required=True,
- help="Path to the torchscript model cpu_jit.pt",
+ help="Path to the torchscript model jit_script.pt",
)
parser.add_argument(
- "--bpe-model",
+ "--tokens",
type=str,
- help="""Path to bpe.model.""",
+ help="""Path to tokens.txt.""",
)
parser.add_argument(
@@ -120,8 +120,8 @@ def greedy_search(
device: torch.device = torch.device("cpu"),
):
assert encoder_out.ndim == 2
- context_size = 2
- blank_id = 0
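+ # Read these from the exported decoder instead of hard-coding them.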
+ context_size = decoder.context_size
+ blank_id = decoder.blank_id
if decoder_out is None:
assert hyp is None, hyp
@@ -190,8 +190,8 @@ def main():
decoder = model.decoder
joiner = model.joiner
- sp = spm.SentencePieceProcessor()
- sp.load(args.bpe_model)
+ token_table = k2.SymbolTable.from_file(args.tokens)
+ context_size = decoder.context_size
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor(args.sample_rate)
@@ -250,9 +250,13 @@ def main():
decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device
)
- context_size = 2
+ text = ""
+ for i in hyp[context_size:]:
+ text += token_table[i]
+ text = text.replace("▁", " ").strip()
+
logging.info(args.sound_file)
- logging.info(sp.decode(hyp[context_size:]))
+ logging.info(text)
logging.info("Decoding Done")
diff --git a/egs/librispeech/ASR/zipformer/onnx_check.py b/egs/librispeech/ASR/zipformer/onnx_check.py
new file mode 100755
index 000000000..b38b875d0
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/onnx_check.py
@@ -0,0 +1,241 @@
+#!/usr/bin/env python3
+#
+# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script checks that the exported ONNX models produce the same output
+as the given torchscript model for the same input.
+
+We use the pre-trained model from
+https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/tokens.txt"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+2. Export the model via torchscript (torch.jit.script())
+
+./zipformer/export.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp/ \
+ --jit 1
+
+It will generate the following file in $repo/exp:
+ - jit_script.pt
+
+3. Export the model to ONNX
+
+./zipformer/export-onnx.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp/
+
+It will generate the following 3 files inside $repo/exp:
+
+ - encoder-epoch-99-avg-1.onnx
+ - decoder-epoch-99-avg-1.onnx
+ - joiner-epoch-99-avg-1.onnx
+
+4. Run this file
+
+./zipformer/onnx_check.py \
+ --jit-filename $repo/exp/jit_script.pt \
+ --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
+"""
+
+import argparse
+import logging
+
+import torch
+from onnx_pretrained import OnnxModel
+
+from icefall import is_module_available
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--jit-filename",
+ required=True,
+ type=str,
+ help="Path to the torchscript model",
+ )
+
+ parser.add_argument(
+ "--onnx-encoder-filename",
+ required=True,
+ type=str,
+ help="Path to the onnx encoder model",
+ )
+
+ parser.add_argument(
+ "--onnx-decoder-filename",
+ required=True,
+ type=str,
+ help="Path to the onnx decoder model",
+ )
+
+ parser.add_argument(
+ "--onnx-joiner-filename",
+ required=True,
+ type=str,
+ help="Path to the onnx joiner model",
+ )
+
+ return parser
+
+
+def test_encoder(
+ torch_model: torch.jit.ScriptModule,
+ onnx_model: OnnxModel,
+):
+ C = 80
+ for i in range(3):
+ N = torch.randint(low=1, high=20, size=(1,)).item()
+ T = torch.randint(low=30, high=50, size=(1,)).item()
+ logging.info(f"test_encoder: iter {i}, N={N}, T={T}")
+
+ x = torch.rand(N, T, C)
+ x_lens = torch.randint(low=30, high=T + 1, size=(N,))
+ x_lens[0] = T
+
+ torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens)
+ torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out)
+
+ onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens)
+
+ assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), (
+ (torch_encoder_out - onnx_encoder_out).abs().max()
+ )
+
+
+def test_decoder(
+ torch_model: torch.jit.ScriptModule,
+ onnx_model: OnnxModel,
+):
+ context_size = onnx_model.context_size
+ vocab_size = onnx_model.vocab_size
+ for i in range(10):
+ N = torch.randint(1, 100, size=(1,)).item()
+ logging.info(f"test_decoder: iter {i}, N={N}")
+ x = torch.randint(
+ low=1,
+ high=vocab_size,
+ size=(N, context_size),
+ dtype=torch.int64,
+ )
+ torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False]))
+ torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out)
+ torch_decoder_out = torch_decoder_out.squeeze(1)
+
+ onnx_decoder_out = onnx_model.run_decoder(x)
+ assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), (
+ (torch_decoder_out - onnx_decoder_out).abs().max()
+ )
+
+
+def test_joiner(
+ torch_model: torch.jit.ScriptModule,
+ onnx_model: OnnxModel,
+):
+ encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1]
+ decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1]
+ for i in range(10):
+ N = torch.randint(1, 100, size=(1,)).item()
+ logging.info(f"test_joiner: iter {i}, N={N}")
+ encoder_out = torch.rand(N, encoder_dim)
+ decoder_out = torch.rand(N, decoder_dim)
+
+ projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out)
+ projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out)
+
+ torch_joiner_out = torch_model.joiner(encoder_out, decoder_out)
+ onnx_joiner_out = onnx_model.run_joiner(
+ projected_encoder_out, projected_decoder_out
+ )
+
+ assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
+ (torch_joiner_out - onnx_joiner_out).abs().max()
+ )
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ logging.info(vars(args))
+
+ torch_model = torch.jit.load(args.jit_filename)
+
+ onnx_model = OnnxModel(
+ encoder_model_filename=args.onnx_encoder_filename,
+ decoder_model_filename=args.onnx_decoder_filename,
+ joiner_model_filename=args.onnx_joiner_filename,
+ )
+
+ logging.info("Test encoder")
+ test_encoder(torch_model, onnx_model)
+
+ logging.info("Test decoder")
+ test_decoder(torch_model, onnx_model)
+
+ logging.info("Test joiner")
+ test_joiner(torch_model, onnx_model)
+ logging.info("Finished checking ONNX models")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+# See https://github.com/pytorch/pytorch/issues/38342
+# and https://github.com/pytorch/pytorch/issues/33354
+#
+# If we don't do this, the delay increases whenever there is
+# a new request that changes the actual batch size.
+# If you use `py-spy dump --pid <pid> --native`, you will
+# see a lot of time is spent in re-compiling the torch script model.
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_set_profiling_mode(False)
+torch._C._set_graph_executor_optimize(False)
+if __name__ == "__main__":
+ torch.manual_seed(20220727)
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/librispeech/ASR/zipformer/onnx_decode.py b/egs/librispeech/ASR/zipformer/onnx_decode.py
new file mode 100755
index 000000000..2aca36ca9
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/onnx_decode.py
@@ -0,0 +1,323 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads ONNX exported models and uses them to decode the test sets.
+
+We use the pre-trained model from
+https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/tokens.txt"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+2. Export the model to ONNX
+
+./zipformer/export-onnx.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp \
+ --causal False
+
+It will generate the following 3 files inside $repo/exp:
+
+ - encoder-epoch-99-avg-1.onnx
+ - decoder-epoch-99-avg-1.onnx
+ - joiner-epoch-99-avg-1.onnx
+
+3. Run this file
+
+./zipformer/onnx_decode.py \
+ --exp-dir $repo/exp \
+ --max-duration 600 \
+ --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+"""
+
+
+import argparse
+import logging
+import time
+from pathlib import Path
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+
+from onnx_pretrained import greedy_search, OnnxModel
+
+from icefall.utils import setup_logger, store_transcripts, write_error_stats
+from k2 import SymbolTable
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--encoder-model-filename",
+ type=str,
+ required=True,
+ help="Path to the encoder onnx model. ",
+ )
+
+ parser.add_argument(
+ "--decoder-model-filename",
+ type=str,
+ required=True,
+ help="Path to the decoder onnx model. ",
+ )
+
+ parser.add_argument(
+ "--joiner-model-filename",
+ type=str,
+ required=True,
+ help="Path to the joiner onnx model. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ help="""Path to tokens.txt.""",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="Valid values are greedy_search and modified_beam_search",
+ )
+
+ return parser
+
+
+def decode_one_batch(
+ model: OnnxModel, token_table: SymbolTable, batch: dict
+) -> List[List[str]]:
+ """Decode one batch and return the result.
+ Currently only greedy_search is supported.
+
+ Args:
+ model:
+ The neural model.
+ token_table:
+ The token table.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+
+ Returns:
+ Return the decoded results for each utterance.
+ """
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
+
+ encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
+
+ hyps = greedy_search(
+ model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
+ )
+
+ def token_ids_to_words(token_ids: List[int]) -> str:
+ text = ""
+ for i in token_ids:
+ text += token_table[i]
+ return text.replace("▁", " ").strip()
+
+ hyps = [token_ids_to_words(h).split() for h in hyps]
+ return hyps
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ model: nn.Module,
+ token_table: SymbolTable,
+) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ model:
+ The neural model.
+ token_table:
+ The token table.
+
+ Returns:
+ - A list of tuples. Each tuple contains three elements:
+ - cut_id,
+ - reference transcript,
+ - predicted result.
+ - The total duration (in seconds) of the dataset.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ log_interval = 10
+ total_duration = 0
+
+ results = []
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+ total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
+
+ hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
+
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results.extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+
+ return results, total_duration
+
+
+def save_results(
+ res_dir: Path,
+ test_set_name: str,
+ results: List[Tuple[str, List[str], List[str]]],
+):
+ recog_path = res_dir / f"recogs-{test_set_name}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = res_dir / f"errs-{test_set_name}.txt"
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
+ with open(errs_info, "w") as f:
+ print("WER", file=f)
+ print(wer, file=f)
+
+ s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+
+ assert (
+ args.decoding_method == "greedy_search"
+ ), "Only supports greedy_search currently."
+ res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
+
+ setup_logger(f"{res_dir}/log-decode")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ logging.info(f"Device: {device}")
+
+ token_table = SymbolTable.from_file(args.tokens)
+
+ logging.info(vars(args))
+
+ logging.info("About to create model")
+ model = OnnxModel(
+ encoder_model_filename=args.encoder_model_filename,
+ decoder_model_filename=args.decoder_model_filename,
+ joiner_model_filename=args.joiner_model_filename,
+ )
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ test_clean_cuts = librispeech.test_clean_cuts()
+ test_other_cuts = librispeech.test_other_cuts()
+
+ test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+ test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+ test_sets = ["test-clean", "test-other"]
+ test_dl = [test_clean_dl, test_other_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ start_time = time.time()
+ results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table)
+ end_time = time.time()
+ elapsed_seconds = end_time - start_time
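+ # Real-time factor: processing time divided by total audio duration; values below 1 mean faster than real time.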
+ rtf = elapsed_seconds / total_duration
+
+ logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
+ logging.info(f"Wave duration: {total_duration:.3f} s")
+ logging.info(
+ f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
+ )
+
+ save_results(res_dir=res_dir, test_set_name=test_set, results=results)
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
index 273f883df..2ce4506a8 100755
--- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
@@ -524,11 +524,11 @@ def main():
hyp,
)
- symbol_table = k2.SymbolTable.from_file(args.tokens)
+ token_table = k2.SymbolTable.from_file(args.tokens)
text = ""
for i in hyp[context_size:]:
- text += symbol_table[i]
+ text += token_table[i]
text = text.replace("▁", " ").strip()
logging.info(args.sound_file)
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained.py b/egs/librispeech/ASR/zipformer/onnx_pretrained.py
deleted file mode 120000
index 0069288fe..000000000
--- a/egs/librispeech/ASR/zipformer/onnx_pretrained.py
+++ /dev/null
@@ -1 +0,0 @@
-../pruned_transducer_stateless7/onnx_pretrained.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained.py b/egs/librispeech/ASR/zipformer/onnx_pretrained.py
new file mode 100755
index 000000000..e8a521460
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained.py
@@ -0,0 +1,419 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads ONNX models and uses them to decode waves.
+You can use the following command to get the exported models:
+
+We use the pre-trained model from
+https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/tokens.txt"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+2. Export the model to ONNX
+
+./zipformer/export-onnx.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp \
+ --causal False
+
+It will generate the following 3 files inside $repo/exp:
+
+ - encoder-epoch-99-avg-1.onnx
+ - decoder-epoch-99-avg-1.onnx
+ - joiner-epoch-99-avg-1.onnx
+
+3. Run this file
+
+./zipformer/onnx_pretrained.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+"""
+
+import argparse
+import logging
+import math
+from typing import List, Tuple
+
+import k2
+import kaldifeat
+import onnxruntime as ort
+import torch
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--encoder-model-filename",
+ type=str,
+ required=True,
+ help="Path to the encoder onnx model. ",
+ )
+
+ parser.add_argument(
+ "--decoder-model-filename",
+ type=str,
+ required=True,
+ help="Path to the decoder onnx model. ",
+ )
+
+ parser.add_argument(
+ "--joiner-model-filename",
+ type=str,
+ required=True,
+ help="Path to the joiner onnx model. ",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ help="""Path to tokens.txt.""",
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ return parser
+
+
+class OnnxModel:
+ def __init__(
+ self,
+ encoder_model_filename: str,
+ decoder_model_filename: str,
+ joiner_model_filename: str,
+ ):
+ session_opts = ort.SessionOptions()
+ session_opts.inter_op_num_threads = 1
+ session_opts.intra_op_num_threads = 4
+
+ self.session_opts = session_opts
+
+ self.init_encoder(encoder_model_filename)
+ self.init_decoder(decoder_model_filename)
+ self.init_joiner(joiner_model_filename)
+
+ def init_encoder(self, encoder_model_filename: str):
+ self.encoder = ort.InferenceSession(
+ encoder_model_filename,
+ sess_options=self.session_opts,
+ )
+
+ def init_decoder(self, decoder_model_filename: str):
+ self.decoder = ort.InferenceSession(
+ decoder_model_filename,
+ sess_options=self.session_opts,
+ )
+
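+        # The export script is expected to store context_size and vocab_size
+        # in the ONNX model metadata, so no extra command-line flags are
+        # needed to recover them here.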
+ decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
+ self.context_size = int(decoder_meta["context_size"])
+ self.vocab_size = int(decoder_meta["vocab_size"])
+
+ logging.info(f"context_size: {self.context_size}")
+ logging.info(f"vocab_size: {self.vocab_size}")
+
+ def init_joiner(self, joiner_model_filename: str):
+ self.joiner = ort.InferenceSession(
+ joiner_model_filename,
+ sess_options=self.session_opts,
+ )
+
+ joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
+ self.joiner_dim = int(joiner_meta["joiner_dim"])
+
+ logging.info(f"joiner_dim: {self.joiner_dim}")
+
+ def run_encoder(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C)
+ x_lens:
+            A 1-D tensor of shape (N,). Its dtype is torch.int64
+ Returns:
+ Return a tuple containing:
+ - encoder_out, its shape is (N, T', joiner_dim)
+ - encoder_out_lens, its shape is (N,)
+ """
+ out = self.encoder.run(
+ [
+ self.encoder.get_outputs()[0].name,
+ self.encoder.get_outputs()[1].name,
+ ],
+ {
+ self.encoder.get_inputs()[0].name: x.numpy(),
+ self.encoder.get_inputs()[1].name: x_lens.numpy(),
+ },
+ )
+ return torch.from_numpy(out[0]), torch.from_numpy(out[1])
+
+ def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ decoder_input:
+ A 2-D tensor of shape (N, context_size)
+ Returns:
+ Return a 2-D tensor of shape (N, joiner_dim)
+ """
+ out = self.decoder.run(
+ [self.decoder.get_outputs()[0].name],
+ {self.decoder.get_inputs()[0].name: decoder_input.numpy()},
+ )[0]
+
+ return torch.from_numpy(out)
+
+ def run_joiner(
+ self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Args:
+ encoder_out:
+ A 2-D tensor of shape (N, joiner_dim)
+ decoder_out:
+ A 2-D tensor of shape (N, joiner_dim)
+ Returns:
+ Return a 2-D tensor of shape (N, vocab_size)
+ """
+ out = self.joiner.run(
+ [self.joiner.get_outputs()[0].name],
+ {
+ self.joiner.get_inputs()[0].name: encoder_out.numpy(),
+ self.joiner.get_inputs()[1].name: decoder_out.numpy(),
+ },
+ )[0]
+
+ return torch.from_numpy(out)
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+def greedy_search(
+ model: OnnxModel,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+) -> List[List[int]]:
+ """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+ Args:
+ model:
+ The transducer model.
+ encoder_out:
+ A 3-D tensor of shape (N, T, joiner_dim)
+ encoder_out_lens:
+ A 1-D tensor of shape (N,).
+ Returns:
+ Return the decoded results for each utterance.
+ """
+ assert encoder_out.ndim == 3, encoder_out.shape
+ assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+ packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+ input=encoder_out,
+ lengths=encoder_out_lens.cpu(),
+ batch_first=True,
+ enforce_sorted=False,
+ )
+
+ blank_id = 0 # hard-code to 0
+
+ batch_size_list = packed_encoder_out.batch_sizes.tolist()
+ N = encoder_out.size(0)
+
+ assert torch.all(encoder_out_lens > 0), encoder_out_lens
+ assert N == batch_size_list[0], (N, batch_size_list)
+
+ context_size = model.context_size
+ hyps = [[blank_id] * context_size for _ in range(N)]
+
+ decoder_input = torch.tensor(
+ hyps,
+ dtype=torch.int64,
+ ) # (N, context_size)
+
+ decoder_out = model.run_decoder(decoder_input)
+
+ offset = 0
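+    # packed_encoder_out.data stores frames time-step by time-step across the
+    # length-sorted batch; batch_sizes[t] is the number of utterances still
+    # active at frame t, so each iteration below decodes one frame for a
+    # (possibly shrinking) batch.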
+ for batch_size in batch_size_list:
+ start = offset
+ end = offset + batch_size
+ current_encoder_out = packed_encoder_out.data[start:end]
+ # current_encoder_out's shape: (batch_size, joiner_dim)
+ offset = end
+
+ decoder_out = decoder_out[:batch_size]
+ logits = model.run_joiner(current_encoder_out, decoder_out)
+
+        # logits' shape: (batch_size, vocab_size)
+
+ assert logits.ndim == 2, logits.shape
+ y = logits.argmax(dim=1).tolist()
+ emitted = False
+ for i, v in enumerate(y):
+ if v != blank_id:
+ hyps[i].append(v)
+ emitted = True
+ if emitted:
+ # update decoder output
+ decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
+ decoder_input = torch.tensor(
+ decoder_input,
+ dtype=torch.int64,
+ )
+ decoder_out = model.run_decoder(decoder_input)
+
+ sorted_ans = [h[context_size:] for h in hyps]
+ ans = []
+ unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+ for i in range(N):
+ ans.append(sorted_ans[unsorted_indices[i]])
+
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ logging.info(vars(args))
+ model = OnnxModel(
+ encoder_model_filename=args.encoder_model_filename,
+ decoder_model_filename=args.decoder_model_filename,
+ joiner_model_filename=args.joiner_model_filename,
+ )
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = "cpu"
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = args.sample_rate
+ opts.mel_opts.num_bins = 80
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {args.sound_files}")
+ waves = read_sound_files(
+ filenames=args.sound_files,
+ expected_sample_rate=args.sample_rate,
+ )
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
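+    # Pad with log(1e-10) so that padded frames look like (near-)silence in
+    # the log-mel feature space.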
+ features = pad_sequence(
+ features,
+ batch_first=True,
+ padding_value=math.log(1e-10),
+ )
+
+ feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
+ encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths)
+
+ hyps = greedy_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ s = "\n"
+
+ token_table = k2.SymbolTable.from_file(args.tokens)
+
+ def token_ids_to_words(token_ids: List[int]) -> str:
+ text = ""
+ for i in token_ids:
+ text += token_table[i]
+ return text.replace("▁", " ").strip()
+
+ for filename, hyp in zip(args.sound_files, hyps):
+ words = token_ids_to_words(hyp)
+ s += f"{filename}:\n{words}\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/librispeech/ASR/zipformer/pretrained.py b/egs/librispeech/ASR/zipformer/pretrained.py
index 2944f79e3..3104b6084 100755
--- a/egs/librispeech/ASR/zipformer/pretrained.py
+++ b/egs/librispeech/ASR/zipformer/pretrained.py
@@ -18,11 +18,14 @@
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
+Note: This is an example for the librispeech dataset. If you are using a different
+dataset, you should change the argument values according to your dataset.
+
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
@@ -31,7 +34,7 @@ You can generate the checkpoint with the following command:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
@@ -42,7 +45,7 @@ Usage of this script:
(1) greedy search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@@ -50,7 +53,7 @@ Usage of this script:
(2) modified beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
+ --tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
@@ -58,7 +61,7 @@ Usage of this script:
(3) fast beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
+ --tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
@@ -71,7 +74,7 @@ Usage of this script:
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
- --bpe-model ./data/lang_bpe_500/bpe.model \
+ --tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@@ -82,7 +85,7 @@ Usage of this script:
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
- --bpe-model ./data/lang_bpe_500/bpe.model \
+ --tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
@@ -93,7 +96,7 @@ Usage of this script:
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
- --bpe-model ./data/lang_bpe_500/bpe.model \
+ --tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
@@ -112,7 +115,6 @@ from typing import List
import k2
import kaldifeat
-import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@@ -120,8 +122,11 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
+from export import num_tokens
from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params
+
+from icefall.utils import make_pad_mask
def get_parser():
@@ -139,9 +144,9 @@ def get_parser():
)
parser.add_argument(
- "--bpe-model",
+ "--tokens",
type=str,
- help="""Path to bpe.model.""",
+ help="""Path to tokens.txt.""",
)
parser.add_argument(
@@ -258,13 +263,11 @@ def main():
params.update(vars(args))
- sp = spm.SentencePieceProcessor()
- sp.load(params.bpe_model)
+ token_table = k2.SymbolTable.from_file(params.tokens)
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
- params.vocab_size = sp.get_piece_size()
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(f"{params}")
@@ -323,6 +326,12 @@ def main():
msg = f"Using {params.method}"
logging.info(msg)
+ def token_ids_to_words(token_ids: List[int]) -> str:
+ text = ""
+ for i in token_ids:
+ text += token_table[i]
+ return text.replace("▁", " ").strip()
+
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@@ -334,8 +343,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
- for hyp in sp.decode(hyp_tokens):
- hyps.append(hyp.split())
+ for hyp in hyp_tokens:
+ hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@@ -344,23 +353,22 @@ def main():
beam=params.beam_size,
)
- for hyp in sp.decode(hyp_tokens):
- hyps.append(hyp.split())
+ for hyp in hyp_tokens:
+ hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
- for hyp in sp.decode(hyp_tokens):
- hyps.append(hyp.split())
+ for hyp in hyp_tokens:
+ hyps.append(token_ids_to_words(hyp))
else:
raise ValueError(f"Unsupported method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
- words = " ".join(hyp)
- s += f"{filename}:\n{words}\n\n"
+ s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")
diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py
index f10d95449..be239e9c3 100755
--- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py
+++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py
@@ -24,7 +24,7 @@ You can generate the checkpoint with the following command:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
@@ -34,7 +34,7 @@ You can generate the checkpoint with the following command:
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--causal 1 \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
@@ -43,7 +43,7 @@ Usage of this script:
(1) ctc-decoding
./zipformer/pretrained_ctc.py \
--checkpoint ./zipformer/exp/pretrained.pt \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--method ctc-decoding \
--sample-rate 16000 \
/path/to/foo.wav \
@@ -90,12 +90,12 @@ from typing import List
import k2
import kaldifeat
-import sentencepiece as spm
import torch
import torchaudio
from ctc_decode import get_decoding_params
+from export import num_tokens
from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params
from icefall.decode import (
get_lattice,
@@ -144,9 +144,9 @@ def get_parser():
)
parser.add_argument(
- "--bpe-model",
+ "--tokens",
type=str,
- help="""Path to bpe.model.
+ help="""Path to tokens.txt.
Used only when method is ctc-decoding.
""",
)
@@ -157,8 +157,8 @@ def get_parser():
default="1best",
help="""Decoding method.
Possible values are:
- (0) ctc-decoding - Use CTC decoding. It uses a sentence
- piece model, i.e., lang_dir/bpe.model, to convert
+ (0) ctc-decoding - Use CTC decoding. It uses a token table,
+ i.e., lang_dir/tokens.txt, to convert
word pieces to words. It needs neither a lexicon
nor an n-gram LM.
(1) 1best - Use the best path as decoding output. Only
@@ -273,11 +273,10 @@ def main():
params.update(get_decoding_params())
params.update(vars(args))
- sp = spm.SentencePieceProcessor()
- sp.load(params.bpe_model)
-
- params.vocab_size = sp.get_piece_size()
- params.blank_id = 0
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.vocab_size = num_tokens(token_table)
+    params.blank_id = token_table["<blk>"]
+ assert params.blank_id == 0
logging.info(f"{params}")
@@ -358,8 +357,7 @@ def main():
lattice=lattice, use_double_scores=params.use_double_scores
)
token_ids = get_texts(best_path)
- hyps = sp.decode(token_ids)
- hyps = [s.split() for s in hyps]
+ hyps = [[token_table[i] for i in ids] for ids in token_ids]
elif params.method in [
"1best",
"nbest-rescoring",
@@ -433,6 +431,7 @@ def main():
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
+ words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
logging.info(s)
diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py
index 9f23eeead..4ee7b7826 100644
--- a/egs/librispeech/ASR/zipformer/scaling.py
+++ b/egs/librispeech/ASR/zipformer/scaling.py
@@ -25,6 +25,11 @@ import math
import torch.nn as nn
from torch import Tensor
+def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
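+    # Numerically stable evaluation of log(exp(x) + exp(y)) using only ops
+    # that ONNX opset 14 supports:
+    #   log(exp(x) + exp(y)) = max(x, y) + log1p(exp(-|x - y|))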
+ max_value = torch.max(x, y)
+ diff = torch.abs(x - y)
+ return max_value + torch.log1p(torch.exp(-diff))
+
# RuntimeError: Exporting the operator logaddexp to ONNX opset version
# 14 is not supported. Please feel free to request support or submit
@@ -33,10 +38,22 @@ from torch import Tensor
# The following function is to solve the above error when exporting
# models to ONNX via torch.jit.trace()
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
- if not torch.jit.is_tracing():
+ # Caution(fangjun): Put torch.jit.is_scripting() before
+ # torch.onnx.is_in_onnx_export();
+ # otherwise, it will cause errors for torch.jit.script().
+ #
+ # torch.logaddexp() works for both torch.jit.script() and
+ # torch.jit.trace() but it causes errors for ONNX export.
+ #
+ if torch.jit.is_scripting():
+ # Note: We cannot use torch.jit.is_tracing() here as it also
+ # matches torch.onnx.export().
return torch.logaddexp(x, y)
+ elif torch.onnx.is_in_onnx_export():
+ return logaddexp_onnx(x, y)
else:
- return (x.exp() + y.exp()).log()
+ # for torch.jit.trace()
+ return torch.logaddexp(x, y)
class PiecewiseLinear(object):
"""
@@ -1334,6 +1351,13 @@ class SwooshL(torch.nn.Module):
return k2.swoosh_l(x)
# return SwooshLFunction.apply(x)
+class SwooshLOnnx(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-L activation.
+ """
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
+
class SwooshRFunction(torch.autograd.Function):
"""
@@ -1400,6 +1424,13 @@ class SwooshR(torch.nn.Module):
return k2.swoosh_r(x)
# return SwooshRFunction.apply(x)
+class SwooshROnnx(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-R activation.
+ """
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp_onnx(zero, x - 1.) - 0.08 * x - 0.313261687
+
# simple version of SwooshL that does not redefine the backprop, used in
# ActivationDropoutAndLinearFunction.
diff --git a/egs/librispeech/ASR/zipformer/scaling_converter.py b/egs/librispeech/ASR/zipformer/scaling_converter.py
index 54a5c2a6a..76622fa12 100644
--- a/egs/librispeech/ASR/zipformer/scaling_converter.py
+++ b/egs/librispeech/ASR/zipformer/scaling_converter.py
@@ -26,7 +26,16 @@ from typing import List, Tuple
import torch
import torch.nn as nn
-from scaling import Balancer, Dropout3, ScaleGrad, Whiten
+from scaling import (
+ Balancer,
+ Dropout3,
+ ScaleGrad,
+ SwooshL,
+ SwooshLOnnx,
+ SwooshR,
+ SwooshROnnx,
+ Whiten,
+)
from zipformer import CompactRelPositionalEncoding
@@ -75,6 +84,10 @@ def convert_scaled_to_non_scaled(
for name, m in model.named_modules():
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
d[name] = nn.Identity()
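+        # SwooshR/SwooshL call custom k2 operators (k2.swoosh_r / k2.swoosh_l)
+        # that have no ONNX export, so for ONNX they are swapped for
+        # pure-PyTorch equivalents built on logaddexp_onnx.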
+ elif is_onnx and isinstance(m, SwooshR):
+ d[name] = SwooshROnnx()
+ elif is_onnx and isinstance(m, SwooshL):
+ d[name] = SwooshLOnnx()
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
# We want to recreate the positional encoding vector when
# the input changes, so we have to use torch.jit.script()
diff --git a/egs/librispeech/ASR/zipformer/streaming_beam_search.py b/egs/librispeech/ASR/zipformer/streaming_beam_search.py
index e6e0fb1c8..3c8565b33 100644
--- a/egs/librispeech/ASR/zipformer/streaming_beam_search.py
+++ b/egs/librispeech/ASR/zipformer/streaming_beam_search.py
@@ -31,6 +31,7 @@ def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
+ blank_penalty: float = 0.0,
) -> None:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
@@ -71,6 +72,9 @@ def greedy_search(
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
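+    # Index 0 is the blank symbol; subtracting a positive blank_penalty from
+    # its logit makes the search less likely to emit blank at this frame.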
+ if blank_penalty != 0.0:
+ logits[:, 0] -= blank_penalty
+
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
@@ -97,6 +101,7 @@ def modified_beam_search(
encoder_out: torch.Tensor,
streams: List[DecodeStream],
num_active_paths: int = 4,
+ blank_penalty: float = 0.0,
) -> None:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@@ -158,6 +163,9 @@ def modified_beam_search(
logits = logits.squeeze(1).squeeze(1)
+ if blank_penalty != 0.0:
+ logits[:, 0] -= blank_penalty
+
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
@@ -205,6 +213,7 @@ def fast_beam_search_one_best(
beam: float,
max_states: int,
max_contexts: int,
+ blank_penalty: float = 0.0,
) -> None:
"""It limits the maximum number of symbols per frame to 1.
@@ -269,6 +278,10 @@ def fast_beam_search_one_best(
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
+
+ if blank_penalty != 0.0:
+ logits[:, 0] -= blank_penalty
+
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
diff --git a/egs/tedlium3/ASR/RESULTS.md b/egs/tedlium3/ASR/RESULTS.md
index 38eaa8f44..bd8a5b43f 100644
--- a/egs/tedlium3/ASR/RESULTS.md
+++ b/egs/tedlium3/ASR/RESULTS.md
@@ -1,5 +1,111 @@
## Results
+### TedLium3 BPE training results (Zipformer)
+
+#### 2023-06-15 (Regular transducer)
+
+Using the code from this PR: https://github.com/k2-fsa/icefall/pull/1125.
+
+Number of model parameters: 65549011, i.e., 65.5 M
+
+The WERs are
+
+| | dev | test | comment |
+|------------------------------------|------------|------------|------------------------------------------|
+| greedy search | 6.74 | 6.16 | --epoch 50, --avg 22, --max-duration 500 |
+| beam search (beam size 4) | 6.56 | 5.95 | --epoch 50, --avg 22, --max-duration 500 |
+| modified beam search (beam size 4) | 6.54 | 6.00 | --epoch 50, --avg 22, --max-duration 500 |
+| fast beam search (set as default) | 6.91 | 6.28 | --epoch 50, --avg 22, --max-duration 500 |
+
+The training command for reproducing the results is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./zipformer/train.py \
+ --use-fp16 true \
+ --world-size 4 \
+ --num-epochs 50 \
+ --start-epoch 0 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000
+```
+
+The tensorboard training log can be found at
+https://tensorboard.dev/experiment/AKXbJha0S9aXyfmuvG4h5A/#scalars
+
+The decoding command is:
+```
+epoch=50
+avg=22
+
+## greedy search
+./zipformer/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir zipformer/exp \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --max-duration 500
+
+## beam search
+./zipformer/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir zipformer/exp \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --max-duration 500 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+## modified beam search
+./zipformer/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir zipformer/exp \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --max-duration 500 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+## fast beam search
+./zipformer/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir ./zipformer/exp \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --max-duration 1500 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+```
+
+A pre-trained model and decoding logs can be found at
+
+#### 2023-06-26 (Modified transducer)
+
+The training command for reproducing the results is given below:
+
+```
+./zipformer/train.py \
+ --use-fp16 true \
+ --world-size 4 \
+ --num-epochs 50 \
+ --start-epoch 0 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000 \
+ --rnnt-type modified
+```
+
+The tensorboard training log can be found at
+https://tensorboard.dev/experiment/3d4bYmbJTGiWQQaW88CVEQ/#scalars
+
+The WERs are
+
+| | dev | test | comment |
+|------------------------------------|------------|------------|------------------------------------------|
+| greedy search | 6.32 | 5.83 | --epoch 50, --avg 22, --max-duration 500 |
+| modified beam search (beam size 4) | 6.16 | 5.79 | --epoch 50, --avg 22, --max-duration 500 |
+| fast beam search (set as default) | 6.30 | 5.89 | --epoch 50, --avg 22, --max-duration 500 |
+
+A pre-trained model and decoding logs can be found at .
+
### TedLium3 BPE training results (Conformer-CTC 2)
#### [conformer_ctc2](./conformer_ctc2)
diff --git a/egs/tedlium3/ASR/zipformer/__init__.py b/egs/tedlium3/ASR/zipformer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/tedlium3/ASR/zipformer/asr_datamodule.py b/egs/tedlium3/ASR/zipformer/asr_datamodule.py
new file mode 120000
index 000000000..49b2ee483
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/asr_datamodule.py
@@ -0,0 +1 @@
+../transducer_stateless/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/zipformer/beam_search.py b/egs/tedlium3/ASR/zipformer/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/zipformer/decode.py b/egs/tedlium3/ASR/zipformer/decode.py
new file mode 100755
index 000000000..ea1cbba1b
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/decode.py
@@ -0,0 +1,833 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 30 \
+ --avg 9 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 30 \
+ --avg 9 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 30 \
+ --avg 9 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 30 \
+ --avg 9 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 30 \
+ --avg 9 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 30 \
+ --avg 9 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode.py \
+ --epoch 30 \
+ --avg 9 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import TedLiumAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding_method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+          if greedy_search is used, it would be "greedy_search".
+          If beam search with a beam size of 7 is used, it would be
+          "beam_7".
+    - value: It contains the decoding result. `len(value)` equals the
+          batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ x, x_lens = model.encoder_embed(feature, feature_lens)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ hyps = []
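+    # sp.decode(sp.unk_id()) yields the textual form of the unknown token
+    # (typically "<unk>"); it is filtered out of every hypothesis below.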
+ unk = sp.decode(sp.unk_id()).strip()
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ allow_partial=True,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyp = [w for w in hyp.split() if w != unk]
+ hyps.append(hyp)
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ allow_partial=True,
+ )
+ for hyp in hyp_tokens:
+ hyp = [word_table[i] for i in hyp if word_table[i] != unk]
+ hyps.append(hyp)
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ allow_partial=True,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyp = [w for w in hyp.split() if w != unk]
+ hyps.append(hyp)
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ allow_partial=True,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyp = [w for w in hyp.split() if w != unk]
+ hyps.append(hyp)
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyp = [w for w in hyp.split() if w != unk]
+ hyps.append(hyp)
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyp = [w for w in hyp.split() if w != unk]
+ hyps.append(hyp)
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyp = [w for w in sp.decode(hyp).split() if w != unk]
+ hyps.append(hyp)
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut id, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ word_table=word_table,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ TedLiumAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ tedlium = TedLiumAsrDataModule(args)
+
+ dev_cuts = tedlium.dev_cuts()
+ test_cuts = tedlium.test_cuts()
+
+ dev_dl = tedlium.test_dataloaders(dev_cuts)
+ test_dl = tedlium.test_dataloaders(test_cuts)
+
+ test_sets = ["dev", "test"]
+ test_dls = [dev_dl, test_dl]
+
+ for name, dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=name,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/tedlium3/ASR/zipformer/decoder.py b/egs/tedlium3/ASR/zipformer/decoder.py
new file mode 120000
index 000000000..5a8018680
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/zipformer/encoder_interface.py b/egs/tedlium3/ASR/zipformer/encoder_interface.py
new file mode 120000
index 000000000..653c5b09a
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/encoder_interface.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/zipformer/export.py b/egs/tedlium3/ASR/zipformer/export.py
new file mode 120000
index 000000000..dfc1bec08
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/zipformer/joiner.py b/egs/tedlium3/ASR/zipformer/joiner.py
new file mode 120000
index 000000000..5b8a36332
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/zipformer/model.py b/egs/tedlium3/ASR/zipformer/model.py
new file mode 100644
index 000000000..90ec7e7aa
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/model.py
@@ -0,0 +1,223 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import k2
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+
+from icefall.utils import add_sos, make_pad_mask
+from scaling import ScaledLinear
+
+
+class Transducer(nn.Module):
+ """It implements https://arxiv.org/pdf/1211.3711.pdf
+ "Sequence Transduction with Recurrent Neural Networks"
+ """
+
+ def __init__(
+ self,
+ encoder_embed: nn.Module,
+ encoder: EncoderInterface,
+ decoder: nn.Module,
+ joiner: nn.Module,
+ encoder_dim: int,
+ decoder_dim: int,
+ joiner_dim: int,
+ vocab_size: int,
+ ):
+ """
+ Args:
+ encoder_embed:
+ It is a Convolutional 2D subsampling module. It converts
+            an input of shape (N, T, idim) to an output of shape
+ (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
+ encoder:
+            It is the transcription network in the paper. It accepts
+ two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
+ It returns two tensors: `logits` of shape (N, T, encoder_dim) and
+ `logit_lens` of shape (N,).
+ decoder:
+ It is the prediction network in the paper. Its input shape
+ is (N, U) and its output shape is (N, U, decoder_dim).
+ It should contain one attribute: `blank_id`.
+ joiner:
+ It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
+ Its output shape is (N, T, U, vocab_size). Note that its output contains
+ unnormalized probs, i.e., not processed by log-softmax.
+ """
+ super().__init__()
+ assert isinstance(encoder, EncoderInterface), type(encoder)
+ assert hasattr(decoder, "blank_id")
+
+ self.encoder_embed = encoder_embed
+ self.encoder = encoder
+ self.decoder = decoder
+ self.joiner = joiner
+
+ self.simple_am_proj = ScaledLinear(
+ encoder_dim,
+ vocab_size,
+ initial_scale=0.25,
+ )
+ self.simple_lm_proj = ScaledLinear(
+ decoder_dim,
+ vocab_size,
+ initial_scale=0.25,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ y: k2.RaggedTensor,
+ prune_range: int = 5,
+ am_scale: float = 0.0,
+ lm_scale: float = 0.0,
+ rnnt_type: str = "regular",
+ ) -> torch.Tensor:
+ """
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C).
+ x_lens:
+ A 1-D tensor of shape (N,). It contains the number of frames in `x`
+ before padding.
+ y:
+ A ragged tensor with 2 axes [utt][label]. It contains labels of each
+ utterance.
+ prune_range:
+            The prune range for rnnt loss, i.e., how many symbols (context)
+ we are considering for each frame to compute the loss.
+ am_scale:
+ The scale to smooth the loss with am (output of encoder network)
+ part
+ lm_scale:
+ The scale to smooth the loss with lm (output of predictor network)
+ part
+ rnnt_type:
+ The type of label topology to use for the transducer loss. One of "regular",
+ "modified", or "constrained".
+ Returns:
+ Return the transducer loss.
+
+ Note:
+ Regarding am_scale & lm_scale, it will make the loss-function one of
+ the form:
+ lm_scale * lm_probs + am_scale * am_probs +
+ (1-lm_scale-am_scale) * combined_probs
+ """
+ assert x.ndim == 3, x.shape
+ assert x_lens.ndim == 1, x_lens.shape
+ assert y.num_axes == 2, y.num_axes
+
+ assert x.size(0) == x_lens.size(0) == y.dim0
+
+ # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
+ x, x_lens = self.encoder_embed(x, x_lens)
+ # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask)
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ assert torch.all(x_lens > 0)
+
+ # Now for the decoder, i.e., the prediction network
+ row_splits = y.shape.row_splits(1)
+ y_lens = row_splits[1:] - row_splits[:-1]
+
+ blank_id = self.decoder.blank_id
+ sos_y = add_sos(y, sos_id=blank_id)
+
+ # sos_y_padded: [B, S + 1], start with SOS.
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+
+ # decoder_out: [B, S + 1, decoder_dim]
+ decoder_out = self.decoder(sos_y_padded)
+
+ # Note: y does not start with SOS
+ # y_padded : [B, S]
+ y_padded = y.pad(mode="constant", padding_value=0)
+
+ y_padded = y_padded.to(torch.int64)
+ boundary = torch.zeros(
+ (encoder_out.size(0), 4),
+ dtype=torch.int64,
+ device=encoder_out.device,
+ )
+ boundary[:, 2] = y_lens
+ boundary[:, 3] = x_lens
+
+ lm = self.simple_lm_proj(decoder_out)
+ am = self.simple_am_proj(encoder_out)
+
+ # if self.training and random.random() < 0.25:
+ # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
+ # if self.training and random.random() < 0.25:
+ # am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
+
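+        # Pruned RNN-T: first compute a cheap "simple" loss from linear
+        # projections of the encoder/decoder outputs; its gradients (px_grad,
+        # py_grad) are then used to pick a small pruning range per frame for
+        # the exact joiner-based loss computed further below.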
+ with torch.cuda.amp.autocast(enabled=False):
+ simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+ lm=lm.float(),
+ am=am.float(),
+ symbols=y_padded,
+ termination_symbol=blank_id,
+ lm_only_scale=lm_scale,
+ am_only_scale=am_scale,
+ boundary=boundary,
+ reduction="sum",
+ return_grad=True,
+ rnnt_type=rnnt_type,
+ )
+
+ # ranges : [B, T, prune_range]
+ ranges = k2.get_rnnt_prune_ranges(
+ px_grad=px_grad,
+ py_grad=py_grad,
+ boundary=boundary,
+ s_range=prune_range,
+ )
+
+ # am_pruned : [B, T, prune_range, encoder_dim]
+ # lm_pruned : [B, T, prune_range, decoder_dim]
+ am_pruned, lm_pruned = k2.do_rnnt_pruning(
+ am=self.joiner.encoder_proj(encoder_out),
+ lm=self.joiner.decoder_proj(decoder_out),
+ ranges=ranges,
+ )
+
+ # logits : [B, T, prune_range, vocab_size]
+
+ # project_input=False since we applied the decoder's input projections
+ # prior to do_rnnt_pruning (this is an optimization for speed).
+ logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ pruned_loss = k2.rnnt_loss_pruned(
+ logits=logits.float(),
+ symbols=y_padded,
+ ranges=ranges,
+ termination_symbol=blank_id,
+ boundary=boundary,
+ reduction="sum",
+ rnnt_type=rnnt_type,
+ )
+
+ return (simple_loss, pruned_loss)
diff --git a/egs/tedlium3/ASR/zipformer/optim.py b/egs/tedlium3/ASR/zipformer/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/zipformer/pretrained.py b/egs/tedlium3/ASR/zipformer/pretrained.py
new file mode 120000
index 000000000..0bd71dde4
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/pretrained.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/zipformer/profile.py b/egs/tedlium3/ASR/zipformer/profile.py
new file mode 120000
index 000000000..c93adbd14
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/profile.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/profile.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/zipformer/scaling.py b/egs/tedlium3/ASR/zipformer/scaling.py
new file mode 120000
index 000000000..6f398f431
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/zipformer/scaling_converter.py b/egs/tedlium3/ASR/zipformer/scaling_converter.py
new file mode 120000
index 000000000..b0ecee05e
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling_converter.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/zipformer/subsampling.py b/egs/tedlium3/ASR/zipformer/subsampling.py
new file mode 120000
index 000000000..01ae9002c
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py
new file mode 100755
index 000000000..9271c8438
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/train.py
@@ -0,0 +1,1308 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# For non-streaming model training:
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000
+
+# For streaming model training:
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --causal 1 \
+ --max-duration 1000
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import TedLiumAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids
+from model import Transducer
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
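+    # Illustrative example (hypothetical values): with max_duration=1000 (s),
+    # world_size=4 and ref_duration=600, 300 observed batches count as
+    # 300 * (1000 * 4) / 600 = 2000 "reference" batches.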
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=50,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.04, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+        help="The prune range for rnnt loss; it means how many symbols (context) "
+        "we use to compute the loss",
+ )
+
+ parser.add_argument(
+ "--rnnt-type",
+ type=str,
+ default="regular",
+ choices=["regular", "modified", "constrained"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+        help="To get pruning ranges, we will calculate a simple version of the "
+        "loss (the joiner is just an addition); this simple loss is also used "
+        "for training (as a regularization term). We will scale the simple "
+        "loss with this parameter before adding it to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+        help="""Save checkpoint after processing this number of batches
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=1,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+ epochs.
+
+    - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+    - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+            "valid_interval": 3000,
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
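+    # e.g. "2,2,3,4,3,2" -> (2, 2, 3, 4, 3, 2)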
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
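+    # For example (hypothetical input length): with feature_dim=80 and
+    # T=100 input frames, encoder_embed produces (100 - 7) // 2 = 46 frames,
+    # each of dimension encoder_dim[0].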
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+        encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute RNNT loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = convert_texts_into_ids(texts, sp)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ rnnt_type=params.rnnt_type,
+ )
+
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
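+        # Illustrative example (hypothetical values): with simple_loss_scale=0.5
+        # and warm_step=2000, at batch_idx_train=1000 the simple-loss scale is
+        # 1.0 - (1000 / 2000) * (1.0 - 0.5) = 0.75 and the pruned-loss scale is
+        # 0.1 + 0.9 * (1000 / 2000) = 0.55.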
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2**22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ tedlium = TedLiumAsrDataModule(args)
+
+ train_cuts = tedlium.train_cuts()
+ train_cuts = train_cuts.filter(lambda c: 1.0 <= c.duration <= 20.0)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = tedlium.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = tedlium.dev_cuts()
+ valid_dl = tedlium.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ TedLiumAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/tedlium3/ASR/zipformer/zipformer.py b/egs/tedlium3/ASR/zipformer/zipformer.py
new file mode 120000
index 000000000..23011dda7
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/RESULTS.md b/egs/wenetspeech/ASR/RESULTS.md
index 658ad4a9b..1a0e0681f 100644
--- a/egs/wenetspeech/ASR/RESULTS.md
+++ b/egs/wenetspeech/ASR/RESULTS.md
@@ -1,5 +1,90 @@
## Results
+### WenetSpeech char-based training results (Non-streaming and streaming) on zipformer model
+
+This is the [pull request](https://github.com/k2-fsa/icefall/pull/1130) in icefall.
+
+#### Non-streaming
+
+Best results (number of params: ~76M):
+
+Type | Greedy search (dev & net & meeting) | Beam search (dev & net & meeting) | Comment
+-- | -- | -- | --
+Non-streaming | 7.36 & 7.65 & 12.43 | 7.32 & 7.61 & 12.35 | --epoch=12
+
+The training command:
+
+```
+./zipformer/train.py \
+ --world-size 6 \
+ --num-epochs 12 \
+ --use-fp16 1 \
+ --max-duration 450 \
+ --training-subset L \
+ --lr-epochs 1.5 \
+ --context-size 2 \
+ --exp-dir zipformer/exp_L_context_2 \
+ --causal 0 \
+ --num-workers 8
+```
+
+The best results for each epoch are listed below:
+
+Epoch | Greedy search (dev & net & meeting) | Modified beam search (dev & net & meeting) | Comment
+-- | -- | -- | --
+4 | 7.83 & 8.86 & 13.73 | 7.75 & 8.81 & 13.67 | avg=1;blank-penalty=2
+5 | 7.75 & 8.46 & 13.38 | 7.68 & 8.41 & 13.27 | avg=1;blank-penalty=2
+6 | 7.72 & 8.19 & 13.16 | 7.62 & 8.14 & 13.06 | avg=1;blank-penalty=2
+7 | 7.59 & 8.08 & 12.97 | 7.53 & 8.01 & 12.87 | avg=2;blank-penalty=2
+8 | 7.68 & 7.87 & 12.96 | 7.61 & 7.81 & 12.88 | avg=1;blank-penalty=2
+9 | 7.57 & 7.77 & 12.87 | 7.5 & 7.71 & 12.77 | avg=1;blank-penalty=2
+10 | 7.45 & 7.7 & 12.69 | 7.39 & 7.63 & 12.59 | avg=2;blank-penalty=2
+11 | 7.35 & 7.67 & 12.46 | 7.31 & 7.63 & 12.43 | avg=3;blank-penalty=2
+12 | 7.36 & 7.65 & 12.43 | 7.32 & 7.61 & 12.35 | avg=4;blank-penalty=2
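+
+As a rough sketch (values taken from the epoch-12 row above; adjust paths and
+settings to your own setup), a matching decoding command could look like:
+
+```
+./zipformer/decode.py \
+  --epoch 12 \
+  --avg 4 \
+  --exp-dir zipformer/exp_L_context_2 \
+  --lang-dir data/lang_char \
+  --max-duration 600 \
+  --decoding-method modified_beam_search \
+  --blank-penalty 2.0
+```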
+
+The pre-trained model is available here : https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615
+
+
+#### Streaming
+
+Best results (number of params: ~76M):
+
+Type | Greedy search (dev & net & meeting) | Beam search (dev & net & meeting) | Comment
+-- | -- | -- | --
+Streaming | 8.45 & 9.89 & 16.46 | 8.21 & 9.77 & 16.07 | --epoch=12; --chunk-size=16; --left-context-frames=256
+Streaming | 8.0 & 9.0 & 15.11 | 7.84 & 8.94 & 14.92 | --epoch=12; --chunk-size=32; --left-context-frames=256
+
+The training command:
+
+```
+./zipformer/train.py \
+ --world-size 8 \
+ --num-epochs 12 \
+ --use-fp16 1 \
+ --max-duration 450 \
+ --training-subset L \
+ --lr-epochs 1.5 \
+ --context-size 2 \
+ --exp-dir zipformer/exp_L_causal_context_2 \
+ --causal 1 \
+ --num-workers 8
+```
+
+Best results for each epoch (--chunk-size=16; --left-context-frames=128):
+
+Epoch | Greedy search (dev & net & meeting) | Modified beam search (dev & net & meeting) | Comment
+-- | -- | -- | --
+6 | 9.14 & 10.75 & 18.15 | 8.79 & 10.54 & 17.64 | avg=1;blank-penalty=1.5
+7 | 9.11 & 10.61 & 17.86 | 8.8 & 10.42 & 17.29 | avg=1;blank-penalty=1.5
+8 | 8.89 & 10.32 & 17.44 | 8.59 & 10.09 & 16.9 | avg=1;blank-penalty=1.5
+9 | 8.86 & 10.11 & 17.35 | 8.55 & 9.87 & 16.76 | avg=1;blank-penalty=1.5
+10 | 8.66 & 10.0 & 16.94 | 8.39 & 9.83 & 16.47 | avg=2;blank-penalty=1.5
+11 | 8.58 & 9.92 & 16.67 | 8.32 & 9.77 & 16.27 | avg=3;blank-penalty=1.5
+12 | 8.45 & 9.89 & 16.46 | 8.21 & 9.77 & 16.07 | avg=4;blank-penalty=1.5
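+
+As a rough sketch (using the --chunk-size and --left-context-frames values from
+the caption above and the epoch-12 row; adjust to your own setup), a matching
+streaming-model decoding command could look like:
+
+```
+./zipformer/decode.py \
+  --epoch 12 \
+  --avg 4 \
+  --exp-dir zipformer/exp_L_causal_context_2 \
+  --lang-dir data/lang_char \
+  --causal 1 \
+  --chunk-size 16 \
+  --left-context-frames 128 \
+  --max-duration 600 \
+  --decoding-method modified_beam_search \
+  --blank-penalty 1.5
+```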
+
+The pre-trained model is available here: https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
+
+
### WenetSpeech char-based training results (offline and streaming) (Pruned Transducer 5)
#### 2022-07-22
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 7cb2e1048..746b212ff 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -292,7 +292,7 @@ class WenetSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
- buffer_size=30000,
+ buffer_size=300000,
drop_last=True,
)
else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index dc431578c..36b8a4b67 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -588,7 +588,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
- texts = [list(str(text)) for text in texts]
+ texts = [list("".join(text.split())) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
diff --git a/egs/wenetspeech/ASR/zipformer/__init__.py b/egs/wenetspeech/ASR/zipformer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/wenetspeech/ASR/zipformer/asr_datamodule.py b/egs/wenetspeech/ASR/zipformer/asr_datamodule.py
new file mode 120000
index 000000000..a074d6085
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/beam_search.py b/egs/wenetspeech/ASR/zipformer/beam_search.py
new file mode 120000
index 000000000..8554e44cc
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/beam_search.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/decode.py b/egs/wenetspeech/ASR/zipformer/decode.py
new file mode 100755
index 000000000..0fbc8244b
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/decode.py
@@ -0,0 +1,818 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao
+# Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) modified beam search
+./zipformer/decode.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(3) fast beam search (trivial_graph)
+./zipformer/decode.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(4) fast beam search (LG)
+./zipformer/decode.py \
+ --epoch 30 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import WenetSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from lhotse.cut import Cut
+from train import add_model_arguments, get_model, get_params
+
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_char",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_LG
+ - fast_beam_search_nbest_oracle
+ If you use fast_beam_search_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search,
+        fast_beam_search_LG, and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding_method is fast_beam_search_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--ilme-scale",
+ type=float,
+ default=0.2,
+ help="""
+ Used only when --decoding_method is fast_beam_search_LG.
+ It specifies the scale for the internal language model estimation.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--blank-penalty",
+ type=float,
+ default=0.0,
+ help="""
+ The penalty applied on blank symbol during decoding.
+ Note: It is a positive value that would be applied to logits like
+ this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+ [batch_size, vocab] and blank id is 0).
+ """,
+ )
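+    # Illustrative effect of --blank-penalty (hypothetical numbers): with
+    # blank_penalty=2.0 and a raw blank logit of 1.3, the adjusted blank logit
+    # becomes 1.3 - 2.0 = -0.7, so emitting the blank symbol becomes less likely.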
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an LG graph. Used
+        only when --decoding_method is fast_beam_search, fast_beam_search_LG,
+        or fast_beam_search_nbest_oracle.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ x, x_lens = model.encoder_embed(feature, feature_lens)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ blank_penalty=params.blank_penalty,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "fast_beam_search_LG":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ blank_penalty=params.blank_penalty,
+ ilme_scale=params.ilme_scale,
+ )
+ for hyp in hyp_tokens:
+ sentence = "".join([lexicon.word_table[i] for i in hyp])
+ hyps.append(list(sentence))
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ blank_penalty=params.blank_penalty,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ blank_penalty=params.blank_penalty,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ blank_penalty=params.blank_penalty,
+ beam=params.beam_size,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ blank_penalty=params.blank_penalty,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ blank_penalty=params.blank_penalty,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+ key = f"blank_penalty_{params.blank_penalty}"
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search_" + key: hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key += f"_beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ilme_scale_{params.ilme_scale}"
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}_" + key: hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an LG graph. Used
+        only when --decoding_method is fast_beam_search, fast_beam_search_LG,
+        or fast_beam_search_nbest_oracle.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ texts = [list("".join(text.split())) for text in texts]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ graph_compiler=graph_compiler,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ this_batch.append((cut_id, ref_text, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ WenetSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "modified_beam_search",
+ "fast_beam_search",
+ "fast_beam_search_LG",
+ "fast_beam_search_nbest_oracle",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"_ilme_scale_{params.ilme_scale}"
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+ params.suffix += f"-blank-penalty-{params.blank_penalty}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
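+    # Token ids are 0-based (blank is id 0), so the vocab size is the largest id plus one.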
+
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ if "fast_beam_search" in params.decoding_method:
+ if "LG" in params.decoding_method:
+ lexicon = Lexicon(params.lang_dir)
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ wenetspeech = WenetSpeechAsrDataModule(args)
+
+ def remove_short_utt(c: Cut):
+ T = ((c.num_frames - 7) // 2 + 1) // 2
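+        # e.g. num_frames = 9 gives T = ((9 - 7) // 2 + 1) // 2 = 1 (kept),
+        # while num_frames = 7 gives T = 0 and the cut is excluded.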
+ if T <= 0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
+ )
+ return T > 0
+
+ dev_cuts = wenetspeech.valid_cuts()
+ dev_cuts = dev_cuts.filter(remove_short_utt)
+ dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
+
+ test_net_cuts = wenetspeech.test_net_cuts()
+ test_net_cuts = test_net_cuts.filter(remove_short_utt)
+ test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
+
+ test_meeting_cuts = wenetspeech.test_meeting_cuts()
+ test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt)
+ test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
+
+ test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
+ test_dls = [dev_dl, test_net_dl, test_meeting_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ graph_compiler=graph_compiler,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/wenetspeech/ASR/zipformer/decode_stream.py b/egs/wenetspeech/ASR/zipformer/decode_stream.py
new file mode 120000
index 000000000..b8d8ddfc4
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/decode_stream.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decode_stream.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/decoder.py b/egs/wenetspeech/ASR/zipformer/decoder.py
new file mode 120000
index 000000000..5a8018680
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/encoder_interface.py b/egs/wenetspeech/ASR/zipformer/encoder_interface.py
new file mode 120000
index 000000000..b9aa0ae08
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/export-onnx-streaming.py b/egs/wenetspeech/ASR/zipformer/export-onnx-streaming.py
new file mode 120000
index 000000000..2962eb784
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/export-onnx-streaming.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx-streaming.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/export-onnx.py b/egs/wenetspeech/ASR/zipformer/export-onnx.py
new file mode 120000
index 000000000..70a15683c
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/export-onnx.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/export.py b/egs/wenetspeech/ASR/zipformer/export.py
new file mode 120000
index 000000000..dfc1bec08
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/jit_pretrained.py b/egs/wenetspeech/ASR/zipformer/jit_pretrained.py
new file mode 120000
index 000000000..25108391f
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/jit_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/jit_pretrained.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/jit_pretrained_streaming.py b/egs/wenetspeech/ASR/zipformer/jit_pretrained_streaming.py
new file mode 120000
index 000000000..1962351e9
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/jit_pretrained_streaming.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/joiner.py b/egs/wenetspeech/ASR/zipformer/joiner.py
new file mode 120000
index 000000000..5b8a36332
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/model.py b/egs/wenetspeech/ASR/zipformer/model.py
new file mode 120000
index 000000000..cd7e07d72
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/model.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/onnx_check.py b/egs/wenetspeech/ASR/zipformer/onnx_check.py
new file mode 120000
index 000000000..f3dd42004
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/onnx_check.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_check.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/onnx_decode.py b/egs/wenetspeech/ASR/zipformer/onnx_decode.py
new file mode 100755
index 000000000..ed5f6db08
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/onnx_decode.py
@@ -0,0 +1,334 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Xiaoyu Yang,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads ONNX exported models and uses them to decode the test sets.
+
+We use the pre-trained model from
+https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/wenetspeech/ASR
+
+repo_url=https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_char/tokens.txt"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-9999.pt
+popd
+
+2. Export the model to ONNX
+
+./zipformer/export-onnx.py \
+ --tokens $repo/data/lang_char/tokens.txt \
+ --epoch 9999 \
+ --avg 1 \
+ --exp-dir $repo/exp/
+
+It will generate the following 3 files inside $repo/exp:
+
+ - encoder-epoch-9999-avg-1.onnx
+ - decoder-epoch-9999-avg-1.onnx
+ - joiner-epoch-9999-avg-1.onnx
+
+3. Run this file
+
+./zipformer/onnx_decode.py \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
+  --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx
+"""
+
+
+import argparse
+import logging
+import time
+from pathlib import Path
+from typing import List, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import WenetSpeechAsrDataModule
+from lhotse.cut import Cut
+from onnx_pretrained import OnnxModel, greedy_search
+
+from icefall.utils import setup_logger, store_transcripts, write_error_stats
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--encoder-model-filename",
+ type=str,
+ required=True,
+ help="Path to the encoder onnx model. ",
+ )
+
+ parser.add_argument(
+ "--decoder-model-filename",
+ type=str,
+ required=True,
+ help="Path to the decoder onnx model. ",
+ )
+
+ parser.add_argument(
+ "--joiner-model-filename",
+ type=str,
+ required=True,
+ help="Path to the joiner onnx model. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="Valid values are greedy_search and modified_beam_search",
+ )
+
+ return parser
+
+
+def decode_one_batch(
+ model: OnnxModel, token_table: k2.SymbolTable, batch: dict
+) -> List[List[str]]:
+ """Decode one batch and return the result.
+    Currently only greedy_search is supported.
+
+ Args:
+ model:
+ The neural model.
+ token_table:
+ Mapping ids to tokens.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+
+ Returns:
+ Return the decoded results for each utterance.
+ """
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
+
+ encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
+
+ hyps = greedy_search(
+ model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
+ )
+
+ hyps = [[token_table[h] for h in hyp] for hyp in hyps]
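+    # With a character-level tokens.txt (e.g. WenetSpeech lang_char), each hyp is a list of characters.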
+ return hyps
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ model: nn.Module,
+ token_table: k2.SymbolTable,
+) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ model:
+ The neural model.
+ token_table:
+ Mapping ids to tokens.
+
+ Returns:
+ - A list of tuples. Each tuple contains three elements:
+ - cut_id,
+ - reference transcript,
+ - predicted result.
+ - The total duration (in seconds) of the dataset.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ log_interval = 10
+ total_duration = 0
+
+ results = []
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+ total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
+
+ hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
+
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = list(ref_text)
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results.extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+
+ return results, total_duration
+
+
+def save_results(
+ res_dir: Path,
+ test_set_name: str,
+ results: List[Tuple[str, List[str], List[str]]],
+):
+ recog_path = res_dir / f"recogs-{test_set_name}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = res_dir / f"errs-{test_set_name}.txt"
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
+ with open(errs_info, "w") as f:
+ print("WER", file=f)
+ print(wer, file=f)
+
+ s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ WenetSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+
+ assert (
+ args.decoding_method == "greedy_search"
+ ), "Only supports greedy_search currently."
+ res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
+
+ setup_logger(f"{res_dir}/log-decode")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ logging.info(f"Device: {device}")
+
+ token_table = k2.SymbolTable.from_file(args.tokens)
+    assert token_table[0] == "<blk>"
+
+ logging.info(vars(args))
+
+ logging.info("About to create model")
+ model = OnnxModel(
+ encoder_model_filename=args.encoder_model_filename,
+ decoder_model_filename=args.decoder_model_filename,
+ joiner_model_filename=args.joiner_model_filename,
+ )
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+
+ wenetspeech = WenetSpeechAsrDataModule(args)
+
+ def remove_short_utt(c: Cut):
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ if T <= 0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
+ )
+ return T > 0
+
+ dev_cuts = wenetspeech.valid_cuts()
+ dev_cuts = dev_cuts.filter(remove_short_utt)
+ dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
+
+ test_net_cuts = wenetspeech.test_net_cuts()
+ test_net_cuts = test_net_cuts.filter(remove_short_utt)
+ test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
+
+ test_meeting_cuts = wenetspeech.test_meeting_cuts()
+ test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt)
+ test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
+
+ test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
+    test_dls = [dev_dl, test_net_dl, test_meeting_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+ start_time = time.time()
+ results, total_duration = decode_dataset(
+ dl=test_dl, model=model, token_table=token_table
+ )
+ end_time = time.time()
+ elapsed_seconds = end_time - start_time
+ rtf = elapsed_seconds / total_duration
+
+ logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
+ logging.info(f"Wave duration: {total_duration:.3f} s")
+ logging.info(
+ f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
+ )
+
+ save_results(res_dir=res_dir, test_set_name=test_set, results=results)
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/wenetspeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/wenetspeech/ASR/zipformer/onnx_pretrained-streaming.py
new file mode 120000
index 000000000..cfea104c2
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/onnx_pretrained-streaming.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/onnx_pretrained.py b/egs/wenetspeech/ASR/zipformer/onnx_pretrained.py
new file mode 120000
index 000000000..8f32f4ee7
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/onnx_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_pretrained.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/optim.py b/egs/wenetspeech/ASR/zipformer/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/pretrained.py b/egs/wenetspeech/ASR/zipformer/pretrained.py
new file mode 120000
index 000000000..0bd71dde4
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/pretrained.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/scaling.py b/egs/wenetspeech/ASR/zipformer/scaling.py
new file mode 120000
index 000000000..6f398f431
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/scaling_converter.py b/egs/wenetspeech/ASR/zipformer/scaling_converter.py
new file mode 120000
index 000000000..b0ecee05e
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling_converter.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/streaming_beam_search.py b/egs/wenetspeech/ASR/zipformer/streaming_beam_search.py
new file mode 120000
index 000000000..b1ed54557
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/streaming_beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/streaming_beam_search.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/streaming_decode.py b/egs/wenetspeech/ASR/zipformer/streaming_decode.py
new file mode 100755
index 000000000..94c5fae5f
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/streaming_decode.py
@@ -0,0 +1,881 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
+# Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+./zipformer/streaming_decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 256 \
+ --exp-dir ./zipformer/exp \
+ --decoding-method greedy_search \
+ --num-decode-streams 2000
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import numpy as np
+import torch
+from asr_datamodule import WenetSpeechAsrDataModule
+from decode_stream import DecodeStream
+from kaldifeat import Fbank, FbankOptions
+from lhotse import CutSet
+from streaming_beam_search import (
+ fast_beam_search_one_best,
+ greedy_search,
+ modified_beam_search,
+)
+from torch import Tensor, nn
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="Path to the lang dir(containing lexicon, tokens, etc.)",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Supported decoding methods are:
+ greedy_search
+ modified_beam_search
+ fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--num_active_paths",
+ type=int,
+ default=4,
+ help="""An interger indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=32,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--blank-penalty",
+ type=float,
+ default=0.0,
+ help="""
+ The penalty applied on blank symbol during decoding.
+ Note: It is a positive value that would be applied to logits like
+ this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+ [batch_size, vocab] and blank id is 0).
+ """,
+ )
+
+ parser.add_argument(
+ "--num-decode-streams",
+ type=int,
+ default=2000,
+ help="The number of streams that can be decoded parallel.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_init_states(
+ model: nn.Module,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+) -> List[torch.Tensor]:
+ """
+ Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ states[-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+ """
+ states = model.encoder.get_init_states(batch_size, device)
+
+ embed_states = model.encoder_embed.get_init_states(batch_size, device)
+ states.append(embed_states)
+
+ processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+ states.append(processed_lens)
+
+ return states
+
+
+def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
+ """Stack list of zipformer states that correspond to separate utterances
+ into a single emformer state, so that it can be used as an input for
+ zipformer when those utterances are formed into a batch.
+
+ Args:
+ state_list:
+        Each element in state_list corresponds to the internal state
+ of the zipformer model for a single utterance. For element-n,
+ state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
+ state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
+ cached_val2, cached_conv1, cached_conv2).
+ state_list[n][-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ state_list[n][-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+ Note:
+ It is the inverse of :func:`unstack_states`.
+ """
+ batch_size = len(state_list)
+ assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
+ tot_num_layers = (len(state_list[0]) - 2) // 6
+
+ batch_states = []
+ for layer in range(tot_num_layers):
+ layer_offset = layer * 6
+ # cached_key: (left_context_len, batch_size, key_dim)
+ cached_key = torch.cat(
+ [state_list[i][layer_offset] for i in range(batch_size)], dim=1
+ )
+ # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+ cached_nonlin_attn = torch.cat(
+ [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
+ )
+ # cached_val1: (left_context_len, batch_size, value_dim)
+ cached_val1 = torch.cat(
+ [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
+ )
+ # cached_val2: (left_context_len, batch_size, value_dim)
+ cached_val2 = torch.cat(
+ [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
+ )
+ # cached_conv1: (#batch, channels, left_pad)
+ cached_conv1 = torch.cat(
+ [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
+ )
+ # cached_conv2: (#batch, channels, left_pad)
+ cached_conv2 = torch.cat(
+ [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
+ )
+ batch_states += [
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ]
+
+ cached_embed_left_pad = torch.cat(
+ [state_list[i][-2] for i in range(batch_size)], dim=0
+ )
+ batch_states.append(cached_embed_left_pad)
+
+ processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
+ batch_states.append(processed_lens)
+
+ return batch_states
+
+
+def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
+ """Unstack the zipformer state corresponding to a batch of utterances
+ into a list of states, where the i-th entry is the state from the i-th
+ utterance in the batch.
+
+ Note:
+ It is the inverse of :func:`stack_states`.
+
+ Args:
+        batch_states: A list of cached tensors of all encoder layers. For layer-i,
+        batch_states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
+        cached_val2, cached_conv1, cached_conv2).
+        batch_states[-2] is the cached left padding for the ConvNeXt module,
+        of shape (batch_size, num_channels, left_pad, num_freqs).
+        batch_states[-1] is processed_lens of shape (batch,), which records the number
+        of processed frames (at 50 Hz frame rate, after encoder_embed) for each sample in the batch.
+
+ Returns:
+        state_list: A list of lists. Each element in state_list corresponds to the
+        internal state of the zipformer model for a single utterance.
+ """
+ assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
+ tot_num_layers = (len(batch_states) - 2) // 6
+
+ processed_lens = batch_states[-1]
+ batch_size = processed_lens.shape[0]
+
+ state_list = [[] for _ in range(batch_size)]
+
+ for layer in range(tot_num_layers):
+ layer_offset = layer * 6
+ # cached_key: (left_context_len, batch_size, key_dim)
+ cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
+ # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+ cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_val1: (left_context_len, batch_size, value_dim)
+ cached_val1_list = batch_states[layer_offset + 2].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_val2: (left_context_len, batch_size, value_dim)
+ cached_val2_list = batch_states[layer_offset + 3].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_conv1: (#batch, channels, left_pad)
+ cached_conv1_list = batch_states[layer_offset + 4].chunk(
+ chunks=batch_size, dim=0
+ )
+ # cached_conv2: (#batch, channels, left_pad)
+ cached_conv2_list = batch_states[layer_offset + 5].chunk(
+ chunks=batch_size, dim=0
+ )
+ for i in range(batch_size):
+ state_list[i] += [
+ cached_key_list[i],
+ cached_nonlin_attn_list[i],
+ cached_val1_list[i],
+ cached_val2_list[i],
+ cached_conv1_list[i],
+ cached_conv2_list[i],
+ ]
+
+ cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
+ for i in range(batch_size):
+ state_list[i].append(cached_embed_left_pad_list[i])
+
+ processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
+ for i in range(batch_size):
+ state_list[i].append(processed_lens_list[i])
+
+ return state_list
+
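+# Illustrative note: stack_states and unstack_states are inverses of each other, e.g.
+#   batched = stack_states([stream_a.states, stream_b.states])
+#   stream_a.states, stream_b.states = unstack_states(batched)
+# recovers the original per-stream state tensors.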
+
+def streaming_forward(
+ features: Tensor,
+ feature_lens: Tensor,
+ model: nn.Module,
+ states: List[Tensor],
+ chunk_size: int,
+ left_context_len: int,
+) -> Tuple[Tensor, Tensor, List[Tensor]]:
+ """
+ Returns encoder outputs, output lengths, and updated states.
+ """
+ cached_embed_left_pad = states[-2]
+ (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
+ x=features,
+ x_lens=feature_lens,
+ cached_left_pad=cached_embed_left_pad,
+ )
+ assert x.size(1) == chunk_size, (x.size(1), chunk_size)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+
+ # processed_mask is used to mask out initial states
+ processed_mask = torch.arange(left_context_len, device=x.device).expand(
+ x.size(0), left_context_len
+ )
+ processed_lens = states[-1] # (batch,)
+ # (batch, left_context_size)
+ processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
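+    # e.g. left_context_len = 4 and processed_lens = [2] give a mask of
+    # [[True, True, False, False]]: the two oldest, still-unfilled cache slots are masked out.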
+ # Update processed lengths
+ new_processed_lens = processed_lens + x_lens
+
+ # (batch, left_context_size + chunk_size)
+ src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
+
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ encoder_states = states[:-2]
+ (
+ encoder_out,
+ encoder_out_lens,
+ new_encoder_states,
+ ) = model.encoder.streaming_forward(
+ x=x,
+ x_lens=x_lens,
+ states=encoder_states,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+    encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+
+ new_states = new_encoder_states + [
+ new_cached_embed_left_pad,
+ new_processed_lens,
+ ]
+ return encoder_out, encoder_out_lens, new_states
+
+
+def decode_one_chunk(
+ params: AttributeDict,
+ model: nn.Module,
+ decode_streams: List[DecodeStream],
+) -> List[int]:
+ """Decode one chunk frames of features for each decode_streams and
+ return the indexes of finished streams in a List.
+
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ decode_streams:
+        A List of DecodeStream, each belonging to an utterance.
+ Returns:
+ Return a List containing which DecodeStreams are finished.
+ """
+ device = model.device
+ chunk_size = int(params.chunk_size)
+ left_context_len = int(params.left_context_frames)
+
+ features = []
+ feature_lens = []
+ states = []
+ processed_lens = [] # Used in fast-beam-search
+
+ for stream in decode_streams:
+ feat, feat_len = stream.get_feature_frames(chunk_size * 2)
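+        # chunk_size counts 50 Hz post-subsampling frames, so read twice as many 100 Hz feature frames.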
+ features.append(feat)
+ feature_lens.append(feat_len)
+ states.append(stream.states)
+ processed_lens.append(stream.done_frames)
+
+ feature_lens = torch.tensor(feature_lens, device=device)
+ features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
+
+ # Make sure the length after encoder_embed is at least 1.
+    # The encoder_embed subsamples features to (T - 7) // 2 frames.
+ # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
+ tail_length = chunk_size * 2 + 7 + 2 * 3
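+    # e.g. chunk_size = 16 gives tail_length = 32 + 7 + 6 = 45 feature frames.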
+ if features.size(1) < tail_length:
+ pad_length = tail_length - features.size(1)
+ feature_lens += pad_length
+ features = torch.nn.functional.pad(
+ features,
+ (0, 0, 0, pad_length),
+ mode="constant",
+ value=LOG_EPS,
+ )
+
+ states = stack_states(states)
+
+ encoder_out, encoder_out_lens, new_states = streaming_forward(
+ features=features,
+ feature_lens=feature_lens,
+ model=model,
+ states=states,
+ chunk_size=chunk_size,
+ left_context_len=left_context_len,
+ )
+
+ encoder_out = model.joiner.encoder_proj(encoder_out)
+
+ if params.decoding_method == "greedy_search":
+ greedy_search(
+ model=model,
+ encoder_out=encoder_out,
+ streams=decode_streams,
+ blank_penalty=params.blank_penalty,
+ )
+ elif params.decoding_method == "fast_beam_search":
+ processed_lens = torch.tensor(processed_lens, device=device)
+ processed_lens = processed_lens + encoder_out_lens
+ fast_beam_search_one_best(
+ model=model,
+ encoder_out=encoder_out,
+ processed_lens=processed_lens,
+ streams=decode_streams,
+ beam=params.beam,
+ max_states=params.max_states,
+ max_contexts=params.max_contexts,
+ blank_penalty=params.blank_penalty,
+ )
+ elif params.decoding_method == "modified_beam_search":
+ modified_beam_search(
+ model=model,
+ streams=decode_streams,
+ encoder_out=encoder_out,
+ num_active_paths=params.num_active_paths,
+ blank_penalty=params.blank_penalty,
+ )
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+
+ states = unstack_states(new_states)
+
+ finished_streams = []
+ for i in range(len(decode_streams)):
+ decode_streams[i].states = states[i]
+ decode_streams[i].done_frames += encoder_out_lens[i]
+ if decode_streams[i].done:
+ finished_streams.append(i)
+
+ return finished_streams
+
+
+def decode_dataset(
+ cuts: CutSet,
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ cuts:
+ Lhotse Cutset containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ lexicon:
+ The Lexicon.
+ decoding_graph:
+      The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
+      only when --decoding_method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ device = model.device
+
+ opts = FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+
+ log_interval = 100
+
+ decode_results = []
+    # Contains the decode streams that are currently running.
+ decode_streams = []
+ for num, cut in enumerate(cuts):
+ # each utterance has a DecodeStream.
+ initial_states = get_init_states(model=model, batch_size=1, device=device)
+ decode_stream = DecodeStream(
+ params=params,
+ cut_id=cut.id,
+ initial_states=initial_states,
+ decoding_graph=decoding_graph,
+ device=device,
+ )
+
+ audio: np.ndarray = cut.load_audio()
+ # audio.shape: (1, num_samples)
+ assert len(audio.shape) == 2
+ assert audio.shape[0] == 1, "Should be single channel"
+ assert audio.dtype == np.float32, audio.dtype
+
+ # The trained model is using normalized samples
+ if audio.max() > 1:
+ logging.warning(
+ f"The audio should be normalized to [-1, 1], audio.max : {audio.max()}."
+ f"Clipping to [-1, 1]."
+ )
+ audio = np.clip(audio, -1, 1)
+
+ samples = torch.from_numpy(audio).squeeze(0)
+
+ fbank = Fbank(opts)
+ feature = fbank(samples.to(device))
+ decode_stream.set_features(feature, tail_pad_len=30)
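+        # The 30 padded tail frames (roughly 0.3 s at a 10 ms frame shift) let the final chunk be decoded.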
+ decode_stream.ground_truth = cut.supervisions[0].text
+
+ decode_streams.append(decode_stream)
+
+ while len(decode_streams) >= params.num_decode_streams:
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ list(decode_streams[i].ground_truth.strip()),
+ [
+ lexicon.token_table[idx]
+ for idx in decode_streams[i].decoding_result()
+ ],
+ )
+ )
+ del decode_streams[i]
+
+ if num % log_interval == 0:
+ logging.info(f"Cuts processed until now is {num}.")
+
+ # decode final chunks of last sequences
+ while len(decode_streams):
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+                    list(decode_streams[i].ground_truth.strip()),
+ [
+ lexicon.token_table[idx]
+ for idx in decode_streams[i].decoding_result()
+ ],
+ )
+ )
+ del decode_streams[i]
+
+ key = f"blank_penalty_{params.blank_penalty}"
+ if params.decoding_method == "greedy_search":
+ key = f"greedy_search_{key}"
+ elif params.decoding_method == "fast_beam_search":
+ key = (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}_{key}"
+ )
+ elif params.decoding_method == "modified_beam_search":
+ key = f"num_active_paths_{params.num_active_paths}_{key}"
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+ return {key: decode_results}
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ WenetSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ assert params.causal, params.causal
+ assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+ params.suffix += f"-blank-penalty-{params.blank_penalty}"
+
+ # for fast_beam_search
+ if params.decoding_method == "fast_beam_search":
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+                if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ decoding_graph = None
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ wenetspeech = WenetSpeechAsrDataModule(args)
+
+ dev_cuts = wenetspeech.valid_cuts()
+ test_net_cuts = wenetspeech.test_net_cuts()
+ test_meeting_cuts = wenetspeech.test_meeting_cuts()
+
+ test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
+ test_cuts = [dev_cuts, test_net_cuts, test_meeting_cuts]
+
+ for test_set, test_cut in zip(test_sets, test_cuts):
+ results_dict = decode_dataset(
+ cuts=test_cut,
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ decoding_graph=decoding_graph,
+ )
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/wenetspeech/ASR/zipformer/subsampling.py b/egs/wenetspeech/ASR/zipformer/subsampling.py
new file mode 120000
index 000000000..01ae9002c
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py
new file mode 100755
index 000000000..83dbfa22f
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/train.py
@@ -0,0 +1,1350 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./zipformer/train.py \
+ --world-size 8 \
+ --num-epochs 12 \
+ --start-epoch 1 \
+ --exp-dir zipformer/exp \
+  --training-subset L \
+ --lr-epochs 1.5 \
+ --max-duration 350
+
+# For mixed precision training:
+
+./zipformer/train.py \
+ --world-size 8 \
+ --num-epochs 12 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --training-subset L \
+ --lr-epochs 1.5 \
+ --max-duration 750
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import WenetSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
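+    # e.g. with max_duration=350, world_size=8 and ref_duration=600, each real batch
+    # counts as 350 * 8 / 600 ≈ 4.67 reference batches.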
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="""Maximum left-contexts for causal training, measured in frames which will
+ be converted to a number of chunks. If splitting into chunks,
+ chunk left-context frames will be chosen randomly from this list; else not relevant.""",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not changing this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="""The prune range for rnnt loss, it means how many symbols(context)
+ we are using to compute the loss""",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="""The scale to smooth the loss with lm
+ (output of prediction network) part.""",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="""The scale to smooth the loss with am (output of encoder network) part.""",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="""To get pruning ranges, we will calculate a simple version
+ loss(joiner is just addition), this simple loss also uses for
+ training (as a regularization item). We will scale the simple loss
+ with this parameter before adding to the final loss.""",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of the model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
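+    # For illustration (hypothetical numbers): at batch_idx_train=1000 with
+    # average_period=200, the update described above becomes
+    #   model_avg = model * 0.2 + model_avg * 0.8.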
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+                       epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+    - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000,
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
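+    # e.g. _to_int_tuple("2,2,3,4,3,2") -> (2, 2, 3, 4, 3, 2)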
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
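+    # For example, a 10-second utterance at 100 frames per second gives T=1000
+    # input frames, (1000 - 7) // 2 = 496 frames after encoder_embed, and roughly
+    # 248 frames after the final downsampling by 2.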
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
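+        # The encoder output dimension equals the largest of the per-stack
+        # encoder dims (cf. get_joiner_model above), hence the max() below.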
+        encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ if "cur_batch_idx" in saved_params:
+ params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+    Compute the transducer loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
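+      graph_compiler:
+        The compiler used to convert transcripts into token ids.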
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = graph_compiler.texts_to_ids(texts)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, _ = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+        # Reduce the scale on the simple loss from 1.0 at the start of training
+        # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
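+        # Illustration of this schedule: at batch_idx_train=0 the scales are
+        # (simple=1.0, pruned=0.1); they move linearly to (simple=s, pruned=1.0)
+        # by batch_idx_train == warm_step (2000 by default, see get_params()).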
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+        The learning rate scheduler; we call step() after every batch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+        The number of processes used in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of this process in DDP training. If DDP is not used, it should
+        be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
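+        # When resuming from a checkpoint saved in the middle of an epoch,
+        # fast-forward over the batches that were already processed.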
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction=sum, so the loss is summed over the utterances
+            # in the batch; no normalization has been applied to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale has become small, try increasing it. The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to behave
+            # differently depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
+
+    params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
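+    # Eden (see optim.py in this recipe) decays the learning rate roughly as
+    #   base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
+    #           * ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25;
+    # see optim.py for the exact form.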
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2**22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ wenetspeech = WenetSpeechAsrDataModule(args)
+
+ train_cuts = wenetspeech.train_cuts()
+ valid_cuts = wenetspeech.valid_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 15 seconds
+ #
+ # Caution: There is a reason to select 15.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 15.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
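+        # e.g. a 15-second cut at 100 frames per second has num_frames=1500,
+        # so T = ((1500 - 7) // 2 + 1) // 2 = 373 frames after subsampling.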
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0]
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = wenetspeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
+
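+    # Note: the OOM sanity check below is disabled by the `if False` guard;
+    # change the condition to `not params.print_diagnostics` to run
+    # scan_pessimistic_batches_for_oom() before training starts.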
+ if False and not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ graph_compiler:
+ The compiler to encode texts to ids.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ texts = supervisions["text"]
+ y = graph_compiler.texts_to_ids(texts)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ WenetSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.lang_dir = Path(args.lang_dir)
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
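+# Limit PyTorch to one CPU thread per process; presumably this avoids CPU
+# oversubscription when several DDP processes and dataloader workers share a machine.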
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/wenetspeech/ASR/zipformer/zipformer.py b/egs/wenetspeech/ASR/zipformer/zipformer.py
new file mode 120000
index 000000000..23011dda7
--- /dev/null
+++ b/egs/wenetspeech/ASR/zipformer/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py
\ No newline at end of file
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index 0f0887859..3d206d139 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -99,6 +99,15 @@ def get_parser():
""",
)
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
parser.add_argument(
"--exp-dir",
type=str,
@@ -242,7 +251,9 @@ def load_checkpoint_if_available(
) -> None:
"""Load checkpoint from file.
- If params.start_epoch is positive, it will load the checkpoint from
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
@@ -261,10 +272,14 @@ def load_checkpoint_if_available(
Returns:
Return None.
"""
- if params.start_epoch <= 0:
- return
- filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
logging.info(f"Loading checkpoint: {filename}")
saved_params = load_checkpoint(
filename,
@@ -283,6 +298,13 @@ def load_checkpoint_if_available(
for k in keys:
params[k] = saved_params[k]
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ if "cur_batch_idx" in saved_params:
+ params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
return saved_params
@@ -438,7 +460,14 @@ def train_one_epoch(
tot_loss = MetricsTracker()
+ cur_batch_idx = params.get("cur_batch_idx", 0)
+
for batch_idx, batch in enumerate(train_dl):
+
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
params.batch_idx_train += 1
x, y, sentence_lengths = batch
batch_size = x.size(0)
@@ -463,6 +492,7 @@ def train_one_epoch(
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
+ params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
@@ -471,6 +501,7 @@ def train_one_epoch(
optimizer=optimizer,
rank=rank,
)
+ del params.cur_batch_idx
if batch_idx % params.log_interval == 0:
# Note: "frames" here means "num_tokens"