HENT-SRT: Hierarchical Efficient Neural Transducer with Self-Distillation for Joint Speech Recognition and Translation

Paper: https://arxiv.org/abs/2506.02157
Amir Hussein 2025-09-19 12:17:53 -04:00 committed by GitHub
parent 63563d16d3
commit 855536d355
42 changed files with 24379 additions and 0 deletions

View File

@@ -0,0 +1,53 @@
# HENT-SRT
This repository contains a **speech-to-text translation (ST)** recipe accompanying our IWSLT 2025 paper:
**HENT-SRT: Hierarchical Efficient Neural Transducer with Self-Distillation for Joint Speech Recognition and Translation**
Paper: <https://arxiv.org/abs/2506.02157>
## Datasets
The recipe combines three conversational, 3-way parallel ST corpora:
- **Tunisian-English (IWSLT22 TA)**
Lhotse recipe: <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/iwslt22_ta.py>
- **Fisher Spanish**
Reference: <https://aclanthology.org/2013.iwslt-papers.14>
- **HKUST (Mandarin Telephone Speech)**
Reference: <https://arxiv.org/abs/2404.11619>
> **Data access:** Fisher and HKUST require an institutional LDC subscription.
> **Recipe status:** Lhotse recipes for Fisher Spanish and HKUST are in progress and will be finalized soon.
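The training commands in [RESULTS](/egs/multi_conv_zh_es_ta/ST/RESULTS.md) read pre-computed 80-dim fbank cuts from `data/fbank` (`--manifest-dir data/fbank`). As a rough sketch of that preparation step with Lhotse (the manifest paths below are placeholders, not the exact names used by this recipe's prepare scripts):
```
# Hedged sketch: compute and store fbank features for prepared cuts with Lhotse.
# The paths are placeholders; see the recipe's data preparation scripts for the
# actual manifest names.
from lhotse import CutSet, Fbank, FbankConfig

cuts = CutSet.from_file("data/manifests/iwslt22_ta_cuts_train.jsonl.gz")  # placeholder
cuts = cuts.compute_and_store_features(
    extractor=Fbank(FbankConfig(num_mel_bins=80)),
    storage_path="data/fbank/iwslt22_ta_train_feats",
    num_jobs=8,
)
cuts.to_file("data/fbank/iwslt22_ta_cuts_train.jsonl.gz")
```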
## Zipformer Multi-joiner ST
This model is similar to the one in <https://www.isca-archive.org/interspeech_2023/wang23oa_interspeech.pdf>, but
our system uses a Zipformer encoder with a pruned transducer and a stateless decoder.
| Dataset | Decoding method | test WER | test BLEU | comment |
| ---------- | -------------------- | -------- | --------- | ----------------------------------------------- |
| iwslt\_ta | modified beam search | 41.6 | 16.3 | --epoch 20, --avg 13, beam(20), |
| hkust | modified beam search | 23.8 | 10.4 | --epoch 20, --avg 13, beam(20), |
| fisher\_sp | modified beam search | 18.0 | 31.0 | --epoch 20, --avg 13, beam(20), |
## HENT-SRT offline
| Dataset | Decoding method | test WER | test BLEU | comment |
| ---------- | -------------------- | -------- | --------- | ----------------------------------------------- |
| iwslt\_ta | modified beam search | 41.4 | 20.6 | --epoch 20, --avg 13, beam(20), BP 1 |
| hkust | modified beam search | 22.8 | 14.7 | --epoch 20, --avg 13, beam(20), BP 1 |
| fisher\_sp | modified beam search | 17.8 | 33.7 | --epoch 20, --avg 13, beam(20), BP 1 |
## HENT-SRT streaming
| Dataset | Decoding method | test WER | test BLEU | comment |
| ---------- | -------------------- | -------- | --------- | ----------------------------------------------- |
| iwslt\_ta | greedy search | 46.2 | 17.3 | --epoch 20, --avg 13, BP 2, chunk-size 64, left-context-frames 128, max-sym-per-frame 20 |
| hkust | greedy search | 27.3 | 11.2 | --epoch 20, --avg 13, BP 2, chunk-size 64, left-context-frames 128, max-sym-per-frame 20|
| fisher\_sp | greedy search | 22.7 | 30.8 | --epoch 20, --avg 13, BP 2, chunk-size 64, left-context-frames 128, max-sym-per-frame 20 |
See [RESULTS](/egs/multi_conv_zh_es_ta/ST/RESULTS.md) for details.

View File

@@ -0,0 +1,404 @@
## Zipformer Multi-joiner ST
### For offline model training:
You can find a pretrained model, training logs, decoding logs, and decoding results at: https://huggingface.co/AmirHussein/HENT-SRT/tree/main/zipformer_multijoiner_st
| Dataset | Decoding method | test WER | test BLEU | comment |
| ---------- | -------------------- | -------- | --------- | ----------------------------------------------- |
| iwslt\_ta | modified beam search | 41.6 | 16.3 | --epoch 25, --avg 13, beam(20), |
| hkust | modified beam search | 23.8 | 10.4 | --epoch 25, --avg 13, beam(20), |
| fisher\_sp | modified beam search | 18.0 | 31.0 | --epoch 25, --avg 13, beam(20), |
The training command:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer_multijoiner_st/train.py \
--base-lr 0.045 \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer_multijoiner_st/exp-multi-joiner-pbe4k \
--causal 0 \
--num-encoder-layers 2,2,2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,1024,1024,768 \
--encoder-dim 192,256,384,512,384,384,384,256 \
--encoder-unmasked-dim 192,192,256,256,256,256,256,192 \
--downsampling-factor 1,2,4,8,8,4,4,2 \
--cnn-module-kernel 31,31,15,15,15,15,31,31 \
--num-heads 4,4,4,8,8,8,4,4 \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--max-duration 400 \
--prune-range 10 \
--warm-step 10000 \
--lr-epochs 6 \
--use-hat False
```
Decoding command:
```
./zipformer_multijoiner_st/decode.py \
--exp-dir ./zipformer_multijoiner_st/exp-multi-joiner-pbe4k \
--epoch 25 \
--avg 13 \
--beam-size 20 \
--max-duration 600 \
--decoding-method modified_beam_search \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--bpe-model data/lang_bpe_5000/bpe.model \
--num-encoder-layers 2,2,2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,1024,1024,768 \
--encoder-dim 192,256,384,512,384,384,384,256 \
--encoder-unmasked-dim 192,192,256,256,256,256,256,192 \
--downsampling-factor 1,2,4,8,8,4,4,2 \
--cnn-module-kernel 31,31,15,15,15,15,31,31 \
--num-heads 4,4,4,8,8,8,4,4 \
--use-averaged-model True
```
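The WER and BLEU numbers above are computed from the hypothesis/reference text files that `decode.py` writes to the experiment directory. As a hedged scoring sketch (the file names below are placeholders, not the exact files produced by the script), the same metrics can be reproduced with `jiwer` and `sacrebleu`:
```
# Hedged scoring sketch; the paths are placeholders for the hypothesis and
# reference files written by decode.py.
import jiwer
import sacrebleu

with open("exp/asr_hyps.txt") as f:
    asr_hyps = [line.strip() for line in f]
with open("exp/asr_refs.txt") as f:
    asr_refs = [line.strip() for line in f]
print(f"WER: {100 * jiwer.wer(asr_refs, asr_hyps):.1f}")

with open("exp/st_hyps.txt") as f:
    st_hyps = [line.strip() for line in f]
with open("exp/st_refs.txt") as f:
    st_refs = [line.strip() for line in f]
print(f"BLEU: {sacrebleu.corpus_bleu(st_hyps, [st_refs]).score:.1f}")
```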
### For streaming model training:
| Dataset | Decoding method | test WER | test BLEU | comment |
| ---------- | -------------------- | -------- | --------- | ----------------------------------------------- |
| iwslt\_ta | greedy search | 44.1 | 6.0 | --epoch 25, --avg 13 |
| hkust | greedy search | 27.4 | 3.7 | --epoch 25, --avg 13 |
| fisher\_sp | greedy search | 19.9 | 16.3 | --epoch 25, --avg 13 |
The training command:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer_multijoiner_st/train.py \
--base-lr 0.045 \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer_multijoiner_st/exp-multi-joiner-pbe4k_causal \
--causal 1 \
--num-encoder-layers 2,2,2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,1024,1024,768 \
--encoder-dim 192,256,384,512,384,384,384,256 \
--encoder-unmasked-dim 192,192,256,256,256,256,256,192 \
--downsampling-factor 1,2,4,8,8,4,4,2 \
--cnn-module-kernel 31,31,15,15,15,15,31,31 \
--num-heads 4,4,4,8,8,8,4,4 \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--max-duration 400 \
--prune-range 10 \
--warm-step 10000 \
--lr-epochs 6 \
--use-hat False
```
Decoding command:
```
./zipformer_multijoiner_st/decode.py \
--exp-dir ./zipformer_multijoiner_st/exp-multi-joiner-pbe4k_causal \
--causal 1 \
--epoch 25 \
--avg 13 \
--beam-size 20 \
--max-duration 600 \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--bpe-model data/lang_bpe_5000/bpe.model \
--num-encoder-layers 2,2,2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,1024,1024,768 \
--encoder-dim 192,256,384,512,384,384,384,256 \
--encoder-unmasked-dim 192,192,256,256,256,256,256,192 \
--downsampling-factor 1,2,4,8,8,4,4,2 \
--cnn-module-kernel 31,31,15,15,15,15,31,31 \
--num-heads 4,4,4,8,8,8,4,4 \
--use-averaged-model True \
--decoding-method greedy_search \
--chunk-size 64 \
--left-context-frames 128 \
--use-hat False \
--max-sym-per-frame 20
```
## HENT-SRT offline
You can find a pretrained model, training logs, decoding logs, and decoding results at: https://huggingface.co/AmirHussein/HENT-SRT/tree/main/hent_srt
| Dataset | Decoding method | test WER | test BLEU | comment |
| ---------- | -------------------- | -------- | --------- | ----------------------------------------------- |
| iwslt\_ta | modified beam search | 41.4 | 20.6 | --epoch 20, --avg 13, beam(20), BP 1 |
| hkust | modified beam search | 22.8 | 14.7 | --epoch 20, --avg 13, beam(20), BP 1 |
| fisher\_sp | modified beam search | 17.8 | 33.7 | --epoch 20, --avg 13, beam(20), BP 1 |
### First, pretrain the offline CR-CTC ASR
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./hent_srt/train.py \
--base-lr 0.055 \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir hent_srt/exp-asr \
--causal 0 \
--num-encoder-layers 2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024 \
--encoder-dim 192,256,384,512,384 \
--encoder-unmasked-dim 192,192,256,256,256 \
--downsampling-factor 1,2,4,8,4 \
--cnn-module-kernel 31,31,15,15,15 \
--num-heads 4,4,4,8,8 \
--st-num-encoder-layers 2,2,2,2,2 \
--st-feedforward-dim 512,512,256,256,256 \
--st-encoder-dim 512,384,256,256,256 \
--st-encoder-unmasked-dim 256,256,256,256,192 \
--st-downsampling-factor 4,4,4,4,4 \
--st-cnn-module-kernel 15,31,31,15,15 \
--st-num-heads 4,4,8,8,8 \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--bpe-model data/lang_bpe_5000/bpe.model \
--manifest-dir data/fbank \
--max-duration 350 \
--prune-range 10 \
--warm-step 8000 \
--ctc-loss-scale 0.2 \
--enable-spec-aug 0 \
--cr-loss-scale 0.2 \
--time-mask-ratio 2.5 \
--use-asr-cr-ctc 1 \
--use-ctc 1 \
--lr-epochs 6 \
--use-hat False \
--use-st-joiner False
```
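`--use-asr-cr-ctc 1` enables CR-CTC: each utterance passes through the encoder as two differently augmented copies (the batch is doubled), and, besides the CTC loss, a KL term scaled by `--cr-loss-scale` pulls the two sets of CTC posteriors toward each other. A condensed sketch of that regularizer, loosely following `forward_cr_ctc` in `hent_srt/model.py`:
```
# Sketch of the consistency-regularization term used with CR-CTC.
# ctc_log_probs: (2 * N, T, V) log-softmax CTC outputs for two augmented
# copies of the same N utterances, stacked along the batch dimension.
import torch
import torch.nn.functional as F

def cr_loss(ctc_log_probs: torch.Tensor) -> torch.Tensor:
    first, second = ctc_log_probs.detach().chunk(2, dim=0)
    exchanged = torch.cat([second, first], dim=0)  # [x1, x2] -> [x2, x1]
    # KL(exchanged || current), with the detached exchanged copy as the target;
    # the recipe additionally masks padded frames before summing.
    return F.kl_div(
        input=ctc_log_probs, target=exchanged, reduction="sum", log_target=True
    )
```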
### Train ST with a Pretrained ASR Initialization
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./hent_srt/train.py \
--base-lr 0.045 \
--world-size 4 \
--num-epochs 25 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir hent_srt/exp-st \
--model-init-ckpt hent_srt/exp-asr/best-valid-loss.pt \
--causal 0 \
--num-encoder-layers 2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024 \
--encoder-dim 192,256,384,512,384 \
--encoder-unmasked-dim 192,192,256,256,256 \
--downsampling-factor 1,2,4,8,4 \
--cnn-module-kernel 31,31,15,15,15 \
--num-heads 4,4,4,8,8 \
--st-num-encoder-layers 2,2,2,2,2 \
--st-feedforward-dim 512,512,256,256,256 \
--st-encoder-dim 384,512,256,256,256 \
--st-encoder-unmasked-dim 256,256,256,256,192 \
--st-downsampling-factor 1,2,4,4,4 \
--st-cnn-module-kernel 15,31,31,15,15 \
--st-num-heads 8,8,8,8,8 \
--output-downsampling-factor 2 \
--st-output-downsampling-factor 1 \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--bpe-model data/lang_bpe_5000/bpe.model \
--manifest-dir data/fbank \
--max-duration 200 \
--prune-range 5 \
--st-prune-range 10 \
--warm-step 10000 \
--ctc-loss-scale 0.1 \
--st-ctc-loss-scale 0.1 \
--enable-spec-aug 0 \
--cr-loss-scale 0.05 \
--st-cr-loss-scale 0.05 \
--time-mask-ratio 2.5 \
--use-asr-cr-ctc 1 \
--use-ctc 1 \
--use-st-cr-ctc 1 \
--use-st-ctc 1 \
--lr-epochs 6 \
--use-hat False \
--use-st-joiner True
```
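`--model-init-ckpt` initializes this stage from the pretrained ASR checkpoint: only parameters whose names and shapes match the new model are copied, and everything else (e.g. the reconfigured ST encoder, the ST decoder and joiner) keeps its fresh initialization. A minimal sketch of that filtering, in the spirit of `filter_state_dict` in `hent_srt/load_model.py`:
```
# Hedged sketch of checkpoint initialization: copy only name/shape matches.
import torch

def init_from_checkpoint(model: torch.nn.Module, ckpt_path: str) -> None:
    src_state = torch.load(ckpt_path, map_location="cpu")["model"]
    dst_state = model.state_dict()
    matched = {
        k: v for k, v in src_state.items()
        if k in dst_state and dst_state[k].shape == v.shape
    }
    dst_state.update(matched)
    model.load_state_dict(dst_state)

# e.g. init_from_checkpoint(model, "hent_srt/exp-asr/best-valid-loss.pt")
```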
### Decode offline HENT-SRT
```
./hent_srt/decode.py \
--epoch 20 --avg 13 --use-averaged-model True \
--beam-size 20 \
--causal 0 \
--exp-dir hent_srt/exp-st \
--bpe-model data/lang_bpe_5000/bpe.model \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--output-downsampling-factor 2 \
--st-output-downsampling-factor 1 \
--max-duration 800 \
--num-encoder-layers 2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024 \
--encoder-dim 192,256,384,512,384 \
--encoder-unmasked-dim 192,192,256,256,256 \
--downsampling-factor 1,2,4,8,4 \
--cnn-module-kernel 31,31,15,15,15 \
--num-heads 4,4,4,8,8 \
--st-num-encoder-layers 2,2,2,2,2 \
--st-feedforward-dim 512,512,256,256,256 \
--st-encoder-dim 384,512,256,256,256 \
--st-encoder-unmasked-dim 256,256,256,256,192 \
--st-downsampling-factor 1,2,4,4,4 \
--st-cnn-module-kernel 15,31,31,15,15 \
--st-num-heads 8,8,8,8,8 \
--decoding-method modified_beam_search \
--use-st-joiner True \
--use-hat-decode False \
--use-ctc 1 \
--use-st-ctc 1 \
--st-blank-penalty 1
```
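`--st-blank-penalty` (the "BP" value in the tables) biases the ST joiner away from emitting blank during decoding, which counters deletion-heavy translations. In icefall-style transducer search this is typically applied by subtracting a constant from the blank logit at every step; a minimal sketch of the idea (not the exact code in `decode.py`):
```
# Illustrative blank-penalty application for one transducer decoding step;
# assumes the blank symbol has index 0, as in this recipe.
import torch

def apply_blank_penalty(logits: torch.Tensor, blank_penalty: float) -> torch.Tensor:
    """logits: (..., vocab_size) joiner output before the search picks a token."""
    if blank_penalty != 0.0:
        logits = logits.clone()
        logits[..., 0] -= blank_penalty
    return logits
```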
## HENT-SRT streaming
| Dataset | Decoding method | test WER | test BLEU | comment |
| ---------- | -------------------- | -------- | --------- | ----------------------------------------------- |
| iwslt\_ta | greedy search | 46.2 | 17.3 | --epoch 20, --avg 13, BP 2, chunk-size 64, left-context-frames 128, max-sym-per-frame 20 |
| hkust | greedy search | 27.3 | 11.2 | --epoch 20, --avg 13, BP 2, chunk-size 64, left-context-frames 128, max-sym-per-frame 20|
| fisher\_sp | greedy search | 22.7 | 30.8 | --epoch 20, --avg 13, BP 2, chunk-size 64, left-context-frames 128, max-sym-per-frame 20 |
### First, pretrain the streaming CR-CTC ASR
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./hent_srt/train.py \
--base-lr 0.055 \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir hent_srt/exp-asr_causal \
--causal 1 \
--num-encoder-layers 2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024 \
--encoder-dim 192,256,384,512,384 \
--encoder-unmasked-dim 192,192,256,256,256 \
--downsampling-factor 1,2,4,8,4 \
--cnn-module-kernel 31,31,15,15,15 \
--num-heads 4,4,4,8,8 \
--st-num-encoder-layers 2,2,2,2,2 \
--st-feedforward-dim 512,512,256,256,256 \
--st-encoder-dim 512,384,256,256,256 \
--st-encoder-unmasked-dim 256,256,256,256,192 \
--st-downsampling-factor 4,4,4,4,4 \
--st-cnn-module-kernel 15,31,31,15,15 \
--st-num-heads 4,4,8,8,8 \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--bpe-model data/lang_bpe_5000/bpe.model \
--manifest-dir data/fbank \
--max-duration 250 \
--prune-range 10 \
--warm-step 8000 \
--ctc-loss-scale 0.2 \
--enable-spec-aug 0 \
--cr-loss-scale 0.2 \
--time-mask-ratio 2.5 \
--use-asr-cr-ctc 1 \
--use-ctc 1 \
--lr-epochs 6 \
--use-hat False \
--use-st-joiner False
```
### Train streaming ST with a Pretrained ASR Initialization
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./hent_srt/train.py \
--base-lr 0.045 \
--world-size 4 \
--num-epochs 25 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir hent_srt/exp-st_causal \
--model-init-ckpt hent_srt/exp-asr_causal/best-valid-loss.pt \
--causal 1 \
--num-encoder-layers 2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024 \
--encoder-dim 192,256,384,512,384 \
--encoder-unmasked-dim 192,192,256,256,256 \
--downsampling-factor 1,2,4,8,4 \
--cnn-module-kernel 31,31,15,15,15 \
--num-heads 4,4,4,8,8 \
--st-num-encoder-layers 2,2,2,2,2 \
--st-feedforward-dim 512,512,256,256,256 \
--st-encoder-dim 384,512,256,256,256 \
--st-encoder-unmasked-dim 256,256,256,256,192 \
--st-downsampling-factor 1,2,4,4,4 \
--st-cnn-module-kernel 15,31,31,15,15 \
--st-num-heads 8,8,8,8,8 \
--output-downsampling-factor 2 \
--st-output-downsampling-factor 1 \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--bpe-model data/lang_bpe_5000/bpe.model \
--manifest-dir data/fbank \
--max-duration 200 \
--prune-range 5 \
--st-prune-range 10 \
--warm-step 10000 \
--ctc-loss-scale 0.1 \
--st-ctc-loss-scale 0.1 \
--enable-spec-aug 0 \
--cr-loss-scale 0.05 \
--st-cr-loss-scale 0.05 \
--time-mask-ratio 2.5 \
--use-asr-cr-ctc 1 \
--use-ctc 1 \
--use-st-cr-ctc 1 \
--use-st-ctc 1 \
--lr-epochs 6 \
--use-hat False \
--use-st-joiner True
```
### Decode streaming HENT-SRT
```
./hent_srt/decode.py \
--epoch 20 --avg 13 --use-averaged-model True \
--causal 1 \
--exp-dir hent_srt/exp-st_causal \
--bpe-model data/lang_bpe_5000/bpe.model \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--output-downsampling-factor 2 \
--st-output-downsampling-factor 1 \
--max-duration 800 \
--num-encoder-layers 2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024 \
--encoder-dim 192,256,384,512,384 \
--encoder-unmasked-dim 192,192,256,256,256 \
--downsampling-factor 1,2,4,8,4 \
--cnn-module-kernel 31,31,15,15,15 \
--num-heads 4,4,4,8,8 \
--st-num-encoder-layers 2,2,2,2,2 \
--st-feedforward-dim 512,512,256,256,256 \
--st-encoder-dim 384,512,256,256,256 \
--st-encoder-unmasked-dim 256,256,256,256,192 \
--st-downsampling-factor 1,2,4,4,4 \
--st-cnn-module-kernel 15,31,31,15,15 \
--st-num-heads 8,8,8,8,8 \
--decoding-method greedy_search \
--use-st-joiner True \
--use-hat-decode False \
--use-ctc 1 \
--use-st-ctc 1 \
--st-blank-penalty 2 \
--chunk-size 64 \
--left-context-frames 128 \
--use-hat False --max-sym-per-frame 20
```
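For a rough sense of latency: `--chunk-size` and `--left-context-frames` count encoder frames after `encoder_embed`, i.e. at about 50 frames per second (100 Hz fbank frames subsampled by 2, see the comments in `hent_srt/export.py`), so the settings above correspond to roughly 1.28 s chunks with 2.56 s of left context:
```
# Back-of-the-envelope chunk/context durations at the ~50 Hz post-subsampling rate.
post_subsampling_fps = 50        # 100 Hz fbank frames subsampled by 2
chunk_size = 64                  # --chunk-size
left_context_frames = 128        # --left-context-frames
print(f"chunk: {chunk_size / post_subsampling_fps:.2f} s")                  # 1.28 s
print(f"left context: {left_context_frames / post_subsampling_fps:.2f} s")  # 2.56 s
```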

View File

@@ -0,0 +1 @@
../zipformer_multijoiner_st/asr_datamodule.py

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,261 @@
# Copyright 2025 Johns Hopkins University (author: Amir Hussein)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengrui Jin,
# Yifan Yang,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from scaling import Balancer
class LSTMDecoder(nn.Module):
"""LSTM decoder."""
def __init__(
self,
vocab_size: int,
blank_id: int,
decoder_dim: int,
num_layers: int,
hidden_dim: int,
embedding_dropout: float = 0.0,
rnn_dropout: float = 0.0,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
blank_id:
The ID of the blank symbol.
decoder_dim:
Dimension of the input embedding.
num_layers:
Number of LSTM layers.
hidden_dim:
Hidden dimension of LSTM layers.
embedding_dropout:
Dropout rate for the embedding layer.
rnn_dropout:
Dropout for LSTM layers.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=decoder_dim,
)
# the balancers are to avoid any drift in the magnitude of the
# embeddings, which would interact badly with parameter averaging.
self.balancer = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
self.blank_id = blank_id
self.vocab_size = vocab_size
self.embedding_dropout = nn.Dropout(embedding_dropout)
self.rnn = nn.LSTM(
input_size=decoder_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=rnn_dropout,
)
self.balancer2 = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
def forward(
self,
y: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
need_pad: bool = False
) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
Returns:
Return a tuple (rnn_out, (h, c)), where rnn_out has shape (N, U, hidden_dim).
"""
y = y.to(torch.int64)
# this stuff about clamp() is a temporary fix for a mismatch
# at utterance start, we use negative ids in beam_search.py
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
embedding_out = self.embedding_dropout(embedding_out)
embedding_out = self.balancer(embedding_out)
if need_pad is True:
embedding_out = pad_sequence(embedding_out, batch_first=True, padding_value=0)
rnn_out, (h, c) = self.rnn(embedding_out, states)
rnn_out = F.relu(rnn_out)
rnn_out = self.balancer2(rnn_out)
return rnn_out, (h, c)
class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""
def __init__(
self,
vocab_size: int,
decoder_dim: int,
blank_id: int,
context_size: int,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
decoder_dim:
Dimension of the input embedding, and of the decoder output.
blank_id:
The ID of the blank symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=decoder_dim,
)
# the balancers are to avoid any drift in the magnitude of the
# embeddings, which would interact badly with parameter averaging.
self.balancer = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
self.blank_id = blank_id
assert context_size >= 1, context_size
self.context_size = context_size
self.vocab_size = vocab_size
if context_size > 1:
self.conv = nn.Conv1d(
in_channels=decoder_dim,
out_channels=decoder_dim,
kernel_size=context_size,
padding=0,
groups=decoder_dim // 4, # group size == 4
bias=False,
)
self.balancer2 = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
y = y.to(torch.int64)
# this stuff about clamp() is a temporary fix for a mismatch
# at utterance start, we use negative ids in beam_search.py
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
embedding_out = self.balancer(embedding_out)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = F.relu(embedding_out)
embedding_out = self.balancer2(embedding_out)
return embedding_out
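# ---------------------------------------------------------------------------
# Illustrative smoke test only (not used by the recipe); the sizes below are
# hypothetical, not the values configured in train.py.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = Decoder(vocab_size=500, decoder_dim=512, blank_id=0, context_size=2)
    y = torch.randint(low=1, high=500, size=(4, 10))  # (N, U) token ids
    # Training: left-pad so every position sees `context_size` previous tokens.
    out = decoder(y, need_pad=True)
    assert out.shape == (4, 10, 512)
    # Inference: feed exactly `context_size` tokens per step, no extra padding.
    step = decoder(y[:, :2], need_pad=False)
    assert step.shape == (4, 1, 512)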

View File

@@ -0,0 +1,640 @@
#!/usr/bin/env python3
#
# Copyright 2025 Johns Hopkins University (author: Amir Hussein)
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
Note: This is an example for the LibriSpeech dataset; if you are using a different
dataset, you should change the argument values according to your dataset.
(1) Export to torchscript model using torch.jit.script()
- For non-streaming model:
./hent_srt/export.py \
--exp-dir ./hent_srt/exp-st \
--causal 0 \
--use-averaged-model 1 \
--tokens data/lang_bpe_5000/tokens.txt \
--st-tokens data/lang_st_bpe_4000/tokens.txt \
--num-encoder-layers 2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024 \
--encoder-dim 192,256,384,512,384 \
--encoder-unmasked-dim 192,192,256,256,256 \
--downsampling-factor 1,2,4,8,4 \
--cnn-module-kernel 31,31,15,15,15 \
--num-heads 4,4,4,8,8 \
--st-num-encoder-layers 2,2,2,2,2 \
--st-feedforward-dim 512,512,256,256,256 \
--st-encoder-dim 384,512,256,256,256 \
--st-encoder-unmasked-dim 256,256,256,256,192 \
--st-downsampling-factor 1,2,4,4,4 \
--st-cnn-module-kernel 15,31,31,15,15 \
--st-num-heads 8,8,8,8,8 \
--epoch 25 \
--avg 13 \
--jit 1 \
--output-downsampling-factor 2 \
--st-output-downsampling-factor 1 \
--use-st-joiner True \
--use-hat-decode False \
--use-ctc 1 \
--use-st-ctc 1
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.
Check ./jit_pretrained.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--chunk-size 32 \
--left-context-frames 128 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
Check ./jit_pretrained_streaming.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
- For non-streaming model:
./hent_srt/export.py \
--exp-dir ./hent_srt/exp-st \
--causal 0 \
--use-averaged-model 1 \
--tokens data/lang_bpe_5000/tokens.txt \
--st-tokens data/lang_st_bpe_4000/tokens.txt \
--num-encoder-layers 2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024 \
--encoder-dim 192,256,384,512,384 \
--encoder-unmasked-dim 192,192,256,256,256 \
--downsampling-factor 1,2,4,8,4 \
--cnn-module-kernel 31,31,15,15,15 \
--num-heads 4,4,4,8,8 \
--st-num-encoder-layers 2,2,2,2,2 \
--st-feedforward-dim 512,512,256,256,256 \
--st-encoder-dim 384,512,256,256,256 \
--st-encoder-unmasked-dim 256,256,256,256,192 \
--st-downsampling-factor 1,2,4,4,4 \
--st-cnn-module-kernel 15,31,31,15,15 \
--st-num-heads 8,8,8,8,8 \
--epoch 20 \
--avg 13 \
--jit 0 \
--output-downsampling-factor 2 \
--st-output-downsampling-factor 1 \
--use-st-joiner True \
--use-hat False \
--use-ctc 1 \
--use-st-ctc 1
- For streaming model:
./hent_srt/export.py \
--exp-dir ./hent_srt/exp-st_causal \
--causal 1 \
--use-averaged-model 1 \
--tokens data/lang_bpe_5000/tokens.txt \
--st-tokens data/lang_st_bpe_4000/tokens.txt \
--num-encoder-layers 2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024 \
--encoder-dim 192,256,384,512,384 \
--encoder-unmasked-dim 192,192,256,256,256 \
--downsampling-factor 1,2,4,8,4 \
--cnn-module-kernel 31,31,15,15,15 \
--num-heads 4,4,4,8,8 \
--st-num-encoder-layers 2,2,2,2,2 \
--st-feedforward-dim 512,512,256,256,256 \
--st-encoder-dim 384,512,256,256,256 \
--st-encoder-unmasked-dim 256,256,256,256,192 \
--st-downsampling-factor 1,2,4,4,4 \
--st-cnn-module-kernel 15,31,31,15,15 \
--st-num-heads 8,8,8,8,8 \
--epoch 20 \
--avg 13 \
--jit 0 \
--output-downsampling-factor 2 \
--st-output-downsampling-factor 1 \
--use-st-joiner True \
--use-hat False \
--use-ctc 1 \
--use-st-ctc 1
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
- For non-streaming model:
To use the generated file with `zipformer/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/multi_conv_zh_es_ta/ST
./hent_srt/decode.py \
--epoch 9999 --avg 1 --use-averaged-model 0 \
--beam-size 20 \
--causal 0 \
--exp-dir hent_srt/exp-st \
--bpe-model data/lang_bpe_5000/bpe.model \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--output-downsampling-factor 2 \
--st-output-downsampling-factor 1 \
--max-duration 800 \
--num-encoder-layers 2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024 \
--encoder-dim 192,256,384,512,384 \
--encoder-unmasked-dim 192,192,256,256,256 \
--downsampling-factor 1,2,4,8,4 \
--cnn-module-kernel 31,31,15,15,15 \
--num-heads 4,4,4,8,8 \
--st-num-encoder-layers 2,2,2,2,2 \
--st-feedforward-dim 512,512,256,256,256 \
--st-encoder-dim 384,512,256,256,256 \
--st-encoder-unmasked-dim 256,256,256,256,192 \
--st-downsampling-factor 1,2,4,4,4 \
--st-cnn-module-kernel 15,31,31,15,15 \
--st-num-heads 8,8,8,8,8 \
--decoding-method modified_beam_search \
--use-st-joiner True \
--use-hat-decode False \
--use-ctc 1 \
--use-st-ctc 1 \
--st-blank-penalty 1
- For streaming model:
To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/multi_conv_zh_es_ta/ST
./hent_srt/decode.py \
--epoch 9999 --avg 1 --use-averaged-model 0 \
--causal 1 \
--exp-dir hent_srt/exp-st_causal \
--bpe-model data/lang_bpe_5000/bpe.model \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--output-downsampling-factor 2 \
--st-output-downsampling-factor 1 \
--max-duration 800 \
--num-encoder-layers 2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024 \
--encoder-dim 192,256,384,512,384 \
--encoder-unmasked-dim 192,192,256,256,256 \
--downsampling-factor 1,2,4,8,4 \
--cnn-module-kernel 31,31,15,15,15 \
--num-heads 4,4,4,8,8 \
--st-num-encoder-layers 2,2,2,2,2 \
--st-feedforward-dim 512,512,256,256,256 \
--st-encoder-dim 384,512,256,256,256 \
--st-encoder-unmasked-dim 256,256,256,256,192 \
--st-downsampling-factor 1,2,4,4,4 \
--st-cnn-module-kernel 15,31,31,15,15 \
--st-num-heads 8,8,8,8,8 \
--decoding-method greedy_search \
--use-st-joiner True \
--use-hat-decode False \
--use-ctc 1 \
--use-st-ctc 1 \
--st-blank-penalty 2 \
--chunk-size 64 \
--left-context-frames 128 \
--use-hat False --max-sym-per-frame 20
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/AmirHussein/HENT-SRT
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/AmirHussein/HENT-SRT
# You will find the pre-trained models in the exp dir
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_bpe_5000/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--st-tokens",
type=str,
default="data/lang_st_bpe_4000/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named jit_script.pt.
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--st-context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
add_model_arguments(parser)
return parser
class EncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, model: nn.Module) -> None:
super().__init__()
self.model = model
def forward(
self, features: Tensor, feature_lengths: Tensor
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""
Args:
features: (N, T, C)
feature_lengths: (N,)
"""
encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens = self.model.forward_encoder(features, feature_lengths)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
st_encoder_out = st_encoder_out.permute(1, 0, 2)
return encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens
class StreamingEncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
assert len(encoder.chunk_size) == 1, encoder.chunk_size
assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
# The encoder_embed subsamples the features: T -> (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
self.pad_length = 7 + 2 * 3
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""Streaming forward for encoder_embed and encoder.
Args:
features: (N, T, C)
feature_lengths: (N,)
states: a list of Tensors
Returns encoder outputs, output lengths, and updated states.
"""
chunk_size = self.chunk_size
left_context_len = self.left_context_len
cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lengths,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = self.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
# if torch.cuda.is_available():
# device = torch.device("cuda", 0)
logging.info(f"device: {device}")
token_table = k2.SymbolTable.from_file(params.tokens)
st_token_table = k2.SymbolTable.from_file(params.st_tokens)
params.blank_id = token_table["<blk>"]
# params.unk_id = sp.piece_to_id("<unk>")
# params.st_unk_id = sp_st.piece_to_id("<unk>")
params.blank_st_id = st_token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1
params.vocab_st_size = num_tokens(st_token_table) + 1
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
), strict=False
)
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
# Wrap encoder and encoder_embed as a module
if params.causal:
model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
chunk_size = model.encoder.chunk_size
left_context_len = model.encoder.left_context_len
filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
else:
model.encoder = EncoderModel(model)
filename = "jit_script.pt"
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
model.save(str(params.exp_dir / filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -0,0 +1 @@
../zipformer_multijoiner_st/joiner.py

View File

@@ -0,0 +1,117 @@
# Copyright 2025 Johns Hopkins University (author: Amir Hussein)
import logging
from typing import Any, Dict, Union
import torch
import torch.nn
import torch.optim
def filter_state_dict(
dst_state: Dict[str, Union[float, torch.Tensor]],
src_state: Dict[str, Union[float, torch.Tensor]],
):
"""Filter name, size mismatch instances between dicts.
Args:
dst_state: reference state dict for filtering
src_state: target state dict for filtering
"""
match_state = {}
for key, value in src_state.items():
if key in dst_state and (dst_state[key].size() == src_state[key].size()):
match_state[key] = value
else:
if key not in dst_state:
logging.warning(
f"Filter out {key} from pretrained dict"
+ " because of name not found in target dict"
)
else:
logging.warning(
f"Filter out {key} from pretrained dict"
+ " because of size mismatch"
+ f"({dst_state[key].size()}-{src_state[key].size()})"
)
return match_state
def load_pretrained_model(
init_param: str,
model: torch.nn.Module,
ignore_init_mismatch: bool,
map_location: str = "cpu",
):
"""Load a model state and set it to the model.
Args:
init_param: <file_path>:<src_key>:<dst_key>:<exclude_keys>
Examples:
>>> load_pretrained_model("somewhere/model.pth", model)
>>> load_pretrained_model("somewhere/model.pth:decoder:decoder", model)
>>> load_pretrained_model("somewhere/model.pth:decoder:decoder:", model)
>>> load_pretrained_model(
... "somewhere/model.pth:decoder:decoder:decoder.embed", model
... )
>>> load_pretrained_model("somewhere/decoder.pth::decoder", model)
"""
sps = init_param.split(":", 4)
if len(sps) == 4:
path, src_key, dst_key, excludes = sps
elif len(sps) == 3:
path, src_key, dst_key = sps
excludes = None
elif len(sps) == 2:
path, src_key = sps
dst_key, excludes = None, None
else:
(path,) = sps
src_key, dst_key, excludes = None, None, None
if src_key == "":
src_key = None
if dst_key == "":
dst_key = None
if dst_key is None:
obj = model
else:
def get_attr(obj: Any, key: str):
"""Get an nested attribute.
>>> class A(torch.nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = torch.nn.Linear(10, 10)
>>> a = A()
>>> assert a.linear.weight is get_attr(a, 'linear.weight')
"""
if key.strip() == "":
return obj
for k in key.split("."):
obj = getattr(obj, k)
return obj
obj = get_attr(model, dst_key)
src_state = torch.load(path, map_location=map_location)
if excludes is not None:
for e in excludes.split(","):
src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
if src_key is not None:
src_state['model'] = {
k[len(src_key) + 1 :]: v
for k, v in src_state['model'].items()
if k.startswith(src_key)
}
dst_state = obj.state_dict()
if ignore_init_mismatch:
src_state = filter_state_dict(dst_state, src_state['model'])
dst_state.update(src_state)
obj.load_state_dict(dst_state)

View File

@@ -0,0 +1,811 @@
# Copyright 2025 Johns Hopkins University (author: Amir Hussein)
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import k2
import torch
from torch import Tensor
from lhotse.dataset import SpecAugment
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos, make_pad_mask, time_warp
class HENT_SRT(nn.Module):
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
decoder: Optional[nn.Module] = None,
joiner: Optional[nn.Module] = None,
st_joiner: Optional[nn.Module] = None,
st_decoder: Optional[nn.Module] = None,
st_encoder: Optional[nn.Module] = None,
encoder_dim: int = 384,
st_encoder_dim: int = 384,
decoder_dim: int = 512,
vocab_size: int = 500,
st_vocab_size: int = 500,
use_transducer: bool = True,
use_ctc: bool = False,
use_st_ctc: bool = False,
use_hat: bool = False,
use_lstm_pred:bool=False,
):
"""A multitask Transducer ASR-ST model with seperate joiners and predictors but shared acoustic encoder.
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
an input of shape (N, T, idim) to an output of of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
It is used when use_transducer is True.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
It is used when use_transducer is True.
use_transducer:
Whether use transducer head. Default: True.
use_ctc:
Whether use CTC head. Default: False.
"""
super().__init__()
assert (
use_transducer or use_ctc
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
assert isinstance(encoder, EncoderInterface), type(encoder)
self.encoder_embed = encoder_embed
self.encoder = encoder
self.use_hat = use_hat
self.use_lstm_pred = use_lstm_pred
self.use_transducer = use_transducer
if use_transducer:
# Modules for Transducer head
assert decoder is not None
assert hasattr(decoder, "blank_id")
assert joiner is not None
self.decoder = decoder
self.joiner = joiner
self.st_joiner = st_joiner
self.st_decoder = st_decoder
self.st_encoder = st_encoder
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_scale=0.25
)
self.simple_lm_proj = ScaledLinear(
decoder_dim, vocab_size, initial_scale=0.25
)
self.simple_st_am_proj = ScaledLinear(
st_encoder_dim, st_vocab_size, initial_scale=0.25
)
self.simple_st_lm_proj = ScaledLinear(
decoder_dim, st_vocab_size, initial_scale=0.25
)
else:
assert decoder is None
assert joiner is None
self.use_ctc = use_ctc
self.use_st_ctc = use_st_ctc
if self.use_ctc:
# Modules for CTC head
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
if self.use_st_ctc:
# Modules for CTC head
self.st_ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(st_encoder_dim, st_vocab_size),
nn.LogSoftmax(dim=-1),
)
def forward_encoder(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Compute encoder outputs.
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
Returns:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
"""
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens)
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens, st_input = self.encoder(x, x_lens, src_key_padding_mask)
if self.st_encoder is not None:
st_src_key_padding_mask = make_pad_mask(encoder_out_lens)
st_encoder_out, st_encoder_out_lens = self.st_encoder(
st_input, x_lens, src_key_padding_mask
)
st_encoder_out = st_encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
else:
st_encoder_out_lens = None
st_encoder_out = None
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
return encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens
def forward_st_ctc(
self,
st_encoder_out: torch.Tensor,
st_encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Target Tensor of shape (sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC log-prob
ctc_output = self.st_ctc_output(st_encoder_out) # (N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets,
input_lengths=st_encoder_out_lens,
target_lengths=target_lengths,
reduction="sum",
)
return ctc_loss
def forward_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Target Tensor of shape (sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets,
input_lengths=encoder_out_lens,
target_lengths=target_lengths,
reduction="sum",
)
return ctc_loss
def forward_cr_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute CTC loss with consistency regularization loss.
Args:
encoder_out:
Encoder output, of shape (2 * N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (2 * N,).
targets:
Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC loss
ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C)
targets=targets.cpu(),
input_lengths=encoder_out_lens.cpu(),
target_lengths=target_lengths.cpu(),
reduction="none",
)
ctc_loss_is_finite = torch.isfinite(ctc_loss)
ctc_loss = ctc_loss[ctc_loss_is_finite]
ctc_loss = ctc_loss.sum()
# Compute consistency regularization loss
exchanged_targets = ctc_output.detach().chunk(2, dim=0)
exchanged_targets = torch.cat(
[exchanged_targets[1], exchanged_targets[0]], dim=0
) # exchange: [x1, x2] -> [x2, x1]
cr_loss = nn.functional.kl_div(
input=ctc_output,
target=exchanged_targets,
reduction="none",
log_target=True,
) # (2 * N, T, C)
length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
return ctc_loss, cr_loss
def forward_st_cr_ctc(
self,
st_encoder_out: torch.Tensor,
st_encoder_out_lens: torch.Tensor,
st_targets: torch.Tensor,
st_target_lengths: torch.Tensor,
# encoder_out: torch.Tensor,
# encoder_out_lens: torch.Tensor,
# targets: torch.Tensor,
# target_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute CTC loss with consistency regularization loss.
Args:
encoder_out:
Encoder output, of shape (2 * N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (2 * N,).
targets:
Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC loss
st_ctc_output = self.st_ctc_output(st_encoder_out) # (2 * N, T, C)
st_ctc_loss = torch.nn.functional.ctc_loss(
log_probs=st_ctc_output.permute(1, 0, 2), # (T, 2 * N, C)
targets=st_targets.cpu(),
input_lengths=st_encoder_out_lens.cpu(),
target_lengths=st_target_lengths.cpu(),
reduction="none",
)
st_ctc_loss_is_finite = torch.isfinite(st_ctc_loss)
st_ctc_loss = st_ctc_loss[st_ctc_loss_is_finite]
st_ctc_loss = st_ctc_loss.sum()
# ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C)
# ctc_loss = torch.nn.functional.ctc_loss(
# log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C)
# targets=targets.cpu(),
# input_lengths=encoder_out_lens.cpu(),
# target_lengths=target_lengths.cpu(),
# reduction="sum",
# )
# if not torch.isfinite(st_ctc_loss):
# breakpoint()
# Compute consistency regularization loss
exchanged_targets = st_ctc_output.detach().chunk(2, dim=0)
exchanged_targets = torch.cat(
[exchanged_targets[1], exchanged_targets[0]], dim=0
) # exchange: [x1, x2] -> [x2, x1]
cr_loss = nn.functional.kl_div(
input=st_ctc_output,
target=exchanged_targets,
reduction="none",
log_target=True,
) # (2 * N, T, C)
length_mask = make_pad_mask(st_encoder_out_lens).unsqueeze(-1)
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
return st_ctc_loss, cr_loss
def forward_st_transducer(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
st_encoder_out: torch.Tensor,
st_encoder_out_lens: torch.Tensor,
y: k2.RaggedTensor,
y_lens: torch.Tensor,
st_y: k2.RaggedTensor,
st_y_lens: torch.Tensor,
prune_range: int = 5,
st_prune_range: int = 10,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
"""Compute Transducer loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
"""
# Now for the decoder, i.e., the prediction network
blank_id = self.decoder.blank_id
st_blank_id = self.st_decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
st_sos_y = add_sos(st_y, sos_id=st_blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
st_sos_y_padded = st_sos_y.pad(mode="constant", padding_value=st_blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
if self.use_lstm_pred:
st_decoder_out, _ = self.st_decoder(st_sos_y_padded)
else:
st_decoder_out = self.st_decoder(st_sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
st_y_padded = st_y.pad(mode="constant", padding_value=0)
st_y_padded = st_y_padded.to(torch.int64)
st_boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
st_boundary[:, 2] = st_y_lens
st_boundary[:, 3] = st_encoder_out_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
st_lm = self.simple_st_lm_proj(st_decoder_out)
st_am = self.simple_st_am_proj(st_encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
st_simple_loss, (st_px_grad, st_py_grad) = k2.rnnt_loss_smoothed(
lm=st_lm.float(),
am=st_am.float(),
symbols=st_y_padded,
termination_symbol=st_blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=st_boundary,
reduction="sum",
return_grad=True,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
use_hat_loss=self.use_hat,
)
# logits : [B, T, prune_range, vocab_size]
st_ranges = k2.get_rnnt_prune_ranges(
px_grad=st_px_grad,
py_grad=st_py_grad,
boundary=st_boundary,
s_range=st_prune_range,
)
st_am_pruned, st_lm_pruned = k2.do_rnnt_pruning(
am=self.st_joiner.encoder_proj(st_encoder_out),
lm=self.st_joiner.decoder_proj(st_decoder_out),
ranges=st_ranges,
)
st_logits = self.st_joiner(st_am_pruned, st_lm_pruned, project_input=False)
# Compute the pruned RNNT (or HAT, if use_hat is enabled) loss for ST
with torch.cuda.amp.autocast(enabled=False):
pruned_st_loss = k2.rnnt_loss_pruned(
logits=st_logits.float(),
symbols=st_y_padded,
ranges=st_ranges,
termination_symbol=st_blank_id,
boundary=st_boundary,
reduction="sum",
use_hat_loss=self.use_hat,
)
return simple_loss, st_simple_loss, pruned_loss, pruned_st_loss
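# Brief sketch of the two-stage loss above: the "simple" loss uses the
# low-rank lm/am projections to obtain per-position gradients
# (px_grad, py_grad), which define a pruning window of width `prune_range`
# (resp. `st_prune_range`) around the most likely alignment; the full joiner
# is then evaluated only inside that window to compute the pruned RNNT
# (or HAT) loss, which keeps memory roughly linear in T instead of T * S.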
def forward_transducer(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
y: k2.RaggedTensor,
y_lens: torch.Tensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[Tensor, Tensor]:
"""Compute Transducer loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
  The prune range for the rnnt loss, i.e., how many symbols (context)
  are considered for each frame when computing the loss.
am_scale:
  The scale to smooth the loss with the am (encoder output) part.
lm_scale:
  The scale to smooth the loss with the lm (predictor output) part.
"""
# Now for the decoder, i.e., the prediction network
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
use_hat_loss=self.use_hat,
)
# logits : [B, T, prune_range, vocab_size]
return simple_loss, pruned_loss
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
st_y: k2.RaggedTensor,
prune_range: int = 5,
st_prune_range: int = 10,
am_scale: float = 0.0,
lm_scale: float = 0.0,
use_st_cr_ctc: bool = False,
use_asr_cr_ctc: bool = False,
use_spec_aug: bool = False,
spec_augment: Optional[SpecAugment] = None,
supervision_segments: Optional[torch.Tensor] = None,
time_warp_factor: Optional[int] = 80,
) -> Tuple[torch.Tensor, ...]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
  A ragged tensor with 2 axes [utt][label]. It contains the ASR labels
  of each utterance.
st_y:
  A ragged tensor with 2 axes [utt][label]. It contains the ST labels
  of each utterance.
prune_range:
  The prune range for the ASR rnnt loss, i.e., how many symbols
  (context) are considered for each frame when computing the loss.
st_prune_range:
  The prune range for the ST rnnt loss.
am_scale:
  The scale to smooth the loss with the am (encoder output) part.
lm_scale:
  The scale to smooth the loss with the lm (predictor output) part.
use_st_cr_ctc:
  Whether to use consistency-regularized CTC for the ST branch.
use_asr_cr_ctc:
  Whether to use consistency-regularized CTC for the ASR branch.
use_spec_aug:
  Whether to apply SpecAugment manually; used only if CR-CTC is enabled.
spec_augment:
  The SpecAugment instance that returns time masks,
  used only if CR-CTC is enabled.
supervision_segments:
  An int tensor of shape ``(S, 3)``. ``S`` is the number of
  supervision segments that exist in ``features``.
  Used only if CR-CTC is enabled.
time_warp_factor:
  Parameter for the time warping; larger values mean more warping.
  Set to ``None``, or less than ``1``, to disable.
  Used only if CR-CTC is enabled.
Returns:
  The ASR/ST transducer, CTC, and consistency losses, in the form
  (simple_loss, st_simple_loss, pruned_loss, st_pruned_loss,
  ctc_loss, st_ctc_loss, cr_loss, st_cr_loss).
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert st_y.num_axes == 2, st_y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
device = x.device
if use_st_cr_ctc or use_asr_cr_ctc:
assert self.use_ctc or self.use_st_ctc
if use_spec_aug:
assert spec_augment is not None and spec_augment.time_warp_factor < 1
# Apply time warping before input duplicating
assert supervision_segments is not None
x = time_warp(
x,
time_warp_factor=time_warp_factor,
supervision_segments=supervision_segments,
)
# Independently apply frequency masking and time masking to the two copies
x = spec_augment(x.repeat(2, 1, 1))
else:
x = x.repeat(2, 1, 1)
x_lens = x_lens.repeat(2)
y = k2.ragged.cat([y, y], axis=0)
if self.st_joiner is not None and self.use_st_ctc:
st_y = k2.ragged.cat([st_y, st_y], axis=0)
# Compute encoder outputs
encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens = self.forward_encoder(x, x_lens)
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
st_row_splits = st_y.shape.row_splits(1)
st_y_lens = st_row_splits[1:] - st_row_splits[:-1]
if self.use_transducer:
# Compute transducer loss
if self.st_joiner is not None:
simple_loss, st_simple_loss, pruned_loss, st_pruned_loss = self.forward_st_transducer(
st_encoder_out=st_encoder_out,
st_encoder_out_lens=st_encoder_out_lens,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
y=y.to(x.device),
y_lens=y_lens,
st_y=st_y.to(x.device),
st_y_lens=st_y_lens,
prune_range=prune_range,
st_prune_range=st_prune_range,
am_scale=am_scale,
lm_scale=lm_scale,
)
if use_asr_cr_ctc:
simple_loss = simple_loss * 0.5
pruned_loss = pruned_loss * 0.5
if use_st_cr_ctc:
st_simple_loss = st_simple_loss * 0.5
st_pruned_loss = st_pruned_loss * 0.5
else:
simple_loss, pruned_loss = self.forward_transducer(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
y=y.to(x.device),
y_lens=y_lens,
prune_range=prune_range,
am_scale=am_scale,
lm_scale=lm_scale,
)
if use_asr_cr_ctc:
simple_loss = simple_loss * 0.5
pruned_loss = pruned_loss * 0.5
st_simple_loss, st_pruned_loss = torch.empty(0), torch.empty(0)
else:
simple_loss = torch.empty(0)
pruned_loss = torch.empty(0)
if self.use_ctc:
# Compute CTC loss
targets = y.values
if not use_asr_cr_ctc:
ctc_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
cr_loss = torch.empty(0)
else:
ctc_loss, cr_loss = self.forward_cr_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
ctc_loss = ctc_loss * 0.5
cr_loss = cr_loss * 0.5
else:
ctc_loss = torch.empty(0)
cr_loss = torch.empty(0)
if self.use_st_ctc:
# Compute CTC loss
st_targets = st_y.values
if not use_st_cr_ctc:
st_ctc_loss = self.forward_st_ctc(
st_encoder_out=st_encoder_out,
st_encoder_out_lens=st_encoder_out_lens,
targets=st_targets,
target_lengths=st_y_lens,
)
st_cr_loss = torch.empty(0)
else:
st_ctc_loss, st_cr_loss = self.forward_st_cr_ctc(
st_encoder_out=st_encoder_out,
st_encoder_out_lens=st_encoder_out_lens,
st_targets=st_targets,
st_target_lengths=st_y_lens,
# encoder_out=encoder_out,
# encoder_out_lens=encoder_out_lens,
# targets=targets,
# target_lengths=y_lens,
)
st_ctc_loss = st_ctc_loss * 0.5
st_cr_loss = st_cr_loss * 0.5
else:
st_ctc_loss = torch.empty(0)
st_cr_loss = torch.empty(0)
return simple_loss, st_simple_loss, pruned_loss, st_pruned_loss, ctc_loss, st_ctc_loss, cr_loss, st_cr_loss
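# The caller receives eight loss tensors:
#   (simple_loss, st_simple_loss, pruned_loss, st_pruned_loss,
#    ctc_loss, st_ctc_loss, cr_loss, st_cr_loss)
# A hedged sketch of how a training loop might reduce them (scale names are
# assumptions for illustration, not the exact recipe values):
#   loss = (
#       simple_scale * (simple_loss + st_simple_loss)
#       + (pruned_loss + st_pruned_loss)
#       + ctc_scale * (ctc_loss + st_ctc_loss)
#       + cr_scale * (cr_loss + st_cr_loss)
#   )
# Disabled branches return empty tensors, which the caller is expected to skip.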

View File

@ -0,0 +1 @@
../zipformer_multijoiner_st/optim.py

View File

@ -0,0 +1 @@
../zipformer_multijoiner_st/profile.py

View File

@ -0,0 +1 @@
../zipformer_multijoiner_st/scaling.py

View File

@ -0,0 +1 @@
../zipformer_multijoiner_st/scaling_converter.py

View File

@ -0,0 +1,244 @@
# Copyright 2025 Johns Hopkins University (author: Amir Hussein)
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List
import k2
import torch
import torch.nn as nn
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from decode_stream import DecodeStream
from icefall.decode import one_best_decoding
from icefall.utils import get_texts
def greedy_search_st(
model: nn.Module,
encoder_out: torch.Tensor,
encoder_out_st: torch.Tensor,
max_sym_per_frame: int,
streams: List[DecodeStream],
st_blank_penalty: float = 0.0,
) -> None:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
streams:
A list of Stream objects.
"""
assert len(streams) == encoder_out_st.size(0)
assert encoder_out_st.ndim == 3
# ST
blank_id_st = model.st_decoder.blank_id
context_size_st = model.st_decoder.context_size
unk_id_st = getattr(model, "unk_id", blank_id_st)
device = model.device
T = encoder_out_st.size(1)
# ASR
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
unk_id = getattr(model, "unk_id", blank_id)
#ST
decoder_input_st = torch.tensor(
[stream.hyp_st[-context_size_st:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, 1, decoder_out_dim)
decoder_out_st = model.st_decoder(decoder_input_st, need_pad=False)
decoder_out_st = model.st_joiner.decoder_proj(decoder_out_st)
# ASR
decoder_input = torch.tensor(
[stream.hyp_asr[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# Maximum symbols per utterance.
max_sym_per_utt = 10000
# symbols per frame
sym_per_frame = 0
# symbols per utterance decoded so far
sym_per_utt = 0
t = 0
# for t in range(T):
while t < T and sym_per_utt < max_sym_per_utt:
if sym_per_frame >= max_sym_per_frame:
sym_per_frame = 0
t += 1
continue
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
# current_encoder_out_st = encoder_out_st[:, t : t + 1, :] # noqa
current_encoder_out_st = encoder_out_st[:, t : t + 1, :].unsqueeze(2)
st_logits = model.st_joiner(
current_encoder_out_st,
decoder_out_st.unsqueeze(1),
project_input=False,
)
# logits'shape (batch_size, vocab_size)
st_logits = st_logits.squeeze(1).squeeze(1)
if st_blank_penalty != 0.0:
st_logits[:, 0] -= st_blank_penalty
assert st_logits.ndim == 2, st_logits.shape
y_st = st_logits.argmax(dim=1).tolist()
for i, v in enumerate(y_st):
if v not in (blank_id_st, unk_id_st):
streams[i].hyp_st.append(v)
# update decoder output
# decoder_input_st = torch.tensor(
# [stream.hyp_st[-context_size_st:].reshape(
# 1, context_size_st) for stream in streams],
# device=device,
# dtype=torch.int64,
# )
decoder_input_st = torch.stack([
torch.tensor(stream.hyp_st[-context_size_st:], device=device, dtype=torch.int64)
for stream in streams]).reshape(len(streams), context_size_st)
decoder_out_st = model.st_decoder(
decoder_input_st,
need_pad=False,
)
decoder_out_st = model.st_joiner.decoder_proj(decoder_out_st)
sym_per_utt += 1
sym_per_frame += 1
else:
sym_per_frame = 0
t += 1
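# A minimal usage sketch (argument values are illustrative assumptions, not
# fixed recipe settings):
#   greedy_search_st(
#       model=model,
#       encoder_out=asr_encoder_out,        # (N, T, C) from the ASR branch
#       encoder_out_st=st_encoder_out,      # (N, T, C) from the ST branch
#       max_sym_per_frame=20,
#       streams=decode_streams,
#       st_blank_penalty=2.0,
#   )
# Hypotheses are accumulated in place on each stream's `hyp_st` list.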
def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
num_active_paths: int = 4,
blank_penalty: float = 0.0,
) -> None:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The RNN-T model.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
streams:
A list of stream objects.
num_active_paths:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert len(streams) == encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size = len(streams)
T = encoder_out.size(1)
B = [stream.hyps for stream in streams]
for t in range(T):
current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.stack(
[hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out, project_input=False)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
if blank_penalty != 0.0:
logits[:, 0] -= blank_penalty
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
for i in range(batch_size):
streams[i].hyps = B[i]
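# Note on the top-k bookkeeping above: the log-probs of all hypotheses of a
# stream are flattened into a single vector of length num_hyps * vocab_size,
# so a flat index decomposes as (hyp_index, token) = divmod(index, vocab_size).
# Illustrative example: with vocab_size = 5, flat index 7 -> hypothesis 1,
# token 2.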

View File

@ -0,0 +1 @@
../zipformer_multijoiner_st/subsampling.py

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,73 @@
import argparse
import jiwer
import os
import re
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dec-file",
type=str,
help="file with decoded text"
)
return parser
def contains_chinese(text):
"""
Check if the given text contains any Chinese characters.
Args:
text (str): The input string.
Returns:
bool: True if the string contains at least one Chinese character, False otherwise.
"""
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]')
return bool(chinese_char_pattern.search(text))
def cer_(file):
hyp = []
ref = []
cer_results = 0
ref_lens = 0
with open(file, 'r', encoding='utf-8') as dec:
for line in dec:
utt_id, target = line.split('\t', 1)
utt_id = utt_id[:-1]
target, txt = target.split("=", 1)
if target == 'ref':
words = txt.strip().strip('[]').split(', ')
word_list = [word.strip("'") for word in words]
# if contains_chinese(" ".join(word_list)):
# word_list = [" ".join(re.findall(r".",word.strip("'"))) for word in words]
ref.append("".join(word_list))
elif target == 'hyp':
words = txt.strip().strip('[]').split(', ')
word_list = [word.strip("'") for word in words]
# if contains_chinese(" ".join(word_list)):
# word_list = ["".join(re.findall(r".",word.strip("'"))) for word in words]
hyp.append("".join(word_list))
for h, r in zip(hyp, ref):
if r:
cer_results += (jiwer.cer(r, h)*len(r))
ref_lens += len(r)
#print(os.path.basename(file))
print(cer_results / ref_lens)
def main():
parse = get_args()
args = parse.parse_args()
cer_(args.dec_file)
if __name__ == "__main__":
main()
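# Hedged usage example (the file names below are placeholders, not paths
# shipped with the recipe):
#   python local/cer.py --dec-file exp/decode/recogs-test.txt
# The script prints the length-weighted CER over all utterances that have a
# non-empty reference.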

View File

@ -0,0 +1,175 @@
#!/usr/bin/env python3
# Johns Hopkins University (authors: Amir Hussein)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import os
from pathlib import Path
import argparse
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-splits",
type=int,
default=5,
help="Number of splits for the train set.",
)
parser.add_argument(
"--start",
type=int,
default=0,
help="Start index of the train set split.",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
help="Stop index of the train set split.",
)
parser.add_argument(
"--test",
action="store_true",
help="If set, only compute features for the dev and val set.",
)
parser.add_argument(
"--datadir",
type=str,
help="Manifests datadir",
)
return parser.parse_args()
def compute_fbank_gpu(args):
src_dir = Path(args.datadir) / "manifests"
output_dir = Path(args.datadir) / "fbank"
num_jobs = min(os.cpu_count(), 10)
num_mel_bins = 80
sampling_rate = 16000
sr = 16000  # target sampling rate used for resampling below
logging.info(f"CPUs: {num_jobs}")
if args.test:
dataset_parts = (
"hkust_test",
"iwslt_ta_test",
"fisher-sp_test",
"dev")
else:
dataset_parts = ("train",)
prefix = "cts"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
prefix=prefix, dataset_parts=dataset_parts, output_dir=src_dir,suffix=suffix,
)
assert manifests is not None
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
for partition, m in manifests.items():
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / f"{cuts_filename}").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
logging.info("About to split cuts into smaller chunks.")
if sr is not None:
logging.info(f"Resampling to {sr}")
cut_set = cut_set.resample(sr)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False,
keep_all_channels=False)
cut_set = cut_set.filter(lambda c: c.duration >= .2 and c.duration <= 30)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
manifest_path=f"{src_dir}/{cuts_filename}",
batch_duration=2000,
num_workers=num_jobs,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
else:
logging.info(f"Processing {partition}")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
manifest_path=f"{src_dir}/{cuts_filename}",
batch_duration=2000,
num_workers=num_jobs,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_gpu(args)
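# Example invocations (matching the calls in prepare.sh; --datadir points at
# the directory containing manifests/ and fbank/):
#   ./local/compute_fbank_gpu.py --datadir data          # train set
#   ./local/compute_fbank_gpu.py --datadir data --test   # dev / test sets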

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

View File

@ -0,0 +1,97 @@
#!/usr/bin/python
from lhotse import RecordingSet, SupervisionSet, CutSet
import argparse
import logging
from lhotse.qa import fix_manifests, validate_recordings_and_supervisions
import pdb
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--sup",
type=str,
default="",
help="Supervisions file",
)
parser.add_argument(
"--rec",
type=str,
default="",
help="Recordings file",
)
parser.add_argument(
"--cut",
type=str,
default="",
help="Cutset file",
)
parser.add_argument(
"--savecut",
type=str,
default="",
help="name of the cutset to be saved",
)
return parser
def valid_asr(cut):
tol = 2e-3
i=0
total_dur = 0
for c in cut:
if c.supervisions != []:
if c.supervisions[0].end > c.duration + tol:
logging.info(f"Supervision beyond the cut. Cut number: {i}")
total_dur += c.duration
logging.info(f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}")
elif c.supervisions[0].start < -tol:
logging.info(f"Supervision starts before the cut. Cut number: {i}")
logging.info(f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}")
else:
continue
else:
logging.info("Empty supervision")
logging.info(f"id: {c.id}")
i += 1
logging.info(f"filtered duration: {total_dur}")
def main():
parser = get_parser()
args = parser.parse_args()
if args.cut != "":
cuts = CutSet.from_file(args.cut)
else:
recordings = RecordingSet.from_file(args.rec)
supervisions = SupervisionSet.from_file(args.sup)
# breakpoint()
logging.info("Example from supervisions:")
logging.info(supervisions[0])
logging.info("Example from recordings")
logging.info("Fixing manifests")
recordings, supervisions = fix_manifests(recordings, supervisions)
logging.info("Validating manifests")
validate_recordings_and_supervisions(recordings, supervisions)
cuts = CutSet.from_manifests(recordings=recordings, supervisions=supervisions)
cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
cuts.describe()
logging.info("Example from cut:")
logging.info(cuts[100])
logging.info("Validating manifests for ASR")
valid_asr(cuts)
if args.savecut != "":
cuts.to_file(args.savecut)
if __name__ == "__main__":
main()
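# Hedged usage examples (the paths below are placeholders):
#   # validate an existing cutset
#   python local/cuts_validate.py --cut data/manifests/cuts_train.jsonl.gz
#   # build, fix, and validate cuts from recordings + supervisions, and
#   # optionally save the result
#   python local/cuts_validate.py \
#     --rec data/manifests/cts_recordings_dev.jsonl.gz \
#     --sup data/manifests/cts_supervisions_dev.jsonl.gz \
#     --savecut data/manifests/cuts_dev.jsonl.gz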

View File

@ -0,0 +1,53 @@
#!/usr/bin/python
# Copyright 2023 Johns Hopkins University (Amir Hussein)
"""
This script prepares st_words.txt (the English translation text) from a cutset.
"""
from lhotse import CutSet
import argparse
import logging
import pdb
from pathlib import Path
import os
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--cut",
type=str,
default="",
help="Cutset file",
)
parser.add_argument(
"--langdir",
type=str,
default="",
help="name of the lang-dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
logging.info("Reading the cuts")
cuts = CutSet.from_file(args.cut)
langdir = Path(args.langdir)
if not os.path.exists(langdir):
os.makedirs(langdir)
with open(langdir / "st_words.txt", 'w') as txt:
for c in cuts:
text = c.supervisions[0].custom['translated_text']['en']
txt.write(text + '\n')
if __name__ == "__main__":
main()
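# Example (as invoked in stage 7 of prepare.sh):
#   python local/prepare_st_transcripts.py \
#     --cut data/fbank/cuts_train.jsonl.gz \
#     --langdir data/lang_st_bpe_4000
# The resulting st_words.txt is then consumed by train_bpe_model.py.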

View File

@ -0,0 +1,54 @@
#!/usr/bin/python
# Copyright 2023 Johns Hopkins University (Amir Hussein)
"""
This script prepares transcript_words.txt (the source-language transcripts) from a cutset.
"""
from lhotse import CutSet
import argparse
import logging
import pdb
from pathlib import Path
import os
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--cut",
type=str,
default="",
help="Cutset file",
)
parser.add_argument(
"--langdir",
type=str,
default="",
help="name of the lang-dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
logging.info("Reading the cuts")
cuts = CutSet.from_file(args.cut)
langdir = Path(args.langdir)
if not os.path.exists(langdir):
os.makedirs(langdir)
with open(langdir / "transcript_words.txt", 'w') as txt:
for c in cuts:
text = c.supervisions[0].text
txt.write(text + '\n')
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py

View File

@ -0,0 +1,175 @@
#!/usr/bin/env bash
# Copyright 2023 Johns Hopkins University (Amir Hussein)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
set -eou pipefail
nj=20
stage=0
stop_stage=7
# We assume dl_dir (download dir) contains the following
# directories and files.
#
# - $dl_dir/cts
#
# You can obtain the conversational corpora (Fisher Spanish, HKUST, and
# IWSLT22 Tunisian-English) from the LDC.
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
#
dl_dir=cts
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate data/lang_bpe_xxx,
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
5000
)
st_vocab_sizes=(
4000
)
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# Download callhome_spanish, fisher_spanish iwslt22_ta and HKUST from LDC
#
# you can create a symlink
#
# ln -sfv /path/to/data $dl_dir/data
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
fbank=data/fbank
manifests=data/manifests
mkdir -p $manifests
sets="hkust iwslt-ta callhome-sp fisher-sp"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Prepare telephone manifest"
# We assume that you have downloaded callhome_spanish, fisher_spanish iwslt22_ta and hkust to $dl_dir/
for set in $sets; do
log "Prepare $set manifests"
if [[ "$set" == "iwslt-ta" ]]; then
if [ ! -d "iwslt22-dialect" ]; then
echo "Splits directory (iwslt22-dialect) does not exist"
echo "Run: git clone https://github.com/kevinduh/iwslt22-dialect"
exit 1
else
lhotse prepare "$set" "$dl_dir/$set" iwslt22-dialect "$manifests"
fi
else
lhotse prepare "$set" "$dl_dir/$set" "$manifests"
# validate recordings and supervisions
fi
# python local/cuts_validate.py \
# --sup "${manifests}/supervisions.jsonl.gz" \
# --rec "${manifests}/recordings.jsonl.gz" \
# --savecut "${manifests}/cuts_${set}.jsonl.gz"
done
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
if [ ! -f ${manifests}/cut_train.jsonl.gz ]; then
log "Combining conversational data to create train, dev sets"
# combine train
lhotse combine $manifests/iwslt-ta_supervisions_train.jsonl.gz $manifests/hkust_supervisions_train.jsonl.gz $manifests/fisher-sp_supervisions_train.jsonl.gz ${manifests}/cts_supervisions_train.jsonl.gz
lhotse combine $manifests/iwslt-ta_recordings_train.jsonl.gz $manifests/hkust_recordings_train.jsonl.gz $manifests/fisher-sp_recordings_train.jsonl.gz ${manifests}/cts_recordings_train.jsonl.gz
# python local/cuts_validate.py --sup $manifests/cts_supervisions_train.jsonl.gz --rec ${manifests}/cts_recordings_train.jsonl.gz
# combine dev
lhotse combine $manifests/iwslt-ta_supervisions_dev1.jsonl.gz $manifests/hkust_supervisions_dev1.jsonl.gz $manifests/fisher-sp_supervisions_dev.jsonl.gz ${manifests}/cts_supervisions_dev.jsonl.gz
lhotse combine $manifests/iwslt-ta_recordings_dev1.jsonl.gz $manifests/hkust_recordings_dev1.jsonl.gz $manifests/fisher-sp_recordings_dev.jsonl.gz ${manifests}/cts_recordings_dev.jsonl.gz
# python local/cuts_validate.py --sup ${manifests}/cts_supervisions_dev.jsonl.gz --rec ${manifests}/cts_recordings_dev.jsonl.gz
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
if [ ! -f ${manifests}/musan_recordings_speech.jsonl.gz ]; then
mkdir -p $manifests
lhotse prepare musan $dl_dir/musan $manifests
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank features"
mkdir -p ${fbank}
./local/compute_fbank_gpu.py --datadir data
./local/compute_fbank_gpu.py --datadir data --test
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
./local/compute_fbank_musan.py
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
mkdir -p ${lang_dir}
cp data/lang_phone/words.txt $lang_dir
if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate text for BPE training from data/fbank/cuts_train.jsonl.gz"
python local/prepare_transcripts.py --cut ${fbank}/cuts_train.jsonl.gz --langdir ${lang_dir}
fi
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
done
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Prepare BPE ST based lang"
for vocab_size in ${st_vocab_sizes[@]}; do
lang_dir=data/lang_st_bpe_${vocab_size}
mkdir -p ${lang_dir}
if [ ! -f $lang_dir/st_words.txt ]; then
log "Generate text for BPE training from data/fbank/cuts_train.jsonl.gz"
python local/prepare_st_transcripts.py --cut ${fbank}/cuts_train.jsonl.gz --langdir ${lang_dir}
fi
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/st_words.txt
done
fi
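# Example invocations (parse_options.sh lets any of the variables defined
# above be overridden from the command line):
#   # run only the feature-extraction stages
#   ./prepare.sh --stage 3 --stop-stage 4
#   # prepare only the BPE lang dirs
#   ./prepare.sh --stage 6 --stop-stage 7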

View File

@ -0,0 +1 @@
../../../icefall/shared

View File

@ -0,0 +1,421 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
# Copyright 2025 Johns Hopkins University (Author: Amir Hussein)
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
K2Speech2TextTranslationDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class MultiLingAsrDataModule:
"""
DataModule for k2 joint ASR/ST experiments.
It assumes there is always one train and one valid dataloader,
but there can be multiple test dataloaders (e.g. the HKUST,
IWSLT22-TA, and Fisher Spanish test sets).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=10,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2Speech2TextTranslationDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2Speech2TextTranslationDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2Speech2TextTranslationDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2Speech2TextTranslationDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2Speech2TextTranslationDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("Train data: About to get training cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train.jsonl.gz"
)
@lru_cache()
def dev_all_cuts(self) -> CutSet:
logging.info("Dev data: About to get develop cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_dev.jsonl.gz"
)
@lru_cache()
def test_hkust(self) -> CutSet:
logging.info("About to get test-hkust cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_hkust_test.jsonl.gz"
)
@lru_cache()
def test_iwslt22(self) -> CutSet:
logging.info("About to get test-iwslt22 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_iwslt_ta_test.jsonl.gz"
)
@lru_cache()
def test_fisher(self) -> CutSet:
logging.info("About to get test-fisher cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_fisher-sp_test.jsonl.gz"
)
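# A hedged usage sketch of this data module (normally done inside train.py;
# the argument and method names follow the definitions above):
#   parser = argparse.ArgumentParser()
#   MultiLingAsrDataModule.add_arguments(parser)
#   args = parser.parse_args()
#   datamodule = MultiLingAsrDataModule(args)
#   train_dl = datamodule.train_dataloaders(datamodule.train_cuts())
#   valid_dl = datamodule.valid_dataloaders(datamodule.dev_all_cuts())
#   test_dl = datamodule.test_dataloaders(datamodule.test_hkust())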

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,131 @@
# Copyright 2025 Johns Hopkins University (author: Amir Hussein)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import Balancer
class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""
def __init__(
self,
vocab_size: int,
decoder_dim: int,
blank_id: int,
context_size: int,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
decoder_dim:
Dimension of the input embedding, and of the decoder output.
blank_id:
The ID of the blank symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=decoder_dim,
)
# the balancers are to avoid any drift in the magnitude of the
# embeddings, which would interact badly with parameter averaging.
self.balancer = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
self.blank_id = blank_id
assert context_size >= 1, context_size
self.context_size = context_size
self.vocab_size = vocab_size
if context_size > 1:
self.conv = nn.Conv1d(
in_channels=decoder_dim,
out_channels=decoder_dim,
kernel_size=context_size,
padding=0,
groups=decoder_dim // 4, # group size == 4
bias=False,
)
self.balancer2 = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
y = y.to(torch.int64)
# this stuff about clamp() is a temporary fix for a mismatch
# at utterance start, we use negative ids in beam_search.py
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
embedding_out = self.balancer(embedding_out)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = F.relu(embedding_out)
embedding_out = self.balancer2(embedding_out)
return embedding_out
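# Illustrative shape walk-through (values assumed for clarity): with
# context_size = 2 and input y of shape (N, U), the embedding gives
# (N, U, decoder_dim); during training need_pad=True left-pads one step so
# the grouped Conv1d over pairs of embeddings again yields (N, U, decoder_dim).
# At inference need_pad=False and y carries exactly the last `context_size`
# tokens, so the convolution produces a single output step.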

View File

@ -0,0 +1,590 @@
#!/usr/bin/env python3
# Copyright 2025 Johns Hopkins University (author: Amir Hussein)
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
Note: The commands below are examples for this recipe; if you are using a
different dataset, change the argument values accordingly.
(1) Export to torchscript model using torch.jit.script()
- For non-streaming model:
./zipformer_multijoiner_st/export.py \
    --exp-dir ./zipformer_multijoiner_st/exp-multi-joiner-nohat-pbe4k \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--bpe-model data/lang_bpe_5000/bpe.model \
--epoch 25 \
--avg 13 \
--jit 1 \
--num-encoder-layers 2,2,2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,1024,1024,768 \
--encoder-dim 192,256,384,512,384,384,384,256 \
--encoder-unmasked-dim 192,192,256,256,256,256,256,192 \
--downsampling-factor 1,2,4,8,8,4,4,2 \
--cnn-module-kernel 31,31,15,15,15,15,31,31 \
--num-heads 4,4,4,8,8,8,4,4 \
--use-hat False
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.
Check ./jit_pretrained.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
- For streaming model:
./zipformer_multijoiner_st/export.py \
--exp-dir ./zipformer_multijoiner_st/exp-multi-joiner-nohat-pbe4k_causal \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--bpe-model data/lang_bpe_5000/bpe.model \
--causal 1 \
--chunk-size 64 \
--left-context-frames 128 \
--epoch 25 \
--avg 13 \
--jit 1 \
--num-encoder-layers 2,2,2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,1024,1024,768 \
--encoder-dim 192,256,384,512,384,384,384,256 \
--encoder-unmasked-dim 192,192,256,256,256,256,256,192 \
--downsampling-factor 1,2,4,8,8,4,4,2 \
--cnn-module-kernel 31,31,15,15,15,15,31,31 \
--num-heads 4,4,4,8,8,8,4,4 \
--use-hat False
It will generate a file `jit_script_chunk_64_left_128.pt` in the given `exp_dir`.
You can later load it by `torch.jit.load("jit_script_chunk_64_left_128.pt")`.
Check ./jit_pretrained_streaming.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
- For non-streaming model:
./zipformer_multijoiner_st/export.py \
--exp-dir ./zipformer_multijoiner_st/exp-multi-joiner-nohat-pbe4k \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--bpe-model data/lang_bpe_5000/bpe.model \
--epoch 25 \
--avg 13 \
--beam-size 20 \
--num-encoder-layers 2,2,2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,1024,1024,768 \
--encoder-dim 192,256,384,512,384,384,384,256 \
--encoder-unmasked-dim 192,192,256,256,256,256,256,192 \
--downsampling-factor 1,2,4,8,8,4,4,2 \
--cnn-module-kernel 31,31,15,15,15,15,31,31 \
--num-heads 4,4,4,8,8,8,4,4
- For streaming model:
./zipformer_multijoiner_st/export.py \
--exp-dir ./zipformer_multijoiner_st/exp-multi-joiner-nohat-pbe4k_causal \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--bpe-model data/lang_bpe_5000/bpe.model \
--epoch 25 \
--avg 13 \
--causal 1 \
--chunk-size 64 \
--left-context-frames 128 \
--num-encoder-layers 2,2,2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,1024,1024,768 \
--encoder-dim 192,256,384,512,384,384,384,256 \
--encoder-unmasked-dim 192,192,256,256,256,256,256,192 \
--downsampling-factor 1,2,4,8,8,4,4,2 \
--cnn-module-kernel 31,31,15,15,15,15,31,31 \
--num-heads 4,4,4,8,8,8,4,4
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
- For non-streaming model:
To use the generated file with `zipformer_multijoiner_st/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/multi_conv_zh_es_ta/ST
./zipformer_multijoiner_st/decode.py \
--exp-dir ./zipformer_multijoiner_st/exp-multi-joiner-nohat-pbe4k \
--epoch 9999 \
--avg 1 \
--beam-size 20 \
--max-duration 600 \
--decoding-method modified_beam_search \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--bpe-model data/lang_bpe_5000/bpe.model \
--num-encoder-layers 2,2,2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,1024,1024,768 \
--encoder-dim 192,256,384,512,384,384,384,256 \
--encoder-unmasked-dim 192,192,256,256,256,256,256,192 \
--downsampling-factor 1,2,4,8,8,4,4,2 \
--cnn-module-kernel 31,31,15,15,15,15,31,31 \
--num-heads 4,4,4,8,8,8,4,4 \
--use-averaged-model false
- For streaming model:
To use the generated file with `zipformer_multijoiner_st/decode.py` and `zipformer_multijoiner_st/streaming_decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/multi_conv_zh_es_ta/ST
# simulated streaming decoding
./zipformer_multijoiner_st/decode.py \
--exp-dir ./zipformer_multijoiner_st/exp-multi-joiner-nohat-pbe4k \
--epoch 9999 \
--avg 1 \
--causal 1 \
--beam-size 20 \
--max-duration 600 \
--bpe-st-model data/lang_st_bpe_4000/bpe.model \
--bpe-model data/lang_bpe_5000/bpe.model \
--num-encoder-layers 2,2,2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,1024,1024,768 \
--encoder-dim 192,256,384,512,384,384,384,256 \
--encoder-unmasked-dim 192,192,256,256,256,256,256,192 \
--downsampling-factor 1,2,4,8,8,4,4,2 \
--cnn-module-kernel 31,31,15,15,15,15,31,31 \
--num-heads 4,4,4,8,8,8,4,4 \
--use-averaged-model false \
--chunk-size 64 \
--left-context-frames 128 \
--max-sym-per-frame 20 \
--use-hat False \
--decoding-method greedy_search
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
- non-streaming model:
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
- streaming model:
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
# You will find the pre-trained models in exp dir
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import sentencepiece as spm
import torch
from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-st-model",
type=str,
default="data/lang_st_bpe_4000/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_5000/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named jit_script.pt.
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
class EncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor
) -> Tuple[Tensor, Tensor]:
"""
Args:
features: (N, T, C)
feature_lengths: (N,)
"""
x, x_lens = self.encoder_embed(features, feature_lengths)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
class StreamingEncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
assert len(encoder.chunk_size) == 1, encoder.chunk_size
assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
# The encoder_embed subsamples the features to (T - 7) // 2 frames
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
self.pad_length = 7 + 2 * 3
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""Streaming forward for encoder_embed and encoder.
Args:
features: (N, T, C)
feature_lengths: (N,)
states: a list of Tensors
Returns encoder outputs, output lengths, and updated states.
"""
chunk_size = self.chunk_size
left_context_len = self.left_context_len
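# states layout: the per-layer encoder caches come first; states[-2] is the
# cached left padding of encoder_embed and states[-1] is processed_lens.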
cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lengths,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = self.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
# if torch.cuda.is_available():
# device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
sp_st = spm.SentencePieceProcessor()
sp_st.load(params.bpe_st_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.blank_st_id = sp_st.piece_to_id("<blk>")
params.st_unk_id = sp_st.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.st_vocab_size = sp_st.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
# Wrap encoder and encoder_embed as a module
if params.causal:
model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
chunk_size = model.encoder.chunk_size
left_context_len = model.encoder.left_context_len
filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
else:
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
filename = "jit_script.pt"
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
model.save(str(params.exp_dir / filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,72 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from scaling import ScaledLinear
from typing import Optional
class Joiner(nn.Module):
def __init__(
self,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
encoder_lid: Optional[int] = None,
):
super().__init__()
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
if encoder_lid:
self.lid_proj = ScaledLinear(encoder_lid, joiner_dim, initial_scale=0.25)
self.output_linear = nn.Linear(joiner_dim, vocab_size)
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
project_input: bool = True,
lid_out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
project_input:
If true, apply input projections encoder_proj and decoder_proj.
If this is false, it is the user's responsibility to do this
manually.
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim, (
encoder_out.shape,
decoder_out.shape,
)
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
elif lid_out is not None:
logit = encoder_out + decoder_out + lid_out
else:
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
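
# ---------------------------------------------------------------------------
# Illustrative shape check (a minimal sketch, not used anywhere in the recipe):
# during pruned RNN-T training the joiner receives encoder/decoder outputs that
# have already been gathered into a prune range `s_range`. The dimensions below
# are arbitrary example values, and ScaledLinear is assumed to behave like
# nn.Linear for shape purposes (as in upstream icefall).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    joiner = Joiner(encoder_dim=384, decoder_dim=512, joiner_dim=512, vocab_size=500)
    N, T, s_range = 2, 50, 5
    encoder_out = torch.randn(N, T, s_range, 384)  # pruned encoder output
    decoder_out = torch.randn(N, T, s_range, 512)  # pruned decoder output
    logits = joiner(encoder_out, decoder_out, project_input=True)
    print(logits.shape)  # expected: (N, T, s_range, vocab_size) == (2, 50, 5, 500)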

View File

@ -0,0 +1,681 @@
# Copyright 2025 Johns Hopkins University (author: Amir Hussein)
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import k2
import torch
from torch import Tensor
from lhotse.dataset import SpecAugment
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask, time_warp
from scaling import ScaledLinear
class StModel(nn.Module):
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
decoder: Optional[nn.Module] = None,
joiner: Optional[nn.Module] = None,
st_joiner: Optional[nn.Module] = None,
st_decoder: Optional[nn.Module] = None,
encoder_dim: int = 384,
decoder_dim: int = 512,
vocab_size: int = 500,
st_vocab_size: int = 500,
use_transducer: bool = True,
use_ctc: bool = False,
use_st_ctc: bool = False,
use_hat: bool = False,
):
"""A multitask Transducer ASR-ST model with seperate joiners and predictors but shared acoustic encoder.
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
an input of shape (N, T, idim) to an output of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
It is used when use_transducer is True.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
It is used when use_transducer is True.
use_transducer:
Whether use transducer head. Default: True.
use_ctc:
Whether use CTC head. Default: False.
"""
super().__init__()
assert (
use_transducer or use_ctc
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
assert isinstance(encoder, EncoderInterface), type(encoder)
self.encoder_embed = encoder_embed
self.encoder = encoder
self.use_hat = use_hat
self.use_transducer = use_transducer
if use_transducer:
# Modules for Transducer head
assert decoder is not None
assert hasattr(decoder, "blank_id")
assert joiner is not None
self.decoder = decoder
self.joiner = joiner
self.st_joiner = st_joiner
self.st_decoder = st_decoder
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_scale=0.25
)
self.simple_lm_proj = ScaledLinear(
decoder_dim, vocab_size, initial_scale=0.25
)
self.simple_st_am_proj = ScaledLinear(
encoder_dim, st_vocab_size, initial_scale=0.25
)
self.simple_st_lm_proj = ScaledLinear(
decoder_dim, st_vocab_size, initial_scale=0.25
)
else:
assert decoder is None
assert joiner is None
self.use_ctc = use_ctc
self.use_st_ctc = use_st_ctc
if self.use_ctc:
# Modules for CTC head
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
if self.use_st_ctc:
# Modules for CTC head
self.st_ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(encoder_dim, st_vocab_size),
nn.LogSoftmax(dim=-1),
)
def forward_encoder(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute encoder outputs.
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
Returns:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
"""
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens)
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
return encoder_out, encoder_out_lens
def forward_st_ctc(
self,
st_encoder_out: torch.Tensor,
st_encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Target Tensor of shape (sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC log-prob
ctc_output = self.st_ctc_output(st_encoder_out) # (N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets,
input_lengths=st_encoder_out_lens,
target_lengths=target_lengths,
reduction="sum",
)
return ctc_loss
def forward_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Target Tensor of shape (sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets,
input_lengths=encoder_out_lens,
target_lengths=target_lengths,
reduction="sum",
)
return ctc_loss
def forward_cr_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute CTC loss with consistency regularization loss.
Args:
encoder_out:
Encoder output, of shape (2 * N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (2 * N,).
targets:
Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC loss
ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C)
targets=targets.cpu(),
input_lengths=encoder_out_lens.cpu(),
target_lengths=target_lengths.cpu(),
reduction="none",
)
ctc_loss_is_finite = torch.isfinite(ctc_loss)
ctc_loss = ctc_loss[ctc_loss_is_finite]
ctc_loss = ctc_loss.sum()
# Compute consistency regularization loss
exchanged_targets = ctc_output.detach().chunk(2, dim=0)
exchanged_targets = torch.cat(
[exchanged_targets[1], exchanged_targets[0]], dim=0
) # exchange: [x1, x2] -> [x2, x1]
cr_loss = nn.functional.kl_div(
input=ctc_output,
target=exchanged_targets,
reduction="none",
log_target=True,
) # (2 * N, T, C)
length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
return ctc_loss, cr_loss
def forward_st_cr_ctc(
self,
st_encoder_out: torch.Tensor,
st_encoder_out_lens: torch.Tensor,
st_targets: torch.Tensor,
st_target_lengths: torch.Tensor,
# encoder_out: torch.Tensor,
# encoder_out_lens: torch.Tensor,
# targets: torch.Tensor,
# target_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute CTC loss with consistency regularization loss.
Args:
encoder_out:
Encoder output, of shape (2 * N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (2 * N,).
targets:
Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC loss
st_ctc_output = self.st_ctc_output(st_encoder_out) # (2 * N, T, C)
st_ctc_loss = torch.nn.functional.ctc_loss(
log_probs=st_ctc_output.permute(1, 0, 2), # (T, 2 * N, C)
targets=st_targets.cpu(),
input_lengths=st_encoder_out_lens.cpu(),
target_lengths=st_target_lengths.cpu(),
reduction="none",
)
st_ctc_loss_is_finite = torch.isfinite(st_ctc_loss)
st_ctc_loss = st_ctc_loss[st_ctc_loss_is_finite]
st_ctc_loss = st_ctc_loss.sum()
# ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C)
# ctc_loss = torch.nn.functional.ctc_loss(
# log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C)
# targets=targets.cpu(),
# input_lengths=encoder_out_lens.cpu(),
# target_lengths=target_lengths.cpu(),
# reduction="sum",
# )
# if not torch.isfinite(st_ctc_loss):
# breakpoint()
# Compute consistency regularization loss
exchanged_targets = st_ctc_output.detach().chunk(2, dim=0)
exchanged_targets = torch.cat(
[exchanged_targets[1], exchanged_targets[0]], dim=0
) # exchange: [x1, x2] -> [x2, x1]
cr_loss = nn.functional.kl_div(
input=st_ctc_output,
target=exchanged_targets,
reduction="none",
log_target=True,
) # (2 * N, T, C)
length_mask = make_pad_mask(st_encoder_out_lens).unsqueeze(-1)
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
return st_ctc_loss, cr_loss
def forward_transducer(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
y: k2.RaggedTensor,
y_lens: torch.Tensor,
st_y: k2.RaggedTensor,
st_y_lens: torch.Tensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
"""Compute Transducer loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for the RNN-T loss, i.e., how many symbols (context)
are considered for each frame when computing the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
"""
# Now for the decoder, i.e., the prediction network
blank_id = self.decoder.blank_id
st_blank_id = self.st_decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
st_sos_y = add_sos(st_y, sos_id=st_blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
st_sos_y_padded = st_sos_y.pad(mode="constant", padding_value=st_blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
st_decoder_out = self.st_decoder(st_sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
st_y_padded = st_y.pad(mode="constant", padding_value=0)
st_y_padded = st_y_padded.to(torch.int64)
st_boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
st_boundary[:, 2] = st_y_lens
st_boundary[:, 3] = encoder_out_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
st_lm = self.simple_st_lm_proj(st_decoder_out)
st_am = self.simple_st_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
st_simple_loss, (st_px_grad, st_py_grad) = k2.rnnt_loss_smoothed(
lm=st_lm.float(),
am=st_am.float(),
symbols=st_y_padded,
termination_symbol=st_blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=st_boundary,
reduction="sum",
return_grad=True,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
use_hat_loss=self.use_hat,
)
# logits : [B, T, prune_range, vocab_size]
st_ranges = k2.get_rnnt_prune_ranges(
px_grad=st_px_grad,
py_grad=st_py_grad,
boundary=st_boundary,
s_range=prune_range,
)
st_am_pruned, st_lm_pruned = k2.do_rnnt_pruning(
am=self.st_joiner.encoder_proj(encoder_out),
lm=self.st_joiner.decoder_proj(st_decoder_out),
ranges=st_ranges,
)
st_logits = self.st_joiner(st_am_pruned, st_lm_pruned, project_input=False)
# Compute HAT loss for st
with torch.cuda.amp.autocast(enabled=False):
pruned_st_loss = k2.rnnt_loss_pruned(
logits=st_logits.float(),
symbols=st_y.pad(mode="constant", padding_value=blank_id).to(torch.int64),
ranges=st_ranges,
termination_symbol=st_blank_id,
boundary=st_boundary,
reduction="sum",
use_hat_loss=self.use_hat,
)
return simple_loss, st_simple_loss, pruned_loss, pruned_st_loss
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
st_y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
use_st_cr_ctc: bool = False,
use_asr_cr_ctc: bool = False,
use_spec_aug: bool = False,
spec_augment: Optional[SpecAugment] = None,
supervision_segments: Optional[torch.Tensor] = None,
time_warp_factor: Optional[int] = 80,
) -> Tuple[torch.Tensor, ...]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for the RNN-T loss, i.e., how many symbols (context)
are considered for each frame when computing the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
use_st_cr_ctc:
Whether to use consistency-regularized CTC for the ST head.
use_asr_cr_ctc:
Whether to use consistency-regularized CTC for the ASR head.
use_spec_aug:
Whether to apply SpecAugment manually; used only if one of the cr-ctc options is True.
spec_augment:
The SpecAugment instance that returns time masks,
used only if one of the cr-ctc options is True.
supervision_segments:
An int tensor of shape ``(S, 3)``. ``S`` is the number of
supervision segments that exist in ``features``.
Used only if one of the cr-ctc options is True.
time_warp_factor:
Parameter for the time warping; larger values mean more warping.
Set to ``None``, or less than ``1``, to disable.
Used only if one of the cr-ctc options is True.
Returns:
Return the transducer, CTC, and consistency-regularization losses, in the form
(simple_loss, st_simple_loss, pruned_loss, st_pruned_loss, ctc_loss, st_ctc_loss, cr_loss, st_cr_loss).
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert st_y.num_axes == 2, st_y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
if use_st_cr_ctc or use_asr_cr_ctc:
assert self.use_ctc or self.use_st_ctc
if use_spec_aug:
assert spec_augment is not None and spec_augment.time_warp_factor < 1
# Apply time warping before input duplicating
assert supervision_segments is not None
x = time_warp(
x,
time_warp_factor=time_warp_factor,
supervision_segments=supervision_segments,
)
# Independently apply frequency masking and time masking to the two copies
x = spec_augment(x.repeat(2, 1, 1))
else:
x = x.repeat(2, 1, 1)
x_lens = x_lens.repeat(2)
y = k2.ragged.cat([y, y], axis=0)
if self.st_joiner is not None and self.use_st_ctc:
st_y = k2.ragged.cat([st_y, st_y], axis=0)
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
st_row_splits = st_y.shape.row_splits(1)
st_y_lens = st_row_splits[1:] - st_row_splits[:-1]
if self.use_transducer:
# Compute transducer loss
if self.st_joiner is not None:
simple_loss, st_simple_loss, pruned_loss, st_pruned_loss = self.forward_transducer(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
y=y.to(x.device),
y_lens=y_lens,
st_y=st_y.to(x.device),
st_y_lens=st_y_lens,
prune_range=prune_range,
am_scale=am_scale,
lm_scale=lm_scale,
)
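# When cr-ctc is enabled, the batch was duplicated above, so the
# corresponding transducer losses are halved to keep their scale unchanged.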
if use_asr_cr_ctc:
simple_loss = simple_loss * 0.5
pruned_loss = pruned_loss * 0.5
if use_st_cr_ctc:
st_simple_loss = st_simple_loss * 0.5
st_pruned_loss = st_pruned_loss * 0.5
else:
simple_loss, pruned_loss = self.forward_transducer(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
y=y.to(x.device),
y_lens=y_lens,
prune_range=prune_range,
am_scale=am_scale,
lm_scale=lm_scale,
)
if use_asr_cr_ctc:
simple_loss = simple_loss * 0.5
pruned_loss = pruned_loss * 0.5
st_simple_loss, st_pruned_loss = torch.empty(0), torch.empty(0)
else:
simple_loss = torch.empty(0)
pruned_loss = torch.empty(0)
if self.use_ctc:
# Compute CTC loss
targets = y.values
if not use_asr_cr_ctc:
ctc_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
cr_loss = torch.empty(0)
else:
ctc_loss, cr_loss = self.forward_cr_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
ctc_loss = ctc_loss * 0.5
cr_loss = cr_loss * 0.5
else:
cr_loss = torch.empty(0)
ctc_loss = torch.empty(0)
if self.use_st_ctc:
st_targets = st_y.values
if not use_st_cr_ctc:
st_ctc_loss = self.forward_st_ctc(
st_encoder_out=encoder_out,
st_encoder_out_lens=encoder_out_lens,
targets=st_targets,
target_lengths=st_y_lens,
)
st_cr_loss = torch.empty(0)
else:
st_ctc_loss, st_cr_loss = self.forward_st_cr_ctc(
st_encoder_out=encoder_out,
st_encoder_out_lens=encoder_out_lens,
st_targets=st_targets,
st_target_lengths=st_y_lens,
# encoder_out=encoder_out,
# encoder_out_lens=encoder_out_lens,
# targets=targets,
# target_lengths=y_lens,
)
st_ctc_loss = st_ctc_loss * 0.5
st_cr_loss = st_cr_loss * 0.5
else:
st_ctc_loss = torch.empty(0)
st_cr_loss = torch.empty(0)
return simple_loss, st_simple_loss, pruned_loss, st_pruned_loss, ctc_loss, st_ctc_loss, cr_loss, st_cr_loss
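
# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the recipe): the `boundary` tensors passed to
# k2.rnnt_loss_smoothed / k2.rnnt_loss_pruned above hold one row per utterance
# of the form [begin_symbol, begin_frame, num_symbols, num_frames], and the
# cr-ctc branch swaps the two augmented copies of each utterance with a
# chunk/cat trick. Both are shown below with made-up sizes; the numbers are
# assumptions for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    y_lens = torch.tensor([7, 4])              # target tokens per utterance
    encoder_out_lens = torch.tensor([31, 18])  # encoder frames per utterance
    boundary = torch.zeros((2, 4), dtype=torch.int64)
    boundary[:, 2] = y_lens
    boundary[:, 3] = encoder_out_lens
    print(boundary)  # tensor([[ 0,  0,  7, 31], [ 0,  0,  4, 18]])

    # The cr-ctc "exchange": [x1, x2] -> [x2, x1] along the batch dimension.
    ctc_output = torch.randn(4, 10, 6)  # (2 * N, T, C) with N = 2
    halves = ctc_output.chunk(2, dim=0)
    exchanged = torch.cat([halves[1], halves[0]], dim=0)
    assert torch.equal(exchanged[:2], ctc_output[2:])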

File diff suppressed because it is too large

View File

@ -0,0 +1,381 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
Note: This is an example for the LibriSpeech dataset; if you are using a different
dataset, you should change the argument values according to your dataset.
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
Usage of this script:
- For non-streaming model:
(1) greedy search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
- For streaming model:
(1) greedy search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./zipformer/exp/epoch-xx.pt`.
Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
fast_beam_search_one_best,
greedy_search_batch,
modified_beam_search,
)
from export import num_tokens
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
from icefall.utils import make_pad_mask
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
logging.info("Creating model")
model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
# model forward
encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
hyps = []
msg = f"Using {params.method}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
raise ValueError(f"Unsupported method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,170 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage: ./zipformer/profile.py
"""
import argparse
import logging
import sentencepiece as spm
import torch
from typing import Tuple
from torch import Tensor, nn
from icefall.utils import make_pad_mask
from icefall.profiler import get_model_profile
from scaling import BiasNorm
from train import (
get_encoder_embed,
get_encoder_model,
get_joiner_model,
add_model_arguments,
get_params,
)
from zipformer import BypassModule
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
add_model_arguments(parser)
return parser
def _bias_norm_flops_compute(module, input, output):
assert len(input) == 1, len(input)
# estimate as layer_norm, see icefall/profiler.py
flops = input[0].numel() * 5
module.__flops__ += int(flops)
def _swoosh_module_flops_compute(module, input, output):
# For SwooshL and SwooshR modules
assert len(input) == 1, len(input)
# estimate as swish/silu, see icefall/profiler.py
flops = input[0].numel()
module.__flops__ += int(flops)
def _bypass_module_flops_compute(module, input, output):
# For Bypass module
assert len(input) == 2, len(input)
flops = input[0].numel() * 2
module.__flops__ += int(flops)
MODULE_HOOK_MAPPING = {
BiasNorm: _bias_norm_flops_compute,
BypassModule: _bypass_module_flops_compute,
}
class Model(nn.Module):
"""A Wrapper for encoder, encoder_embed, and encoder_proj"""
def __init__(
self,
encoder: nn.Module,
encoder_embed: nn.Module,
encoder_proj: nn.Module,
) -> None:
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj
def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]:
x, x_lens = self.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
logits = self.encoder_proj(encoder_out)
return logits, encoder_out_lens
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
# We only profile the encoder part
model = Model(
encoder=get_encoder_model(params),
encoder_embed=get_encoder_embed(params),
encoder_proj=get_joiner_model(params).encoder_proj,
)
model.eval()
model.to(device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# for 30-second input
B, T, D = 1, 3000, 80
feature = torch.ones(B, T, D, dtype=torch.float32).to(device)
feature_lens = torch.full((B,), T, dtype=torch.int64).to(device)
flops, params = get_model_profile(
model=model,
args=(feature, feature_lens),
module_hoop_mapping=MODULE_HOOK_MAPPING,
)
logging.info(f"For the encoder part, params: {params}, flops: {flops}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

File diff suppressed because it is too large

View File

@ -0,0 +1,104 @@
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file replaces various modules in a model.
Specifically, ActivationBalancer is replaced with an identity operator;
Whiten is also replaced with an identity operator;
BasicNorm is replaced by a module with `exp` removed.
"""
import copy
from typing import List, Tuple
import torch
import torch.nn as nn
from scaling import (
Balancer,
Dropout3,
ScaleGrad,
SwooshL,
SwooshLOnnx,
SwooshR,
SwooshROnnx,
Whiten,
)
from zipformer import CompactRelPositionalEncoding
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
# get_submodule was added to nn.Module at v1.9.0
def get_submodule(model, target):
if target == "":
return model
atoms: List[str] = target.split(".")
mod: torch.nn.Module = model
for item in atoms:
if not hasattr(mod, item):
raise AttributeError(
mod._get_name() + " has no " "attribute `" + item + "`"
)
mod = getattr(mod, item)
if not isinstance(mod, torch.nn.Module):
raise AttributeError("`" + item + "` is not " "an nn.Module")
return mod
def convert_scaled_to_non_scaled(
model: nn.Module,
inplace: bool = False,
is_pnnx: bool = False,
is_onnx: bool = False,
):
"""
Args:
model:
The model to be converted.
inplace:
If True, the input model is modified inplace.
If False, the input model is copied and we modify the copied version.
is_pnnx:
True if we are going to export the model for PNNX.
is_onnx:
True if we are going to export the model for ONNX.
Return:
Return a model without scaled layers.
"""
if not inplace:
model = copy.deepcopy(model)
d = {}
for name, m in model.named_modules():
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
d[name] = nn.Identity()
elif is_onnx and isinstance(m, SwooshR):
d[name] = SwooshROnnx()
elif is_onnx and isinstance(m, SwooshL):
d[name] = SwooshLOnnx()
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
# We want to recreate the positional encoding vector when
# the input changes, so we have to use torch.jit.script()
# to replace torch.jit.trace()
d[name] = torch.jit.script(m)
for k, v in d.items():
if "." in k:
parent, child = k.rsplit(".", maxsplit=1)
setattr(get_submodule(model, parent), child, v)
else:
setattr(model, k, v)
return model
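
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): export.py calls
# convert_scaled_to_non_scaled(model, inplace=True) before torch.jit.script.
# The toy module below is an assumption for demonstration; it only shows that
# train-time helper layers such as Balancer are swapped for nn.Identity.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    toy = nn.Sequential(nn.Linear(4, 4), Balancer(4, channel_dim=-1), SwooshL())
    converted = convert_scaled_to_non_scaled(toy, inplace=False)
    print(converted)  # the Balancer is replaced by Identity(); Linear and SwooshL remain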

View File

@ -0,0 +1,406 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import warnings
import torch
from torch import Tensor, nn
from scaling import (
Balancer,
BiasNorm,
Dropout3,
FloatLike,
Optional,
ScaledConv2d,
ScaleGrad,
ScheduledFloat,
SwooshL,
SwooshR,
Whiten,
)
class ConvNeXt(nn.Module):
"""
Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
"""
def __init__(
self,
channels: int,
hidden_ratio: int = 3,
kernel_size: Tuple[int, int] = (7, 7),
layerdrop_rate: FloatLike = None,
):
super().__init__()
self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
hidden_channels = channels * hidden_ratio
if layerdrop_rate is None:
layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
self.layerdrop_rate = layerdrop_rate
self.depthwise_conv = nn.Conv2d(
in_channels=channels,
out_channels=channels,
groups=channels,
kernel_size=kernel_size,
padding=self.padding,
)
self.pointwise_conv1 = nn.Conv2d(
in_channels=channels, out_channels=hidden_channels, kernel_size=1
)
self.hidden_balancer = Balancer(
hidden_channels,
channel_dim=1,
min_positive=0.3,
max_positive=1.0,
min_abs=0.75,
max_abs=5.0,
)
self.activation = SwooshL()
self.pointwise_conv2 = ScaledConv2d(
in_channels=hidden_channels,
out_channels=channels,
kernel_size=1,
initial_scale=0.01,
)
self.out_balancer = Balancer(
channels,
channel_dim=1,
min_positive=0.4,
max_positive=0.6,
min_abs=1.0,
max_abs=6.0,
)
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=5.0,
prob=(0.025, 0.25),
grad_scale=0.01,
)
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return self.forward_internal(x)
layerdrop_rate = float(self.layerdrop_rate)
if layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = (
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
> layerdrop_rate
)
else:
mask = None
# turns out this caching idea does not work with --world-size > 1
# return caching_eval(self.forward_internal, x, mask)
return self.forward_internal(x, mask)
def forward_internal(
self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
) -> Tensor:
"""
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
The returned value has the same shape as x.
"""
bypass = x
x = self.depthwise_conv(x)
x = self.pointwise_conv1(x)
x = self.hidden_balancer(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
if layer_skip_mask is not None:
x = x * layer_skip_mask
x = bypass + x
x = self.out_balancer(x)
if x.requires_grad:
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
x = self.out_whiten(x)
x = x.transpose(1, 3) # (N, C, H, W)
return x
def streaming_forward(
self,
x: Tensor,
cached_left_pad: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)
Returns:
- The returned value has the same shape as x.
- Updated cached_left_pad.
"""
padding = self.padding
# The length without right padding for depth-wise conv
T = x.size(2) - padding[0]
bypass = x[:, :, :T, :]
# Pad left side
assert cached_left_pad.size(2) == padding[0], (
cached_left_pad.size(2),
padding[0],
)
x = torch.cat([cached_left_pad, x], dim=2)
# Update cached left padding
cached_left_pad = x[:, :, T : padding[0] + T, :]
# depthwise_conv
x = torch.nn.functional.conv2d(
x,
weight=self.depthwise_conv.weight,
bias=self.depthwise_conv.bias,
padding=(0, padding[1]),
groups=self.depthwise_conv.groups,
)
x = self.pointwise_conv1(x)
x = self.hidden_balancer(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
x = bypass + x
return x, cached_left_pad
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = (T-3)//2 - 2 == (T-7)//2
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
dropout: FloatLike = 0.1,
) -> None:
"""
Args:
in_channels:
Number of channels in. The input shape is (N, T, in_channels).
Caution: It requires: T >=7, in_channels >=7
out_channels:
Output dim. The output shape is (N, (T-7)//2, out_channels)
layer1_channels:
Number of channels in layer1
layer2_channels:
Number of channels in layer2
layer3_channels:
Number of channels in layer3
dropout:
Dropout rate applied to the output of this module
"""
assert in_channels >= 7
super().__init__()
# The ScaleGrad module is there to prevent the gradients
# w.r.t. the weight or bias of the first Conv2d module in self.conv from
# exceeding the range of fp16 when using automatic mixed precision (amp)
# training. (The second one is necessary to stop its bias from getting
# a too-large gradient).
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=(0, 1), # (time, freq)
),
ScaleGrad(0.2),
Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
SwooshR(),
nn.Conv2d(
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
padding=0,
),
Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
SwooshR(),
nn.Conv2d(
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=(1, 2), # (time, freq)
),
Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
SwooshR(),
)
# just one convnext layer
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
# (in_channels-3)//4
self.out_width = (((in_channels - 1) // 2) - 1) // 2
self.layer3_channels = layer3_channels
self.out = nn.Linear(self.out_width * layer3_channels, out_channels)
# use a larger than normal grad_scale on this whitening module; there is
# only one such module, so there is not a concern about adding together
# many copies of this extra gradient term.
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
prob=(0.025, 0.25),
grad_scale=0.02,
)
# max_log_eps=0.0 is to prevent both eps and the output of self.out from
# getting large, there is an unnecessary degree of freedom.
self.out_norm = BiasNorm(out_channels)
self.dropout = Dropout3(dropout, shared_dim=1)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in `x` before padding.
Returns:
- a tensor of shape (N, (T-7)//2, odim)
- output lengths, of shape (batch_size,)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
# gradients.
x = self.conv(x)
x = self.convnext(x)
# Now x is of shape (N, odim, (T-7)//2, (idim-3)//4)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, (T-7)//2, out_width * layer3_channels)
x = self.out(x)
# Now x is of shape (N, (T-7)//2, odim)
x = self.out_whiten(x)
x = self.out_norm(x)
x = self.dropout(x)
if torch.jit.is_scripting() or torch.jit.is_tracing():
x_lens = (x_lens - 7) // 2
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
x_lens = (x_lens - 7) // 2
assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())
return x, x_lens
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
cached_left_pad: Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in `x` before padding.
Returns:
- a tensor of shape (N, (T-7)//2, odim)
- output lengths, of shape (batch_size,)
- updated cache
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# T' = (T-7)//2
x = self.conv(x)
# T' = (T-7)//2-3
x, cached_left_pad = self.convnext.streaming_forward(
x, cached_left_pad=cached_left_pad
)
# Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, T', out_width * layer3_channels)
x = self.out(x)
# Now x is of shape (N, T', odim)
x = self.out_norm(x)
if torch.jit.is_scripting() or torch.jit.is_tracing():
assert self.convnext.padding[0] == 3
# The ConvNeXt module needs 3 frames of right padding after subsampling
x_lens = (x_lens - 7) // 2 - 3
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# The ConvNeXt module needs 3 frames of right padding after subsampling
assert self.convnext.padding[0] == 3
x_lens = (x_lens - 7) // 2 - 3
assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())
return x, x_lens, cached_left_pad
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> Tensor:
"""Get initial states for Conv2dSubsampling module.
It is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
"""
left_pad = self.convnext.padding[0]
freq = self.out_width
channels = self.layer3_channels
cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
device
)
return cached_embed_left_pad
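
# ---------------------------------------------------------------------------
# Quick length-formula check (a minimal sketch, not used by the recipe): it
# verifies that Conv2dSubsampling maps (N, T, idim) to (N, (T-7)//2, odim).
# The feature dimension 80 and output dimension 192 are example values chosen
# for illustration, assuming the scaling.py modules behave as in upstream
# icefall.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    layer = Conv2dSubsampling(in_channels=80, out_channels=192)
    layer.eval()
    x = torch.randn(2, 107, 80)        # (N, T, idim)
    x_lens = torch.tensor([107, 100])  # frames per utterance before padding
    with torch.no_grad():
        y, y_lens = layer(x, x_lens)
    print(y.shape, y_lens)  # expected: (2, 50, 192) and lengths [50, 46]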

File diff suppressed because it is too large

File diff suppressed because it is too large