Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Commit 50471d6f11: Merge branch 'k2-fsa:master' into master
@@ -375,7 +375,7 @@ Please see: [

#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file computes the per-cut sampling weights for the AudioSet dataset
from the label frequencies in the input manifest. The weights are written
to a text file with one `cut_id weight` pair per line.
"""

import argparse

from lhotse import load_manifest


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--input-manifest", type=str, default="data/fbank/cuts_audioset_full.jsonl.gz"
    )

    parser.add_argument(
        "--output",
        type=str,
        required=True,
    )
    return parser


def main():
    # Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py
    parser = get_parser()
    args = parser.parse_args()

    cuts = load_manifest(args.input_manifest)

    print(f"A total of {len(cuts)} cuts.")

    # Count how often each of the 527 AudioSet classes occurs.
    label_count = [0] * 527  # a total of 527 classes
    for c in cuts:
        audio_event = c.supervisions[0].audio_event
        labels = list(map(int, audio_event.split(";")))
        for label in labels:
            label_count[label] += 1

    # A cut's weight is the sum of the inverse frequencies of its labels,
    # so cuts containing rare classes receive larger sampling weights.
    with open(args.output, "w") as f:
        for c in cuts:
            audio_event = c.supervisions[0].audio_event
            labels = list(map(int, audio_event.split(";")))
            weight = 0
            for label in labels:
                weight += 1000 / (label_count[label] + 0.01)
            f.write(f"{c.id} {weight}\n")


if __name__ == "__main__":
    main()
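As a sanity check on the weighting above: a cut's weight is the sum of 1000 / (count + 0.01) over its labels, so cuts containing rare classes dominate the sampling distribution. A tiny worked example with made-up counts (not real AudioSet statistics):

```python
label_count = {0: 10000, 1: 50}  # toy counts: label 0 is common, label 1 is rare

# Cut A carries only the common label; cut B carries both labels.
weight_a = 1000 / (label_count[0] + 0.01)                       # ~0.1
weight_b = sum(1000 / (label_count[l] + 0.01) for l in (0, 1))  # ~20.1

print(weight_a, weight_b)  # cut B is ~200x more likely to be drawn
```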

@@ -10,6 +10,7 @@ stage=-1
stop_stage=4

dl_dir=$PWD/download
+fbank_dir=data/fbank

# we assume that you have downloaded AudioSet and placed
# it under $dl_dir/audioset; the folder structure should look like
@@ -49,7 +50,6 @@ fi

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set"
-  fbank_dir=data/fbank
  if [ ! -e $fbank_dir/.balanced.done ]; then
    python local/generate_audioset_manifest.py \
      --dataset-dir $dl_dir/audioset \
@@ -102,3 +102,14 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
    touch data/fbank/.musan.done
  fi
fi
+
+# The following stages are required to do weighted-sampling training
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare for weighted-sampling training"
+  if [ ! -e $fbank_dir/cuts_audioset_full.jsonl.gz ]; then
+    lhotse combine $fbank_dir/cuts_audioset_balanced.jsonl.gz $fbank_dir/cuts_audioset_unbalanced.jsonl.gz $fbank_dir/cuts_audioset_full.jsonl.gz
+  fi
+  python ./local/compute_weight.py \
+    --input-manifest $fbank_dir/cuts_audioset_full.jsonl.gz \
+    --output $fbank_dir/sample_weights_full.txt
+fi
@@ -31,6 +31,7 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
+    WeightedSimpleCutSampler,
)
from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
    AudioSamples,
@@ -99,6 +100,20 @@ class AudioSetATDatamodule:
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
+        group.add_argument(
+            "--weighted-sampler",
+            type=str2bool,
+            default=False,
+            help="When enabled, samples are drawn by their weights. "
+            "It cannot be used together with the bucketing sampler.",
+        )
+        group.add_argument(
+            "--num-samples",
+            type=int,
+            default=200000,
+            help="The number of samples to be drawn in each epoch. Only used "
+            "for the weighted sampler.",
+        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
@@ -295,6 +310,9 @@ class AudioSetATDatamodule:
        )

        if self.args.bucketing_sampler:
+            assert (
+                not self.args.weighted_sampler
+            ), "weighted sampling is not supported in bucket sampler"
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
@@ -304,13 +322,26 @@ class AudioSetATDatamodule:
                drop_last=self.args.drop_last,
            )
        else:
-            logging.info("Using SimpleCutSampler.")
-            train_sampler = SimpleCutSampler(
-                cuts_train,
-                max_duration=self.args.max_duration,
-                shuffle=self.args.shuffle,
-                drop_last=self.args.drop_last,
-            )
+            if self.args.weighted_sampler:
+                # assert self.args.audioset_subset == "full", "Only use weighted sampling for full audioset"
+                logging.info("Using weighted SimpleCutSampler")
+                weights = self.audioset_sampling_weights()
+                train_sampler = WeightedSimpleCutSampler(
+                    cuts_train,
+                    weights,
+                    num_samples=self.args.num_samples,
+                    max_duration=self.args.max_duration,
+                    shuffle=False,  # shuffling is not supported
+                    drop_last=self.args.drop_last,
+                )
+            else:
+                logging.info("Using SimpleCutSampler.")
+                train_sampler = SimpleCutSampler(
+                    cuts_train,
+                    max_duration=self.args.max_duration,
+                    shuffle=self.args.shuffle,
+                    drop_last=self.args.drop_last,
+                )
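WeightedSimpleCutSampler draws cuts in proportion to the precomputed weights instead of uniformly, which is why shuffling is disabled and a fixed --num-samples budget replaces a full pass over the data. A toy sketch of the underlying idea, not lhotse's actual implementation:

```python
import torch

# Four cuts with precomputed sampling weights; the rare-label cut (index 1)
# has a much larger weight, so it is drawn far more often (with replacement).
weights = torch.tensor([0.1, 20.1, 0.1, 0.1])
drawn = torch.multinomial(weights, num_samples=8, replacement=True)
print(drawn)  # mostly index 1, e.g. tensor([1, 1, 3, 1, 1, 0, 1, 1])
```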

        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:

@@ -373,11 +404,9 @@ class AudioSetATDatamodule:
    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        test = AudioTaggingDataset(
-            input_strategy=(
-                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-                if self.args.on_the_fly_feats
-                else eval(self.args.input_strategy)()
-            ),
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else eval(self.args.input_strategy)(),
            return_cuts=self.args.return_cuts,
        )
        sampler = DynamicBucketingSampler(

@@ -397,21 +426,30 @@ class AudioSetATDatamodule:
    @lru_cache()
    def audioset_train_cuts(self) -> CutSet:
        logging.info("About to get the audioset training cuts.")
-        balanced_cuts = load_manifest_lazy(
-            self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
-        )
-        if self.args.audioset_subset == "full":
-            unbalanced_cuts = load_manifest_lazy(
-                self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
-            )
-            cuts = CutSet.mux(
-                balanced_cuts,
-                unbalanced_cuts,
-                weights=[20000, 2000000],
-                stop_early=True,
+        if not self.args.weighted_sampler:
+            balanced_cuts = load_manifest_lazy(
+                self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
+            )
+            if self.args.audioset_subset == "full":
+                unbalanced_cuts = load_manifest_lazy(
+                    self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
+                )
+                cuts = CutSet.mux(
+                    balanced_cuts,
+                    unbalanced_cuts,
+                    weights=[20000, 2000000],
+                    stop_early=True,
+                )
+            else:
+                cuts = balanced_cuts
+        else:
+            # assert self.args.audioset_subset == "full", "Only do weighted sampling for full AudioSet"
+            cuts = load_manifest(
+                self.args.manifest_dir
+                / f"cuts_audioset_{self.args.audioset_subset}.jsonl.gz"
+            )
        logging.info(f"Get {len(cuts)} cuts in total.")

        return cuts
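Here CutSet.mux interleaves the two lazily loaded manifests, drawing from each with probability proportional to the given weights, which roughly correspond to the subset sizes (about 20k balanced vs. 2M unbalanced cuts); with stop_early=True, iteration stops once one of the sources is exhausted.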

    @lru_cache()
@@ -420,3 +458,22 @@ class AudioSetATDatamodule:
        return load_manifest_lazy(
            self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz"
        )
+
+    @lru_cache()
+    def audioset_sampling_weights(self):
+        logging.info(
+            f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet"
+        )
+        weights = []
+        with open(
+            self.args.manifest_dir / f"sample_weights_{self.args.audioset_subset}.txt",
+            "r",
+        ) as f:
+            while True:
+                line = f.readline()
+                if not line:
+                    break
+                weight = float(line.split()[1])
+                weights.append(weight)
+        logging.info(f"Got the sampling weights for {len(weights)} cuts")
+        return weights
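The file read here is the one written by local/compute_weight.py in stage 5 of prepare.sh: one `<cut-id> <weight>` pair per line. Only the weight column is used, so the code assumes the file's line order matches the order of the cuts in the training manifest.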

@@ -789,12 +789,14 @@ def train_one_epoch(
            rank=0,
        )

+    num_samples = 0
    for batch_idx, batch in enumerate(train_dl):
        if batch_idx % 10 == 0:
            set_batch_count(model, get_adjusted_batch_count(params))

        params.batch_idx_train += 1
        batch_size = batch["inputs"].size(0)
+        num_samples += batch_size

        try:
            with torch.cuda.amp.autocast(enabled=params.use_fp16):
@@ -919,6 +921,12 @@ def train_one_epoch(
                tb_writer, "train/valid_", params.batch_idx_train
            )

+        if num_samples > params.num_samples:
+            logging.info(
+                f"Number of training samples exceeds {params.num_samples} in this epoch, move on to next epoch"
+            )
+            break
+
    loss_value = tot_loss["loss"] / tot_loss["frames"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
@@ -1032,7 +1040,8 @@ def run(rank, world_size, args):

        return True

-    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    if not params.weighted_sampler:
+        train_cuts = train_cuts.filter(remove_short_and_long_utt)

    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only load the sampler's state dict when it loads a checkpoint
@@ -29,17 +29,21 @@ def simple_cleanup(text: str) -> str:

# Assign text of the supervisions and remove unnecessary entries.
def main():
-    assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR"
+    assert (
+        len(sys.argv) == 4
+    ), "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR KEEP_CUSTOM_FIELDS"
    fname = Path(sys.argv[1]).name
    oname = Path(sys.argv[2]) / fname
+    # Note: bool("False") is True, so parse the string value explicitly.
+    keep_custom_fields = sys.argv[3].lower() == "true"
    with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
        for line in fin:
            cut = json.loads(line)
            cut["supervisions"][0]["text"] = simple_cleanup(
                cut["supervisions"][0]["custom"]["texts"][0]
            )
-            del cut["supervisions"][0]["custom"]
-            del cut["custom"]
+            if not keep_custom_fields:
+                del cut["supervisions"][0]["custom"]
+                del cut["custom"]
            fout.write((json.dumps(cut) + "\n").encode())
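The explicit string comparison above matters because the flag arrives from the shell as a literal string; a plain bool() call would treat any non-empty argument as true:

```python
# bool() on a non-empty string is always True, so bool("False") would
# silently keep the custom fields even when the user asked to drop them.
print(bool("False"))              # True
print("False".lower() == "true")  # False -- parses the intent correctly
```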

@@ -29,6 +29,11 @@ export CUDA_VISIBLE_DEVICES=""
# - speech
dl_dir=$PWD/download

+# If you want to do PromptASR experiments, please set it to True,
+# as this will keep the texts and pre_text information required for
+# the training of PromptASR.
+keep_custom_fields=False
+
. shared/parse_options.sh || exit 1

# vocab size for sentence piece models.
@@ -134,7 +139,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  for subset in small medium large dev test_clean test_other; do
    if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
      log "Prepare manifest for subset: ${subset}"
-      ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir
+      ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir $keep_custom_fields
    fi
  done
fi
@@ -307,6 +307,23 @@ done

To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html).

+We also support training Zipformer with AMP + bf16 (this requires hardware bf16 support). See [here](https://github.com/k2-fsa/icefall/pull/1700) for more details and pre-trained models. **The same command can be used for decoding and exporting the model.**
+
+The AMP + bf16 training command is:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+./zipformer/train.py \
+  --world-size 4 \
+  --num-epochs 50 \
+  --start-epoch 1 \
+  --use-fp16 0 \
+  --use-bf16 1 \
+  --exp-dir zipformer/exp_amp_bf16 \
+  --causal 0 \
+  --full-libri 1 \
+  --max-duration 1000
+```
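Before passing --use-bf16 1, it is worth confirming that the GPU actually supports bfloat16; the training script asserts this via torch.cuda.is_bf16_supported(), which requires e.g. an NVIDIA Ampere card or newer:

```python
import torch

# True on GPUs with native bf16 support; train.py refuses to run otherwise.
print(torch.cuda.is_bf16_supported())
```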

##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M

The tensorboard log can be found at
@@ -120,6 +120,7 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
+from lhotse import set_caching_enabled
from train import add_model_arguments, get_model, get_params

from icefall.checkpoint import (
@@ -296,6 +297,13 @@ def get_parser():
        """,
    )

+    parser.add_argument(
+        "--skip-scoring",
+        type=str2bool,
+        default=False,
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
+    )
+
    add_model_arguments(parser)

    return parser
@@ -455,7 +463,7 @@ def decode_one_batch(
        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
        hyps = [s.split() for s in hyps]
        key = "ctc-decoding"
-        return {key: hyps}
+        return {key: hyps}  # note: returns words

    if params.decoding_method == "attention-decoder-rescoring-no-ngram":
        best_path_dict = rescore_with_attention_decoder_no_ngram(
@@ -492,7 +500,7 @@ def decode_one_batch(
        )
        hyps = get_texts(best_path)
        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
+        key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}"  # noqa
        return {key: hyps}

    if params.decoding_method in ["1best", "nbest"]:
@@ -500,7 +508,7 @@ def decode_one_batch(
            best_path = one_best_decoding(
                lattice=lattice, use_double_scores=params.use_double_scores
            )
-            key = "no_rescore"
+            key = "no-rescore"
        else:
            best_path = nbest_decoding(
                lattice=lattice,
@@ -508,11 +516,11 @@ def decode_one_batch(
                use_double_scores=params.use_double_scores,
                nbest_scale=params.nbest_scale,
            )
-            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+            key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa

        hyps = get_texts(best_path)
        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        return {key: hyps}
+        return {key: hyps}  # note: returns BPE tokens

    assert params.decoding_method in [
        "nbest-rescoring",
@@ -646,7 +654,27 @@ def decode_dataset(
    return results


-def save_results(
+def save_asr_output(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    """
+    Save text produced by ASR.
+    """
+    for key, results in results_dict.items():
+
+        recogs_filename = (
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+        )
+
+        results = sorted(results)
+        store_transcripts(filename=recogs_filename, texts=results)
+
+        logging.info(f"The transcripts are stored in {recogs_filename}")
+
+
+def save_wer_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
@@ -661,32 +689,30 @@ def save_results(

    test_set_wers = dict()
    for key, results in results_dict.items():
-        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
-        if enable_log:
-            logging.info(f"The transcripts are stored in {recog_path}")
-
        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        with open(errs_filename, "w") as f:
-            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
+        with open(errs_filename, "w", encoding="utf8") as fd:
+            wer = write_error_stats(
+                fd, f"{test_set_name}_{key}", results, enable_log=enable_log
+            )
        test_set_wers[key] = wer
-        if enable_log:
-            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+        logging.info(f"Wrote detailed error stats to {errs_filename}")

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    with open(errs_info, "w") as f:
-        print("settings\tWER", file=f)
-        for key, val in test_set_wers:
-            print("{}\t{}".format(key, val), file=f)
-
-    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
-    note = "\tbest for {}".format(test_set_name)
+
+    wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+
+    with open(wer_filename, "w", encoding="utf8") as fd:
+        print("settings\tWER", file=fd)
+        for key, val in test_set_wers:
+            print(f"{key}\t{val}", file=fd)
+
+    s = f"\nFor {test_set_name}, WER of different settings are:\n"
+    note = f"\tbest for {test_set_name}"
    for key, val in test_set_wers:
-        s += "{}\t{}{}\n".format(key, val, note)
+        s += f"{key}\t{val}{note}\n"
        note = ""
    logging.info(s)
@@ -705,6 +731,9 @@ def main():
    params.update(get_decoding_params())
    params.update(vars(args))

+    # enable AudioCache
+    set_caching_enabled(True)  # lhotse
+
    assert params.decoding_method in (
        "ctc-greedy-search",
        "ctc-decoding",
@@ -719,9 +748,9 @@ def main():
    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:
-        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+        params.suffix = f"iter-{params.iter}_avg-{params.avg}"
    else:
-        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+        params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"

    if params.causal:
        assert (
@@ -730,11 +759,11 @@ def main():
        assert (
            "," not in params.left_context_frames
        ), "left_context_frames should be one value in decoding."
-        params.suffix += f"-chunk-{params.chunk_size}"
-        params.suffix += f"-left-context-{params.left_context_frames}"
+        params.suffix += f"_chunk-{params.chunk_size}"
+        params.suffix += f"_left-context-{params.left_context_frames}"

    if params.use_averaged_model:
-        params.suffix += "-use-averaged-model"
+        params.suffix += "_use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")
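The switch from `-` to `_` as the field separator makes the suffix unambiguous to parse: fields are joined by underscores while multi-word field names keep their internal hyphens, yielding names like `epoch-30_avg-9_chunk-16_left-context-128_use-averaged-model` (the epoch, avg, and context values here are hypothetical).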

@@ -940,12 +969,19 @@ def main():
            G=G,
        )

-        save_results(
+        save_asr_output(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

+        if not params.skip_scoring:
+            save_wer_results(
+                params=params,
+                test_set_name=test_set,
+                results_dict=results_dict,
+            )
+
    logging.info("Done!")
@@ -121,6 +121,7 @@ from beam_search import (
    modified_beam_search_lm_shallow_fusion,
    modified_beam_search_LODR,
)
+from lhotse import set_caching_enabled
from train import add_model_arguments, get_model, get_params

from icefall import ContextGraph, LmScorer, NgramLm
@@ -369,6 +370,14 @@ def get_parser():
    modified_beam_search_LODR.
    """,
    )

+    parser.add_argument(
+        "--skip-scoring",
+        type=str2bool,
+        default=False,
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
+    )
+
    add_model_arguments(parser)

    return parser
@@ -590,21 +599,23 @@ def decode_one_batch(
        )
        hyps.append(sp.decode(hyp).split())

+    # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" )
+    prefix = f"{params.decoding_method}"
    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}
    elif "fast_beam_search" in params.decoding_method:
-        key = f"beam_{params.beam}_"
-        key += f"max_contexts_{params.max_contexts}_"
-        key += f"max_states_{params.max_states}"
+        prefix += f"_beam-{params.beam}"
+        prefix += f"_max-contexts-{params.max_contexts}"
+        prefix += f"_max-states-{params.max_states}"
        if "nbest" in params.decoding_method:
-            key += f"_num_paths_{params.num_paths}_"
-            key += f"nbest_scale_{params.nbest_scale}"
+            prefix += f"_num-paths-{params.num_paths}"
+            prefix += f"_nbest-scale-{params.nbest_scale}"
        if "LG" in params.decoding_method:
-            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+            prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}"

-        return {key: hyps}
+        return {prefix: hyps}
    elif "modified_beam_search" in params.decoding_method:
-        prefix = f"beam_size_{params.beam_size}"
+        prefix += f"_beam-size-{params.beam_size}"
        if params.decoding_method in (
            "modified_beam_search_lm_rescore",
            "modified_beam_search_lm_rescore_LODR",
@@ -617,10 +628,11 @@ def decode_one_batch(
            return ans
        else:
            if params.has_contexts:
-                prefix += f"-context-score-{params.context_score}"
+                prefix += f"_context-score-{params.context_score}"
            return {prefix: hyps}
    else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        prefix += f"_beam-size-{params.beam_size}"
+        return {prefix: hyps}


def decode_dataset(
@@ -707,46 +719,58 @@ def decode_dataset(
    return results


-def save_results(
+def save_asr_output(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    """
+    Save text produced by ASR.
+    """
+    for key, results in results_dict.items():
+
+        recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+
+        results = sorted(results)
+        store_transcripts(filename=recogs_filename, texts=results)
+
+        logging.info(f"The transcripts are stored in {recogs_filename}")
+
+
+def save_wer_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]],
+):
+    """
+    Save WER and per-utterance word alignments.
+    """
    test_set_wers = dict()
    for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
-        results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
-        logging.info(f"The transcripts are stored in {recog_path}")
-
        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
-        with open(errs_filename, "w") as f:
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+        with open(errs_filename, "w", encoding="utf8") as fd:
            wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True
+                fd, f"{test_set_name}-{key}", results, enable_log=True
            )
        test_set_wers[key] = wer

-        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+        logging.info(f"Wrote detailed error stats to {errs_filename}")

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
-    with open(errs_info, "w") as f:
-        print("settings\tWER", file=f)
-        for key, val in test_set_wers:
-            print("{}\t{}".format(key, val), file=f)
-
-    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
-    note = "\tbest for {}".format(test_set_name)
+
+    wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+
+    with open(wer_filename, "w", encoding="utf8") as fd:
+        print("settings\tWER", file=fd)
+        for key, val in test_set_wers:
+            print(f"{key}\t{val}", file=fd)
+
+    s = f"\nFor {test_set_name}, WER of different settings are:\n"
+    note = f"\tbest for {test_set_name}"
    for key, val in test_set_wers:
-        s += "{}\t{}{}\n".format(key, val, note)
+        s += f"{key}\t{val}{note}\n"
        note = ""
    logging.info(s)
@@ -762,6 +786,9 @@ def main():
    params = get_params()
    params.update(vars(args))

+    # enable AudioCache
+    set_caching_enabled(True)  # lhotse
+
    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
@@ -783,9 +810,9 @@ def main():
        params.has_contexts = False

    if params.iter > 0:
-        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+        params.suffix = f"iter-{params.iter}_avg-{params.avg}"
    else:
-        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+        params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"

    if params.causal:
        assert (
@@ -794,20 +821,20 @@ def main():
        assert (
            "," not in params.left_context_frames
        ), "left_context_frames should be one value in decoding."
-        params.suffix += f"-chunk-{params.chunk_size}"
-        params.suffix += f"-left-context-{params.left_context_frames}"
+        params.suffix += f"_chunk-{params.chunk_size}"
+        params.suffix += f"_left-context-{params.left_context_frames}"

    if "fast_beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam}"
-        params.suffix += f"-max-contexts-{params.max_contexts}"
-        params.suffix += f"-max-states-{params.max_states}"
+        params.suffix += f"_beam-{params.beam}"
+        params.suffix += f"_max-contexts-{params.max_contexts}"
+        params.suffix += f"_max-states-{params.max_states}"
        if "nbest" in params.decoding_method:
-            params.suffix += f"-nbest-scale-{params.nbest_scale}"
-            params.suffix += f"-num-paths-{params.num_paths}"
+            params.suffix += f"_nbest-scale-{params.nbest_scale}"
+            params.suffix += f"_num-paths-{params.num_paths}"
        if "LG" in params.decoding_method:
-            params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+            params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}"
    if params.decoding_method in (
        "modified_beam_search",
        "modified_beam_search_LODR",
@@ -815,19 +842,19 @@ def main():
        if params.has_contexts:
            params.suffix += f"-context-score-{params.context_score}"
    else:
-        params.suffix += f"-context-{params.context_size}"
-        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+        params.suffix += f"_context-{params.context_size}"
+        params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}"

    if params.use_shallow_fusion:
-        params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+        params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}"

    if "LODR" in params.decoding_method:
        params.suffix += (
-            f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+            f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
        )

    if params.use_averaged_model:
-        params.suffix += "-use-averaged-model"
+        params.suffix += "_use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")
@@ -1038,12 +1065,19 @@ def main():
            ngram_lm_scale=ngram_lm_scale,
        )

-        save_results(
+        save_asr_output(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

+        if not params.skip_scoring:
+            save_wer_results(
+                params=params,
+                test_set_name=test_set,
+                results_dict=results_dict,
+            )
+
    logging.info("Done!")
@@ -218,7 +218,7 @@ class OnnxEncoder(nn.Module):
          - encoder_out_lens, A 1-D tensor of shape (N,)
        """
        x, x_lens = self.encoder_embed(x, x_lens)
-        src_key_padding_mask = make_pad_mask(x_lens)
+        src_key_padding_mask = make_pad_mask(x_lens, x.shape[1])
        x = x.permute(1, 0, 2)
        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)
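Passing the time dimension explicitly ties the mask's shape to x itself rather than to max(x_lens), presumably so the two always agree in the traced ONNX graph. A minimal sketch of the assumed make_pad_mask semantics (icefall's own utility, simplified here):

```python
import torch

def make_pad_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # True marks padded positions: index >= length for each sequence.
    idx = torch.arange(max_len, device=lengths.device)
    return idx.unsqueeze(0) >= lengths.unsqueeze(1)

print(make_pad_mask(torch.tensor([2, 3]), 4))
# tensor([[False, False,  True,  True],
#         [False, False, False,  True]])
```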

@@ -297,7 +297,7 @@ class SoftmaxFunction(torch.autograd.Function):
        # (presumably) that op does not support float16, and autocast
        # is enabled.
        if torch.is_autocast_enabled():
-            ans = ans.to(torch.float16)
+            ans = ans.to(torch.get_autocast_gpu_dtype())
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
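torch.get_autocast_gpu_dtype() returns whichever dtype autocast is currently configured for, so the same cast now serves both fp16 and bf16 runs; a quick illustration:

```python
import torch

# Under CUDA autocast configured for bf16, the query returns torch.bfloat16
# (and torch.float16 for a default fp16 autocast region).
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
    print(torch.get_autocast_gpu_dtype())  # torch.bfloat16
```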

@@ -1234,7 +1234,7 @@ class DoubleSwishFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
            x = x.to(torch.float32)

        s = torch.sigmoid(x - 1.0)
@@ -1346,7 +1346,7 @@ class SwooshLFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
            x = x.to(torch.float32)

        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@@ -1379,7 +1379,7 @@ class SwooshLFunction(torch.autograd.Function):
        d_int = d_scaled.to(torch.uint8)
        ctx.save_for_backward(d_int)
        if x.dtype == torch.float16 or torch.is_autocast_enabled():
-            y = y.to(torch.float16)
+            y = y.to(torch.get_autocast_gpu_dtype())
        return y

    @staticmethod
@@ -1425,7 +1425,7 @@ class SwooshRFunction(torch.autograd.Function):
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad

-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
            x = x.to(torch.float32)

        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@@ -1455,7 +1455,7 @@ class SwooshRFunction(torch.autograd.Function):
        d_int = d_scaled.to(torch.uint8)
        ctx.save_for_backward(d_int)
        if x.dtype == torch.float16 or torch.is_autocast_enabled():
-            y = y.to(torch.float16)
+            y = y.to(torch.get_autocast_gpu_dtype())
        return y

    @staticmethod
@@ -43,7 +43,7 @@ import torch
from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
-from lhotse import CutSet
+from lhotse import CutSet, set_caching_enabled
from streaming_beam_search import (
    fast_beam_search_one_best,
    greedy_search,
@@ -76,6 +76,13 @@ def get_parser():
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

+    parser.add_argument(
+        "--label",
+        type=str,
+        default="",
+        help="""Extra label of the decoding run.""",
+    )
+
    parser.add_argument(
        "--epoch",
        type=int,
@@ -188,6 +195,14 @@ def get_parser():
        help="The number of streams that can be decoded in parallel.",
    )

+    parser.add_argument(
+        "--skip-scoring",
+        type=str2bool,
+        default=False,
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
+    )
+
    add_model_arguments(parser)

    return parser
@@ -640,46 +655,60 @@ def decode_dataset(
    return {key: decode_results}


-def save_results(
+def save_asr_output(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
-    test_set_wers = dict()
+    """
+    Save text produced by ASR.
+    """
    for key, results in results_dict.items():
-        recog_path = (
+        recogs_filename = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
-        logging.info(f"The transcripts are stored in {recog_path}")
+        store_transcripts(filename=recogs_filename, texts=results)
+        logging.info(f"The transcripts are stored in {recogs_filename}")

+
+def save_wer_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+    """
+    Save WER and per-utterance word alignments.
+    """
+    test_set_wers = dict()
+    for key, results in results_dict.items():
        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
-        with open(errs_filename, "w") as f:
+        with open(errs_filename, "w", encoding="utf8") as fd:
            wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True
+                fd, f"{test_set_name}-{key}", results, enable_log=True
            )
        test_set_wers[key] = wer

-        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+        logging.info(f"Wrote detailed error stats to {errs_filename}")

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
+
+    wer_filename = (
        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
    )
-    with open(errs_info, "w") as f:
-        print("settings\tWER", file=f)
+    with open(wer_filename, "w", encoding="utf8") as fd:
+        print("settings\tWER", file=fd)
        for key, val in test_set_wers:
-            print("{}\t{}".format(key, val), file=f)
+            print(f"{key}\t{val}", file=fd)

-    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
-    note = "\tbest for {}".format(test_set_name)
+    s = f"\nFor {test_set_name}, WER of different settings are:\n"
+    note = f"\tbest for {test_set_name}"
    for key, val in test_set_wers:
-        s += "{}\t{}{}\n".format(key, val, note)
+        s += f"{key}\t{val}{note}\n"
        note = ""
    logging.info(s)
@@ -694,6 +723,9 @@ def main():
    params = get_params()
    params.update(vars(args))

+    # enable AudioCache
+    set_caching_enabled(True)  # lhotse
+
    params.res_dir = params.exp_dir / "streaming" / params.decoding_method

    if params.iter > 0:
@@ -706,18 +738,21 @@ def main():
        assert (
            "," not in params.left_context_frames
        ), "left_context_frames should be one value in decoding."
-        params.suffix += f"-chunk-{params.chunk_size}"
-        params.suffix += f"-left-context-{params.left_context_frames}"
+        params.suffix += f"_chunk-{params.chunk_size}"
+        params.suffix += f"_left-context-{params.left_context_frames}"

    # for fast_beam_search
    if params.decoding_method == "fast_beam_search":
-        params.suffix += f"-beam-{params.beam}"
-        params.suffix += f"-max-contexts-{params.max_contexts}"
-        params.suffix += f"-max-states-{params.max_states}"
+        params.suffix += f"_beam-{params.beam}"
+        params.suffix += f"_max-contexts-{params.max_contexts}"
+        params.suffix += f"_max-states-{params.max_states}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

+    if params.label:
+        params.suffix += f"-{params.label}"
+
    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")
@@ -845,12 +880,21 @@ def main():
            decoding_graph=decoding_graph,
        )

-        save_results(
+        save_asr_output(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

+        if not params.skip_scoring:
+            save_wer_results(
+                params=params,
+                test_set_name=test_set,
+                results_dict=results_dict,
+            )
+
    logging.info("Done!")
@@ -521,6 +521,13 @@ def get_parser():
        help="Whether to use half precision training.",
    )

+    parser.add_argument(
+        "--use-bf16",
+        type=str2bool,
+        default=False,
+        help="Whether to use bf16 in AMP.",
+    )
+
    add_model_arguments(parser)

    return parser
@@ -1027,7 +1034,9 @@ def train_one_epoch(
        batch_size = len(batch["supervisions"]["text"])

        try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch.cuda.amp.autocast(
+                enabled=params.use_autocast, dtype=params.dtype
+            ):
                loss, loss_info = compute_loss(
                    params=params,
                    model=model,
@@ -1047,9 +1056,7 @@ def train_one_epoch(
            scaler.update()
            optimizer.zero_grad()
        except Exception as e:
-            logging.info(
-                f"Caught exception: {e}."
-            )
+            logging.info(f"Caught exception: {e}.")
            save_bad_model()
            display_and_save_batch(batch, params=params, sp=sp)
            raise
@@ -1090,7 +1097,7 @@ def train_one_epoch(
                rank=rank,
            )

-        if batch_idx % 100 == 0 and params.use_fp16:
+        if batch_idx % 100 == 0 and params.use_autocast:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
@@ -1109,14 +1116,14 @@ def train_one_epoch(

        if batch_idx % params.log_interval == 0:
            cur_lr = max(scheduler.get_last_lr())
-            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+            cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0

            logging.info(
                f"Epoch {params.cur_epoch}, "
                f"batch {batch_idx}, loss[{loss_info}], "
                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "")
            )

            if tb_writer is not None:
@@ -1128,7 +1135,7 @@ def train_one_epoch(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
-                if params.use_fp16:
+                if params.use_autocast:
                    tb_writer.add_scalar(
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
                    )
@@ -1204,9 +1211,25 @@ def run(rank, world_size, args):
        params.ctc_loss_scale = 1.0
    else:
        assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
-            params.ctc_loss_scale, params.attention_decoder_loss_scale
+            params.ctc_loss_scale,
+            params.attention_decoder_loss_scale,
        )

+    if params.use_bf16:  # amp + bf16
+        assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!"
+        assert not params.use_fp16, "You can only use either fp16 or bf16"
+        params.dtype = torch.bfloat16
+        params.use_autocast = True
+    elif params.use_fp16:  # amp + fp16
+        params.dtype = torch.float16
+        params.use_autocast = True
+    else:  # fp32
+        params.dtype = torch.float32
+        params.use_autocast = False
+
+    logging.info(f"Using dtype={params.dtype}")
+    logging.info(f"Use AMP={params.use_autocast}")
+
    logging.info(params)

    logging.info("About to create model")
@@ -1339,7 +1362,7 @@ def run(rank, world_size, args):
        params=params,
    )

-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])
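Keeping the scaler enabled under bf16 is harmless but mostly redundant: gradient scaling exists to stop fp16's narrow exponent range from underflowing small gradients, while bf16 keeps fp32's full exponent range:

```python
import torch

# fp16 overflows beyond ~65504, hence GradScaler for AMP + fp16;
# bf16 spans roughly the same range as fp32 (~3.4e38).
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # 3.3895313892515355e+38
```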

@@ -1439,7 +1462,9 @@ def scan_pessimistic_batches_for_oom(
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch.cuda.amp.autocast(
+                enabled=params.use_autocast, dtype=params.dtype
+            ):
                loss, _ = compute_loss(
                    params=params,
                    model=model,