mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Update results.
This commit is contained in:
parent
994b8a7716
commit
a613e85900
@ -19,7 +19,8 @@ The following table lists the differences among them.
|
||||
| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss |
|
||||
| `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss |
|
||||
| `pruned_transducer_stateless3` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + using GigaSpeech as extra training data |
|
||||
| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + Random combiner|
|
||||
| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless2 + save averaged models periodically during training |
|
||||
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + more layers + Random combiner|
|
||||
|
||||
|
||||
The decoder in `transducer_stateless` is modified from the paper
|
||||
|
@ -1,9 +1,142 @@
|
||||
## Results
|
||||
|
||||
### LibriSpeech BPE training results (Pruned Transducer 3)
|
||||
### LibriSpeech BPE training results (Pruned Stateless Transducer 5)
|
||||
|
||||
[pruned_transducer_stateless5](./pruned_transducer_stateless5)
|
||||
|
||||
Same as `Pruned Stateless Transducer 2` but with more layers.
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/330>
|
||||
|
||||
Note models in `pruned_transducer_stateless` and `pruned_transducer_stateless2`
|
||||
have about 80 M parameters.
|
||||
|
||||
The notations `large` and `medium` below are from the [Conformer](https://arxiv.org/pdf/2005.08100.pdf)
|
||||
paper, where the large model has about 188 M parameters and the medium model
|
||||
has 30.8 M parameters.
|
||||
|
||||
#### Large
|
||||
|
||||
Number of model parameters 118129516 (i.e, 118.13 M).
|
||||
|
||||
| | test-clean | test-other | comment |
|
||||
|-------------------------------------|------------|------------|----------------------------------------|
|
||||
| greedy search (max sym per frame 1) | 2.39 | 5.57 | --epoch 39 --avg 7 --max-duration 600 |
|
||||
| modified beam search | 2.35 | 5.50 | --epoch 39 --avg 7 --max-duration 600 |
|
||||
| fast beam search | 2.38 | 5.50 | --epoch 39 --avg 7 --max-duration 600 |
|
||||
|
||||
The training commands are:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
|
||||
./pruned_transducer_stateless5/train.py \
|
||||
--world-size 8 \
|
||||
--num-epochs 40 \
|
||||
--start-epoch 0 \
|
||||
--full-libri 1 \
|
||||
--exp-dir pruned_transducer_stateless5/exp-L \
|
||||
--max-duration 300 \
|
||||
--use-fp16 0 \
|
||||
--num-encoder-layers 18 \
|
||||
--dim-feedforward 2048 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 512 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512
|
||||
```
|
||||
|
||||
The tensorboard log can be found at
|
||||
<https://tensorboard.dev/experiment/Zq0h3KpnQ2igWbeR4U82Pw/>
|
||||
|
||||
The decoding commands are:
|
||||
|
||||
```bash
|
||||
for method in greedy_search modified_beam_search fast_beam_search; do
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 39 \
|
||||
--avg 7 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp-L \
|
||||
--max-duration 600 \
|
||||
--decoding-method $method \
|
||||
--max-sym-per-frame 1 \
|
||||
--num-encoder-layers 18 \
|
||||
--dim-feedforward 2048 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 512 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512
|
||||
done
|
||||
```
|
||||
|
||||
You can find a pretrained model, training logs, decoding logs, and decoding
|
||||
results at:
|
||||
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13>
|
||||
|
||||
|
||||
#### Medium
|
||||
|
||||
Number of model parameters 30896748 (i.e, 30.9 M).
|
||||
|
||||
| | test-clean | test-other | comment |
|
||||
|-------------------------------------|------------|------------|-----------------------------------------|
|
||||
| greedy search (max sym per frame 1) | 2.88 | 6.69 | --epoch 39 --avg 17 --max-duration 600 |
|
||||
| modified beam search | 2.83 | 6.59 | --epoch 39 --avg 17 --max-duration 600 |
|
||||
| fast beam search | 2.83 | 6.61 | --epoch 39 --avg 17 --max-duration 600 |
|
||||
|
||||
The training commands are:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
|
||||
./pruned_transducer_stateless5/train.py \
|
||||
--world-size 8 \
|
||||
--num-epochs 40 \
|
||||
--start-epoch 0 \
|
||||
--full-libri 1 \
|
||||
--exp-dir pruned_transducer_stateless5/exp-M \
|
||||
--max-duration 300 \
|
||||
--use-fp16 0 \
|
||||
--num-encoder-layers 18 \
|
||||
--dim-feedforward 1024 \
|
||||
--nhead 4 \
|
||||
--encoder-dim 256 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512
|
||||
```
|
||||
|
||||
The tensorboard log can be found at
|
||||
<https://tensorboard.dev/experiment/bOQvULPsQ1iL7xpdI0VbXw/>
|
||||
|
||||
The decoding commands are:
|
||||
|
||||
```bash
|
||||
for method in greedy_search modified_beam_search fast_beam_search; do
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 39 \
|
||||
--avg 17 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp-M \
|
||||
--max-duration 600 \
|
||||
--decoding-method $method \
|
||||
--max-sym-per-frame 1 \
|
||||
--num-encoder-layers 18 \
|
||||
--dim-feedforward 1024 \
|
||||
--nhead 4 \
|
||||
--encoder-dim 256 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512
|
||||
done
|
||||
```
|
||||
|
||||
You can find a pretrained model, training logs, decoding logs, and decoding
|
||||
results at:
|
||||
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-M-2022-05-13>
|
||||
|
||||
|
||||
### LibriSpeech BPE training results (Pruned Stateless Transducer 3)
|
||||
|
||||
[pruned_transducer_stateless3](./pruned_transducer_stateless3)
|
||||
Same as `Pruned Transducer 2` but using the XL subset from
|
||||
Same as `Pruned Stateless Transducer 2` but using the XL subset from
|
||||
[GigaSpeech](https://github.com/SpeechColab/GigaSpeech) as extra training data.
|
||||
|
||||
During training, it selects either a batch from GigaSpeech with prob `giga_prob`
|
||||
@ -104,6 +237,7 @@ done
|
||||
The following table shows the
|
||||
[Nbest oracle WER](http://kaldi-asr.org/doc/lattices.html#lattices_operations_oracle)
|
||||
for fast beam search.
|
||||
|
||||
| epoch | avg | num_paths | nbest_scale | test-clean | test-other |
|
||||
|-------|-----|-----------|-------------|------------|------------|
|
||||
| 27 | 10 | 50 | 0.5 | 0.91 | 2.74 |
|
||||
|
@ -19,40 +19,40 @@
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
"""
|
||||
|
||||
|
||||
@ -485,7 +485,7 @@ def main():
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> is defined in local/train_bpe_model.py
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
@ -18,41 +18,41 @@
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless4/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method greedy_search
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless4/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless4/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless4/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||
--max-duration 1500 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
"""
|
||||
|
||||
|
||||
@ -69,7 +69,7 @@ import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
@ -98,33 +98,34 @@ def get_parser():
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg-last-n",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch and --avg are ignored and it
|
||||
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
|
||||
where xxx is the number of processed batches while
|
||||
saving that checkpoint.
|
||||
""",
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless4/exp",
|
||||
default="pruned_transducer_stateless5/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
@ -151,7 +152,7 @@ def get_parser():
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
@ -253,7 +254,7 @@ def decode_one_batch(
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search(
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -271,6 +272,7 @@ def decode_one_batch(
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
@ -278,6 +280,7 @@ def decode_one_batch(
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
@ -357,9 +360,9 @@ def decode_dataset(
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 100
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 2
|
||||
log_interval = 10
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
@ -455,13 +458,19 @@ def main():
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -478,8 +487,9 @@ def main():
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
@ -487,8 +497,20 @@ def main():
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if params.avg_last_n > 0:
|
||||
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
@ -20,23 +20,23 @@
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
./pruned_transducer_stateless5/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file exp_dir/pretrained.pt
|
||||
|
||||
To use the generated file with `pruned_transducer_stateless2/decode.py`,
|
||||
To use the generated file with `pruned_transducer_stateless5/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 100 \
|
||||
@ -49,7 +49,7 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import get_params, get_transducer_model
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.utils import str2bool
|
||||
@ -80,7 +80,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
default="pruned_transducer_stateless5/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
@ -109,6 +109,8 @@ def get_parser():
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -21,22 +21,22 @@ Usage:
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./pruned_transducer_stateless4/train.py \
|
||||
./pruned_transducer_stateless5/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir pruned_transducer_stateless4/exp \
|
||||
--exp-dir pruned_transducer_stateless5/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 300
|
||||
|
||||
# For mix precision training:
|
||||
|
||||
./pruned_transducer_stateless4/train.py \
|
||||
./pruned_transducer_stateless5/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--use_fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless4/exp \
|
||||
--use-fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless5/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 550
|
||||
|
||||
@ -185,7 +185,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless4/exp",
|
||||
default="pruned_transducer_stateless5/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
@ -712,25 +712,29 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=(params.batch_idx_train / params.model_warm_step),
|
||||
)
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=(params.batch_idx_train / params.model_warm_step),
|
||||
)
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
scaler.scale(loss).backward()
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
scaler.scale(loss).backward()
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
except: # noqa
|
||||
display_and_save_batch(batch, params=params, sp=sp)
|
||||
raise
|
||||
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
return
|
||||
@ -975,6 +979,38 @@ def run(rank, world_size, args):
|
||||
cleanup_dist()
|
||||
|
||||
|
||||
def display_and_save_batch(
|
||||
batch: dict,
|
||||
params: AttributeDict,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
) -> None:
|
||||
"""Display the batch statistics and save the batch into disk.
|
||||
|
||||
Args:
|
||||
batch:
|
||||
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
||||
for the content in it.
|
||||
params:
|
||||
Parameters for training. See :func:`get_params`.
|
||||
sp:
|
||||
The BPE model.
|
||||
"""
|
||||
from lhotse.utils import uuid4
|
||||
|
||||
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
|
||||
logging.info(f"Saving batch to {filename}")
|
||||
torch.save(batch, filename)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
features = batch["inputs"]
|
||||
|
||||
logging.info(f"features shape: {features.shape}")
|
||||
|
||||
y = sp.encode(supervisions["text"], out_type=int)
|
||||
num_tokens = sum(len(i) for i in y)
|
||||
logging.info(f"num tokens: {num_tokens}")
|
||||
|
||||
|
||||
def scan_pessimistic_batches_for_oom(
|
||||
model: nn.Module,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
@ -1006,7 +1042,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
except RuntimeError as e:
|
||||
except Exception as e:
|
||||
if "CUDA out of memory" in str(e):
|
||||
logging.error(
|
||||
"Your GPU ran out of memory with the current "
|
||||
@ -1015,6 +1051,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
f"Failing criterion: {criterion} "
|
||||
f"(={crit_values[criterion]}) ..."
|
||||
)
|
||||
display_and_save_batch(batch, params=params, sp=sp)
|
||||
raise
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user