diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml
index 3bbd4c49b..5f4a425d9 100644
--- a/.github/workflows/run-pretrained-transducer-stateless.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless.yml
@@ -74,11 +74,11 @@ jobs:
mkdir tmp
cd tmp
git lfs install
- git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27
+ git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10
cd ..
tree tmp
- soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
- ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
+ soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav
+ ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav
- name: Run greedy search decoding
shell: bash
@@ -87,11 +87,11 @@ jobs:
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method greedy_search \
- --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
- --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
- ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
- ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
- ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav
+ --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \
+ --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \
+ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \
+ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \
+ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding
shell: bash
@@ -101,8 +101,8 @@ jobs:
./transducer_stateless/pretrained.py \
--method beam_search \
--beam-size 4 \
- --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
- --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
- ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
- ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
- ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav
+ --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \
+ --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \
+ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \
+ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \
+ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav
diff --git a/README.md b/README.md
index f7aed9dc3..7dee1c1d6 100644
--- a/README.md
+++ b/README.md
@@ -84,12 +84,12 @@ The best WER using beam search with beam size 4 is:
| | test-clean | test-other |
|-----|------------|------------|
-| WER | 2.83 | 7.19 |
+| WER | 2.76 | 6.97 |
Note: No auxiliary losses are used in the training and no LMs are used
in the decoding.
-We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [](https://colab.research.google.com/drive/1Lm37sNajIpkV4HTzMDF7sn9l0JpfmekN?usp=sharing)
+We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [](https://colab.research.google.com/drive/1Rc4Is-3Yp9LbcEz_Iy8hfyenyHsyjvqE?usp=sharing)
### Aishell
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 1476c0528..18d9d4dec 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -4,7 +4,7 @@
#### Conformer encoder + embedding decoder
-Using commit `14c93add507982306f5a478cd144e0e32e0f970d`.
+Using commit `TODO`.
Conformer encoder + non-current decoder. The decoder
contains only an embedding layer and a Conv1d (with kernel size 2).
@@ -13,8 +13,8 @@ The WERs are
| | test-clean | test-other | comment |
|---------------------------|------------|------------|------------------------------------------|
-| greedy search | 2.85 | 7.30 | --epoch 29, --avg 13, --max-duration 100 |
-| beam search (beam size 4) | 2.83 | 7.19 | |
+| greedy search | 2.77 | 7.07 | --epoch 30, --avg 13, --max-duration 100 |
+| beam search (beam size 4) | 2.76 | 6.97 | |
The training command for reproducing is given below:
@@ -32,11 +32,11 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
```
The tensorboard training log can be found at
-
+
The decoding command is:
```
-epoch=29
+epoch=36
avg=13
## greedy search
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index 9ed9b2ad1..989caa802 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -66,6 +66,9 @@ def greedy_search(
# symbols per utterance decoded so far
sym_per_utt = 0
+ encoder_out_len = torch.tensor([1])
+ decoder_out_len = torch.tensor([1])
+
while t < T and sym_per_utt < max_sym_per_utt:
if sym_per_frame >= max_sym_per_frame:
sym_per_frame = 0
@@ -75,7 +78,9 @@ def greedy_search(
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
- logits = model.joiner(current_encoder_out, decoder_out)
+ logits = model.joiner(
+ current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
+ )
# logits is (1, 1, 1, vocab_size)
y = logits.argmax().item()
@@ -262,6 +267,9 @@ def beam_search(
sym_per_utt = 0
+ encoder_out_len = torch.tensor([1])
+ decoder_out_len = torch.tensor([1])
+
decoder_cache: Dict[str, torch.Tensor] = {}
while t < T and sym_per_utt < max_sym_per_utt:
@@ -294,7 +302,12 @@ def beam_search(
cached_key += f"-t-{t}"
if cached_key not in joint_cache:
- logits = model.joiner(current_encoder_out, decoder_out)
+ logits = model.joiner(
+ current_encoder_out,
+ decoder_out,
+ encoder_out_len,
+ decoder_out_len,
+ )
# TODO(fangjun): Ccale the blank posterior
diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py
index 2ef3f1de6..9fd9da4f1 100644
--- a/egs/librispeech/ASR/transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless/joiner.py
@@ -22,33 +22,51 @@ class Joiner(nn.Module):
def __init__(self, input_dim: int, output_dim: int):
super().__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
self.output_linear = nn.Linear(input_dim, output_dim)
def forward(
- self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
+ self,
+ encoder_out: torch.Tensor,
+ decoder_out: torch.Tensor,
+ encoder_out_len: torch.Tensor,
+ decoder_out_len: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
- Output from the encoder. Its shape is (N, T, C).
+ Output from the encoder. Its shape is (N, T, self.input_dim).
decoder_out:
- Output from the decoder. Its shape is (N, U, C).
+ Output from the decoder. Its shape is (N, U, self.input_dim).
Returns:
- Return a tensor of shape (N, T, U, C).
+ Return a tensor of shape (sum_all_TU, self.output_dim).
"""
assert encoder_out.ndim == decoder_out.ndim == 3
assert encoder_out.size(0) == decoder_out.size(0)
- assert encoder_out.size(2) == decoder_out.size(2)
+ assert encoder_out.size(2) == self.input_dim
+ assert decoder_out.size(2) == self.input_dim
- encoder_out = encoder_out.unsqueeze(2)
- # Now encoder_out is (N, T, 1, C)
+ N = encoder_out.size(0)
- decoder_out = decoder_out.unsqueeze(1)
- # Now decoder_out is (N, 1, U, C)
+ encoder_out_list = [
+ encoder_out[i, : encoder_out_len[i], :] for i in range(N)
+ ]
- logit = encoder_out + decoder_out
- logit = torch.tanh(logit)
+ decoder_out_list = [
+ decoder_out[i, : decoder_out_len[i], :] for i in range(N)
+ ]
- output = self.output_linear(logit)
+ x = [
+ e.unsqueeze(1) + d.unsqueeze(0)
+ for e, d in zip(encoder_out_list, decoder_out_list)
+ ]
- return output
+ x = [p.reshape(-1, self.input_dim) for p in x]
+ x = torch.cat(x)
+
+ activations = torch.tanh(x)
+
+ logits = self.output_linear(activations)
+
+ return logits
diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py
index 2f0f9a183..98a6f0f37 100644
--- a/egs/librispeech/ASR/transducer_stateless/model.py
+++ b/egs/librispeech/ASR/transducer_stateless/model.py
@@ -14,15 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Note we use `rnnt_loss` from torchaudio, which exists only in
-torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
-"""
import k2
import torch
import torch.nn as nn
-import torchaudio
-import torchaudio.functional
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
@@ -102,18 +96,24 @@ class Transducer(nn.Module):
decoder_out = self.decoder(sos_y_padded)
- logits = self.joiner(encoder_out, decoder_out)
+ # +1 here since a blank is prepended to each utterance.
+ logits = self.joiner(
+ encoder_out=encoder_out,
+ decoder_out=decoder_out,
+ encoder_out_len=x_lens,
+ decoder_out_len=y_lens + 1,
+ )
# rnnt_loss requires 0 padded targets
# Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0)
- assert hasattr(torchaudio.functional, "rnnt_loss"), (
- f"Current torchaudio version: {torchaudio.__version__}\n"
- "Please install a version >= 0.10.0"
- )
+ # We don't put this `import` at the beginning of the file
+ # as it is required only in the training, not during the
+ # reference stage
+ import optimized_transducer
- loss = torchaudio.functional.rnnt_loss(
+ loss = optimized_transducer.transducer_loss(
logits=logits,
targets=y_padded,
logit_lengths=x_lens,
diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py
index 832fd556e..0a5a42089 100644
--- a/egs/yesno/ASR/tdnn/asr_datamodule.py
+++ b/egs/yesno/ASR/tdnn/asr_datamodule.py
@@ -180,7 +180,7 @@ class YesNoAsrDataModule(DataModule):
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
- Fbank(FbankConfig(num_mel_bins=23))
+ FbankConfig(sampling_rate=8000, num_mel_bins=23)
),
return_cuts=self.args.return_cuts,
)
diff --git a/requirements.txt b/requirements.txt
index 4eaa86a67..09d9ef69f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,4 @@ kaldialign
sentencepiece>=0.1.96
tensorboard
typeguard
+optimized_transducer