mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Use optimized_transducer to compute transducer loss. (#162)
* WIP: Use optimized_transducer to compute transducer loss. * Minor fixes. * Fix decoding. * Fix decoding. * Add RESULTS. * Update RESULTS. * Update CI. * Fix sampling rate for yesno recipe.
This commit is contained in:
parent
319e120869
commit
4c1b3665ee
@ -74,11 +74,11 @@ jobs:
|
|||||||
mkdir tmp
|
mkdir tmp
|
||||||
cd tmp
|
cd tmp
|
||||||
git lfs install
|
git lfs install
|
||||||
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27
|
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10
|
||||||
cd ..
|
cd ..
|
||||||
tree tmp
|
tree tmp
|
||||||
soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
|
soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav
|
||||||
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
|
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav
|
||||||
|
|
||||||
- name: Run greedy search decoding
|
- name: Run greedy search decoding
|
||||||
shell: bash
|
shell: bash
|
||||||
@ -87,11 +87,11 @@ jobs:
|
|||||||
cd egs/librispeech/ASR
|
cd egs/librispeech/ASR
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
|
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \
|
||||||
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
|
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \
|
||||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \
|
||||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \
|
||||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
- name: Run beam search decoding
|
- name: Run beam search decoding
|
||||||
shell: bash
|
shell: bash
|
||||||
@ -101,8 +101,8 @@ jobs:
|
|||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
|
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \
|
||||||
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
|
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \
|
||||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \
|
||||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \
|
||||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -84,12 +84,12 @@ The best WER using beam search with beam size 4 is:
|
|||||||
|
|
||||||
| | test-clean | test-other |
|
| | test-clean | test-other |
|
||||||
|-----|------------|------------|
|
|-----|------------|------------|
|
||||||
| WER | 2.83 | 7.19 |
|
| WER | 2.76 | 6.97 |
|
||||||
|
|
||||||
Note: No auxiliary losses are used in the training and no LMs are used
|
Note: No auxiliary losses are used in the training and no LMs are used
|
||||||
in the decoding.
|
in the decoding.
|
||||||
|
|
||||||
We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [](https://colab.research.google.com/drive/1Lm37sNajIpkV4HTzMDF7sn9l0JpfmekN?usp=sharing)
|
We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [](https://colab.research.google.com/drive/1Rc4Is-3Yp9LbcEz_Iy8hfyenyHsyjvqE?usp=sharing)
|
||||||
|
|
||||||
### Aishell
|
### Aishell
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
#### Conformer encoder + embedding decoder
|
#### Conformer encoder + embedding decoder
|
||||||
|
|
||||||
Using commit `14c93add507982306f5a478cd144e0e32e0f970d`.
|
Using commit `TODO`.
|
||||||
|
|
||||||
Conformer encoder + non-current decoder. The decoder
|
Conformer encoder + non-current decoder. The decoder
|
||||||
contains only an embedding layer and a Conv1d (with kernel size 2).
|
contains only an embedding layer and a Conv1d (with kernel size 2).
|
||||||
@ -13,8 +13,8 @@ The WERs are
|
|||||||
|
|
||||||
| | test-clean | test-other | comment |
|
| | test-clean | test-other | comment |
|
||||||
|---------------------------|------------|------------|------------------------------------------|
|
|---------------------------|------------|------------|------------------------------------------|
|
||||||
| greedy search | 2.85 | 7.30 | --epoch 29, --avg 13, --max-duration 100 |
|
| greedy search | 2.77 | 7.07 | --epoch 30, --avg 13, --max-duration 100 |
|
||||||
| beam search (beam size 4) | 2.83 | 7.19 | |
|
| beam search (beam size 4) | 2.76 | 6.97 | |
|
||||||
|
|
||||||
The training command for reproducing is given below:
|
The training command for reproducing is given below:
|
||||||
|
|
||||||
@ -32,11 +32,11 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
```
|
```
|
||||||
|
|
||||||
The tensorboard training log can be found at
|
The tensorboard training log can be found at
|
||||||
<https://tensorboard.dev/experiment/Mjx7MeTgR3Oyr1yBCwjozw/>
|
<https://tensorboard.dev/experiment/6fnVojoUQTmEJVq1yG34Vw/>
|
||||||
|
|
||||||
The decoding command is:
|
The decoding command is:
|
||||||
```
|
```
|
||||||
epoch=29
|
epoch=36
|
||||||
avg=13
|
avg=13
|
||||||
|
|
||||||
## greedy search
|
## greedy search
|
||||||
|
@ -66,6 +66,9 @@ def greedy_search(
|
|||||||
# symbols per utterance decoded so far
|
# symbols per utterance decoded so far
|
||||||
sym_per_utt = 0
|
sym_per_utt = 0
|
||||||
|
|
||||||
|
encoder_out_len = torch.tensor([1])
|
||||||
|
decoder_out_len = torch.tensor([1])
|
||||||
|
|
||||||
while t < T and sym_per_utt < max_sym_per_utt:
|
while t < T and sym_per_utt < max_sym_per_utt:
|
||||||
if sym_per_frame >= max_sym_per_frame:
|
if sym_per_frame >= max_sym_per_frame:
|
||||||
sym_per_frame = 0
|
sym_per_frame = 0
|
||||||
@ -75,7 +78,9 @@ def greedy_search(
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
logits = model.joiner(current_encoder_out, decoder_out)
|
logits = model.joiner(
|
||||||
|
current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
|
||||||
|
)
|
||||||
# logits is (1, 1, 1, vocab_size)
|
# logits is (1, 1, 1, vocab_size)
|
||||||
|
|
||||||
y = logits.argmax().item()
|
y = logits.argmax().item()
|
||||||
@ -262,6 +267,9 @@ def beam_search(
|
|||||||
|
|
||||||
sym_per_utt = 0
|
sym_per_utt = 0
|
||||||
|
|
||||||
|
encoder_out_len = torch.tensor([1])
|
||||||
|
decoder_out_len = torch.tensor([1])
|
||||||
|
|
||||||
decoder_cache: Dict[str, torch.Tensor] = {}
|
decoder_cache: Dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
while t < T and sym_per_utt < max_sym_per_utt:
|
while t < T and sym_per_utt < max_sym_per_utt:
|
||||||
@ -294,7 +302,12 @@ def beam_search(
|
|||||||
|
|
||||||
cached_key += f"-t-{t}"
|
cached_key += f"-t-{t}"
|
||||||
if cached_key not in joint_cache:
|
if cached_key not in joint_cache:
|
||||||
logits = model.joiner(current_encoder_out, decoder_out)
|
logits = model.joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
encoder_out_len,
|
||||||
|
decoder_out_len,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(fangjun): Ccale the blank posterior
|
# TODO(fangjun): Ccale the blank posterior
|
||||||
|
|
||||||
|
@ -22,33 +22,51 @@ class Joiner(nn.Module):
|
|||||||
def __init__(self, input_dim: int, output_dim: int):
|
def __init__(self, input_dim: int, output_dim: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
self.output_linear = nn.Linear(input_dim, output_dim)
|
self.output_linear = nn.Linear(input_dim, output_dim)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
encoder_out_len: torch.Tensor,
|
||||||
|
decoder_out_len: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
encoder_out:
|
encoder_out:
|
||||||
Output from the encoder. Its shape is (N, T, C).
|
Output from the encoder. Its shape is (N, T, self.input_dim).
|
||||||
decoder_out:
|
decoder_out:
|
||||||
Output from the decoder. Its shape is (N, U, C).
|
Output from the decoder. Its shape is (N, U, self.input_dim).
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, T, U, C).
|
Return a tensor of shape (sum_all_TU, self.output_dim).
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == decoder_out.ndim == 3
|
assert encoder_out.ndim == decoder_out.ndim == 3
|
||||||
assert encoder_out.size(0) == decoder_out.size(0)
|
assert encoder_out.size(0) == decoder_out.size(0)
|
||||||
assert encoder_out.size(2) == decoder_out.size(2)
|
assert encoder_out.size(2) == self.input_dim
|
||||||
|
assert decoder_out.size(2) == self.input_dim
|
||||||
|
|
||||||
encoder_out = encoder_out.unsqueeze(2)
|
N = encoder_out.size(0)
|
||||||
# Now encoder_out is (N, T, 1, C)
|
|
||||||
|
|
||||||
decoder_out = decoder_out.unsqueeze(1)
|
encoder_out_list = [
|
||||||
# Now decoder_out is (N, 1, U, C)
|
encoder_out[i, : encoder_out_len[i], :] for i in range(N)
|
||||||
|
]
|
||||||
|
|
||||||
logit = encoder_out + decoder_out
|
decoder_out_list = [
|
||||||
logit = torch.tanh(logit)
|
decoder_out[i, : decoder_out_len[i], :] for i in range(N)
|
||||||
|
]
|
||||||
|
|
||||||
output = self.output_linear(logit)
|
x = [
|
||||||
|
e.unsqueeze(1) + d.unsqueeze(0)
|
||||||
|
for e, d in zip(encoder_out_list, decoder_out_list)
|
||||||
|
]
|
||||||
|
|
||||||
return output
|
x = [p.reshape(-1, self.input_dim) for p in x]
|
||||||
|
x = torch.cat(x)
|
||||||
|
|
||||||
|
activations = torch.tanh(x)
|
||||||
|
|
||||||
|
logits = self.output_linear(activations)
|
||||||
|
|
||||||
|
return logits
|
||||||
|
@ -14,15 +14,9 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
|
||||||
Note we use `rnnt_loss` from torchaudio, which exists only in
|
|
||||||
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
|
|
||||||
"""
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchaudio
|
|
||||||
import torchaudio.functional
|
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
|
||||||
from icefall.utils import add_sos
|
from icefall.utils import add_sos
|
||||||
@ -102,18 +96,24 @@ class Transducer(nn.Module):
|
|||||||
|
|
||||||
decoder_out = self.decoder(sos_y_padded)
|
decoder_out = self.decoder(sos_y_padded)
|
||||||
|
|
||||||
logits = self.joiner(encoder_out, decoder_out)
|
# +1 here since a blank is prepended to each utterance.
|
||||||
|
logits = self.joiner(
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
decoder_out=decoder_out,
|
||||||
|
encoder_out_len=x_lens,
|
||||||
|
decoder_out_len=y_lens + 1,
|
||||||
|
)
|
||||||
|
|
||||||
# rnnt_loss requires 0 padded targets
|
# rnnt_loss requires 0 padded targets
|
||||||
# Note: y does not start with SOS
|
# Note: y does not start with SOS
|
||||||
y_padded = y.pad(mode="constant", padding_value=0)
|
y_padded = y.pad(mode="constant", padding_value=0)
|
||||||
|
|
||||||
assert hasattr(torchaudio.functional, "rnnt_loss"), (
|
# We don't put this `import` at the beginning of the file
|
||||||
f"Current torchaudio version: {torchaudio.__version__}\n"
|
# as it is required only in the training, not during the
|
||||||
"Please install a version >= 0.10.0"
|
# reference stage
|
||||||
)
|
import optimized_transducer
|
||||||
|
|
||||||
loss = torchaudio.functional.rnnt_loss(
|
loss = optimized_transducer.transducer_loss(
|
||||||
logits=logits,
|
logits=logits,
|
||||||
targets=y_padded,
|
targets=y_padded,
|
||||||
logit_lengths=x_lens,
|
logit_lengths=x_lens,
|
||||||
|
@ -180,7 +180,7 @@ class YesNoAsrDataModule(DataModule):
|
|||||||
train = K2SpeechRecognitionDataset(
|
train = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_strategy=OnTheFlyFeatures(
|
input_strategy=OnTheFlyFeatures(
|
||||||
Fbank(FbankConfig(num_mel_bins=23))
|
FbankConfig(sampling_rate=8000, num_mel_bins=23)
|
||||||
),
|
),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
|
@ -3,3 +3,4 @@ kaldialign
|
|||||||
sentencepiece>=0.1.96
|
sentencepiece>=0.1.96
|
||||||
tensorboard
|
tensorboard
|
||||||
typeguard
|
typeguard
|
||||||
|
optimized_transducer
|
||||||
|
Loading…
x
Reference in New Issue
Block a user