Use optimized_transducer to compute transducer loss. (#162)

* WIP: Use optimized_transducer to compute transducer loss.

* Minor fixes.

* Fix decoding.

* Fix decoding.

* Add RESULTS.

* Update RESULTS.

* Update CI.

* Fix sampling rate for yesno recipe.
This commit is contained in:
Fangjun Kuang 2022-01-10 11:54:58 +08:00 committed by GitHub
parent 319e120869
commit 4c1b3665ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 80 additions and 48 deletions

View File

@ -74,11 +74,11 @@ jobs:
mkdir tmp mkdir tmp
cd tmp cd tmp
git lfs install git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10
cd .. cd ..
tree tmp tree tmp
soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav
- name: Run greedy search decoding - name: Run greedy search decoding
shell: bash shell: bash
@ -87,11 +87,11 @@ jobs:
cd egs/librispeech/ASR cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \ ./transducer_stateless/pretrained.py \
--method greedy_search \ --method greedy_search \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \ --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \ --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding - name: Run beam search decoding
shell: bash shell: bash
@ -101,8 +101,8 @@ jobs:
./transducer_stateless/pretrained.py \ ./transducer_stateless/pretrained.py \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \ --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \ --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav

View File

@ -84,12 +84,12 @@ The best WER using beam search with beam size 4 is:
| | test-clean | test-other | | | test-clean | test-other |
|-----|------------|------------| |-----|------------|------------|
| WER | 2.83 | 7.19 | | WER | 2.76 | 6.97 |
Note: No auxiliary losses are used in the training and no LMs are used Note: No auxiliary losses are used in the training and no LMs are used
in the decoding. in the decoding.
We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Lm37sNajIpkV4HTzMDF7sn9l0JpfmekN?usp=sharing) We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Rc4Is-3Yp9LbcEz_Iy8hfyenyHsyjvqE?usp=sharing)
### Aishell ### Aishell

View File

@ -4,7 +4,7 @@
#### Conformer encoder + embedding decoder #### Conformer encoder + embedding decoder
Using commit `14c93add507982306f5a478cd144e0e32e0f970d`. Using commit `TODO`.
Conformer encoder + non-current decoder. The decoder Conformer encoder + non-current decoder. The decoder
contains only an embedding layer and a Conv1d (with kernel size 2). contains only an embedding layer and a Conv1d (with kernel size 2).
@ -13,8 +13,8 @@ The WERs are
| | test-clean | test-other | comment | | | test-clean | test-other | comment |
|---------------------------|------------|------------|------------------------------------------| |---------------------------|------------|------------|------------------------------------------|
| greedy search | 2.85 | 7.30 | --epoch 29, --avg 13, --max-duration 100 | | greedy search | 2.77 | 7.07 | --epoch 30, --avg 13, --max-duration 100 |
| beam search (beam size 4) | 2.83 | 7.19 | | | beam search (beam size 4) | 2.76 | 6.97 | |
The training command for reproducing is given below: The training command for reproducing is given below:
@ -32,11 +32,11 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
``` ```
The tensorboard training log can be found at The tensorboard training log can be found at
<https://tensorboard.dev/experiment/Mjx7MeTgR3Oyr1yBCwjozw/> <https://tensorboard.dev/experiment/6fnVojoUQTmEJVq1yG34Vw/>
The decoding command is: The decoding command is:
``` ```
epoch=29 epoch=36
avg=13 avg=13
## greedy search ## greedy search

View File

@ -66,6 +66,9 @@ def greedy_search(
# symbols per utterance decoded so far # symbols per utterance decoded so far
sym_per_utt = 0 sym_per_utt = 0
encoder_out_len = torch.tensor([1])
decoder_out_len = torch.tensor([1])
while t < T and sym_per_utt < max_sym_per_utt: while t < T and sym_per_utt < max_sym_per_utt:
if sym_per_frame >= max_sym_per_frame: if sym_per_frame >= max_sym_per_frame:
sym_per_frame = 0 sym_per_frame = 0
@ -75,7 +78,9 @@ def greedy_search(
# fmt: off # fmt: off
current_encoder_out = encoder_out[:, t:t+1, :] current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on # fmt: on
logits = model.joiner(current_encoder_out, decoder_out) logits = model.joiner(
current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
)
# logits is (1, 1, 1, vocab_size) # logits is (1, 1, 1, vocab_size)
y = logits.argmax().item() y = logits.argmax().item()
@ -262,6 +267,9 @@ def beam_search(
sym_per_utt = 0 sym_per_utt = 0
encoder_out_len = torch.tensor([1])
decoder_out_len = torch.tensor([1])
decoder_cache: Dict[str, torch.Tensor] = {} decoder_cache: Dict[str, torch.Tensor] = {}
while t < T and sym_per_utt < max_sym_per_utt: while t < T and sym_per_utt < max_sym_per_utt:
@ -294,7 +302,12 @@ def beam_search(
cached_key += f"-t-{t}" cached_key += f"-t-{t}"
if cached_key not in joint_cache: if cached_key not in joint_cache:
logits = model.joiner(current_encoder_out, decoder_out) logits = model.joiner(
current_encoder_out,
decoder_out,
encoder_out_len,
decoder_out_len,
)
# TODO(fangjun): Ccale the blank posterior # TODO(fangjun): Ccale the blank posterior

View File

@ -22,33 +22,51 @@ class Joiner(nn.Module):
def __init__(self, input_dim: int, output_dim: int): def __init__(self, input_dim: int, output_dim: int):
super().__init__() super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.output_linear = nn.Linear(input_dim, output_dim) self.output_linear = nn.Linear(input_dim, output_dim)
def forward( def forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
encoder_out_len: torch.Tensor,
decoder_out_len: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
encoder_out: encoder_out:
Output from the encoder. Its shape is (N, T, C). Output from the encoder. Its shape is (N, T, self.input_dim).
decoder_out: decoder_out:
Output from the decoder. Its shape is (N, U, C). Output from the decoder. Its shape is (N, U, self.input_dim).
Returns: Returns:
Return a tensor of shape (N, T, U, C). Return a tensor of shape (sum_all_TU, self.output_dim).
""" """
assert encoder_out.ndim == decoder_out.ndim == 3 assert encoder_out.ndim == decoder_out.ndim == 3
assert encoder_out.size(0) == decoder_out.size(0) assert encoder_out.size(0) == decoder_out.size(0)
assert encoder_out.size(2) == decoder_out.size(2) assert encoder_out.size(2) == self.input_dim
assert decoder_out.size(2) == self.input_dim
encoder_out = encoder_out.unsqueeze(2) N = encoder_out.size(0)
# Now encoder_out is (N, T, 1, C)
decoder_out = decoder_out.unsqueeze(1) encoder_out_list = [
# Now decoder_out is (N, 1, U, C) encoder_out[i, : encoder_out_len[i], :] for i in range(N)
]
logit = encoder_out + decoder_out decoder_out_list = [
logit = torch.tanh(logit) decoder_out[i, : decoder_out_len[i], :] for i in range(N)
]
output = self.output_linear(logit) x = [
e.unsqueeze(1) + d.unsqueeze(0)
for e, d in zip(encoder_out_list, decoder_out_list)
]
return output x = [p.reshape(-1, self.input_dim) for p in x]
x = torch.cat(x)
activations = torch.tanh(x)
logits = self.output_linear(activations)
return logits

View File

@ -14,15 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Note we use `rnnt_loss` from torchaudio, which exists only in
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
"""
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchaudio
import torchaudio.functional
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from icefall.utils import add_sos from icefall.utils import add_sos
@ -102,18 +96,24 @@ class Transducer(nn.Module):
decoder_out = self.decoder(sos_y_padded) decoder_out = self.decoder(sos_y_padded)
logits = self.joiner(encoder_out, decoder_out) # +1 here since a blank is prepended to each utterance.
logits = self.joiner(
encoder_out=encoder_out,
decoder_out=decoder_out,
encoder_out_len=x_lens,
decoder_out_len=y_lens + 1,
)
# rnnt_loss requires 0 padded targets # rnnt_loss requires 0 padded targets
# Note: y does not start with SOS # Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0) y_padded = y.pad(mode="constant", padding_value=0)
assert hasattr(torchaudio.functional, "rnnt_loss"), ( # We don't put this `import` at the beginning of the file
f"Current torchaudio version: {torchaudio.__version__}\n" # as it is required only in the training, not during the
"Please install a version >= 0.10.0" # reference stage
) import optimized_transducer
loss = torchaudio.functional.rnnt_loss( loss = optimized_transducer.transducer_loss(
logits=logits, logits=logits,
targets=y_padded, targets=y_padded,
logit_lengths=x_lens, logit_lengths=x_lens,

View File

@ -180,7 +180,7 @@ class YesNoAsrDataModule(DataModule):
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=23)) FbankConfig(sampling_rate=8000, num_mel_bins=23)
), ),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )

View File

@ -3,3 +3,4 @@ kaldialign
sentencepiece>=0.1.96 sentencepiece>=0.1.96
tensorboard tensorboard
typeguard typeguard
optimized_transducer