mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)
Use optimized_transducer to compute transducer loss. (#162)
* WIP: Use optimized_transducer to compute transducer loss.
* Minor fixes.
* Fix decoding.
* Fix decoding.
* Add RESULTS.
* Update RESULTS.
* Update CI.
* Fix sampling rate for yesno recipe.
This commit is contained in:
parent 319e120869, commit 4c1b3665ee
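The gist of the switch: `torchaudio.functional.rnnt_loss` takes dense 4-D logits of shape `(N, max_T, max_U, vocab_size)`, padded to the longest utterance on both axes, whereas `optimized_transducer.transducer_loss` takes logits flattened to `(sum_all_TU, vocab_size)`, concatenating each utterance's own `(T_i, U_i)` grid. A back-of-the-envelope sketch of the memory difference, with invented lengths:

```python
# Invented per-utterance lengths, purely to illustrate the arithmetic.
T = [800, 350, 120]  # encoder frames per utterance
U = [120, 60, 15]    # decoder positions per utterance
V = 500              # vocabulary size

dense = len(T) * max(T) * max(U) * V         # (N, max_T, max_U, V) layout
flat = sum(t * u for t, u in zip(T, U)) * V  # (sum_all_TU, V) layout
print(f"dense: {dense:,} logits, flat: {flat:,} logits")
# The more the lengths vary within a batch, the bigger the saving.
```

The hunks below thread per-utterance lengths through the joiner so that it can emit exactly this flattened layout.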
@@ -74,11 +74,11 @@ jobs:
       mkdir tmp
       cd tmp
       git lfs install
-      git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27
+      git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10
       cd ..
       tree tmp
-      soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
-      ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
+      soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav
+      ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav

     - name: Run greedy search decoding
       shell: bash
@@ -87,11 +87,11 @@ jobs:
       cd egs/librispeech/ASR
       ./transducer_stateless/pretrained.py \
         --method greedy_search \
-        --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
-        --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
-        ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
-        ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
-        ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav
+        --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \
+        --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \
+        ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \
+        ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \
+        ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav

     - name: Run beam search decoding
       shell: bash
@@ -101,8 +101,8 @@ jobs:
       ./transducer_stateless/pretrained.py \
         --method beam_search \
         --beam-size 4 \
-        --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
-        --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
-        ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
-        ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
-        ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav
+        --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \
+        --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \
+        ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \
+        ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \
+        ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav
@@ -84,12 +84,12 @@ The best WER using beam search with beam size 4 is:

 |     | test-clean | test-other |
 |-----|------------|------------|
-| WER | 2.83       | 7.19       |
+| WER | 2.76       | 6.97       |

 Note: No auxiliary losses are used in the training and no LMs are used
 in the decoding.

-We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [](https://colab.research.google.com/drive/1Lm37sNajIpkV4HTzMDF7sn9l0JpfmekN?usp=sharing)
+We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [](https://colab.research.google.com/drive/1Rc4Is-3Yp9LbcEz_Iy8hfyenyHsyjvqE?usp=sharing)

 ### Aishell
@@ -4,7 +4,7 @@

 #### Conformer encoder + embedding decoder

-Using commit `14c93add507982306f5a478cd144e0e32e0f970d`.
+Using commit `TODO`.

 Conformer encoder + non-recurrent decoder. The decoder
 contains only an embedding layer and a Conv1d (with kernel size 2).
@@ -13,8 +13,8 @@ The WERs are

 |                           | test-clean | test-other | comment                                  |
 |---------------------------|------------|------------|------------------------------------------|
-| greedy search             | 2.85       | 7.30       | --epoch 29, --avg 13, --max-duration 100 |
-| beam search (beam size 4) | 2.83       | 7.19       |                                          |
+| greedy search             | 2.77       | 7.07       | --epoch 30, --avg 13, --max-duration 100 |
+| beam search (beam size 4) | 2.76       | 6.97       |                                          |

 The training command for reproducing is given below:
@@ -32,11 +32,11 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 ```

 The tensorboard training log can be found at
-<https://tensorboard.dev/experiment/Mjx7MeTgR3Oyr1yBCwjozw/>
+<https://tensorboard.dev/experiment/6fnVojoUQTmEJVq1yG34Vw/>

 The decoding command is:
 ```
-epoch=29
+epoch=36
 avg=13

 ## greedy search
@@ -66,6 +66,9 @@ def greedy_search(
     # symbols per utterance decoded so far
     sym_per_utt = 0

+    encoder_out_len = torch.tensor([1])
+    decoder_out_len = torch.tensor([1])
+
     while t < T and sym_per_utt < max_sym_per_utt:
         if sym_per_frame >= max_sym_per_frame:
             sym_per_frame = 0
@@ -75,7 +78,9 @@ def greedy_search(
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :]
         # fmt: on
-        logits = model.joiner(current_encoder_out, decoder_out)
+        logits = model.joiner(
+            current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
+        )
         # logits is (1, 1, 1, vocab_size)

         y = logits.argmax().item()
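Since greedy search feeds the joiner one encoder frame and one decoder step at a time, both length tensors stay at `[1]` and the flattened joiner output is a single row of vocabulary scores. A small shape sketch (invented dimensions, a bare `nn.Linear` standing in for the joiner's projection):

```python
import torch
import torch.nn as nn

input_dim, vocab_size = 512, 500  # invented dimensions
output_linear = nn.Linear(input_dim, vocab_size)

current_encoder_out = torch.randn(1, 1, input_dim)  # (N=1, T=1, C)
decoder_out = torch.randn(1, 1, input_dim)          # (N=1, U=1, C)

# encoder_out_len = decoder_out_len = [1] makes sum_all_TU == 1,
# i.e. exactly one row of logits for the current (frame, decoder state).
x = current_encoder_out[0, :1].unsqueeze(1) + decoder_out[0, :1].unsqueeze(0)
logits = output_linear(torch.tanh(x.reshape(-1, input_dim)))
assert logits.shape == (1, vocab_size)
y = logits.argmax().item()  # emitted symbol (or blank) at frame t
```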
@@ -262,6 +267,9 @@ def beam_search(

     sym_per_utt = 0

+    encoder_out_len = torch.tensor([1])
+    decoder_out_len = torch.tensor([1])
+
     decoder_cache: Dict[str, torch.Tensor] = {}

     while t < T and sym_per_utt < max_sym_per_utt:
@@ -294,7 +302,12 @@ def beam_search(

         cached_key += f"-t-{t}"
         if cached_key not in joint_cache:
-            logits = model.joiner(current_encoder_out, decoder_out)
+            logits = model.joiner(
+                current_encoder_out,
+                decoder_out,
+                encoder_out_len,
+                decoder_out_len,
+            )

             # TODO(fangjun): Scale the blank posterior

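The beam-search path guards the new joiner call with `joint_cache`, keyed by the hypothesis plus the frame index, so identical (hypothesis, frame) pairs never pay for a second joiner evaluation. A minimal sketch of that memoization pattern (the helper name and its `compute` callable are hypothetical):

```python
from typing import Dict

import torch

joint_cache: Dict[str, torch.Tensor] = {}

def joint_logits(cached_key: str, t: int, compute) -> torch.Tensor:
    """Memoize joiner outputs per (hypothesis, frame) pair."""
    key = f"{cached_key}-t-{t}"
    if key not in joint_cache:
        joint_cache[key] = compute()
    return joint_cache[key]

# Hypothetical usage inside the beam-search loop:
# logits = joint_logits(cached_key, t, lambda: model.joiner(
#     current_encoder_out, decoder_out, encoder_out_len, decoder_out_len))
```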
@@ -22,33 +22,51 @@ class Joiner(nn.Module):
     def __init__(self, input_dim: int, output_dim: int):
         super().__init__()

+        self.input_dim = input_dim
+        self.output_dim = output_dim
         self.output_linear = nn.Linear(input_dim, output_dim)

     def forward(
-        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
+        self,
+        encoder_out: torch.Tensor,
+        decoder_out: torch.Tensor,
+        encoder_out_len: torch.Tensor,
+        decoder_out_len: torch.Tensor,
     ) -> torch.Tensor:
         """
         Args:
           encoder_out:
-            Output from the encoder. Its shape is (N, T, C).
+            Output from the encoder. Its shape is (N, T, self.input_dim).
           decoder_out:
-            Output from the decoder. Its shape is (N, U, C).
+            Output from the decoder. Its shape is (N, U, self.input_dim).
         Returns:
-          Return a tensor of shape (N, T, U, C).
+          Return a tensor of shape (sum_all_TU, self.output_dim).
         """
         assert encoder_out.ndim == decoder_out.ndim == 3
         assert encoder_out.size(0) == decoder_out.size(0)
-        assert encoder_out.size(2) == decoder_out.size(2)
+        assert encoder_out.size(2) == self.input_dim
+        assert decoder_out.size(2) == self.input_dim

-        encoder_out = encoder_out.unsqueeze(2)
-        # Now encoder_out is (N, T, 1, C)
-        decoder_out = decoder_out.unsqueeze(1)
-        # Now decoder_out is (N, 1, U, C)
+        N = encoder_out.size(0)

-        logit = encoder_out + decoder_out
-        logit = torch.tanh(logit)
+        encoder_out_list = [
+            encoder_out[i, : encoder_out_len[i], :] for i in range(N)
+        ]

-        output = self.output_linear(logit)
+        decoder_out_list = [
+            decoder_out[i, : decoder_out_len[i], :] for i in range(N)
+        ]

-        return output
+        x = [
+            e.unsqueeze(1) + d.unsqueeze(0)
+            for e, d in zip(encoder_out_list, decoder_out_list)
+        ]
+
+        x = [p.reshape(-1, self.input_dim) for p in x]
+        x = torch.cat(x)
+
+        activations = torch.tanh(x)
+
+        logits = self.output_linear(activations)
+
+        return logits
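To see what the new `forward` produces, here is a hedged usage sketch: two utterances with different valid lengths go in, and the number of output rows equals `sum(T_i * U_i)` rather than `N * max_T * max_U`. `Joiner` is assumed to be the class from the hunk above (import path assumed); all dimensions are invented.

```python
import torch

from joiner import Joiner  # the module changed above; path assumed

joiner = Joiner(input_dim=512, output_dim=500)

encoder_out = torch.randn(2, 10, 512)    # N=2, padded to max_T=10
decoder_out = torch.randn(2, 7, 512)     # padded to max_U=7
encoder_out_len = torch.tensor([10, 6])  # valid frames per utterance
decoder_out_len = torch.tensor([7, 4])   # valid decoder steps per utterance

logits = joiner(encoder_out, decoder_out, encoder_out_len, decoder_out_len)

# 10*7 + 6*4 = 94 rows, instead of N*max_T*max_U = 140: padding is dropped.
assert logits.shape == (10 * 7 + 6 * 4, 500)
```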
@@ -14,15 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""
-Note we use `rnnt_loss` from torchaudio, which exists only in
-torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
-"""
 import k2
 import torch
 import torch.nn as nn
-import torchaudio
-import torchaudio.functional
 from encoder_interface import EncoderInterface

 from icefall.utils import add_sos
@@ -102,18 +96,24 @@ class Transducer(nn.Module):

         decoder_out = self.decoder(sos_y_padded)

-        logits = self.joiner(encoder_out, decoder_out)
+        # +1 here since a blank is prepended to each utterance.
+        logits = self.joiner(
+            encoder_out=encoder_out,
+            decoder_out=decoder_out,
+            encoder_out_len=x_lens,
+            decoder_out_len=y_lens + 1,
+        )

         # rnnt_loss requires 0 padded targets
         # Note: y does not start with SOS
         y_padded = y.pad(mode="constant", padding_value=0)

-        assert hasattr(torchaudio.functional, "rnnt_loss"), (
-            f"Current torchaudio version: {torchaudio.__version__}\n"
-            "Please install a version >= 0.10.0"
-        )
+        # We don't put this `import` at the beginning of the file
+        # as it is required only in the training, not during the
+        # inference stage.
+        import optimized_transducer

-        loss = torchaudio.functional.rnnt_loss(
+        loss = optimized_transducer.transducer_loss(
             logits=logits,
             targets=y_padded,
             logit_lengths=x_lens,
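For context, a self-contained toy call of `optimized_transducer.transducer_loss`, loosely following that library's README; every size here is invented. The `sum_all_TU` arithmetic mirrors the `y_lens + 1` above: each utterance contributes `T_i * (U_i + 1)` rows because of the prepended blank.

```python
import torch
import optimized_transducer

logit_lengths = torch.tensor([3, 2], dtype=torch.int32)   # T_i per utterance
target_lengths = torch.tensor([2, 1], dtype=torch.int32)  # U_i per utterance
targets = torch.tensor([[1, 2], [3, 0]], dtype=torch.int32)  # zero-padded

# 3*(2+1) + 2*(1+1) = 13 rows of flattened joiner output.
sum_all_TU = int(torch.sum(logit_lengths * (target_lengths + 1)))
logits = torch.randn(sum_all_TU, 5, requires_grad=True)  # vocab_size = 5

loss = optimized_transducer.transducer_loss(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
    blank=0,
    reduction="mean",
)
loss.backward()
```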
@@ -180,7 +180,7 @@ class YesNoAsrDataModule(DataModule):
         train = K2SpeechRecognitionDataset(
             cut_transforms=transforms,
             input_strategy=OnTheFlyFeatures(
-                Fbank(FbankConfig(num_mel_bins=23))
+                Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=23))
             ),
             return_cuts=self.args.return_cuts,
         )
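The fix matters because `FbankConfig` defaults to a 16 kHz sampling rate while the yesno recordings are 8 kHz, so the on-the-fly extractor must be told the true rate. A hedged standalone sketch of the corrected extractor (lhotse API of this era assumed, fake audio):

```python
import torch
from lhotse import Fbank, FbankConfig

extractor = Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=23))

samples = torch.randn(1, 8000)  # one second of fake 8 kHz audio
feats = extractor.extract_batch(samples, sampling_rate=8000)
print(feats.shape)  # roughly (1, 100, 23): ~100 frames of 23-dim fbank
```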
@@ -3,3 +3,4 @@ kaldialign
 sentencepiece>=0.1.96
 tensorboard
 typeguard
+optimized_transducer