diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml index 5f4a425d9..de66b90c5 100644 --- a/.github/workflows/run-pretrained-transducer-stateless.yml +++ b/.github/workflows/run-pretrained-transducer-stateless.yml @@ -74,24 +74,53 @@ jobs: mkdir tmp cd tmp git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10 + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07 cd .. tree tmp - soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav - ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav + soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav + ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav - - name: Run greedy search decoding + - name: Run greedy search decoding (max-sym-per-frame 1) shell: bash run: | export PYTHONPATH=$PWD:PYTHONPATH cd egs/librispeech/ASR ./transducer_stateless/pretrained.py \ --method greedy_search \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav + --max-sym-per-frame 1 \ + --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav + + - name: Run greedy search decoding (max-sym-per-frame 2) + shell: bash + run: | + export PYTHONPATH=$PWD:PYTHONPATH + cd egs/librispeech/ASR + ./transducer_stateless/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame 2 \ + --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav + + - name: Run greedy search decoding (max-sym-per-frame 3) + shell: bash + run: | + export PYTHONPATH=$PWD:PYTHONPATH + cd egs/librispeech/ASR + ./transducer_stateless/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame 3 \ + --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav - name: Run beam search decoding shell: bash @@ -101,8 +130,22 @@ jobs: ./transducer_stateless/pretrained.py \ --method beam_search \ --beam-size 4 \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav + --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav + + - name: Run modified beam search decoding + shell: bash + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + cd egs/librispeech/ASR + ./transducer_stateless/pretrained.py \ + --method modified_beam_search \ + --beam-size 4 \ + --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav diff --git a/README.md b/README.md index 38c25900f..28c9b6ce4 100644 --- a/README.md +++ b/README.md @@ -80,16 +80,16 @@ We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open Using Conformer as encoder. The decoder consists of 1 embedding layer and 1 convolutional layer. -The best WER using beam search with beam size 4 is: +The best WER using modified beam search with beam size 4 is: | | test-clean | test-other | |-----|------------|------------| -| WER | 2.68 | 6.72 | +| WER | 2.67 | 6.64 | Note: No auxiliary losses are used in the training and no LMs are used in the decoding. -We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Rc4Is-3Yp9LbcEz_Iy8hfyenyHsyjvqE?usp=sharing) +We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing) ### Aishell diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py index dca084477..c2c6552a9 100644 --- a/egs/aishell/ASR/transducer_stateless/decoder.py +++ b/egs/aishell/ASR/transducer_stateless/decoder.py @@ -82,17 +82,17 @@ class Decoder(nn.Module): Returns: Return a tensor of shape (N, U, embedding_dim). """ - embeding_out = self.embedding(y) + embedding_out = self.embedding(y) if self.context_size > 1: - embeding_out = embeding_out.permute(0, 2, 1) + embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embeding_out = F.pad( - embeding_out, pad=(self.context_size - 1, 0) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) ) else: # During inference time, there is no need to do extra padding # as we only need one output - assert embeding_out.size(-1) == self.context_size - embeding_out = self.conv(embeding_out) - embeding_out = embeding_out.permute(0, 2, 1) - return embeding_out + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + return embedding_out diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py index 641555bdb..5687260df 100755 --- a/egs/aishell/ASR/transducer_stateless/export.py +++ b/egs/aishell/ASR/transducer_stateless/export.py @@ -48,6 +48,7 @@ from pathlib import Path import sentencepiece as spm import torch +import torch.nn as nn from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -133,7 +134,7 @@ def get_params() -> AttributeDict: return params -def get_encoder_model(params: AttributeDict): +def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = Conformer( num_features=params.feature_dim, output_dim=params.encoder_out_dim, @@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict): return encoder -def get_decoder_model(params: AttributeDict): +def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, @@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict): return decoder -def get_joiner_model(params: AttributeDict): +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( input_dim=params.encoder_out_dim, output_dim=params.vocab_size, @@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict): return joiner -def get_transducer_model(params: AttributeDict): +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py index 65ac5f3ff..db89c4d67 100755 --- a/egs/aishell/ASR/transducer_stateless/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless/pretrained.py @@ -44,11 +44,12 @@ Note: ./transducer_stateless/exp/pretrained.pt is generated by import argparse import logging import math -from typing import List from pathlib import Path +from typing import List import kaldifeat import torch +import torch.nn as nn import torchaudio from beam_search import beam_search, greedy_search from conformer import Conformer @@ -57,10 +58,10 @@ from joiner import Joiner from model import Transducer from torch.nn.utils.rnn import pad_sequence -from icefall.env import get_env_info -from icefall.utils import AttributeDict -from icefall.lexicon import Lexicon from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict def get_parser(): @@ -150,7 +151,7 @@ def get_params() -> AttributeDict: return params -def get_encoder_model(params: AttributeDict): +def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = Conformer( num_features=params.feature_dim, output_dim=params.encoder_out_dim, @@ -164,7 +165,7 @@ def get_encoder_model(params: AttributeDict): return encoder -def get_decoder_model(params: AttributeDict): +def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, @@ -174,7 +175,7 @@ def get_decoder_model(params: AttributeDict): return decoder -def get_joiner_model(params: AttributeDict): +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( input_dim=params.encoder_out_dim, output_dim=params.vocab_size, @@ -182,7 +183,7 @@ def get_joiner_model(params: AttributeDict): return joiner -def get_transducer_model(params: AttributeDict): +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py index 7da8e28a1..0c180b260 100755 --- a/egs/aishell/ASR/transducer_stateless/train.py +++ b/egs/aishell/ASR/transducer_stateless/train.py @@ -204,7 +204,7 @@ def get_params() -> AttributeDict: return params -def get_encoder_model(params: AttributeDict): +def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Conformer and Transformer encoder = Conformer( num_features=params.feature_dim, @@ -219,7 +219,7 @@ def get_encoder_model(params: AttributeDict): return encoder -def get_decoder_model(params: AttributeDict): +def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, @@ -229,7 +229,7 @@ def get_decoder_model(params: AttributeDict): return decoder -def get_joiner_model(params: AttributeDict): +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( input_dim=params.encoder_out_dim, output_dim=params.vocab_size, @@ -237,7 +237,7 @@ def get_joiner_model(params: AttributeDict): return joiner -def get_transducer_model(params: AttributeDict): +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index ffeaaae68..d78447593 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -4,62 +4,73 @@ #### Conformer encoder + embedding decoder -Using commit `4c1b3665ee6efb935f4dd93a80ff0e154b13efb6`. +Using commit `a8150021e01d34ecbd6198fe03a57eacf47a16f2`. -Conformer encoder + non-current decoder. The decoder +Conformer encoder + non-recurrent decoder. The decoder contains only an embedding layer and a Conv1d (with kernel size 2). The WERs are -| | test-clean | test-other | comment | -|---------------------------|------------|------------|------------------------------------------| -| greedy search | 2.69 | 6.81 | --epoch 71, --avg 15, --max-duration 100 | -| beam search (beam size 4) | 2.68 | 6.72 | --epoch 71, --avg 15, --max-duration 100 | +| | test-clean | test-other | comment | +|-------------------------------------|------------|------------|------------------------------------------| +| greedy search (max sym per frame 1) | 2.68 | 6.71 | --epoch 61, --avg 18, --max-duration 100 | +| greedy search (max sym per frame 2) | 2.69 | 6.71 | --epoch 61, --avg 18, --max-duration 100 | +| greedy search (max sym per frame 3) | 2.69 | 6.71 | --epoch 61, --avg 18, --max-duration 100 | +| modified beam search (beam size 4) | 2.67 | 6.64 | --epoch 61, --avg 18, --max-duration 100 | + The training command for reproducing is given below: ``` +cd egs/librispeech/ASR/ +./prepare.sh export CUDA_VISIBLE_DEVICES="0,1,2,3" - ./transducer_stateless/train.py \ --world-size 4 \ --num-epochs 76 \ --start-epoch 0 \ --exp-dir transducer_stateless/exp-full \ --full-libri 1 \ - --max-duration 250 \ - --lr-factor 3 + --max-duration 300 \ + --lr-factor 5 \ + --bpe-model data/lang_bpe_500/bpe.model \ + --modified-transducer-prob 0.25 ``` The tensorboard training log can be found at - + The decoding command is: ``` -epoch=71 -avg=15 +epoch=61 +avg=18 ## greedy search -./transducer_stateless/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir transducer_stateless/exp-full \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --max-duration 100 +for sym in 1 2 3; do + ./transducer_stateless/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir transducer_stateless/exp-full \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --max-duration 100 \ + --max-sym-per-frame $sym +done + +## modified beam search -## beam search ./transducer_stateless/decode.py \ --epoch $epoch \ --avg $avg \ --exp-dir transducer_stateless/exp-full \ --bpe-model ./data/lang_bpe_500/bpe.model \ --max-duration 100 \ - --decoding-method beam_search \ + --context-size 2 \ + --decoding-method modified_beam_search \ --beam-size 4 ``` You can find a pretrained model by visiting - + #### Conformer encoder + LSTM decoder diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index c1fa814c0..cb0bd5c2d 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -41,6 +41,7 @@ from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -123,6 +124,15 @@ def get_parser(): """, ) + parser.add_argument( + "--num-decoder-layers", + type=int, + default=6, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + parser.add_argument( "--lr-factor", type=float, @@ -210,7 +220,6 @@ def get_params() -> AttributeDict: "use_feat_batchnorm": True, "attention_dim": 512, "nhead": 8, - "num_decoder_layers": 6, # parameters for loss "beam_size": 10, "reduction": "sum", @@ -357,9 +366,17 @@ def compute_loss( supervisions, subsampling_factor=params.subsampling_factor ) - token_ids = graph_compiler.texts_to_ids(texts) - - decoding_graph = graph_compiler.compile(token_ids) + if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler): + # Works with a BPE model + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + elif isinstance(graph_compiler, CtcTrainingGraphCompiler): + # Works with a phone lexicon + decoding_graph = graph_compiler.compile(texts) + else: + raise ValueError( + f"Unsupported type of graph compiler: {type(graph_compiler)}" + ) dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -584,12 +601,38 @@ def run(rank, world_size, args): if torch.cuda.is_available(): device = torch.device("cuda", rank) - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) + if "lang_bpe" in params.lang_dir: + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + elif "lang_phone" in params.lang_dir: + assert params.att_rate == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. Set --att-rate=0 " + "for pure CTC training when using a phone-based lang dir." + ) + assert params.num_decoder_layers == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. " + "Set --num-decoder-layers=0 for pure CTC training when using " + "a phone-based lang dir." + ) + graph_compiler = CtcTrainingGraphCompiler( + lexicon, + device=device, + ) + # Manually add the sos/eos ID with their default values + # from the BPE recipe which we're adapting here. + graph_compiler.sos_id = 1 + graph_compiler.eos_id = 1 + else: + raise ValueError( + f"Unsupported type of lang dir (we expected it to have " + f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" + ) logging.info("About to create model") model = Conformer( @@ -607,7 +650,9 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: - model = DDP(model, device_ids=[rank]) + # Note: find_unused_parameters=True is needed in case we + # want to set params.att_rate = 0 (i.e. att decoder is not trained) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) optimizer = Noam( model.parameters(), diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py index f45d06ce9..11032f31a 100644 --- a/egs/librispeech/ASR/transducer/beam_search.py +++ b/egs/librispeech/ASR/transducer/beam_search.py @@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: blank_id = model.decoder.blank_id device = model.device - sos = torch.tensor([blank_id], device=device).reshape(1, 1) + sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape( + 1, 1 + ) decoder_out, (h, c) = model.decoder(sos) T = encoder_out.size(1) t = 0 diff --git a/egs/librispeech/ASR/transducer/model.py b/egs/librispeech/ASR/transducer/model.py index fa0b2dd68..8305248c9 100644 --- a/egs/librispeech/ASR/transducer/model.py +++ b/egs/librispeech/ASR/transducer/model.py @@ -99,6 +99,7 @@ class Transducer(nn.Module): sos_y = add_sos(y, sos_id=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + sos_y_padded = sos_y_padded.to(torch.int64) decoder_out, _ = self.decoder(sos_y_padded) diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py index dfc22fcf8..3531a9633 100644 --- a/egs/librispeech/ASR/transducer_lstm/beam_search.py +++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py @@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: blank_id = model.decoder.blank_id device = model.device - sos = torch.tensor([blank_id], device=device).reshape(1, 1) + sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape( + 1, 1 + ) decoder_out, (h, c) = model.decoder(sos) T = encoder_out.size(1) t = 0 diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py index cb9afd8a2..31843b60e 100644 --- a/egs/librispeech/ASR/transducer_lstm/model.py +++ b/egs/librispeech/ASR/transducer_lstm/model.py @@ -101,6 +101,7 @@ class Transducer(nn.Module): sos_y = add_sos(y, sos_id=sos_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + sos_y_padded = sos_y_padded.to(torch.int64) decoder_out, _ = self.decoder(sos_y_padded) diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index 341c74fab..c5efb733d 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -17,7 +17,6 @@ from dataclasses import dataclass from typing import Dict, List, Optional -import numpy as np import torch from model import Transducer @@ -48,7 +47,7 @@ def greedy_search( device = model.device decoder_input = torch.tensor( - [blank_id] * context_size, device=device + [blank_id] * context_size, device=device, dtype=torch.int64 ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -108,8 +107,9 @@ class Hypothesis: # Newly predicted tokens are appended to `ys`. ys: List[int] - # The log prob of ys - log_prob: float + # The log prob of ys. + # It contains only one entry. + log_prob: torch.Tensor @property def key(self) -> str: @@ -145,8 +145,10 @@ class HypothesisList(object): """ key = hyp.key if key in self: - old_hyp = self._data[key] - old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob) + old_hyp = self._data[key] # shallow copy + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -184,7 +186,7 @@ class HypothesisList(object): assert key in self, f"{key} does not exist" del self._data[key] - def filter(self, threshold: float) -> "HypothesisList": + def filter(self, threshold: torch.Tensor) -> "HypothesisList": """Remove all Hypotheses whose log_prob is less than threshold. Caution: @@ -312,6 +314,113 @@ def run_joiner( return log_prob +def modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. + """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + decoder_input = torch.tensor( + [blank_id] * context_size, device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + encoder_out_len = torch.tensor([1]) + decoder_out_len = torch.tensor([1]) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :] + # current_encoder_out is of shape (1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_output is of shape (num_hyps, 1, decoder_output_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, -1 + ) + + logits = model.joiner( + current_encoder_out, + decoder_out, + encoder_out_len.expand(decoder_out.size(0)), + decoder_out_len.expand(decoder_out.size(0)), + ) + # logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[i] + if new_token != blank_id: + new_ys.append(new_token) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + return ys + + def beam_search( model: Transducer, encoder_out: torch.Tensor, @@ -351,7 +460,12 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) max_sym_per_utt = 20000 @@ -371,9 +485,6 @@ def beam_search( joint_cache: Dict[str, torch.Tensor] = {} - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - while True: y_star = A.get_most_probable() A.remove(y_star) @@ -396,18 +507,21 @@ def beam_search( # First, process the blank symbol skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob.item() + new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): + for idx in range(values.size(0)): + i = indices[idx].item() if i == blank_id: continue + new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v + + new_log_prob = y_star.log_prob + values[idx] A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) # Check whether B contains more than "beam" elements more probable diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index e5987b75e..c101d9397 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -46,7 +46,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import beam_search, greedy_search +from beam_search import beam_search, greedy_search, modified_beam_search from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -104,6 +104,7 @@ def get_parser(): help="""Possible values are: - greedy_search - beam_search + - modified_beam_search """, ) @@ -111,7 +112,8 @@ def get_parser(): "--beam-size", type=int, default=4, - help="Used only when --decoding-method is beam_search", + help="""Used only when --decoding-method is + beam_search or modified_beam_search""", ) parser.add_argument( @@ -125,7 +127,8 @@ def get_parser(): "--max-sym-per-frame", type=int, default=3, - help="Maximum number of symbols per frame", + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", ) return parser @@ -256,6 +259,10 @@ def decode_one_batch( hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) + elif params.decoding_method == "modified_beam_search": + hyp = modified_beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" @@ -389,11 +396,15 @@ def main(): params = get_params() params.update(vars(args)) - assert params.decoding_method in ("greedy_search", "beam_search") + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + ) params.res_dir = params.exp_dir / params.decoding_method params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.decoding_method == "beam_search": + if "beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index dca084477..b82fed37b 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -75,24 +75,24 @@ class Decoder(nn.Module): """ Args: y: - A 2-D tensor of shape (N, U) with blank prepended. + A 2-D tensor of shape (N, U). need_pad: True to left pad the input. Should be True during training. False to not pad the input. Should be False during inference. Returns: Return a tensor of shape (N, U, embedding_dim). """ - embeding_out = self.embedding(y) + embedding_out = self.embedding(y) if self.context_size > 1: - embeding_out = embeding_out.permute(0, 2, 1) + embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embeding_out = F.pad( - embeding_out, pad=(self.context_size - 1, 0) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) ) else: # During inference time, there is no need to do extra padding # as we only need one output - assert embeding_out.size(-1) == self.context_size - embeding_out = self.conv(embeding_out) - embeding_out = embeding_out.permute(0, 2, 1) - return embeding_out + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + return embedding_out diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py index 641555bdb..5687260df 100755 --- a/egs/librispeech/ASR/transducer_stateless/export.py +++ b/egs/librispeech/ASR/transducer_stateless/export.py @@ -48,6 +48,7 @@ from pathlib import Path import sentencepiece as spm import torch +import torch.nn as nn from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -133,7 +134,7 @@ def get_params() -> AttributeDict: return params -def get_encoder_model(params: AttributeDict): +def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = Conformer( num_features=params.feature_dim, output_dim=params.encoder_out_dim, @@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict): return encoder -def get_decoder_model(params: AttributeDict): +def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, @@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict): return decoder -def get_joiner_model(params: AttributeDict): +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( input_dim=params.encoder_out_dim, output_dim=params.vocab_size, @@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict): return joiner -def get_transducer_model(params: AttributeDict): +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 98a6f0f37..8281e1fb5 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random + import k2 import torch import torch.nn as nn @@ -62,6 +64,7 @@ class Transducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, y: k2.RaggedTensor, + modified_transducer_prob: float = 0.0, ) -> torch.Tensor: """ Args: @@ -73,6 +76,8 @@ class Transducer(nn.Module): y: A ragged tensor with 2 axes [utt][label]. It contains labels of each utterance. + modified_transducer_prob: + The probability to use modified transducer loss. Returns: Return the transducer loss. """ @@ -93,6 +98,7 @@ class Transducer(nn.Module): sos_y = add_sos(y, sos_id=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + sos_y_padded = sos_y_padded.to(torch.int64) decoder_out = self.decoder(sos_y_padded) @@ -113,6 +119,16 @@ class Transducer(nn.Module): # reference stage import optimized_transducer + assert 0 <= modified_transducer_prob <= 1 + + if modified_transducer_prob == 0: + one_sym_per_frame = False + elif random.random() < modified_transducer_prob: + # random.random() returns a float in the range [0, 1) + one_sym_per_frame = True + else: + one_sym_per_frame = False + loss = optimized_transducer.transducer_loss( logits=logits, targets=y_padded, @@ -120,6 +136,8 @@ class Transducer(nn.Module): target_lengths=y_lens, blank=blank_id, reduction="sum", + one_sym_per_frame=one_sym_per_frame, + from_log_softmax=False, ) return loss diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index e5dba8f0e..ad8d89918 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -22,10 +22,11 @@ Usage: --checkpoint ./transducer_stateless/exp/pretrained.pt \ --bpe-model ./data/lang_bpe_500/bpe.model \ --method greedy_search \ + --max-sym-per-frame 1 \ /path/to/foo.wav \ /path/to/bar.wav \ -(1) beam search +(2) beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ --bpe-model ./data/lang_bpe_500/bpe.model \ @@ -34,6 +35,15 @@ Usage: /path/to/foo.wav \ /path/to/bar.wav \ +(3) modified beam search +./transducer_stateless/pretrained.py \ + --checkpoint ./transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav \ + You can also use `./transducer_stateless/exp/epoch-xx.pt`. Note: ./transducer_stateless/exp/pretrained.pt is generated by @@ -49,8 +59,9 @@ from typing import List import kaldifeat import sentencepiece as spm import torch +import torch.nn as nn import torchaudio -from beam_search import beam_search, greedy_search +from beam_search import beam_search, greedy_search, modified_beam_search from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -90,6 +101,7 @@ def get_parser(): help="""Possible values are: - greedy_search - beam_search + - modified_beam_search """, ) @@ -107,7 +119,7 @@ def get_parser(): "--beam-size", type=int, default=4, - help="Used only when --method is beam_search", + help="Used only when --method is beam_search and modified_beam_search ", ) parser.add_argument( @@ -148,7 +160,7 @@ def get_params() -> AttributeDict: return params -def get_encoder_model(params: AttributeDict): +def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = Conformer( num_features=params.feature_dim, output_dim=params.encoder_out_dim, @@ -162,7 +174,7 @@ def get_encoder_model(params: AttributeDict): return encoder -def get_decoder_model(params: AttributeDict): +def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, @@ -172,7 +184,7 @@ def get_decoder_model(params: AttributeDict): return decoder -def get_joiner_model(params: AttributeDict): +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( input_dim=params.encoder_out_dim, output_dim=params.vocab_size, @@ -180,7 +192,7 @@ def get_joiner_model(params: AttributeDict): return joiner -def get_transducer_model(params: AttributeDict): +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) @@ -217,6 +229,7 @@ def read_sound_files( return ans +@torch.no_grad() def main(): parser = get_parser() args = parser.parse_args() @@ -300,6 +313,10 @@ def main(): hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) + elif params.method == "modified_beam_search": + hyp = modified_beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) else: raise ValueError(f"Unsupported method: {params.method}") diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 694ebf1d5..544f6e9b1 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -138,6 +138,17 @@ def get_parser(): "2 means tri-gram", ) + parser.add_argument( + "--modified-transducer-prob", + type=float, + default=0.25, + help="""The probability to use modified transducer loss. + In modified transduer, it limits the maximum number of symbols + per frame to 1. See also the option --max-sym-per-frame in + transducer_stateless/decode.py + """, + ) + return parser @@ -213,7 +224,7 @@ def get_params() -> AttributeDict: return params -def get_encoder_model(params: AttributeDict): +def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Conformer and Transformer encoder = Conformer( num_features=params.feature_dim, @@ -228,7 +239,7 @@ def get_encoder_model(params: AttributeDict): return encoder -def get_decoder_model(params: AttributeDict): +def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, @@ -238,7 +249,7 @@ def get_decoder_model(params: AttributeDict): return decoder -def get_joiner_model(params: AttributeDict): +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( input_dim=params.encoder_out_dim, output_dim=params.vocab_size, @@ -246,7 +257,7 @@ def get_joiner_model(params: AttributeDict): return joiner -def get_transducer_model(params: AttributeDict): +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) @@ -383,7 +394,12 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=y) + loss = model( + x=feature, + x_lens=feature_lens, + y=y, + modified_transducer_prob=params.modified_transducer_prob, + ) assert loss.requires_grad == is_training diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py index b4c87d964..570ed7d7a 100644 --- a/icefall/graph_compiler.py +++ b/icefall/graph_compiler.py @@ -89,6 +89,29 @@ class CtcTrainingGraphCompiler(object): return decoding_graph + def texts_to_ids(self, texts: List[str]) -> List[List[int]]: + """Convert a list of texts to a list-of-list of word IDs. + + Args: + texts: + It is a list of strings. Each string consists of space(s) + separated words. An example containing two strings is given below: + + ['HELLO ICEFALL', 'HELLO k2'] + Returns: + Return a list-of-list of word IDs. + """ + word_ids_list = [] + for text in texts: + word_ids = [] + for word in text.split(): + if word in self.word_table: + word_ids.append(self.word_table[word]) + else: + word_ids.append(self.oov_id) + word_ids_list.append(word_ids) + return word_ids_list + def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa: """Convert a list of transcript texts to an FsaVec.