Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-08 16:44:20 +00:00)

commit a613e85900 (parent 994b8a7716)

    Update results.
@@ -19,7 +19,8 @@ The following table lists the differences among them.
 | `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss |
 | `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss |
 | `pruned_transducer_stateless3` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + using GigaSpeech as extra training data |
-| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + random combiner |
+| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | Same as `pruned_transducer_stateless2` + save averaged models periodically during training |
+| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + more layers + random combiner |

 The decoder in `transducer_stateless` is modified from the paper
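The "save averaged models periodically during training" entry refers to `checkpoint-xxx.pt` files whose weights are later combined by element-wise averaging at decode time. A minimal sketch of that averaging step, assuming each checkpoint stores its weights under a `"model"` key as the icefall training scripts do (the helper name is illustrative, not the icefall API):

```python
from pathlib import Path
from typing import Dict, List

import torch


def average_state_dicts(filenames: List[Path]) -> Dict[str, torch.Tensor]:
    """Element-wise average of the model weights stored in several checkpoints."""
    avg: Dict[str, torch.Tensor] = {}
    for f in filenames:
        # Assumption: each file holds {"model": state_dict, ...}.
        state = torch.load(f, map_location="cpu")["model"]
        for k, v in state.items():
            # Accumulate in float64 to limit rounding error over many files.
            avg[k] = avg.get(k, 0) + v.to(torch.float64)
    n = len(filenames)
    return {k: (v / n).to(torch.float32) for k, v in avg.items()}
```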
@@ -1,9 +1,142 @@
 ## Results

-### LibriSpeech BPE training results (Pruned Transducer 3)
+### LibriSpeech BPE training results (Pruned Stateless Transducer 5)
+
+[pruned_transducer_stateless5](./pruned_transducer_stateless5)
+
+Same as `Pruned Stateless Transducer 2` but with more layers.
+
+See <https://github.com/k2-fsa/icefall/pull/330>
+
+Note: models in `pruned_transducer_stateless` and `pruned_transducer_stateless2`
+have about 80 M parameters.
+
+The notations `large` and `medium` below are from the [Conformer](https://arxiv.org/pdf/2005.08100.pdf)
+paper, where the large model has about 118 M parameters and the medium model
+has 30.8 M parameters.
+
+#### Large
+
+Number of model parameters: 118129516 (i.e., 118.13 M).
+
+|                                     | test-clean | test-other | comment                                |
+|-------------------------------------|------------|------------|----------------------------------------|
+| greedy search (max sym per frame 1) | 2.39       | 5.57       | --epoch 39 --avg 7 --max-duration 600  |
+| modified beam search                | 2.35       | 5.50       | --epoch 39 --avg 7 --max-duration 600  |
+| fast beam search                    | 2.38       | 5.50       | --epoch 39 --avg 7 --max-duration 600  |
+
+The training commands are:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 8 \
+  --num-epochs 40 \
+  --start-epoch 0 \
+  --full-libri 1 \
+  --exp-dir pruned_transducer_stateless5/exp-L \
+  --max-duration 300 \
+  --use-fp16 0 \
+  --num-encoder-layers 18 \
+  --dim-feedforward 2048 \
+  --nhead 8 \
+  --encoder-dim 512 \
+  --decoder-dim 512 \
+  --joiner-dim 512
+```
+
+The tensorboard log can be found at
+<https://tensorboard.dev/experiment/Zq0h3KpnQ2igWbeR4U82Pw/>
+
+The decoding commands are:
+
+```bash
+for method in greedy_search modified_beam_search fast_beam_search; do
+  ./pruned_transducer_stateless5/decode.py \
+    --epoch 39 \
+    --avg 7 \
+    --exp-dir ./pruned_transducer_stateless5/exp-L \
+    --max-duration 600 \
+    --decoding-method $method \
+    --max-sym-per-frame 1 \
+    --num-encoder-layers 18 \
+    --dim-feedforward 2048 \
+    --nhead 8 \
+    --encoder-dim 512 \
+    --decoder-dim 512 \
+    --joiner-dim 512
+done
+```
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13>
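The parameter counts quoted in this section (118129516 for the large model, 30896748 for the medium one below) are plain element counts over `model.parameters()`; a minimal sketch of how such a figure is computed:

```python
import torch.nn as nn


def count_parameters(model: nn.Module) -> int:
    """Total number of elements across all parameters of the model."""
    return sum(p.numel() for p in model.parameters())


# Applied to the transducer built by the recipe's get_transducer_model(),
# this should reproduce the figures quoted in this section.
```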
+#### Medium
+
+Number of model parameters: 30896748 (i.e., 30.9 M).
+
+|                                     | test-clean | test-other | comment                                 |
+|-------------------------------------|------------|------------|-----------------------------------------|
+| greedy search (max sym per frame 1) | 2.88       | 6.69       | --epoch 39 --avg 17 --max-duration 600  |
+| modified beam search                | 2.83       | 6.59       | --epoch 39 --avg 17 --max-duration 600  |
+| fast beam search                    | 2.83       | 6.61       | --epoch 39 --avg 17 --max-duration 600  |
+
+The training commands are:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 8 \
+  --num-epochs 40 \
+  --start-epoch 0 \
+  --full-libri 1 \
+  --exp-dir pruned_transducer_stateless5/exp-M \
+  --max-duration 300 \
+  --use-fp16 0 \
+  --num-encoder-layers 18 \
+  --dim-feedforward 1024 \
+  --nhead 4 \
+  --encoder-dim 256 \
+  --decoder-dim 512 \
+  --joiner-dim 512
+```
+
+The tensorboard log can be found at
+<https://tensorboard.dev/experiment/bOQvULPsQ1iL7xpdI0VbXw/>
+
+The decoding commands are:
+
+```bash
+for method in greedy_search modified_beam_search fast_beam_search; do
+  ./pruned_transducer_stateless5/decode.py \
+    --epoch 39 \
+    --avg 17 \
+    --exp-dir ./pruned_transducer_stateless5/exp-M \
+    --max-duration 600 \
+    --decoding-method $method \
+    --max-sym-per-frame 1 \
+    --num-encoder-layers 18 \
+    --dim-feedforward 1024 \
+    --nhead 4 \
+    --encoder-dim 256 \
+    --decoder-dim 512 \
+    --joiner-dim 512
+done
+```
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-M-2022-05-13>
+
+
+### LibriSpeech BPE training results (Pruned Stateless Transducer 3)

 [pruned_transducer_stateless3](./pruned_transducer_stateless3)

-Same as `Pruned Transducer 2` but using the XL subset from
+Same as `Pruned Stateless Transducer 2` but using the XL subset from
 [GigaSpeech](https://github.com/SpeechColab/GigaSpeech) as extra training data.

 During training, it selects either a batch from GigaSpeech with prob `giga_prob`
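The `giga_prob` mixing works at the batch level: each training step draws a GigaSpeech batch with probability `giga_prob`, otherwise a LibriSpeech batch. A minimal sketch of that kind of two-stream sampling (illustrative only; the recipe's actual datamodule handles this differently in detail):

```python
import random
from typing import Iterator


def mix_loaders(libri_dl, giga_dl, giga_prob: float, seed: int = 0) -> Iterator:
    """Yield batches, picking a GigaSpeech batch with probability giga_prob.

    Both loaders are assumed to behave as (effectively) infinite iterators,
    e.g. cycled over their datasets.
    """
    rng = random.Random(seed)
    libri_it, giga_it = iter(libri_dl), iter(giga_dl)
    while True:
        if rng.random() < giga_prob:
            yield next(giga_it)
        else:
            yield next(libri_it)
```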
@@ -104,6 +237,7 @@ done
 The following table shows the
 [Nbest oracle WER](http://kaldi-asr.org/doc/lattices.html#lattices_operations_oracle)
 for fast beam search.

 | epoch | avg | num_paths | nbest_scale | test-clean | test-other |
 |-------|-----|-----------|-------------|------------|------------|
 | 27    | 10  | 50        | 0.5         | 0.91       | 2.74       |
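The Nbest oracle WER is a diagnostic: for each utterance it scores the best of `num_paths` hypotheses sampled from the fast-beam-search lattice (after scaling scores by `nbest_scale`), so it lower-bounds what any rescoring of that n-best list could achieve. A plain edit-distance sketch of the oracle selection (not the k2 lattice-based implementation used here):

```python
from typing import List


def edit_distance(ref: List[str], hyp: List[str]) -> int:
    """Word-level Levenshtein distance between a reference and a hypothesis."""
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(
                dp[j] + 1,        # deletion (ref word missing from hyp)
                dp[j - 1] + 1,    # insertion (extra hyp word)
                prev + (r != h),  # substitution or match
            )
    return dp[-1]


def oracle_wer(refs: List[List[str]], nbests: List[List[List[str]]]) -> float:
    """Pick the best hypothesis per utterance, then compute the overall WER."""
    errs = sum(min(edit_distance(r, h) for h in nb) for r, nb in zip(refs, nbests))
    words = sum(len(r) for r in refs)
    return errs / words
```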
@@ -19,40 +19,40 @@
 Usage:
 (1) greedy search
 ./pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 600 \
     --decoding-method greedy_search

 (2) beam search (not recommended)
 ./pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4

 (3) modified beam search
 ./pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4

 (4) fast beam search
 ./pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 4 \
     --max-contexts 4 \
     --max-states 8
 """
@@ -485,7 +485,7 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <unk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
@@ -18,41 +18,41 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless4/decode.py \
+./pruned_transducer_stateless5/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless4/exp \
-    --max-duration 100 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
     --decoding-method greedy_search

-(2) beam search
-./pruned_transducer_stateless4/decode.py \
+(2) beam search (not recommended)
+./pruned_transducer_stateless5/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless4/exp \
-    --max-duration 100 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4

 (3) modified beam search
-./pruned_transducer_stateless4/decode.py \
+./pruned_transducer_stateless5/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless4/exp \
-    --max-duration 100 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4

 (4) fast beam search
-./pruned_transducer_stateless4/decode.py \
+./pruned_transducer_stateless5/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless4/exp \
-    --max-duration 1500 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 4 \
     --max-contexts 4 \
     --max-states 8
 """
@@ -69,7 +69,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import (
     beam_search,
-    fast_beam_search,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -98,33 +98,34 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
     )

+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
-    )
-
-    parser.add_argument(
-        "--avg-last-n",
-        type=int,
-        default=0,
-        help="""If positive, --epoch and --avg are ignored and it
-        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
-        where xxx is the number of processed batches while
-        saving that checkpoint.
-        """,
+        "'--epoch' and '--iter'",
     )

     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless4/exp",
+        default="pruned_transducer_stateless5/exp",
         help="The experiment dir",
     )
@@ -151,7 +152,7 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="""An interger indicating how many candidates we will keep for each
+        help="""An integer indicating how many candidates we will keep for each
         frame. Used only when --decoding-method is beam_search or
         modified_beam_search.""",
     )
@@ -253,7 +254,7 @@ def decode_one_batch(
     hyps = []

     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search(
+        hyp_tokens = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
|
|||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@@ -278,6 +280,7 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
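Both batched searches now receive `encoder_out_lens`, the number of valid encoder frames per utterance, so that padding frames in a batch are not decoded. A minimal sketch of how such lengths are typically turned into a padding mask (illustrative; not the icefall implementation):

```python
import torch


def make_pad_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Return a bool mask of shape (batch, max_len).

    mask[i, t] is True for padded frames, i.e. t >= lengths[i], so a
    batched search can skip them instead of emitting symbols there.
    """
    arange = torch.arange(max_len, device=lengths.device)
    return arange.unsqueeze(0) >= lengths.unsqueeze(1)


# Example: a batch of 3 utterances with 5, 3, and 4 valid frames.
lengths = torch.tensor([5, 3, 4])
print(make_pad_mask(lengths, max_len=5))
```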
@@ -357,9 +360,9 @@ def decode_dataset(
     num_batches = "?"

     if params.decoding_method == "greedy_search":
-        log_interval = 100
+        log_interval = 50
     else:
-        log_interval = 2
+        log_interval = 10

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -455,13 +458,19 @@ def main():
     )
     params.res_dir = params.exp_dir / params.decoding_method

-    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

     if "fast_beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -478,8 +487,9 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()

     logging.info(params)
@@ -487,8 +497,20 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)

-    if params.avg_last_n > 0:
-        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
         logging.info(f"averaging {filenames}")
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))
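The new `--iter`/`--avg` logic relies on `find_checkpoints(exp_dir, iteration=-params.iter)` returning `checkpoint-xxx.pt` files sorted newest first and restricted to `xxx <= iter`, so the slice `[: params.avg]` keeps the `avg` most recent ones for averaging. A rough sketch of that selection under those assumptions (the real helper lives in `icefall.checkpoint`):

```python
import re
from pathlib import Path
from typing import List


def find_checkpoints_sketch(exp_dir: Path, iteration: int = 0) -> List[str]:
    """Sketch of checkpoint discovery for --iter/--avg style averaging.

    Returns checkpoint-xxx.pt paths sorted by xxx, largest first.
    If iteration is negative, keep only checkpoints with xxx <= -iteration.
    """
    pattern = re.compile(r"checkpoint-(\d+)\.pt$")
    found = []
    for f in exp_dir.glob("checkpoint-*.pt"):
        m = pattern.search(f.name)
        if m:
            found.append((int(m.group(1)), str(f)))
    if iteration < 0:
        found = [(it, f) for it, f in found if it <= -iteration]
    found.sort(reverse=True)
    return [f for _, f in found]


# E.g. with checkpoints saved at 4000, 8000, and 12000 batches,
#   find_checkpoints_sketch(exp_dir, iteration=-8000)[:2]
# keeps checkpoint-8000.pt and checkpoint-4000.pt for averaging.
```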
@@ -20,23 +20,23 @@
 # to a single one using model averaging.
 """
 Usage:
-./pruned_transducer_stateless2/export.py \
-  --exp-dir ./pruned_transducer_stateless2/exp \
+./pruned_transducer_stateless5/export.py \
+  --exp-dir ./pruned_transducer_stateless5/exp \
   --bpe-model data/lang_bpe_500/bpe.model \
   --epoch 20 \
   --avg 10

 It will generate a file exp_dir/pretrained.pt

-To use the generated file with `pruned_transducer_stateless2/decode.py`,
+To use the generated file with `pruned_transducer_stateless5/decode.py`,
 you can do:

     cd /path/to/exp_dir
     ln -s pretrained.pt epoch-9999.pt

     cd /path/to/egs/librispeech/ASR
-    ./pruned_transducer_stateless2/decode.py \
-        --exp-dir ./pruned_transducer_stateless2/exp \
+    ./pruned_transducer_stateless5/decode.py \
+        --exp-dir ./pruned_transducer_stateless5/exp \
         --epoch 9999 \
         --avg 1 \
         --max-duration 100 \
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
from train import get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
@@ -80,7 +80,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="pruned_transducer_stateless5/exp",
         help="""It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
         """,
|
|||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
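export.py now calls `add_model_arguments(parser)` because the stateless5 recipe makes the encoder shape configurable from the command line, and export must accept the same flags to rebuild exactly the architecture that was trained. A sketch of what such a helper looks like (the argument set is taken from the training commands above; defaults here are illustrative, and the real definition lives in pruned_transducer_stateless5/train.py):

```python
import argparse


def add_model_arguments(parser: argparse.ArgumentParser) -> None:
    """Model-shape flags shared by train.py, decode.py, and export.py."""
    parser.add_argument("--num-encoder-layers", type=int, default=18)
    parser.add_argument("--dim-feedforward", type=int, default=2048)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--encoder-dim", type=int, default=512)
    parser.add_argument("--decoder-dim", type=int, default=512)
    parser.add_argument("--joiner-dim", type=int, default=512)
```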
@@ -21,22 +21,22 @@ Usage:

 export CUDA_VISIBLE_DEVICES="0,1,2,3"

-./pruned_transducer_stateless4/train.py \
+./pruned_transducer_stateless5/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless4/exp \
+  --exp-dir pruned_transducer_stateless5/exp \
   --full-libri 1 \
   --max-duration 300

 # For mixed precision training:

-./pruned_transducer_stateless4/train.py \
+./pruned_transducer_stateless5/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --use_fp16 1 \
-  --exp-dir pruned_transducer_stateless4/exp \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
   --full-libri 1 \
   --max-duration 550
@@ -185,7 +185,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless4/exp",
+        default="pruned_transducer_stateless5/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -712,25 +712,29 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
-            loss, loss_info = compute_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                batch=batch,
-                is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step),
-            )
-        # summary stats
-        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
-        # NOTE: We use reduction==sum and loss is computed over utterances
-        # in the batch and there is no normalization to it so far.
-        scaler.scale(loss).backward()
-        scheduler.step_batch(params.batch_idx_train)
-        scaler.step(optimizer)
-        scaler.update()
-        optimizer.zero_grad()
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=(params.batch_idx_train / params.model_warm_step),
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise

         if params.print_diagnostics and batch_idx == 5:
             return
@@ -975,6 +979,38 @@ def run(rank, world_size, args):
         cleanup_dist()


+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
 def scan_pessimistic_batches_for_oom(
     model: nn.Module,
     train_dl: torch.utils.data.DataLoader,
@@ -1006,7 +1042,7 @@ def scan_pessimistic_batches_for_oom(
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
-        except RuntimeError as e:
+        except Exception as e:
             if "CUDA out of memory" in str(e):
                 logging.error(
                     "Your GPU ran out of memory with the current "
@@ -1015,6 +1051,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
+            display_and_save_batch(batch, params=params, sp=sp)
             raise
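With these changes, any batch that crashes training (or triggers OOM in the pessimistic scan) is dumped to `exp_dir/batch-<uuid>.pt` before the exception propagates. A quick way to inspect such a dump afterwards (a sketch; the filename below is a placeholder for whatever the log reported):

```python
import torch

# Path printed by display_and_save_batch in the training log;
# "batch-xxx.pt" here stands in for the actual uuid-named file.
batch = torch.load("pruned_transducer_stateless5/exp/batch-xxx.pt")

features = batch["inputs"]
supervisions = batch["supervisions"]
print(features.shape)              # (N, T, num_features)
print(supervisions["text"])        # transcripts of the offending batch
print(supervisions["num_frames"])  # frame counts, to spot outliers
```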