Use modified transducer loss in training. (#179)

* Use modified transducer loss in training.

* Minor fix.

* Add modified beam search.

* Add modified beam search.

* Minor fixes.

* Fix typo.

* Update RESULTS.

* Fix a typo.

* Minor fixes.
This commit is contained in:
Fangjun Kuang 2022-02-07 18:37:36 +08:00 committed by GitHub
parent 35ecd7e562
commit a8150021e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 288 additions and 61 deletions

View File

@ -74,24 +74,53 @@ jobs:
mkdir tmp mkdir tmp
cd tmp cd tmp
git lfs install git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07
cd .. cd ..
tree tmp tree tmp
soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav
- name: Run greedy search decoding - name: Run greedy search decoding (max-sym-per-frame 1)
shell: bash shell: bash
run: | run: |
export PYTHONPATH=$PWD:PYTHONPATH export PYTHONPATH=$PWD:PYTHONPATH
cd egs/librispeech/ASR cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \ ./transducer_stateless/pretrained.py \
--method greedy_search \ --method greedy_search \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \ --max-sym-per-frame 1 \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \ --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \ --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 2)
shell: bash
run: |
export PYTHONPATH=$PWD:PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame 2 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 3)
shell: bash
run: |
export PYTHONPATH=$PWD:PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame 3 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding - name: Run beam search decoding
shell: bash shell: bash
@ -101,8 +130,22 @@ jobs:
./transducer_stateless/pretrained.py \ ./transducer_stateless/pretrained.py \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \ --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \ --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
- name: Run modified beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method modified_beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav

View File

@ -80,16 +80,16 @@ We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open
Using Conformer as encoder. The decoder consists of 1 embedding layer Using Conformer as encoder. The decoder consists of 1 embedding layer
and 1 convolutional layer. and 1 convolutional layer.
The best WER using beam search with beam size 4 is: The best WER using modified beam search with beam size 4 is:
| | test-clean | test-other | | | test-clean | test-other |
|-----|------------|------------| |-----|------------|------------|
| WER | 2.68 | 6.72 | | WER | 2.67 | 6.64 |
Note: No auxiliary losses are used in the training and no LMs are used Note: No auxiliary losses are used in the training and no LMs are used
in the decoding. in the decoding.
We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Rc4Is-3Yp9LbcEz_Iy8hfyenyHsyjvqE?usp=sharing) We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing)
### Aishell ### Aishell

View File

@ -4,62 +4,73 @@
#### Conformer encoder + embedding decoder #### Conformer encoder + embedding decoder
Using commit `4c1b3665ee6efb935f4dd93a80ff0e154b13efb6`. Using commit `TODO`.
Conformer encoder + non-current decoder. The decoder Conformer encoder + non-recurrent decoder. The decoder
contains only an embedding layer and a Conv1d (with kernel size 2). contains only an embedding layer and a Conv1d (with kernel size 2).
The WERs are The WERs are
| | test-clean | test-other | comment | | | test-clean | test-other | comment |
|---------------------------|------------|------------|------------------------------------------| |-------------------------------------|------------|------------|------------------------------------------|
| greedy search | 2.69 | 6.81 | --epoch 71, --avg 15, --max-duration 100 | | greedy search (max sym per frame 1) | 2.68 | 6.71 | --epoch 61, --avg 18, --max-duration 100 |
| beam search (beam size 4) | 2.68 | 6.72 | --epoch 71, --avg 15, --max-duration 100 | | greedy search (max sym per frame 2) | 2.69 | 6.71 | --epoch 61, --avg 18, --max-duration 100 |
| greedy search (max sym per frame 3) | 2.69 | 6.71 | --epoch 61, --avg 18, --max-duration 100 |
| modified beam search (beam size 4) | 2.67 | 6.64 | --epoch 61, --avg 18, --max-duration 100 |
The training command for reproducing is given below: The training command for reproducing is given below:
``` ```
cd egs/librispeech/ASR/
./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1,2,3" export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer_stateless/train.py \ ./transducer_stateless/train.py \
--world-size 4 \ --world-size 4 \
--num-epochs 76 \ --num-epochs 76 \
--start-epoch 0 \ --start-epoch 0 \
--exp-dir transducer_stateless/exp-full \ --exp-dir transducer_stateless/exp-full \
--full-libri 1 \ --full-libri 1 \
--max-duration 250 \ --max-duration 300 \
--lr-factor 3 --lr-factor 5 \
--bpe-model data/lang_bpe_500/bpe.model \
--modified-transducer-prob 0.25
``` ```
The tensorboard training log can be found at The tensorboard training log can be found at
<https://tensorboard.dev/experiment/qGdqzHnxS0WJ695OXfZDzA/#scalars&_smoothingWeight=0> <https://tensorboard.dev/experiment/qgvWkbF2R46FYA6ZMNmOjA/#scalars>
The decoding command is: The decoding command is:
``` ```
epoch=71 epoch=61
avg=15 avg=18
## greedy search ## greedy search
./transducer_stateless/decode.py \ for sym in 1 2 3; do
--epoch $epoch \ ./transducer_stateless/decode.py \
--avg $avg \ --epoch $epoch \
--exp-dir transducer_stateless/exp-full \ --avg $avg \
--bpe-model ./data/lang_bpe_500/bpe.model \ --exp-dir transducer_stateless/exp-full \
--max-duration 100 --bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 100 \
--max-sym-per-frame $sym
done
## modified beam search
## beam search
./transducer_stateless/decode.py \ ./transducer_stateless/decode.py \
--epoch $epoch \ --epoch $epoch \
--avg $avg \ --avg $avg \
--exp-dir transducer_stateless/exp-full \ --exp-dir transducer_stateless/exp-full \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 100 \ --max-duration 100 \
--decoding-method beam_search \ --context-size 2 \
--decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
``` ```
You can find a pretrained model by visiting You can find a pretrained model by visiting
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10> <https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07>
#### Conformer encoder + LSTM decoder #### Conformer encoder + LSTM decoder

View File

@ -17,7 +17,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional from typing import Dict, List, Optional
import numpy as np
import torch import torch
from model import Transducer from model import Transducer
@ -108,8 +107,9 @@ class Hypothesis:
# Newly predicted tokens are appended to `ys`. # Newly predicted tokens are appended to `ys`.
ys: List[int] ys: List[int]
# The log prob of ys # The log prob of ys.
log_prob: float # It contains only one entry.
log_prob: torch.Tensor
@property @property
def key(self) -> str: def key(self) -> str:
@ -145,8 +145,10 @@ class HypothesisList(object):
""" """
key = hyp.key key = hyp.key
if key in self: if key in self:
old_hyp = self._data[key] old_hyp = self._data[key] # shallow copy
old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob) torch.logaddexp(
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
)
else: else:
self._data[key] = hyp self._data[key] = hyp
@ -184,7 +186,7 @@ class HypothesisList(object):
assert key in self, f"{key} does not exist" assert key in self, f"{key} does not exist"
del self._data[key] del self._data[key]
def filter(self, threshold: float) -> "HypothesisList": def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold. """Remove all Hypotheses whose log_prob is less than threshold.
Caution: Caution:
@ -312,6 +314,113 @@ def run_joiner(
return log_prob return log_prob
def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
T = encoder_out.size(1)
B = HypothesisList()
B.add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
encoder_out_len = torch.tensor([1])
decoder_out_len = torch.tensor([1])
for t in range(T):
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# current_encoder_out is of shape (1, 1, encoder_out_dim)
# fmt: on
A = list(B)
B = HypothesisList()
ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
# ys_log_probs is of shape (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyp in A],
device=device,
)
# decoder_input is of shape (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
# decoder_output is of shape (num_hyps, 1, decoder_output_dim)
current_encoder_out = current_encoder_out.expand(
decoder_out.size(0), 1, -1
)
logits = model.joiner(
current_encoder_out,
decoder_out,
encoder_out_len.expand(decoder_out.size(0)),
decoder_out_len.expand(decoder_out.size(0)),
)
# logits is of shape (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1)
log_probs.add_(ys_log_probs)
log_probs = log_probs.reshape(-1)
topk_log_probs, topk_indexes = log_probs.topk(beam)
# topk_hyp_indexes are indexes into `A`
topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1)
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
for i in range(len(topk_hyp_indexes)):
hyp = A[topk_hyp_indexes[i]]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[i]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B.add(new_hyp)
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys
def beam_search( def beam_search(
model: Transducer, model: Transducer,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
@ -351,7 +460,12 @@ def beam_search(
t = 0 t = 0
B = HypothesisList() B = HypothesisList()
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) B.add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
max_sym_per_utt = 20000 max_sym_per_utt = 20000
@ -371,9 +485,6 @@ def beam_search(
joint_cache: Dict[str, torch.Tensor] = {} joint_cache: Dict[str, torch.Tensor] = {}
# TODO(fangjun): Implement prefix search to update the `log_prob`
# of hypotheses in A
while True: while True:
y_star = A.get_most_probable() y_star = A.get_most_probable()
A.remove(y_star) A.remove(y_star)
@ -396,18 +507,21 @@ def beam_search(
# First, process the blank symbol # First, process the blank symbol
skip_log_prob = log_prob[blank_id] skip_log_prob = log_prob[blank_id]
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item() new_y_star_log_prob = y_star.log_prob + skip_log_prob
# ys[:] returns a copy of ys # ys[:] returns a copy of ys
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
# Second, process other non-blank labels # Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1) values, indices = log_prob.topk(beam + 1)
for i, v in zip(indices.tolist(), values.tolist()): for idx in range(values.size(0)):
i = indices[idx].item()
if i == blank_id: if i == blank_id:
continue continue
new_ys = y_star.ys + [i] new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
new_log_prob = y_star.log_prob + values[idx]
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
# Check whether B contains more than "beam" elements more probable # Check whether B contains more than "beam" elements more probable

View File

@ -46,7 +46,7 @@ import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import beam_search, greedy_search from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer from conformer import Conformer
from decoder import Decoder from decoder import Decoder
from joiner import Joiner from joiner import Joiner
@ -104,6 +104,7 @@ def get_parser():
help="""Possible values are: help="""Possible values are:
- greedy_search - greedy_search
- beam_search - beam_search
- modified_beam_search
""", """,
) )
@ -111,7 +112,8 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="Used only when --decoding-method is beam_search", help="""Used only when --decoding-method is
beam_search or modified_beam_search""",
) )
parser.add_argument( parser.add_argument(
@ -125,7 +127,8 @@ def get_parser():
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
default=3, default=3,
help="Maximum number of symbols per frame", help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
) )
return parser return parser
@ -256,6 +259,10 @@ def decode_one_batch(
hyp = beam_search( hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size model=model, encoder_out=encoder_out_i, beam=params.beam_size
) )
elif params.decoding_method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else: else:
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"
@ -389,11 +396,15 @@ def main():
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
assert params.decoding_method in ("greedy_search", "beam_search") assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "beam_search": if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}" params.suffix += f"-beam-{params.beam_size}"
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"

View File

@ -75,7 +75,7 @@ class Decoder(nn.Module):
""" """
Args: Args:
y: y:
A 2-D tensor of shape (N, U) with blank prepended. A 2-D tensor of shape (N, U).
need_pad: need_pad:
True to left pad the input. Should be True during training. True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference. False to not pad the input. Should be False during inference.

View File

@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -62,6 +64,7 @@ class Transducer(nn.Module):
x: torch.Tensor, x: torch.Tensor,
x_lens: torch.Tensor, x_lens: torch.Tensor,
y: k2.RaggedTensor, y: k2.RaggedTensor,
modified_transducer_prob: float = 0.0,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
@ -73,6 +76,8 @@ class Transducer(nn.Module):
y: y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance. utterance.
modified_transducer_prob:
The probability to use modified transducer loss.
Returns: Returns:
Return the transducer loss. Return the transducer loss.
""" """
@ -114,6 +119,16 @@ class Transducer(nn.Module):
# reference stage # reference stage
import optimized_transducer import optimized_transducer
assert 0 <= modified_transducer_prob <= 1
if modified_transducer_prob == 0:
one_sym_per_frame = False
elif random.random() < modified_transducer_prob:
# random.random() returns a float in the range [0, 1)
one_sym_per_frame = True
else:
one_sym_per_frame = False
loss = optimized_transducer.transducer_loss( loss = optimized_transducer.transducer_loss(
logits=logits, logits=logits,
targets=y_padded, targets=y_padded,
@ -121,6 +136,7 @@ class Transducer(nn.Module):
target_lengths=y_lens, target_lengths=y_lens,
blank=blank_id, blank=blank_id,
reduction="sum", reduction="sum",
one_sym_per_frame=one_sym_per_frame,
from_log_softmax=False, from_log_softmax=False,
) )

View File

@ -22,10 +22,11 @@ Usage:
--checkpoint ./transducer_stateless/exp/pretrained.pt \ --checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \ --method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav \
(1) beam search (2) beam search
./transducer_stateless/pretrained.py \ ./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \ --checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
@ -34,6 +35,15 @@ Usage:
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav \
(3) modified beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
You can also use `./transducer_stateless/exp/epoch-xx.pt`. You can also use `./transducer_stateless/exp/epoch-xx.pt`.
Note: ./transducer_stateless/exp/pretrained.pt is generated by Note: ./transducer_stateless/exp/pretrained.pt is generated by
@ -51,7 +61,7 @@ import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchaudio import torchaudio
from beam_search import beam_search, greedy_search from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer from conformer import Conformer
from decoder import Decoder from decoder import Decoder
from joiner import Joiner from joiner import Joiner
@ -91,6 +101,7 @@ def get_parser():
help="""Possible values are: help="""Possible values are:
- greedy_search - greedy_search
- beam_search - beam_search
- modified_beam_search
""", """,
) )
@ -108,7 +119,7 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="Used only when --method is beam_search", help="Used only when --method is beam_search and modified_beam_search ",
) )
parser.add_argument( parser.add_argument(
@ -218,6 +229,7 @@ def read_sound_files(
return ans return ans
@torch.no_grad()
def main(): def main():
parser = get_parser() parser = get_parser()
args = parser.parse_args() args = parser.parse_args()
@ -301,6 +313,10 @@ def main():
hyp = beam_search( hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size model=model, encoder_out=encoder_out_i, beam=params.beam_size
) )
elif params.method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else: else:
raise ValueError(f"Unsupported method: {params.method}") raise ValueError(f"Unsupported method: {params.method}")

View File

@ -138,6 +138,17 @@ def get_parser():
"2 means tri-gram", "2 means tri-gram",
) )
parser.add_argument(
"--modified-transducer-prob",
type=float,
default=0.25,
help="""The probability to use modified transducer loss.
In modified transduer, it limits the maximum number of symbols
per frame to 1. See also the option --max-sym-per-frame in
transducer_stateless/decode.py
""",
)
return parser return parser
@ -383,7 +394,12 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
loss = model(x=feature, x_lens=feature_lens, y=y) loss = model(
x=feature,
x_lens=feature_lens,
y=y,
modified_transducer_prob=params.modified_transducer_prob,
)
assert loss.requires_grad == is_training assert loss.requires_grad == is_training