Update results.

Fangjun Kuang 2022-05-13 18:04:03 +08:00
parent 994b8a7716
commit a613e85900
6 changed files with 312 additions and 116 deletions

View File

@ -19,7 +19,8 @@ The following table lists the differences among them.
| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss | | `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss |
| `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | | `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss |
| `pruned_transducer_stateless3` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + using GigaSpeech as extra training data | | `pruned_transducer_stateless3` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + using GigaSpeech as extra training data |
| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + Random combiner| | `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless2 + save averaged models periodically during training |
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + more layers + Random combiner|
The decoder in `transducer_stateless` is modified from the paper The decoder in `transducer_stateless` is modified from the paper

View File

@ -1,9 +1,142 @@
## Results ## Results
### LibriSpeech BPE training results (Pruned Transducer 3) ### LibriSpeech BPE training results (Pruned Stateless Transducer 5)
[pruned_transducer_stateless5](./pruned_transducer_stateless5)
Same as `Pruned Stateless Transducer 2` but with more layers.
See <https://github.com/k2-fsa/icefall/pull/330>
Note that the models in `pruned_transducer_stateless` and `pruned_transducer_stateless2`
have about 80 M parameters.
The notations `large` and `medium` below are from the [Conformer](https://arxiv.org/pdf/2005.08100.pdf)
paper, where the large model has about 118 M parameters and the medium model
has 30.8 M parameters.
#### Large
Number of model parameters: 118129516 (i.e., 118.13 M).
| | test-clean | test-other | comment |
|-------------------------------------|------------|------------|----------------------------------------|
| greedy search (max sym per frame 1) | 2.39 | 5.57 | --epoch 39 --avg 7 --max-duration 600 |
| modified beam search | 2.35 | 5.50 | --epoch 39 --avg 7 --max-duration 600 |
| fast beam search | 2.38 | 5.50 | --epoch 39 --avg 7 --max-duration 600 |
The training commands are:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./pruned_transducer_stateless5/train.py \
--world-size 8 \
--num-epochs 40 \
--start-epoch 0 \
--full-libri 1 \
--exp-dir pruned_transducer_stateless5/exp-L \
--max-duration 300 \
--use-fp16 0 \
--num-encoder-layers 18 \
--dim-feedforward 2048 \
--nhead 8 \
--encoder-dim 512 \
--decoder-dim 512 \
--joiner-dim 512
```
The tensorboard log can be found at
<https://tensorboard.dev/experiment/Zq0h3KpnQ2igWbeR4U82Pw/>
The decoding commands are:
```bash
for method in greedy_search modified_beam_search fast_beam_search; do
./pruned_transducer_stateless5/decode.py \
--epoch 39 \
--avg 7 \
--exp-dir ./pruned_transducer_stateless5/exp-L \
--max-duration 600 \
--decoding-method $method \
--max-sym-per-frame 1 \
--num-encoder-layers 18 \
--dim-feedforward 2048 \
--nhead 8 \
--encoder-dim 512 \
--decoder-dim 512 \
--joiner-dim 512
done
```
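To make the `--epoch 39 --avg 7` pair concrete, here is a minimal sketch of which checkpoint files would be averaged. It assumes the averaging window is the `avg` consecutive epochs ending at `--epoch` (inclusive) and that checkpoints are named `epoch-N.pt`, as in the `export.py` usage. The helper name `select_epoch_checkpoints` is hypothetical and not part of icefall.

```python
from pathlib import Path
from typing import List


def select_epoch_checkpoints(exp_dir: str, epoch: int, avg: int) -> List[Path]:
    """Return the `avg` consecutive epoch checkpoints ending at `epoch`.

    Assumption: checkpoints are named epoch-<N>.pt and epochs count from 0.
    """
    if avg < 1:
        raise ValueError(f"--avg must be at least 1, got {avg}")
    start = epoch - avg + 1
    if start < 0:
        raise ValueError(f"--epoch {epoch} with --avg {avg} needs epochs below 0")
    return [Path(exp_dir) / f"epoch-{i}.pt" for i in range(start, epoch + 1)]


# The "Large" decoding setup above: 7 checkpoints ending at epoch 39.
files = select_epoch_checkpoints(
    "pruned_transducer_stateless5/exp-L", epoch=39, avg=7
)
print([f.name for f in files])
```

Under these assumptions the window covers `epoch-33.pt` through `epoch-39.pt`.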
You can find a pretrained model, training logs, decoding logs, and decoding
results at:
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13>
#### Medium
Number of model parameters: 30896748 (i.e., 30.9 M).
| | test-clean | test-other | comment |
|-------------------------------------|------------|------------|-----------------------------------------|
| greedy search (max sym per frame 1) | 2.88 | 6.69 | --epoch 39 --avg 17 --max-duration 600 |
| modified beam search | 2.83 | 6.59 | --epoch 39 --avg 17 --max-duration 600 |
| fast beam search | 2.83 | 6.61 | --epoch 39 --avg 17 --max-duration 600 |
The training commands are:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./pruned_transducer_stateless5/train.py \
--world-size 8 \
--num-epochs 40 \
--start-epoch 0 \
--full-libri 1 \
--exp-dir pruned_transducer_stateless5/exp-M \
--max-duration 300 \
--use-fp16 0 \
--num-encoder-layers 18 \
--dim-feedforward 1024 \
--nhead 4 \
--encoder-dim 256 \
--decoder-dim 512 \
--joiner-dim 512
```
The tensorboard log can be found at
<https://tensorboard.dev/experiment/bOQvULPsQ1iL7xpdI0VbXw/>
The decoding commands are:
```bash
for method in greedy_search modified_beam_search fast_beam_search; do
./pruned_transducer_stateless5/decode.py \
--epoch 39 \
--avg 17 \
--exp-dir ./pruned_transducer_stateless5/exp-M \
--max-duration 600 \
--decoding-method $method \
--max-sym-per-frame 1 \
--num-encoder-layers 18 \
--dim-feedforward 1024 \
--nhead 4 \
--encoder-dim 256 \
--decoder-dim 512 \
--joiner-dim 512
done
```
You can find a pretrained model, training logs, decoding logs, and decoding
results at:
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-M-2022-05-13>
### LibriSpeech BPE training results (Pruned Stateless Transducer 3)
[pruned_transducer_stateless3](./pruned_transducer_stateless3) [pruned_transducer_stateless3](./pruned_transducer_stateless3)
Same as `Pruned Transducer 2` but using the XL subset from Same as `Pruned Stateless Transducer 2` but using the XL subset from
[GigaSpeech](https://github.com/SpeechColab/GigaSpeech) as extra training data. [GigaSpeech](https://github.com/SpeechColab/GigaSpeech) as extra training data.
During training, it selects either a batch from GigaSpeech with prob `giga_prob` During training, it selects either a batch from GigaSpeech with prob `giga_prob`
@ -104,6 +237,7 @@ done
The following table shows the The following table shows the
[Nbest oracle WER](http://kaldi-asr.org/doc/lattices.html#lattices_operations_oracle) [Nbest oracle WER](http://kaldi-asr.org/doc/lattices.html#lattices_operations_oracle)
for fast beam search. for fast beam search.
| epoch | avg | num_paths | nbest_scale | test-clean | test-other | | epoch | avg | num_paths | nbest_scale | test-clean | test-other |
|-------|-----|-----------|-------------|------------|------------| |-------|-----|-----------|-------------|------------|------------|
| 27 | 10 | 50 | 0.5 | 0.91 | 2.74 | | 27 | 10 | 50 | 0.5 | 0.91 | 2.74 |

View File

@ -19,40 +19,40 @@
Usage: Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (not recommended) (2) beam search (not recommended)
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--beam 4 \ --beam 4 \
--max-contexts 4 \ --max-contexts 4 \
--max-states 8 --max-states 8
""" """
@ -485,7 +485,7 @@ def main():
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model) sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py # <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()

View File

@ -18,41 +18,41 @@
""" """
Usage: Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless4/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (2) beam search (not recommended)
./pruned_transducer_stateless4/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless4/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless4/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 1500 \ --max-duration 600 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--beam 4 \ --beam 4 \
--max-contexts 4 \ --max-contexts 4 \
--max-states 8 --max-states 8
""" """
@ -69,7 +69,7 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search, fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
@ -98,33 +98,34 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=28, default=28,
help="It specifies the checkpoint to use for decoding." help="""It specifies the checkpoint to use for decoding.
"Note: Epoch counts from 0.", Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
) )
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch' and '--iter'",
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
""",
) )
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless4/exp", default="pruned_transducer_stateless5/exp",
help="The experiment dir", help="The experiment dir",
) )
@ -151,7 +152,7 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="""An interger indicating how many candidates we will keep for each help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or frame. Used only when --decoding-method is beam_search or
modified_beam_search.""", modified_beam_search.""",
) )
@ -253,7 +254,7 @@ def decode_one_batch(
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search( hyp_tokens = fast_beam_search_one_best(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -271,6 +272,7 @@ def decode_one_batch(
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
@ -278,6 +280,7 @@ def decode_one_batch(
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
@ -357,9 +360,9 @@ def decode_dataset(
num_batches = "?" num_batches = "?"
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
log_interval = 100 log_interval = 50
else: else:
log_interval = 2 log_interval = 10
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
@ -455,13 +458,19 @@ def main():
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}" params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
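The new suffix logic in this hunk can be read as a pure function. The sketch below is an illustration only: `build_suffix` is a hypothetical helper, and the default values (in particular `context_size=2`) are assumptions chosen for the example rather than values taken from this diff.

```python
def build_suffix(
    decoding_method: str,
    avg: int,
    epoch: int = 28,
    iteration: int = 0,
    beam: int = 4,
    max_contexts: int = 4,
    max_states: int = 8,
    beam_size: int = 4,
    context_size: int = 2,  # assumed default, not shown in the diff
    max_sym_per_frame: int = 1,
) -> str:
    """Build the result-file suffix the way the right-hand side of the hunk does."""
    # A positive --iter takes precedence over --epoch.
    if iteration > 0:
        suffix = f"iter-{iteration}-avg-{avg}"
    else:
        suffix = f"epoch-{epoch}-avg-{avg}"

    # "fast_beam_search" also contains "beam_search", so it is checked first.
    if "fast_beam_search" in decoding_method:
        suffix += f"-beam-{beam}"
        suffix += f"-max-contexts-{max_contexts}"
        suffix += f"-max-states-{max_states}"
    elif "beam_search" in decoding_method:
        suffix += f"-{decoding_method}-beam-size-{beam_size}"
    else:
        suffix += f"-context-{context_size}"
        suffix += f"-max-sym-per-frame-{max_sym_per_frame}"
    return suffix


for method in ("greedy_search", "modified_beam_search", "fast_beam_search"):
    print(build_suffix(method, avg=7, epoch=39))
```

Including the decoding method in the `beam_search` branch keeps results from `beam_search` and `modified_beam_search` from sharing the same file name when they use the same beam size.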
@ -478,8 +487,9 @@ def main():
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model) sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py # <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
logging.info(params) logging.info(params)
@ -487,8 +497,20 @@ def main():
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params) model = get_transducer_model(params)
if params.avg_last_n > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}") logging.info(f"averaging {filenames}")
model.to(device) model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
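The `--iter` branch above can be sketched without touching the file system. This is a rough illustration under two assumptions that the diff does not confirm: that iteration checkpoints are named `checkpoint-<iter>.pt`, and that passing a negative `iteration` to `find_checkpoints` means "checkpoints at or below that iteration, newest first". The helper `pick_iter_checkpoints` is hypothetical.

```python
import re
from typing import List


def pick_iter_checkpoints(names: List[str], iteration: int, avg: int) -> List[str]:
    """Pick the `avg` newest checkpoints whose iteration is <= `iteration`."""
    pattern = re.compile(r"checkpoint-(\d+)\.pt$")
    found = []
    for name in names:
        match = pattern.search(name)
        if match and int(match.group(1)) <= iteration:
            found.append((int(match.group(1)), name))

    # Newest (largest iteration) first, then keep at most `avg` of them.
    found.sort(reverse=True)
    selected = [name for _, name in found[:avg]]

    # Same two failure cases as in the hunk above.
    if len(selected) == 0:
        raise ValueError(
            f"No checkpoints found for --iter {iteration}, --avg {avg}"
        )
    if len(selected) < avg:
        raise ValueError(
            f"Not enough checkpoints ({len(selected)}) found for"
            f" --iter {iteration}, --avg {avg}"
        )
    return selected


names = [f"checkpoint-{i}.pt" for i in (2000, 4000, 6000, 8000)]
print(pick_iter_checkpoints(names, iteration=6000, avg=2))
```

Raising when fewer than `--avg` checkpoints exist avoids silently averaging a smaller set than the one named in the result suffix.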

View File

@ -20,23 +20,23 @@
# to a single one using model averaging. # to a single one using model averaging.
""" """
Usage: Usage:
./pruned_transducer_stateless2/export.py \ ./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
It will generate a file exp_dir/pretrained.pt It will generate a file exp_dir/pretrained.pt
To use the generated file with `pruned_transducer_stateless2/decode.py`, To use the generated file with `pruned_transducer_stateless5/decode.py`,
you can do: you can do:
cd /path/to/exp_dir cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR cd /path/to/egs/librispeech/ASR
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless5/decode.py \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--epoch 9999 \ --epoch 9999 \
--avg 1 \ --avg 1 \
--max-duration 100 \ --max-duration 100 \
@ -49,7 +49,7 @@ from pathlib import Path
import sentencepiece as spm import sentencepiece as spm
import torch import torch
from train import get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool from icefall.utils import str2bool
@ -80,7 +80,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless2/exp", default="pruned_transducer_stateless5/exp",
help="""It specifies the directory where all training related help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved
""", """,
@ -109,6 +109,8 @@ def get_parser():
"2 means tri-gram", "2 means tri-gram",
) )
add_model_arguments(parser)
return parser return parser

View File

@ -21,22 +21,22 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3" export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless4/train.py \ ./pruned_transducer_stateless5/train.py \
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --num-epochs 30 \
--start-epoch 0 \ --start-epoch 0 \
--exp-dir pruned_transducer_stateless4/exp \ --exp-dir pruned_transducer_stateless5/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 300 --max-duration 300
# For mixed precision training: # For mixed precision training:
./pruned_transducer_stateless4/train.py \ ./pruned_transducer_stateless5/train.py \
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --num-epochs 30 \
--start-epoch 0 \ --start-epoch 0 \
--use_fp16 1 \ --use-fp16 1 \
--exp-dir pruned_transducer_stateless4/exp \ --exp-dir pruned_transducer_stateless5/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 550 --max-duration 550
@ -185,7 +185,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless4/exp", default="pruned_transducer_stateless5/exp",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved
@ -712,25 +712,29 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): try:
loss, loss_info = compute_loss( with torch.cuda.amp.autocast(enabled=params.use_fp16):
params=params, loss, loss_info = compute_loss(
model=model, params=params,
sp=sp, model=model,
batch=batch, sp=sp,
is_training=True, batch=batch,
warmup=(params.batch_idx_train / params.model_warm_step), is_training=True,
) warmup=(params.batch_idx_train / params.model_warm_step),
# summary stats )
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances # NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far. # in the batch and there is no normalization to it so far.
scaler.scale(loss).backward() scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train) scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params, sp=sp)
raise
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 5:
return return
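The pattern introduced in this hunk (dump the offending batch, then re-raise) can be sketched with the standard library alone. Note the substitutions: this sketch uses `pickle` and `uuid.uuid4` in place of `torch.save` and `lhotse.utils.uuid4`, and it catches `Exception` rather than using a bare `except:`. The names `save_batch`, `run_step`, and `failing_step` are hypothetical.

```python
import logging
import pickle
import tempfile
import uuid
from pathlib import Path
from typing import Any, Callable, Dict


def save_batch(batch: Dict[str, Any], exp_dir: Path) -> Path:
    """Write the batch to a uniquely named file so it can be inspected later."""
    filename = exp_dir / f"batch-{uuid.uuid4()}.pkl"
    logging.info("Saving batch to %s", filename)
    with open(filename, "wb") as f:
        pickle.dump(batch, f)
    return filename


def run_step(
    step: Callable[[Dict[str, Any]], float],
    batch: Dict[str, Any],
    exp_dir: Path,
) -> float:
    """Run one training step; on failure, save the batch and re-raise."""
    try:
        return step(batch)
    except Exception:
        save_batch(batch, exp_dir)
        raise  # keep the original traceback


def failing_step(batch: Dict[str, Any]) -> float:
    raise RuntimeError("simulated failure")


exp_dir = Path(tempfile.mkdtemp())
batch = {"supervisions": {"text": ["hello world"]}}
try:
    run_step(failing_step, batch, exp_dir)
except RuntimeError:
    print(f"saved: {sorted(p.name for p in exp_dir.glob('batch-*.pkl'))}")
```

Because the exception is re-raised, training still stops at the failing step; the saved file only makes the failure reproducible offline.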
@ -975,6 +979,38 @@ def run(rank, world_size, args):
cleanup_dist() cleanup_dist()
def display_and_save_batch(
batch: dict,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> None:
"""Display the batch statistics and save the batch into disk.
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
params:
Parameters for training. See :func:`get_params`.
sp:
The BPE model.
"""
from lhotse.utils import uuid4
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
supervisions = batch["supervisions"]
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
y = sp.encode(supervisions["text"], out_type=int)
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")
def scan_pessimistic_batches_for_oom( def scan_pessimistic_batches_for_oom(
model: nn.Module, model: nn.Module,
train_dl: torch.utils.data.DataLoader, train_dl: torch.utils.data.DataLoader,
@ -1006,7 +1042,7 @@ def scan_pessimistic_batches_for_oom(
loss.backward() loss.backward()
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
except RuntimeError as e: except Exception as e:
if "CUDA out of memory" in str(e): if "CUDA out of memory" in str(e):
logging.error( logging.error(
"Your GPU ran out of memory with the current " "Your GPU ran out of memory with the current "
@ -1015,6 +1051,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} " f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..." f"(={crit_values[criterion]}) ..."
) )
display_and_save_batch(batch, params=params, sp=sp)
raise raise