Merge branch 'k2-fsa:master' into master

This commit is contained in:
Yu Lianjie 2024-08-24 12:20:54 +08:00 committed by GitHub
commit 50471d6f11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 498 additions and 159 deletions

View File

@ -375,7 +375,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[libricss]: egs/libricss/SURT
[libriheavy]: egs/libriheavy/ASR
[mgb2]: egs/mgb2/ASR
[peoplespeech]: egs/peoplespeech/ASR
[peoplespeech]: egs/peoples_speech/ASR
[spgispeech]: egs/spgispeech/ASR
[voxpopuli]: egs/voxpopuli/ASR
[xbmu-amdo31]: egs/xbmu-amdo31/ASR

View File

@ -35,16 +35,40 @@ python zipformer/train.py \
--master-port 13455
```
We recommend that you train the model with weighted sampler, as the model converges
faster with better performance:
| Model | mAP |
| ------ | ------- |
| Zipformer-AT, train with weighted sampler | 46.6 |
The evaluation command is:
```bash
python zipformer/evaluate.py \
--epoch 32 \
--avg 8 \
--exp-dir zipformer/exp_at_as_full \
--max-duration 500
export CUDA_VISIBLE_DEVICES="4,5,6,7"
subset=full
weighted_sampler=1
bucket_sampler=0
lr_epochs=15
python zipformer/train.py \
--world-size 4 \
--audioset-subset $subset \
--num-epochs 120 \
--start-epoch 1 \
--use-fp16 1 \
--num-events 527 \
--lr-epochs $lr_epochs \
--exp-dir zipformer/exp_AS_${subset}_weighted_sampler${weighted_sampler} \
--weighted-sampler $weighted_sampler \
--bucketing-sampler $bucket_sampler \
--max-duration 1000 \
--enable-musan True \
--master-port 13452
```
The command for evaluation is the same. The pre-trained model can be downloaded from https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-M-weighted-sampler
#### small-scaled model, number of model parameters: 22125218, i.e., 22.13 M

View File

@ -0,0 +1,73 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file generates the manifest and computes the fbank features for AudioSet
dataset. The generated manifests and features are stored in data/fbank.
"""
import argparse
import lhotse
from lhotse import load_manifest
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--input-manifest", type=str, default="data/fbank/cuts_audioset_full.jsonl.gz"
)
parser.add_argument(
"--output",
type=str,
required=True,
)
return parser
def main():
# Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py
parser = get_parser()
args = parser.parse_args()
cuts = load_manifest(args.input_manifest)
print(f"A total of {len(cuts)} cuts.")
label_count = [0] * 527 # a total of 527 classes
for c in cuts:
audio_event = c.supervisions[0].audio_event
labels = list(map(int, audio_event.split(";")))
for label in labels:
label_count[label] += 1
with open(args.output, "w") as f:
for c in cuts:
audio_event = c.supervisions[0].audio_event
labels = list(map(int, audio_event.split(";")))
weight = 0
for label in labels:
weight += 1000 / (label_count[label] + 0.01)
f.write(f"{c.id} {weight}\n")
if __name__ == "__main__":
main()

View File

@ -10,6 +10,7 @@ stage=-1
stop_stage=4
dl_dir=$PWD/download
fbank_dir=data/fbank
# we assume that you have your downloaded the AudioSet and placed
# it under $dl_dir/audioset, the folder structure should look like
@ -49,7 +50,6 @@ fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set"
fbank_dir=data/fbank
if [! -e $fbank_dir/.balanced.done]; then
python local/generate_audioset_manifest.py \
--dataset-dir $dl_dir/audioset \
@ -102,3 +102,14 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
touch data/fbank/.musan.done
fi
fi
# The following stages are required to do weighted-sampling training
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare for weighted-sampling training"
if [ ! -e $fbank_dir/cuts_audioset_full.jsonl.gz ]; then
lhotse combine $fbank_dir/cuts_audioset_balanced.jsonl.gz $fbank_dir/cuts_audioset_unbalanced.jsonl.gz $fbank_dir/cuts_audioset_full.jsonl.gz
fi
python ./local/compute_weight.py \
--input-manifest $fbank_dir/cuts_audioset_full.jsonl.gz \
--output $fbank_dir/sampling_weights_full.txt
fi

View File

@ -31,6 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
WeightedSimpleCutSampler,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
@ -99,6 +100,20 @@ class AudioSetATDatamodule:
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--weighted-sampler",
type=str2bool,
default=False,
help="When enabled, samples are drawn from by their weights. "
"It cannot be used together with bucketing sampler",
)
group.add_argument(
"--num-samples",
type=int,
default=200000,
help="The number of samples to be drawn in each epoch. Only be used"
"for weighed sampler",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
@ -295,6 +310,9 @@ class AudioSetATDatamodule:
)
if self.args.bucketing_sampler:
assert (
not self.args.weighted_sampler
), "weighted sampling is not supported in bucket sampler"
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
@ -304,13 +322,26 @@ class AudioSetATDatamodule:
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
drop_last=self.args.drop_last,
)
if self.args.weighted_sampler:
# assert self.args.audioset_subset == "full", "Only use weighted sampling for full audioset"
logging.info("Using weighted SimpleCutSampler")
weights = self.audioset_sampling_weights()
train_sampler = WeightedSimpleCutSampler(
cuts_train,
weights,
num_samples=self.args.num_samples,
max_duration=self.args.max_duration,
shuffle=False, # do not support shuffle
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
drop_last=self.args.drop_last,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
@ -373,11 +404,9 @@ class AudioSetATDatamodule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = AudioTaggingDataset(
input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)()
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
@ -397,21 +426,30 @@ class AudioSetATDatamodule:
@lru_cache()
def audioset_train_cuts(self) -> CutSet:
logging.info("About to get the audioset training cuts.")
balanced_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
)
if self.args.audioset_subset == "full":
unbalanced_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
)
cuts = CutSet.mux(
balanced_cuts,
unbalanced_cuts,
weights=[20000, 2000000],
stop_early=True,
if not self.args.weighted_sampler:
balanced_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
)
if self.args.audioset_subset == "full":
unbalanced_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
)
cuts = CutSet.mux(
balanced_cuts,
unbalanced_cuts,
weights=[20000, 2000000],
stop_early=True,
)
else:
cuts = balanced_cuts
else:
cuts = balanced_cuts
# assert self.args.audioset_subset == "full", "Only do weighted sampling for full AudioSet"
cuts = load_manifest(
self.args.manifest_dir
/ f"cuts_audioset_{self.args.audioset_subset}.jsonl.gz"
)
logging.info(f"Get {len(cuts)} cuts in total.")
return cuts
@lru_cache()
@ -420,3 +458,22 @@ class AudioSetATDatamodule:
return load_manifest_lazy(
self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz"
)
@lru_cache()
def audioset_sampling_weights(self):
logging.info(
f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet"
)
weights = []
with open(
self.args.manifest_dir / f"sample_weights_{self.args.audioset_subset}.txt",
"r",
) as f:
while True:
line = f.readline()
if not line:
break
weight = float(line.split()[1])
weights.append(weight)
logging.info(f"Get the sampling weight for {len(weights)} cuts")
return weights

View File

@ -789,12 +789,14 @@ def train_one_epoch(
rank=0,
)
num_samples = 0
for batch_idx, batch in enumerate(train_dl):
if batch_idx % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params))
params.batch_idx_train += 1
batch_size = batch["inputs"].size(0)
num_samples += batch_size
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
@ -919,6 +921,12 @@ def train_one_epoch(
tb_writer, "train/valid_", params.batch_idx_train
)
if num_samples > params.num_samples:
logging.info(
f"Number of training samples exceeds {params.num_samples} in this epoch, move on to next epoch"
)
break
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
@ -1032,7 +1040,8 @@ def run(rank, world_size, args):
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)
if not params.weighted_sampler:
train_cuts = train_cuts.filter(remove_short_and_long_utt)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint

View File

@ -29,17 +29,21 @@ def simple_cleanup(text: str) -> str:
# Assign text of the supervisions and remove unnecessary entries.
def main():
assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR"
assert (
len(sys.argv) == 4
), "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR KEEP_CUSTOM_FIELDS"
fname = Path(sys.argv[1]).name
oname = Path(sys.argv[2]) / fname
keep_custom_fields = bool(sys.argv[3])
with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
for line in fin:
cut = json.loads(line)
cut["supervisions"][0]["text"] = simple_cleanup(
cut["supervisions"][0]["custom"]["texts"][0]
)
del cut["supervisions"][0]["custom"]
del cut["custom"]
if not keep_custom_fields:
del cut["supervisions"][0]["custom"]
del cut["custom"]
fout.write((json.dumps(cut) + "\n").encode())

View File

@ -29,6 +29,11 @@ export CUDA_VISIBLE_DEVICES=""
# - speech
dl_dir=$PWD/download
# If you want to do PromptASR experiments, please set it to True
# as this will keep the texts and pre_text information required for
# the training of PromptASR.
keep_custom_fields=False
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
@ -134,7 +139,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
for subset in small medium large dev test_clean test_other; do
if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
log "Prepare manifest for subset : ${subset}"
./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir
./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir $keep_custom_fields
fi
done
fi

View File

@ -307,6 +307,23 @@ done
To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html).
We also support training Zipformer with AMP+bf16 format (requires bf16 support). See [here](https://github.com/k2-fsa/icefall/pull/1700) for more details and pre-trained models. **The same command can be used for decoding and exporting the model.**
The amp+bf16 training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/train.py \
--world-size 4 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 0 \
--use-bf16 1 \
--exp-dir zipformer/exp_amp_bf16 \
--causal 0 \
--full-libri 1 \
--max-duration 1000
```
##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M
The tensorboard log can be found at

View File

@ -120,6 +120,7 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from lhotse import set_caching_enabled
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
@ -296,6 +297,13 @@ def get_parser():
""",
)
parser.add_argument(
"--skip-scoring",
type=str2bool,
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets)."""
)
add_model_arguments(parser)
return parser
@ -455,7 +463,7 @@ def decode_one_batch(
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "ctc-decoding"
return {key: hyps}
return {key: hyps} # note: returns words
if params.decoding_method == "attention-decoder-rescoring-no-ngram":
best_path_dict = rescore_with_attention_decoder_no_ngram(
@ -492,7 +500,7 @@ def decode_one_batch(
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa
return {key: hyps}
if params.decoding_method in ["1best", "nbest"]:
@ -500,7 +508,7 @@ def decode_one_batch(
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no_rescore"
key = "no-rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
@ -508,11 +516,11 @@ def decode_one_batch(
use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale,
)
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
return {key: hyps} # note: returns BPE tokens
assert params.decoding_method in [
"nbest-rescoring",
@ -646,7 +654,27 @@ def decode_dataset(
return results
def save_results(
def save_asr_output(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
"""
Save text produced by ASR.
"""
for key, results in results_dict.items():
recogs_filename = (
params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recogs_filename, texts=results)
logging.info(f"The transcripts are stored in {recogs_filename}")
def save_wer_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
@ -661,32 +689,30 @@ def save_results(
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
with open(errs_filename, "w", encoding="utf8") as fd:
wer = write_error_stats(
fd, f"{test_set_name}_{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))
logging.info(f"Wrote detailed error stats to {errs_filename}")
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(wer_filename, "w", encoding="utf8") as fd:
print("settings\tWER", file=fd)
for key, val in test_set_wers:
print(f"{key}\t{val}", file=fd)
s = f"\nFor {test_set_name}, WER of different settings are:\n"
note = f"\tbest for {test_set_name}"
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
s += f"{key}\t{val}{note}\n"
note = ""
logging.info(s)
@ -705,6 +731,9 @@ def main():
params.update(get_decoding_params())
params.update(vars(args))
# enable AudioCache
set_caching_enabled(True) # lhotse
assert params.decoding_method in (
"ctc-greedy-search",
"ctc-decoding",
@ -719,9 +748,9 @@ def main():
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
params.suffix = f"iter-{params.iter}_avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"
if params.causal:
assert (
@ -730,11 +759,11 @@ def main():
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
params.suffix += f"_chunk-{params.chunk_size}"
params.suffix += f"_left-context-{params.left_context_frames}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
params.suffix += "_use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -940,12 +969,19 @@ def main():
G=G,
)
save_results(
save_asr_output(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
if not params.skip_scoring:
save_wer_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")

View File

@ -121,6 +121,7 @@ from beam_search import (
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
)
from lhotse import set_caching_enabled
from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph, LmScorer, NgramLm
@ -369,6 +370,14 @@ def get_parser():
modified_beam_search_LODR.
""",
)
parser.add_argument(
"--skip-scoring",
type=str2bool,
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets).""",
)
add_model_arguments(parser)
return parser
@ -590,21 +599,23 @@ def decode_one_batch(
)
hyps.append(sp.decode(hyp).split())
# prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" )
prefix = f"{params.decoding_method}"
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
prefix += f"_beam-{params.beam}"
prefix += f"_max-contexts-{params.max_contexts}"
prefix += f"_max-states-{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
prefix += f"_num-paths-{params.num_paths}"
prefix += f"_nbest-scale-{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}"
return {key: hyps}
return {prefix: hyps}
elif "modified_beam_search" in params.decoding_method:
prefix = f"beam_size_{params.beam_size}"
prefix += f"_beam-size-{params.beam_size}"
if params.decoding_method in (
"modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR",
@ -617,10 +628,11 @@ def decode_one_batch(
return ans
else:
if params.has_contexts:
prefix += f"-context-score-{params.context_score}"
prefix += f"_context-score-{params.context_score}"
return {prefix: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
prefix += f"_beam-size-{params.beam_size}"
return {prefix: hyps}
def decode_dataset(
@ -707,46 +719,58 @@ def decode_dataset(
return results
def save_results(
def save_asr_output(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
"""
Save text produced by ASR.
"""
for key, results in results_dict.items():
recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recogs_filename, texts=results)
logging.info(f"The transcripts are stored in {recogs_filename}")
def save_wer_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]],
):
"""
Save WER and per-utterance word alignments.
"""
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w", encoding="utf8") as fd:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
fd, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
logging.info(f"Wrote detailed error stats to {errs_filename}")
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(wer_filename, "w", encoding="utf8") as fd:
print("settings\tWER", file=fd)
for key, val in test_set_wers:
print(f"{key}\t{val}", file=fd)
s = f"\nFor {test_set_name}, WER of different settings are:\n"
note = f"\tbest for {test_set_name}"
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
s += f"{key}\t{val}{note}\n"
note = ""
logging.info(s)
@ -762,6 +786,9 @@ def main():
params = get_params()
params.update(vars(args))
# enable AudioCache
set_caching_enabled(True) # lhotse
assert params.decoding_method in (
"greedy_search",
"beam_search",
@ -783,9 +810,9 @@ def main():
params.has_contexts = False
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
params.suffix = f"iter-{params.iter}_avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"
if params.causal:
assert (
@ -794,20 +821,20 @@ def main():
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
params.suffix += f"_chunk-{params.chunk_size}"
params.suffix += f"_left-context-{params.left_context_frames}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"_beam-{params.beam}"
params.suffix += f"_max-contexts-{params.max_contexts}"
params.suffix += f"_max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
params.suffix += f"_nbest-scale-{params.nbest_scale}"
params.suffix += f"_num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}"
if params.decoding_method in (
"modified_beam_search",
"modified_beam_search_LODR",
@ -815,19 +842,19 @@ def main():
if params.has_contexts:
params.suffix += f"-context-score-{params.context_score}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"_context-{params.context_size}"
params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_shallow_fusion:
params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}"
if "LODR" in params.decoding_method:
params.suffix += (
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
params.suffix += "_use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -1038,12 +1065,19 @@ def main():
ngram_lm_scale=ngram_lm_scale,
)
save_results(
save_asr_output(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
if not params.skip_scoring:
save_wer_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")

View File

@ -218,7 +218,7 @@ class OnnxEncoder(nn.Module):
- encoder_out_lens, A 1-D tensor of shape (N,)
"""
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
src_key_padding_mask = make_pad_mask(x_lens, x.shape[1])
x = x.permute(1, 0, 2)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2)

View File

@ -297,7 +297,7 @@ class SoftmaxFunction(torch.autograd.Function):
# (presumably) that op does not support float16, and autocast
# is enabled.
if torch.is_autocast_enabled():
ans = ans.to(torch.float16)
ans = ans.to(torch.get_autocast_gpu_dtype())
ctx.save_for_backward(ans)
ctx.x_dtype = x.dtype
ctx.dim = dim
@ -1234,7 +1234,7 @@ class DoubleSwishFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor) -> Tensor:
requires_grad = x.requires_grad
if x.dtype == torch.float16:
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
x = x.to(torch.float32)
s = torch.sigmoid(x - 1.0)
@ -1346,7 +1346,7 @@ class SwooshLFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor) -> Tensor:
requires_grad = x.requires_grad
if x.dtype == torch.float16:
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
x = x.to(torch.float32)
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@ -1379,7 +1379,7 @@ class SwooshLFunction(torch.autograd.Function):
d_int = d_scaled.to(torch.uint8)
ctx.save_for_backward(d_int)
if x.dtype == torch.float16 or torch.is_autocast_enabled():
y = y.to(torch.float16)
y = y.to(torch.get_autocast_gpu_dtype())
return y
@staticmethod
@ -1425,7 +1425,7 @@ class SwooshRFunction(torch.autograd.Function):
def forward(ctx, x: Tensor) -> Tensor:
requires_grad = x.requires_grad
if x.dtype == torch.float16:
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
x = x.to(torch.float32)
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@ -1455,7 +1455,7 @@ class SwooshRFunction(torch.autograd.Function):
d_int = d_scaled.to(torch.uint8)
ctx.save_for_backward(d_int)
if x.dtype == torch.float16 or torch.is_autocast_enabled():
y = y.to(torch.float16)
y = y.to(torch.get_autocast_gpu_dtype())
return y
@staticmethod

View File

@ -43,7 +43,7 @@ import torch
from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from lhotse import CutSet, set_caching_enabled
from streaming_beam_search import (
fast_beam_search_one_best,
greedy_search,
@ -76,6 +76,13 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--label",
type=str,
default="",
help="""Extra label of the decoding run.""",
)
parser.add_argument(
"--epoch",
type=int,
@ -188,6 +195,14 @@ def get_parser():
help="The number of streams that can be decoded parallel.",
)
parser.add_argument(
"--skip-scoring",
type=str2bool,
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets)."""
)
add_model_arguments(parser)
return parser
@ -640,46 +655,60 @@ def decode_dataset(
return {key: decode_results}
def save_results(
def save_asr_output(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
"""
Save text produced by ASR.
"""
for key, results in results_dict.items():
recog_path = (
recogs_filename = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
store_transcripts(filename=recogs_filename, texts=results)
logging.info(f"The transcripts are stored in {recogs_filename}")
def save_wer_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
"""
Save WER and per-utterance word alignments.
"""
test_set_wers = dict()
for key, results in results_dict.items():
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
with open(errs_filename, "w", encoding="utf8") as fd:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
fd, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
logging.info(f"Wrote detailed error stats to {errs_filename}")
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
wer_filename = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
with open(wer_filename, "w", encoding="utf8") as fd:
print("settings\tWER", file=fd)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
print(f"{key}\t{val}", file=fd)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
s = f"\nFor {test_set_name}, WER of different settings are:\n"
note = f"\tbest for {test_set_name}"
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
s += f"{key}\t{val}{note}\n"
note = ""
logging.info(s)
@ -694,6 +723,9 @@ def main():
params = get_params()
params.update(vars(args))
# enable AudioCache
set_caching_enabled(True) # lhotse
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
@ -706,18 +738,21 @@ def main():
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
params.suffix += f"_chunk-{params.chunk_size}"
params.suffix += f"_left-context-{params.left_context_frames}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"_beam-{params.beam}"
params.suffix += f"_max-contexts-{params.max_contexts}"
params.suffix += f"_max-states-{params.max_states}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
if params.label:
params.suffix += f"-{params.label}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -845,12 +880,21 @@ def main():
decoding_graph=decoding_graph,
)
save_results(
save_asr_output(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
if not params.skip_scoring:
save_wer_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")

View File

@ -521,6 +521,13 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--use-bf16",
type=str2bool,
default=False,
help="Whether to use bf16 in AMP.",
)
add_model_arguments(parser)
return parser
@ -1027,7 +1034,9 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch.cuda.amp.autocast(
enabled=params.use_autocast, dtype=params.dtype
):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1047,9 +1056,7 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except Exception as e:
logging.info(
f"Caught exception: {e}."
)
logging.info(f"Caught exception: {e}.")
save_bad_model()
display_and_save_batch(batch, params=params, sp=sp)
raise
@ -1090,7 +1097,7 @@ def train_one_epoch(
rank=rank,
)
if batch_idx % 100 == 0 and params.use_fp16:
if batch_idx % 100 == 0 and params.use_autocast:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
@ -1109,14 +1116,14 @@ def train_one_epoch(
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "")
)
if tb_writer is not None:
@ -1128,7 +1135,7 @@ def train_one_epoch(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
if params.use_autocast:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
@ -1204,9 +1211,25 @@ def run(rank, world_size, args):
params.ctc_loss_scale = 1.0
else:
assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
params.ctc_loss_scale, params.attention_decoder_loss_scale
params.ctc_loss_scale,
params.attention_decoder_loss_scale,
)
if params.use_bf16: # amp + bf16
assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!"
assert not params.use_fp16, "You can only use either fp16 or bf16"
params.dtype = torch.bfloat16
params.use_autocast = True
elif params.use_fp16: # amp + fp16
params.dtype = torch.float16
params.use_autocast = True
else: # fp32
params.dtype = torch.float32
params.use_autocast = False
logging.info(f"Using dtype={params.dtype}")
logging.info(f"Use AMP={params.use_autocast}")
logging.info(params)
logging.info("About to create model")
@ -1339,7 +1362,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1439,7 +1462,9 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch.cuda.amp.autocast(
enabled=params.use_autocast, dtype=params.dtype
):
loss, _ = compute_loss(
params=params,
model=model,