From a8150021e01d34ecbd6198fe03a57eacf47a16f2 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 7 Feb 2022 18:37:36 +0800 Subject: [PATCH] Use modified transducer loss in training. (#179) * Use modified transducer loss in training. * Minor fix. * Add modified beam search. * Add modified beam search. * Minor fixes. * Fix typo. * Update RESULTS. * Fix a typo. * Minor fixes. --- .../run-pretrained-transducer-stateless.yml | 71 +++++++-- README.md | 6 +- egs/librispeech/ASR/RESULTS.md | 53 ++++--- .../ASR/transducer_stateless/beam_search.py | 140 ++++++++++++++++-- .../ASR/transducer_stateless/decode.py | 21 ++- .../ASR/transducer_stateless/decoder.py | 2 +- .../ASR/transducer_stateless/model.py | 16 ++ .../ASR/transducer_stateless/pretrained.py | 22 ++- .../ASR/transducer_stateless/train.py | 18 ++- 9 files changed, 288 insertions(+), 61 deletions(-) diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml index 5f4a425d9..de66b90c5 100644 --- a/.github/workflows/run-pretrained-transducer-stateless.yml +++ b/.github/workflows/run-pretrained-transducer-stateless.yml @@ -74,24 +74,53 @@ jobs: mkdir tmp cd tmp git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10 + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07 cd .. tree tmp - soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav - ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav + soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav + ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav - - name: Run greedy search decoding + - name: Run greedy search decoding (max-sym-per-frame 1) shell: bash run: | export PYTHONPATH=$PWD:PYTHONPATH cd egs/librispeech/ASR ./transducer_stateless/pretrained.py \ --method greedy_search \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav + --max-sym-per-frame 1 \ + --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav + + - name: Run greedy search decoding (max-sym-per-frame 2) + shell: bash + run: | + export PYTHONPATH=$PWD:PYTHONPATH + cd egs/librispeech/ASR + ./transducer_stateless/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame 2 \ + --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav + + - name: Run greedy search decoding (max-sym-per-frame 3) + shell: bash + run: | + export PYTHONPATH=$PWD:PYTHONPATH + cd egs/librispeech/ASR + ./transducer_stateless/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame 3 \ + --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav - name: Run beam search decoding shell: bash @@ -101,8 +130,22 @@ jobs: ./transducer_stateless/pretrained.py \ --method beam_search \ --beam-size 4 \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav + --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav + + - name: Run modified beam search decoding + shell: bash + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + cd egs/librispeech/ASR + ./transducer_stateless/pretrained.py \ + --method modified_beam_search \ + --beam-size 4 \ + --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav diff --git a/README.md b/README.md index 38c25900f..28c9b6ce4 100644 --- a/README.md +++ b/README.md @@ -80,16 +80,16 @@ We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open Using Conformer as encoder. The decoder consists of 1 embedding layer and 1 convolutional layer. -The best WER using beam search with beam size 4 is: +The best WER using modified beam search with beam size 4 is: | | test-clean | test-other | |-----|------------|------------| -| WER | 2.68 | 6.72 | +| WER | 2.67 | 6.64 | Note: No auxiliary losses are used in the training and no LMs are used in the decoding. -We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Rc4Is-3Yp9LbcEz_Iy8hfyenyHsyjvqE?usp=sharing) +We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing) ### Aishell diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index ffeaaae68..17679ba3d 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -4,62 +4,73 @@ #### Conformer encoder + embedding decoder -Using commit `4c1b3665ee6efb935f4dd93a80ff0e154b13efb6`. +Using commit `TODO`. -Conformer encoder + non-current decoder. The decoder +Conformer encoder + non-recurrent decoder. The decoder contains only an embedding layer and a Conv1d (with kernel size 2). The WERs are -| | test-clean | test-other | comment | -|---------------------------|------------|------------|------------------------------------------| -| greedy search | 2.69 | 6.81 | --epoch 71, --avg 15, --max-duration 100 | -| beam search (beam size 4) | 2.68 | 6.72 | --epoch 71, --avg 15, --max-duration 100 | +| | test-clean | test-other | comment | +|-------------------------------------|------------|------------|------------------------------------------| +| greedy search (max sym per frame 1) | 2.68 | 6.71 | --epoch 61, --avg 18, --max-duration 100 | +| greedy search (max sym per frame 2) | 2.69 | 6.71 | --epoch 61, --avg 18, --max-duration 100 | +| greedy search (max sym per frame 3) | 2.69 | 6.71 | --epoch 61, --avg 18, --max-duration 100 | +| modified beam search (beam size 4) | 2.67 | 6.64 | --epoch 61, --avg 18, --max-duration 100 | + The training command for reproducing is given below: ``` +cd egs/librispeech/ASR/ +./prepare.sh export CUDA_VISIBLE_DEVICES="0,1,2,3" - ./transducer_stateless/train.py \ --world-size 4 \ --num-epochs 76 \ --start-epoch 0 \ --exp-dir transducer_stateless/exp-full \ --full-libri 1 \ - --max-duration 250 \ - --lr-factor 3 + --max-duration 300 \ + --lr-factor 5 \ + --bpe-model data/lang_bpe_500/bpe.model \ + --modified-transducer-prob 0.25 ``` The tensorboard training log can be found at - + The decoding command is: ``` -epoch=71 -avg=15 +epoch=61 +avg=18 ## greedy search -./transducer_stateless/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir transducer_stateless/exp-full \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --max-duration 100 +for sym in 1 2 3; do + ./transducer_stateless/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir transducer_stateless/exp-full \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --max-duration 100 \ + --max-sym-per-frame $sym +done + +## modified beam search -## beam search ./transducer_stateless/decode.py \ --epoch $epoch \ --avg $avg \ --exp-dir transducer_stateless/exp-full \ --bpe-model ./data/lang_bpe_500/bpe.model \ --max-duration 100 \ - --decoding-method beam_search \ + --context-size 2 \ + --decoding-method modified_beam_search \ --beam-size 4 ``` You can find a pretrained model by visiting - + #### Conformer encoder + LSTM decoder diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index 1cce48235..c5efb733d 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -17,7 +17,6 @@ from dataclasses import dataclass from typing import Dict, List, Optional -import numpy as np import torch from model import Transducer @@ -108,8 +107,9 @@ class Hypothesis: # Newly predicted tokens are appended to `ys`. ys: List[int] - # The log prob of ys - log_prob: float + # The log prob of ys. + # It contains only one entry. + log_prob: torch.Tensor @property def key(self) -> str: @@ -145,8 +145,10 @@ class HypothesisList(object): """ key = hyp.key if key in self: - old_hyp = self._data[key] - old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob) + old_hyp = self._data[key] # shallow copy + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -184,7 +186,7 @@ class HypothesisList(object): assert key in self, f"{key} does not exist" del self._data[key] - def filter(self, threshold: float) -> "HypothesisList": + def filter(self, threshold: torch.Tensor) -> "HypothesisList": """Remove all Hypotheses whose log_prob is less than threshold. Caution: @@ -312,6 +314,113 @@ def run_joiner( return log_prob +def modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. + """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + decoder_input = torch.tensor( + [blank_id] * context_size, device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + encoder_out_len = torch.tensor([1]) + decoder_out_len = torch.tensor([1]) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :] + # current_encoder_out is of shape (1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_output is of shape (num_hyps, 1, decoder_output_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, -1 + ) + + logits = model.joiner( + current_encoder_out, + decoder_out, + encoder_out_len.expand(decoder_out.size(0)), + decoder_out_len.expand(decoder_out.size(0)), + ) + # logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[i] + if new_token != blank_id: + new_ys.append(new_token) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + return ys + + def beam_search( model: Transducer, encoder_out: torch.Tensor, @@ -351,7 +460,12 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) max_sym_per_utt = 20000 @@ -371,9 +485,6 @@ def beam_search( joint_cache: Dict[str, torch.Tensor] = {} - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - while True: y_star = A.get_most_probable() A.remove(y_star) @@ -396,18 +507,21 @@ def beam_search( # First, process the blank symbol skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob.item() + new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): + for idx in range(values.size(0)): + i = indices[idx].item() if i == blank_id: continue + new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v + + new_log_prob = y_star.log_prob + values[idx] A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) # Check whether B contains more than "beam" elements more probable diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index e5987b75e..c101d9397 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -46,7 +46,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import beam_search, greedy_search +from beam_search import beam_search, greedy_search, modified_beam_search from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -104,6 +104,7 @@ def get_parser(): help="""Possible values are: - greedy_search - beam_search + - modified_beam_search """, ) @@ -111,7 +112,8 @@ def get_parser(): "--beam-size", type=int, default=4, - help="Used only when --decoding-method is beam_search", + help="""Used only when --decoding-method is + beam_search or modified_beam_search""", ) parser.add_argument( @@ -125,7 +127,8 @@ def get_parser(): "--max-sym-per-frame", type=int, default=3, - help="Maximum number of symbols per frame", + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", ) return parser @@ -256,6 +259,10 @@ def decode_one_batch( hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) + elif params.decoding_method == "modified_beam_search": + hyp = modified_beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" @@ -389,11 +396,15 @@ def main(): params = get_params() params.update(vars(args)) - assert params.decoding_method in ("greedy_search", "beam_search") + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + ) params.res_dir = params.exp_dir / params.decoding_method params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.decoding_method == "beam_search": + if "beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index c2c6552a9..b82fed37b 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -75,7 +75,7 @@ class Decoder(nn.Module): """ Args: y: - A 2-D tensor of shape (N, U) with blank prepended. + A 2-D tensor of shape (N, U). need_pad: True to left pad the input. Should be True during training. False to not pad the input. Should be False during inference. diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 17b5f63e5..8281e1fb5 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random + import k2 import torch import torch.nn as nn @@ -62,6 +64,7 @@ class Transducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, y: k2.RaggedTensor, + modified_transducer_prob: float = 0.0, ) -> torch.Tensor: """ Args: @@ -73,6 +76,8 @@ class Transducer(nn.Module): y: A ragged tensor with 2 axes [utt][label]. It contains labels of each utterance. + modified_transducer_prob: + The probability to use modified transducer loss. Returns: Return the transducer loss. """ @@ -114,6 +119,16 @@ class Transducer(nn.Module): # reference stage import optimized_transducer + assert 0 <= modified_transducer_prob <= 1 + + if modified_transducer_prob == 0: + one_sym_per_frame = False + elif random.random() < modified_transducer_prob: + # random.random() returns a float in the range [0, 1) + one_sym_per_frame = True + else: + one_sym_per_frame = False + loss = optimized_transducer.transducer_loss( logits=logits, targets=y_padded, @@ -121,6 +136,7 @@ class Transducer(nn.Module): target_lengths=y_lens, blank=blank_id, reduction="sum", + one_sym_per_frame=one_sym_per_frame, from_log_softmax=False, ) diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index c248de777..ad8d89918 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -22,10 +22,11 @@ Usage: --checkpoint ./transducer_stateless/exp/pretrained.pt \ --bpe-model ./data/lang_bpe_500/bpe.model \ --method greedy_search \ + --max-sym-per-frame 1 \ /path/to/foo.wav \ /path/to/bar.wav \ -(1) beam search +(2) beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ --bpe-model ./data/lang_bpe_500/bpe.model \ @@ -34,6 +35,15 @@ Usage: /path/to/foo.wav \ /path/to/bar.wav \ +(3) modified beam search +./transducer_stateless/pretrained.py \ + --checkpoint ./transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav \ + You can also use `./transducer_stateless/exp/epoch-xx.pt`. Note: ./transducer_stateless/exp/pretrained.pt is generated by @@ -51,7 +61,7 @@ import sentencepiece as spm import torch import torch.nn as nn import torchaudio -from beam_search import beam_search, greedy_search +from beam_search import beam_search, greedy_search, modified_beam_search from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -91,6 +101,7 @@ def get_parser(): help="""Possible values are: - greedy_search - beam_search + - modified_beam_search """, ) @@ -108,7 +119,7 @@ def get_parser(): "--beam-size", type=int, default=4, - help="Used only when --method is beam_search", + help="Used only when --method is beam_search and modified_beam_search ", ) parser.add_argument( @@ -218,6 +229,7 @@ def read_sound_files( return ans +@torch.no_grad() def main(): parser = get_parser() args = parser.parse_args() @@ -301,6 +313,10 @@ def main(): hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) + elif params.method == "modified_beam_search": + hyp = modified_beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) else: raise ValueError(f"Unsupported method: {params.method}") diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 950a88a35..544f6e9b1 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -138,6 +138,17 @@ def get_parser(): "2 means tri-gram", ) + parser.add_argument( + "--modified-transducer-prob", + type=float, + default=0.25, + help="""The probability to use modified transducer loss. + In modified transduer, it limits the maximum number of symbols + per frame to 1. See also the option --max-sym-per-frame in + transducer_stateless/decode.py + """, + ) + return parser @@ -383,7 +394,12 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=y) + loss = model( + x=feature, + x_lens=feature_lens, + y=y, + modified_transducer_prob=params.modified_transducer_prob, + ) assert loss.requires_grad == is_training