mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Update RESULTS.md.
This commit is contained in:
parent
d9beb73869
commit
c5291c828c
@ -84,7 +84,7 @@ The best WER using modified beam search with beam size 4 is:
|
|||||||
|
|
||||||
| | test-clean | test-other |
|
| | test-clean | test-other |
|
||||||
|-----|------------|------------|
|
|-----|------------|------------|
|
||||||
| WER | 2.61 | 6.46 |
|
| WER | 2.56 | 6.27 |
|
||||||
|
|
||||||
Note: No auxiliary losses are used in the training and no LMs are used
|
Note: No auxiliary losses are used in the training and no LMs are used
|
||||||
in the decoding.
|
in the decoding.
|
||||||
|
@ -15,6 +15,7 @@ The following table lists the differences among them.
|
|||||||
| `transducer_stateless` | Conformer | Embedding + Conv1d | |
|
| `transducer_stateless` | Conformer | Embedding + Conv1d | |
|
||||||
| `transducer_lstm` | LSTM | LSTM | |
|
| `transducer_lstm` | LSTM | LSTM | |
|
||||||
| `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data |
|
| `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data |
|
||||||
|
| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss |
|
||||||
|
|
||||||
The decoder in `transducer_stateless` is modified from the paper
|
The decoder in `transducer_stateless` is modified from the paper
|
||||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||||
|
@ -2,12 +2,111 @@
|
|||||||
|
|
||||||
### LibriSpeech BPE training results (Pruned Transducer)
|
### LibriSpeech BPE training results (Pruned Transducer)
|
||||||
|
|
||||||
#### Conformer encoder + embedding decoder
|
|
||||||
|
|
||||||
Conformer encoder + non-current decoder. The decoder
|
Conformer encoder + non-current decoder. The decoder
|
||||||
contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
|
contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
|
||||||
layer (to transform tensor dim).
|
layer (to transform tensor dim).
|
||||||
|
|
||||||
|
#### 2022-03-12
|
||||||
|
|
||||||
|
[pruned_transducer_stateless](./pruned_transducer_stateless)
|
||||||
|
|
||||||
|
Using commit `1603744469d167d848e074f2ea98c587153205fa`.
|
||||||
|
See <https://github.com/k2-fsa/icefall/pull/248>
|
||||||
|
|
||||||
|
The WERs are:
|
||||||
|
|
||||||
|
| | test-clean | test-other | comment |
|
||||||
|
|-------------------------------------|------------|------------|------------------------------------------|
|
||||||
|
| greedy search (max sym per frame 1) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 |
|
||||||
|
| greedy search (max sym per frame 2) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 |
|
||||||
|
| greedy search (max sym per frame 3) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 |
|
||||||
|
| modified beam search (beam size 4) | 2.56 | 6.27 | --epoch 39, --avg 15, --max-duration 100 |
|
||||||
|
| beam search (beam size 4) | 2.57 | 6.27 | --epoch 39, --avg 15, --max-duration 100 |
|
||||||
|
|
||||||
|
The decoding time for `test-clean` and `test-other` is given below:
|
||||||
|
(A V100 GPU with 32 GB RAM is used for decoding. Note: Not all GPU RAM is used during decoding.)
|
||||||
|
|
||||||
|
| decoding method | test-clean (seconds) | test-other (seconds)|
|
||||||
|
|---|---:|---:|
|
||||||
|
| greedy search (--max-sym-per-frame=1) | 160 | 159 |
|
||||||
|
| greedy search (--max-sym-per-frame=2) | 184 | 177 |
|
||||||
|
| greedy search (--max-sym-per-frame=3) | 210 | 213 |
|
||||||
|
| modified beam search (--beam-size 4)| 273 | 269 |
|
||||||
|
|beam search (--beam-size 4) | 2741 | 2221 |
|
||||||
|
|
||||||
|
We recommend you to use `modified_beam_search`.
|
||||||
|
|
||||||
|
Training command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd egs/librispeech/ASR/
|
||||||
|
./prepare.sh
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||||
|
|
||||||
|
. path.sh
|
||||||
|
|
||||||
|
./pruned_transducer_stateless/train.py \
|
||||||
|
--world-size 8 \
|
||||||
|
--num-epochs 60 \
|
||||||
|
--start-epoch 0 \
|
||||||
|
--exp-dir pruned_transducer_stateless/exp \
|
||||||
|
--full-libri 1 \
|
||||||
|
--max-duration 300 \
|
||||||
|
--prune-range 5 \
|
||||||
|
--lr-factor 5 \
|
||||||
|
--lm-scale 0.25
|
||||||
|
```
|
||||||
|
|
||||||
|
The tensorboard training log can be found at
|
||||||
|
<https://tensorboard.dev/experiment/WKRFY5fYSzaVBHahenpNlA/>
|
||||||
|
|
||||||
|
The command for decoding is:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
epoch=42
|
||||||
|
avg=11
|
||||||
|
sym=1
|
||||||
|
|
||||||
|
# greedy search
|
||||||
|
|
||||||
|
./pruned_transducer_stateless/decode.py \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg \
|
||||||
|
--exp-dir ./pruned_transducer_stateless/exp \
|
||||||
|
--max-duration 100 \
|
||||||
|
--decoding-method greedy_search \
|
||||||
|
--beam-size 4 \
|
||||||
|
--max-sym-per-frame $sym
|
||||||
|
|
||||||
|
# modified beam search
|
||||||
|
./pruned_transducer_stateless/decode.py \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg \
|
||||||
|
--exp-dir ./pruned_transducer_stateless/exp \
|
||||||
|
--max-duration 100 \
|
||||||
|
--decoding-method modified_beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
# beam search
|
||||||
|
# (not recommended)
|
||||||
|
./pruned_transducer_stateless/decode.py \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg \
|
||||||
|
--exp-dir ./pruned_transducer_stateless/exp \
|
||||||
|
--max-duration 100 \
|
||||||
|
--decoding-method beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
```
|
||||||
|
|
||||||
|
You can find a pre-trained model, decoding logs, and decoding results at
|
||||||
|
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12>
|
||||||
|
|
||||||
|
#### 2022-02-18
|
||||||
|
|
||||||
|
[pruned_transducer_stateless](./pruned_transducer_stateless)
|
||||||
|
|
||||||
|
|
||||||
The WERs are
|
The WERs are
|
||||||
|
|
||||||
| | test-clean | test-other | comment |
|
| | test-clean | test-other | comment |
|
||||||
@ -62,7 +161,7 @@ See
|
|||||||
|
|
||||||
##### 2022-03-01
|
##### 2022-03-01
|
||||||
|
|
||||||
Using commit `fill in it after merging`.
|
Using commit `2332ba312d7ce72f08c7bac1e3312f7e3dd722dc`.
|
||||||
|
|
||||||
It uses [GigaSpeech](https://github.com/SpeechColab/GigaSpeech)
|
It uses [GigaSpeech](https://github.com/SpeechColab/GigaSpeech)
|
||||||
as extra training data. 20% of the time it selects a batch from L subset of
|
as extra training data. 20% of the time it selects a batch from L subset of
|
||||||
@ -129,6 +228,9 @@ sym=1
|
|||||||
--beam-size 4
|
--beam-size 4
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can find a pretrained model by visiting
|
||||||
|
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01>
|
||||||
|
|
||||||
|
|
||||||
##### 2022-02-07
|
##### 2022-02-07
|
||||||
|
|
||||||
|
@ -56,13 +56,9 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from beam_search import beam_search, greedy_search, modified_beam_search
|
from beam_search import beam_search, greedy_search, modified_beam_search
|
||||||
from conformer import Conformer
|
from train import get_params, get_transducer_model
|
||||||
from decoder import Decoder
|
|
||||||
from joiner import Joiner
|
|
||||||
from model import Transducer
|
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
@ -143,72 +139,6 @@ def get_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def get_params() -> AttributeDict:
|
|
||||||
params = AttributeDict(
|
|
||||||
{
|
|
||||||
# parameters for conformer
|
|
||||||
"feature_dim": 80,
|
|
||||||
"subsampling_factor": 4,
|
|
||||||
"attention_dim": 512,
|
|
||||||
"nhead": 8,
|
|
||||||
"dim_feedforward": 2048,
|
|
||||||
"num_encoder_layers": 12,
|
|
||||||
"vgg_frontend": False,
|
|
||||||
# parameters for decoder
|
|
||||||
"embedding_dim": 512,
|
|
||||||
"env_info": get_env_info(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|
||||||
# TODO: We can add an option to switch between Conformer and Transformer
|
|
||||||
encoder = Conformer(
|
|
||||||
num_features=params.feature_dim,
|
|
||||||
output_dim=params.vocab_size,
|
|
||||||
subsampling_factor=params.subsampling_factor,
|
|
||||||
d_model=params.attention_dim,
|
|
||||||
nhead=params.nhead,
|
|
||||||
dim_feedforward=params.dim_feedforward,
|
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
|
||||||
vgg_frontend=params.vgg_frontend,
|
|
||||||
)
|
|
||||||
return encoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
|
||||||
decoder = Decoder(
|
|
||||||
vocab_size=params.vocab_size,
|
|
||||||
embedding_dim=params.embedding_dim,
|
|
||||||
blank_id=params.blank_id,
|
|
||||||
context_size=params.context_size,
|
|
||||||
)
|
|
||||||
return decoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|
||||||
joiner = Joiner(
|
|
||||||
input_dim=params.vocab_size,
|
|
||||||
inner_dim=params.embedding_dim,
|
|
||||||
output_dim=params.vocab_size,
|
|
||||||
)
|
|
||||||
return joiner
|
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|
||||||
encoder = get_encoder_model(params)
|
|
||||||
decoder = get_decoder_model(params)
|
|
||||||
joiner = get_joiner_model(params)
|
|
||||||
|
|
||||||
model = Transducer(
|
|
||||||
encoder=encoder,
|
|
||||||
decoder=decoder,
|
|
||||||
joiner=joiner,
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@ -489,8 +419,5 @@ def main():
|
|||||||
logging.info("Done!")
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
torch.set_num_interop_threads(1)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -39,7 +39,7 @@ you can do:
|
|||||||
--exp-dir ./pruned_transducer_stateless/exp \
|
--exp-dir ./pruned_transducer_stateless/exp \
|
||||||
--epoch 9999 \
|
--epoch 9999 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--max-duration 1 \
|
--max-duration 100 \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model
|
--bpe-model data/lang_bpe_500/bpe.model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -49,15 +49,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
from train import get_params, get_transducer_model
|
||||||
from conformer import Conformer
|
|
||||||
from decoder import Decoder
|
|
||||||
from joiner import Joiner
|
|
||||||
from model import Transducer
|
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.utils import str2bool
|
||||||
from icefall.utils import AttributeDict, str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -117,71 +112,6 @@ def get_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def get_params() -> AttributeDict:
|
|
||||||
params = AttributeDict(
|
|
||||||
{
|
|
||||||
# parameters for conformer
|
|
||||||
"feature_dim": 80,
|
|
||||||
"subsampling_factor": 4,
|
|
||||||
"attention_dim": 512,
|
|
||||||
"nhead": 8,
|
|
||||||
"dim_feedforward": 2048,
|
|
||||||
"num_encoder_layers": 12,
|
|
||||||
"vgg_frontend": False,
|
|
||||||
# parameters for decoder
|
|
||||||
"embedding_dim": 512,
|
|
||||||
"env_info": get_env_info(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|
||||||
encoder = Conformer(
|
|
||||||
num_features=params.feature_dim,
|
|
||||||
output_dim=params.vocab_size,
|
|
||||||
subsampling_factor=params.subsampling_factor,
|
|
||||||
d_model=params.attention_dim,
|
|
||||||
nhead=params.nhead,
|
|
||||||
dim_feedforward=params.dim_feedforward,
|
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
|
||||||
vgg_frontend=params.vgg_frontend,
|
|
||||||
)
|
|
||||||
return encoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
|
||||||
decoder = Decoder(
|
|
||||||
vocab_size=params.vocab_size,
|
|
||||||
embedding_dim=params.embedding_dim,
|
|
||||||
blank_id=params.blank_id,
|
|
||||||
context_size=params.context_size,
|
|
||||||
)
|
|
||||||
return decoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|
||||||
joiner = Joiner(
|
|
||||||
input_dim=params.vocab_size,
|
|
||||||
inner_dim=params.embedding_dim,
|
|
||||||
output_dim=params.vocab_size,
|
|
||||||
)
|
|
||||||
return joiner
|
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|
||||||
encoder = get_encoder_model(params)
|
|
||||||
decoder = get_decoder_model(params)
|
|
||||||
joiner = get_joiner_model(params)
|
|
||||||
|
|
||||||
model = Transducer(
|
|
||||||
encoder=encoder,
|
|
||||||
decoder=decoder,
|
|
||||||
joiner=joiner,
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
@ -49,17 +49,10 @@ from typing import List
|
|||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import beam_search, greedy_search
|
from beam_search import beam_search, greedy_search, modified_beam_search
|
||||||
from conformer import Conformer
|
|
||||||
from decoder import Decoder
|
|
||||||
from joiner import Joiner
|
|
||||||
from model import Transducer
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
from icefall.env import get_env_info
|
|
||||||
from icefall.utils import AttributeDict
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -91,6 +84,7 @@ def get_parser():
|
|||||||
help="""Possible values are:
|
help="""Possible values are:
|
||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
|
- modified_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -104,11 +98,18 @@ def get_parser():
|
|||||||
"The sample rate has to be 16kHz.",
|
"The sample rate has to be 16kHz.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16000,
|
||||||
|
help="The sample rate of the input sound file",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="Used only when --method is beam_search",
|
help="Used only when --method is beam_search and modified_beam_search",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -130,72 +131,6 @@ def get_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def get_params() -> AttributeDict:
|
|
||||||
params = AttributeDict(
|
|
||||||
{
|
|
||||||
"sample_rate": 16000,
|
|
||||||
# parameters for conformer
|
|
||||||
"feature_dim": 80,
|
|
||||||
"subsampling_factor": 4,
|
|
||||||
"attention_dim": 512,
|
|
||||||
"nhead": 8,
|
|
||||||
"dim_feedforward": 2048,
|
|
||||||
"num_encoder_layers": 12,
|
|
||||||
"vgg_frontend": False,
|
|
||||||
# parameters for decoder
|
|
||||||
"embedding_dim": 512,
|
|
||||||
"env_info": get_env_info(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|
||||||
encoder = Conformer(
|
|
||||||
num_features=params.feature_dim,
|
|
||||||
output_dim=params.vocab_size,
|
|
||||||
subsampling_factor=params.subsampling_factor,
|
|
||||||
d_model=params.attention_dim,
|
|
||||||
nhead=params.nhead,
|
|
||||||
dim_feedforward=params.dim_feedforward,
|
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
|
||||||
vgg_frontend=params.vgg_frontend,
|
|
||||||
)
|
|
||||||
return encoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
|
||||||
decoder = Decoder(
|
|
||||||
vocab_size=params.vocab_size,
|
|
||||||
embedding_dim=params.embedding_dim,
|
|
||||||
blank_id=params.blank_id,
|
|
||||||
context_size=params.context_size,
|
|
||||||
)
|
|
||||||
return decoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|
||||||
joiner = Joiner(
|
|
||||||
input_dim=params.vocab_size,
|
|
||||||
inner_dim=params.embedding_dim,
|
|
||||||
output_dim=params.vocab_size,
|
|
||||||
)
|
|
||||||
return joiner
|
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|
||||||
encoder = get_encoder_model(params)
|
|
||||||
decoder = get_decoder_model(params)
|
|
||||||
joiner = get_joiner_model(params)
|
|
||||||
|
|
||||||
model = Transducer(
|
|
||||||
encoder=encoder,
|
|
||||||
decoder=decoder,
|
|
||||||
joiner=joiner,
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def read_sound_files(
|
def read_sound_files(
|
||||||
filenames: List[str], expected_sample_rate: float
|
filenames: List[str], expected_sample_rate: float
|
||||||
) -> List[torch.Tensor]:
|
) -> List[torch.Tensor]:
|
||||||
@ -220,6 +155,7 @@ def read_sound_files(
|
|||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -278,10 +214,9 @@ def main():
|
|||||||
|
|
||||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||||
|
|
||||||
with torch.no_grad():
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
encoder_out, encoder_out_lens = model.encoder(
|
x=features, x_lens=feature_lengths
|
||||||
x=features, x_lens=feature_lengths
|
)
|
||||||
)
|
|
||||||
|
|
||||||
num_waves = encoder_out.size(0)
|
num_waves = encoder_out.size(0)
|
||||||
hyps = []
|
hyps = []
|
||||||
@ -303,6 +238,10 @@ def main():
|
|||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
)
|
)
|
||||||
|
elif params.method == "modified_beam_search":
|
||||||
|
hyp = modified_beam_search(
|
||||||
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user