Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-27 02:34:21 +00:00)

Commit 50471d6f11: Merge branch 'k2-fsa:master' into master
@@ -375,7 +375,7 @@ Please see: [

New file local/compute_weight.py (AudioSet AT recipe), added in full:

#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes the per-cut sampling weights for the AudioSet cuts
in data/fbank. The generated weight file is used for weighted-sampling
training.
"""

import argparse

import lhotse
from lhotse import load_manifest


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--input-manifest", type=str, default="data/fbank/cuts_audioset_full.jsonl.gz"
    )

    parser.add_argument(
        "--output",
        type=str,
        required=True,
    )
    return parser


def main():
    # Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py
    parser = get_parser()
    args = parser.parse_args()

    cuts = load_manifest(args.input_manifest)

    print(f"A total of {len(cuts)} cuts.")

    label_count = [0] * 527  # AudioSet has a total of 527 classes
    for c in cuts:
        audio_event = c.supervisions[0].audio_event
        labels = list(map(int, audio_event.split(";")))
        for label in labels:
            label_count[label] += 1

    with open(args.output, "w") as f:
        for c in cuts:
            audio_event = c.supervisions[0].audio_event
            labels = list(map(int, audio_event.split(";")))
            weight = 0
            for label in labels:
                weight += 1000 / (label_count[label] + 0.01)
            f.write(f"{c.id} {weight}\n")


if __name__ == "__main__":
    main()
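The weight of a cut is the sum of inverse label frequencies over its labels (scaled by 1000, with 0.01 added to avoid division by zero), so cuts containing rare audio events are drawn more often. A minimal sketch of the idea on a toy three-class histogram (the numbers below are made up for illustration):

```python
# Toy illustration of the inverse-frequency weighting used above.
# label_count is a hypothetical histogram over 3 classes, not real AudioSet data.
label_count = [10000, 500, 5]

def cut_weight(labels):
    # Rare labels contribute large terms, common labels small ones.
    return sum(1000 / (label_count[label] + 0.01) for label in labels)

print(cut_weight([0]))     # common event only -> ~0.1
print(cut_weight([2]))     # rare event        -> ~199.6
print(cut_weight([0, 2]))  # mixed             -> ~199.7
```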
prepare.sh (AudioSet AT recipe):

@@ -10,6 +10,7 @@ stage=-1
 stop_stage=4

 dl_dir=$PWD/download
+fbank_dir=data/fbank

 # we assume that you have downloaded AudioSet and placed
 # it under $dl_dir/audioset, the folder structure should look like
@@ -49,7 +50,6 @@ fi

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set"
-  fbank_dir=data/fbank
   if [ ! -e $fbank_dir/.balanced.done ]; then
     python local/generate_audioset_manifest.py \
       --dataset-dir $dl_dir/audioset \
@@ -102,3 +102,14 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
     touch data/fbank/.musan.done
   fi
 fi
+
+# The following stages are required to do weighted-sampling training
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare for weighted-sampling training"
+  if [ ! -e $fbank_dir/cuts_audioset_full.jsonl.gz ]; then
+    lhotse combine $fbank_dir/cuts_audioset_balanced.jsonl.gz $fbank_dir/cuts_audioset_unbalanced.jsonl.gz $fbank_dir/cuts_audioset_full.jsonl.gz
+  fi
+  python ./local/compute_weight.py \
+    --input-manifest $fbank_dir/cuts_audioset_full.jsonl.gz \
+    --output $fbank_dir/sampling_weights_full.txt
+fi
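Stage 5 first merges the balanced and unbalanced cut manifests into one full manifest, then derives per-cut weights from it. A rough Python equivalent of the `lhotse combine` step, as a sketch using the same paths as above:

```python
from lhotse import CutSet, combine

# Merge the balanced and unbalanced manifests into a single "full" manifest,
# mirroring the `lhotse combine` CLI call in stage 5.
balanced = CutSet.from_file("data/fbank/cuts_audioset_balanced.jsonl.gz")
unbalanced = CutSet.from_file("data/fbank/cuts_audioset_unbalanced.jsonl.gz")
combine(balanced, unbalanced).to_file("data/fbank/cuts_audioset_full.jsonl.gz")
```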
AudioSetATDatamodule (AudioSet AT data module):

@@ -31,6 +31,7 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     PrecomputedFeatures,
     SimpleCutSampler,
     SpecAugment,
+    WeightedSimpleCutSampler,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
     AudioSamples,
@@ -99,6 +100,20 @@ class AudioSetATDatamodule:
             help="Maximum pooled recordings duration (seconds) in a "
             "single batch. You can reduce it if it causes CUDA OOM.",
         )
+        group.add_argument(
+            "--weighted-sampler",
+            type=str2bool,
+            default=False,
+            help="When enabled, samples are drawn according to their weights. "
+            "It cannot be used together with the bucketing sampler.",
+        )
+        group.add_argument(
+            "--num-samples",
+            type=int,
+            default=200000,
+            help="The number of samples to be drawn in each epoch. "
+            "Only used for the weighted sampler.",
+        )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
@@ -295,6 +310,9 @@ class AudioSetATDatamodule:
         )

         if self.args.bucketing_sampler:
+            assert (
+                not self.args.weighted_sampler
+            ), "weighted sampling is not supported in the bucketing sampler"
             logging.info("Using DynamicBucketingSampler.")
             train_sampler = DynamicBucketingSampler(
                 cuts_train,
@@ -304,13 +322,26 @@ class AudioSetATDatamodule:
                 drop_last=self.args.drop_last,
             )
         else:
-            logging.info("Using SimpleCutSampler.")
-            train_sampler = SimpleCutSampler(
-                cuts_train,
-                max_duration=self.args.max_duration,
-                shuffle=self.args.shuffle,
-                drop_last=self.args.drop_last,
-            )
+            if self.args.weighted_sampler:
+                # assert self.args.audioset_subset == "full", "Only use weighted sampling for full audioset"
+                logging.info("Using weighted SimpleCutSampler")
+                weights = self.audioset_sampling_weights()
+                train_sampler = WeightedSimpleCutSampler(
+                    cuts_train,
+                    weights,
+                    num_samples=self.args.num_samples,
+                    max_duration=self.args.max_duration,
+                    shuffle=False,  # shuffle is not supported
+                    drop_last=self.args.drop_last,
+                )
+            else:
+                logging.info("Using SimpleCutSampler.")
+                train_sampler = SimpleCutSampler(
+                    cuts_train,
+                    max_duration=self.args.max_duration,
+                    shuffle=self.args.shuffle,
+                    drop_last=self.args.drop_last,
+                )
         logging.info("About to create train dataloader")

         if sampler_state_dict is not None:
@@ -373,11 +404,9 @@ class AudioSetATDatamodule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = AudioTaggingDataset(
-            input_strategy=(
-                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-                if self.args.on_the_fly_feats
-                else eval(self.args.input_strategy)()
-            ),
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
@@ -397,21 +426,30 @@ class AudioSetATDatamodule:
     @lru_cache()
     def audioset_train_cuts(self) -> CutSet:
         logging.info("About to get the audioset training cuts.")
-        balanced_cuts = load_manifest_lazy(
-            self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
-        )
-        if self.args.audioset_subset == "full":
-            unbalanced_cuts = load_manifest_lazy(
-                self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
-            )
-            cuts = CutSet.mux(
-                balanced_cuts,
-                unbalanced_cuts,
-                weights=[20000, 2000000],
-                stop_early=True,
-            )
-        else:
-            cuts = balanced_cuts
+        if not self.args.weighted_sampler:
+            balanced_cuts = load_manifest_lazy(
+                self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
+            )
+            if self.args.audioset_subset == "full":
+                unbalanced_cuts = load_manifest_lazy(
+                    self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
+                )
+                cuts = CutSet.mux(
+                    balanced_cuts,
+                    unbalanced_cuts,
+                    weights=[20000, 2000000],
+                    stop_early=True,
+                )
+            else:
+                cuts = balanced_cuts
+        else:
+            # assert self.args.audioset_subset == "full", "Only do weighted sampling for full AudioSet"
+            cuts = load_manifest(
+                self.args.manifest_dir
+                / f"cuts_audioset_{self.args.audioset_subset}.jsonl.gz"
+            )
+        logging.info(f"Get {len(cuts)} cuts in total.")
+
         return cuts

     @lru_cache()
@@ -420,3 +458,22 @@ class AudioSetATDatamodule:
         return load_manifest_lazy(
             self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz"
         )
+
+    @lru_cache()
+    def audioset_sampling_weights(self):
+        logging.info(
+            f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet"
+        )
+        weights = []
+        with open(
+            self.args.manifest_dir / f"sample_weights_{self.args.audioset_subset}.txt",
+            "r",
+        ) as f:
+            while True:
+                line = f.readline()
+                if not line:
+                    break
+                weight = float(line.split()[1])
+                weights.append(weight)
+        logging.info(f"Get the sampling weight for {len(weights)} cuts")
+        return weights
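With `--weighted-sampler True`, the weights list returned by `audioset_sampling_weights()` must line up index-for-index with the cuts in the loaded manifest, which is why the weighted branch loads the full manifest eagerly rather than muxing lazy manifests. Note that the data module reads `sample_weights_{subset}.txt` while stage 5 of prepare.sh writes `sampling_weights_full.txt`; the two names must agree for the pipeline to run end to end. A minimal usage sketch, with hypothetical paths and assuming the same `WeightedSimpleCutSampler` keyword arguments used above:

```python
from lhotse import load_manifest
from lhotse.dataset import WeightedSimpleCutSampler

# Hypothetical paths; in the recipe these come from --manifest-dir.
cuts = load_manifest("data/fbank/cuts_audioset_full.jsonl.gz")
weights = [
    float(line.split()[1]) for line in open("data/fbank/sample_weights_full.txt")
]
assert len(weights) == len(cuts)  # weights are positional, one per cut

sampler = WeightedSimpleCutSampler(
    cuts,
    weights,
    num_samples=200000,   # cuts drawn per epoch
    max_duration=1000.0,  # seconds of audio per batch
    shuffle=False,        # shuffle is unsupported with weights
    drop_last=True,
)
for batch_cuts in sampler:  # each item is a CutSet forming one batch
    pass
```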
train_one_epoch / run (AudioSet AT training script):

@@ -789,12 +789,14 @@ def train_one_epoch(
         rank=0,
     )

+    num_samples = 0
     for batch_idx, batch in enumerate(train_dl):
         if batch_idx % 10 == 0:
             set_batch_count(model, get_adjusted_batch_count(params))

         params.batch_idx_train += 1
         batch_size = batch["inputs"].size(0)
+        num_samples += batch_size

         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
@@ -919,6 +921,12 @@ def train_one_epoch(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )

+        if num_samples > params.num_samples:
+            logging.info(
+                f"Number of training samples exceeds {params.num_samples} in this epoch, move on to the next epoch"
+            )
+            break
+
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
@@ -1032,7 +1040,8 @@ def run(rank, world_size, args):

         return True

-    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    if not params.weighted_sampler:
+        train_cuts = train_cuts.filter(remove_short_and_long_utt)

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
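Since the weighted sampler draws cuts by weight (typically with replacement), an epoch is no longer one pass over the data; training instead counts samples and breaks once `params.num_samples` have been seen. A toy sketch of that pattern, with a made-up loader and numbers:

```python
# Toy illustration of a sample-capped epoch: iterate an (effectively
# unbounded) weighted loader and stop after num_samples examples.
num_samples_per_epoch = 8

def toy_batches():
    while True:          # a weighted sampler can yield batches indefinitely
        yield [0, 1, 2]  # pretend batch of 3 samples

seen = 0
for batch in toy_batches():
    seen += len(batch)
    # ... training step would go here ...
    if seen > num_samples_per_epoch:
        print(f"Seen {seen} samples, move on to the next epoch")
        break
```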
local/prepare_manifest.py (Libriheavy):

@@ -29,17 +29,21 @@ def simple_cleanup(text: str) -> str:

 # Assign text of the supervisions and remove unnecessary entries.
 def main():
-    assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR"
+    assert (
+        len(sys.argv) == 4
+    ), "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR KEEP_CUSTOM_FIELDS"
     fname = Path(sys.argv[1]).name
     oname = Path(sys.argv[2]) / fname
+    keep_custom_fields = bool(sys.argv[3])
     with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
         for line in fin:
             cut = json.loads(line)
             cut["supervisions"][0]["text"] = simple_cleanup(
                 cut["supervisions"][0]["custom"]["texts"][0]
             )
-            del cut["supervisions"][0]["custom"]
-            del cut["custom"]
+            if not keep_custom_fields:
+                del cut["supervisions"][0]["custom"]
+                del cut["custom"]
             fout.write((json.dumps(cut) + "\n").encode())
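One caveat worth noting: `bool(sys.argv[3])` is True for any non-empty string, including the literal `"False"` that prepare.sh passes, so the custom fields would always be kept. A string-to-bool helper in the style of icefall's `str2bool` avoids this; the sketch below illustrates the point and is not the repo's exact code:

```python
# Sketch of an explicit string-to-bool parse; bool("False") would be True,
# since Python's bool() only checks string emptiness.
def str2bool(value: str) -> bool:
    if value.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if value.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise ValueError(f"Not a boolean string: {value!r}")

keep_custom_fields = str2bool("False")  # -> False, as intended
```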
prepare.sh (Libriheavy recipe):

@@ -29,6 +29,11 @@ export CUDA_VISIBLE_DEVICES=""
 # - speech
 dl_dir=$PWD/download

+# If you want to do PromptASR experiments, please set it to True
+# as this will keep the texts and pre_text information required for
+# the training of PromptASR.
+keep_custom_fields=False
+
 . shared/parse_options.sh || exit 1

 # vocab size for sentence piece models.
@@ -134,7 +139,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   for subset in small medium large dev test_clean test_other; do
     if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
       log "Prepare manifest for subset : ${subset}"
-      ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir
+      ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir $keep_custom_fields
     fi
   done
 fi
RESULTS.md (LibriSpeech Zipformer):

@@ -307,6 +307,23 @@ done

 To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html).

+We also support training Zipformer with AMP+bf16 format (requires bf16 support). See [here](https://github.com/k2-fsa/icefall/pull/1700) for more details and pre-trained models. **The same command can be used for decoding and exporting the model.**
+
+The amp+bf16 training command is:
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+./zipformer/train.py \
+  --world-size 4 \
+  --num-epochs 50 \
+  --start-epoch 1 \
+  --use-fp16 0 \
+  --use-bf16 1 \
+  --exp-dir zipformer/exp_amp_bf16 \
+  --causal 0 \
+  --full-libri 1 \
+  --max-duration 1000
+```
+
 ##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M

 The tensorboard log can be found at
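A note on the flags, grounded in the train.py changes later in this commit: `--use-fp16` and `--use-bf16` are mutually exclusive, and `--use-bf16 1` asserts `torch.cuda.is_bf16_supported()` before enabling autocast in bfloat16.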
CTC decoding script (LibriSpeech Zipformer):

@@ -120,6 +120,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
+from lhotse import set_caching_enabled
 from train import add_model_arguments, get_model, get_params

 from icefall.checkpoint import (
@@ -296,6 +297,13 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--skip-scoring",
+        type=str2bool,
+        default=False,
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -455,7 +463,7 @@ def decode_one_batch(
         # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
         hyps = [s.split() for s in hyps]
         key = "ctc-decoding"
-        return {key: hyps}
+        return {key: hyps}  # note: returns words

     if params.decoding_method == "attention-decoder-rescoring-no-ngram":
         best_path_dict = rescore_with_attention_decoder_no_ngram(
@@ -492,7 +500,7 @@ def decode_one_batch(
         )
         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
+        key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}"  # noqa
         return {key: hyps}

     if params.decoding_method in ["1best", "nbest"]:
@@ -500,7 +508,7 @@ def decode_one_batch(
             best_path = one_best_decoding(
                 lattice=lattice, use_double_scores=params.use_double_scores
             )
-            key = "no_rescore"
+            key = "no-rescore"
         else:
             best_path = nbest_decoding(
                 lattice=lattice,
@@ -508,11 +516,11 @@ def decode_one_batch(
                 use_double_scores=params.use_double_scores,
                 nbest_scale=params.nbest_scale,
             )
-            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+            key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa

         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
-        return {key: hyps}
+        return {key: hyps}  # note: returns BPE tokens

     assert params.decoding_method in [
         "nbest-rescoring",
@@ -646,7 +654,27 @@ def decode_dataset(
     return results


-def save_results(
+def save_asr_output(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    """
+    Save text produced by ASR.
+    """
+    for key, results in results_dict.items():
+
+        recogs_filename = (
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+        )
+
+        results = sorted(results)
+        store_transcripts(filename=recogs_filename, texts=results)
+
+        logging.info(f"The transcripts are stored in {recogs_filename}")
+
+
+def save_wer_results(
     params: AttributeDict,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
@@ -661,32 +689,30 @@ def save_results(

     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
-        if enable_log:
-            logging.info(f"The transcripts are stored in {recog_path}")
-
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        with open(errs_filename, "w") as f:
-            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
+        with open(errs_filename, "w", encoding="utf8") as fd:
+            wer = write_error_stats(
+                fd, f"{test_set_name}_{key}", results, enable_log=enable_log
+            )
             test_set_wers[key] = wer
-        if enable_log:
-            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+        logging.info(f"Wrote detailed error stats to {errs_filename}")

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    with open(errs_info, "w") as f:
-        print("settings\tWER", file=f)
-        for key, val in test_set_wers:
-            print("{}\t{}".format(key, val), file=f)
-
-    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
-    note = "\tbest for {}".format(test_set_name)
+
+    wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+
+    with open(wer_filename, "w", encoding="utf8") as fd:
+        print("settings\tWER", file=fd)
+        for key, val in test_set_wers:
+            print(f"{key}\t{val}", file=fd)
+
+    s = f"\nFor {test_set_name}, WER of different settings are:\n"
+    note = f"\tbest for {test_set_name}"
     for key, val in test_set_wers:
-        s += "{}\t{}{}\n".format(key, val, note)
+        s += f"{key}\t{val}{note}\n"
         note = ""
     logging.info(s)

@@ -705,6 +731,9 @@ def main():
     params.update(get_decoding_params())
     params.update(vars(args))

+    # enable AudioCache
+    set_caching_enabled(True)  # lhotse
+
     assert params.decoding_method in (
         "ctc-greedy-search",
         "ctc-decoding",
@@ -719,9 +748,9 @@ def main():
     params.res_dir = params.exp_dir / params.decoding_method

     if params.iter > 0:
-        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+        params.suffix = f"iter-{params.iter}_avg-{params.avg}"
     else:
-        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+        params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"

     if params.causal:
         assert (
@@ -730,11 +759,11 @@ def main():
         assert (
             "," not in params.left_context_frames
         ), "left_context_frames should be one value in decoding."
-        params.suffix += f"-chunk-{params.chunk_size}"
-        params.suffix += f"-left-context-{params.left_context_frames}"
+        params.suffix += f"_chunk-{params.chunk_size}"
+        params.suffix += f"_left-context-{params.left_context_frames}"

     if params.use_averaged_model:
-        params.suffix += "-use-averaged-model"
+        params.suffix += "_use-averaged-model"

     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
@@ -940,12 +969,19 @@ def main():
             G=G,
         )

-        save_results(
+        save_asr_output(
             params=params,
             test_set_name=test_set,
             results_dict=results_dict,
         )

+        if not params.skip_scoring:
+            save_wer_results(
+                params=params,
+                test_set_name=test_set,
+                results_dict=results_dict,
+            )
+
     logging.info("Done!")
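Most of the key and suffix renaming in this commit follows one convention: `_` separates the fields of a result name, while `-` binds a key to its value within a field, keeping filenames and result keys machine-splittable. A toy illustration with made-up values:

```python
# Toy illustration of the naming convention: "_" separates fields,
# "-" joins a key with its value inside a field.
suffix = "epoch-30_avg-9"
suffix += "_chunk-16"
suffix += "_left-context-128"

fields = suffix.split("_")
print(fields)  # ['epoch-30', 'avg-9', 'chunk-16', 'left-context-128']
```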
Transducer decoding script (LibriSpeech Zipformer):

@@ -121,6 +121,7 @@ from beam_search import (
     modified_beam_search_lm_shallow_fusion,
     modified_beam_search_LODR,
 )
+from lhotse import set_caching_enabled
 from train import add_model_arguments, get_model, get_params

 from icefall import ContextGraph, LmScorer, NgramLm
@@ -369,6 +370,14 @@ def get_parser():
           modified_beam_search_LODR.
         """,
     )

+    parser.add_argument(
+        "--skip-scoring",
+        type=str2bool,
+        default=False,
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -590,21 +599,23 @@ def decode_one_batch(
             )
             hyps.append(sp.decode(hyp).split())

+    # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" )
+    prefix = f"{params.decoding_method}"
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
     elif "fast_beam_search" in params.decoding_method:
-        key = f"beam_{params.beam}_"
-        key += f"max_contexts_{params.max_contexts}_"
-        key += f"max_states_{params.max_states}"
+        prefix += f"_beam-{params.beam}"
+        prefix += f"_max-contexts-{params.max_contexts}"
+        prefix += f"_max-states-{params.max_states}"
         if "nbest" in params.decoding_method:
-            key += f"_num_paths_{params.num_paths}_"
-            key += f"nbest_scale_{params.nbest_scale}"
+            prefix += f"_num-paths-{params.num_paths}"
+            prefix += f"_nbest-scale-{params.nbest_scale}"
             if "LG" in params.decoding_method:
-                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+                prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}"

-        return {key: hyps}
+        return {prefix: hyps}
     elif "modified_beam_search" in params.decoding_method:
-        prefix = f"beam_size_{params.beam_size}"
+        prefix += f"_beam-size-{params.beam_size}"
         if params.decoding_method in (
             "modified_beam_search_lm_rescore",
             "modified_beam_search_lm_rescore_LODR",
@@ -617,10 +628,11 @@ def decode_one_batch(
             return ans
         else:
             if params.has_contexts:
-                prefix += f"-context-score-{params.context_score}"
+                prefix += f"_context-score-{params.context_score}"
             return {prefix: hyps}
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        prefix += f"_beam-size-{params.beam_size}"
+        return {prefix: hyps}


 def decode_dataset(
@@ -707,46 +719,58 @@ def decode_dataset(
     return results


-def save_results(
+def save_asr_output(
     params: AttributeDict,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
+    """
+    Save text produced by ASR.
+    """
+    for key, results in results_dict.items():
+
+        recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+
+        results = sorted(results)
+        store_transcripts(filename=recogs_filename, texts=results)
+
+        logging.info(f"The transcripts are stored in {recogs_filename}")
+
+
+def save_wer_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]],
+):
+    """
+    Save WER and per-utterance word alignments.
+    """
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
-        results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
-        logging.info(f"The transcripts are stored in {recog_path}")
-
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
-        with open(errs_filename, "w") as f:
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+        with open(errs_filename, "w", encoding="utf8") as fd:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True
+                fd, f"{test_set_name}-{key}", results, enable_log=True
             )
             test_set_wers[key] = wer

-        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+        logging.info(f"Wrote detailed error stats to {errs_filename}")

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
-    with open(errs_info, "w") as f:
-        print("settings\tWER", file=f)
-        for key, val in test_set_wers:
-            print("{}\t{}".format(key, val), file=f)
-
-    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
-    note = "\tbest for {}".format(test_set_name)
+
+    wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+
+    with open(wer_filename, "w", encoding="utf8") as fd:
+        print("settings\tWER", file=fd)
+        for key, val in test_set_wers:
+            print(f"{key}\t{val}", file=fd)
+
+    s = f"\nFor {test_set_name}, WER of different settings are:\n"
+    note = f"\tbest for {test_set_name}"
     for key, val in test_set_wers:
-        s += "{}\t{}{}\n".format(key, val, note)
+        s += f"{key}\t{val}{note}\n"
         note = ""
     logging.info(s)

@@ -762,6 +786,9 @@ def main():
     params = get_params()
     params.update(vars(args))

+    # enable AudioCache
+    set_caching_enabled(True)  # lhotse
+
     assert params.decoding_method in (
         "greedy_search",
         "beam_search",
@@ -783,9 +810,9 @@ def main():
         params.has_contexts = False

     if params.iter > 0:
-        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+        params.suffix = f"iter-{params.iter}_avg-{params.avg}"
     else:
-        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+        params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"

     if params.causal:
         assert (
@@ -794,20 +821,20 @@ def main():
         assert (
             "," not in params.left_context_frames
         ), "left_context_frames should be one value in decoding."
-        params.suffix += f"-chunk-{params.chunk_size}"
-        params.suffix += f"-left-context-{params.left_context_frames}"
+        params.suffix += f"_chunk-{params.chunk_size}"
+        params.suffix += f"_left-context-{params.left_context_frames}"

     if "fast_beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam}"
-        params.suffix += f"-max-contexts-{params.max_contexts}"
-        params.suffix += f"-max-states-{params.max_states}"
+        params.suffix += f"_beam-{params.beam}"
+        params.suffix += f"_max-contexts-{params.max_contexts}"
+        params.suffix += f"_max-states-{params.max_states}"
         if "nbest" in params.decoding_method:
-            params.suffix += f"-nbest-scale-{params.nbest_scale}"
-            params.suffix += f"-num-paths-{params.num_paths}"
+            params.suffix += f"_nbest-scale-{params.nbest_scale}"
+            params.suffix += f"_num-paths-{params.num_paths}"
             if "LG" in params.decoding_method:
-                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+                params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}"
         if params.decoding_method in (
             "modified_beam_search",
             "modified_beam_search_LODR",
@@ -815,19 +842,19 @@ def main():
         if params.has_contexts:
             params.suffix += f"-context-score-{params.context_score}"
         else:
-            params.suffix += f"-context-{params.context_size}"
-            params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+            params.suffix += f"_context-{params.context_size}"
+            params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}"

     if params.use_shallow_fusion:
-        params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+        params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}"

     if "LODR" in params.decoding_method:
         params.suffix += (
-            f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+            f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
         )

     if params.use_averaged_model:
-        params.suffix += "-use-averaged-model"
+        params.suffix += "_use-averaged-model"

     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
@@ -1038,12 +1065,19 @@ def main():
             ngram_lm_scale=ngram_lm_scale,
         )

-        save_results(
+        save_asr_output(
             params=params,
             test_set_name=test_set,
             results_dict=results_dict,
         )

+        if not params.skip_scoring:
+            save_wer_results(
+                params=params,
+                test_set_name=test_set,
+                results_dict=results_dict,
+            )
+
     logging.info("Done!")
ONNX export (class OnnxEncoder):

@@ -218,7 +218,7 @@ class OnnxEncoder(nn.Module):
           - encoder_out_lens, A 1-D tensor of shape (N,)
         """
         x, x_lens = self.encoder_embed(x, x_lens)
-        src_key_padding_mask = make_pad_mask(x_lens)
+        src_key_padding_mask = make_pad_mask(x_lens, x.shape[1])
        x = x.permute(1, 0, 2)
        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)
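Passing `x.shape[1]` pins the mask width to the padded feature length rather than to `max(x_lens)`; the two can differ for inputs fed to a traced ONNX graph. A self-contained sketch of the behavior (an illustration, not icefall's exact `make_pad_mask`):

```python
import torch

def pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    # True marks padded positions. Without max_len the width is max(lengths),
    # which may be smaller than the actual padded tensor width x.shape[1].
    max_len = max(max_len, int(lengths.max()))
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)

lengths = torch.tensor([3, 5])
print(pad_mask(lengths).shape)     # torch.Size([2, 5])
print(pad_mask(lengths, 8).shape)  # torch.Size([2, 8]), matches x.shape[1]
```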
Scaling functions (SoftmaxFunction, DoubleSwishFunction, SwooshLFunction, SwooshRFunction):

@@ -297,7 +297,7 @@ class SoftmaxFunction(torch.autograd.Function):
         # (presumably) that op does not support float16, and autocast
         # is enabled.
         if torch.is_autocast_enabled():
-            ans = ans.to(torch.float16)
+            ans = ans.to(torch.get_autocast_gpu_dtype())
         ctx.save_for_backward(ans)
         ctx.x_dtype = x.dtype
         ctx.dim = dim
@@ -1234,7 +1234,7 @@ class DoubleSwishFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
             x = x.to(torch.float32)

         s = torch.sigmoid(x - 1.0)
@@ -1346,7 +1346,7 @@ class SwooshLFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
             x = x.to(torch.float32)

         zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@@ -1379,7 +1379,7 @@ class SwooshLFunction(torch.autograd.Function):
         d_int = d_scaled.to(torch.uint8)
         ctx.save_for_backward(d_int)
         if x.dtype == torch.float16 or torch.is_autocast_enabled():
-            y = y.to(torch.float16)
+            y = y.to(torch.get_autocast_gpu_dtype())
         return y

     @staticmethod
@@ -1425,7 +1425,7 @@ class SwooshRFunction(torch.autograd.Function):
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad

-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
             x = x.to(torch.float32)

         zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@@ -1455,7 +1455,7 @@ class SwooshRFunction(torch.autograd.Function):
         d_int = d_scaled.to(torch.uint8)
         ctx.save_for_backward(d_int)
         if x.dtype == torch.float16 or torch.is_autocast_enabled():
-            y = y.to(torch.float16)
+            y = y.to(torch.get_autocast_gpu_dtype())
         return y

     @staticmethod
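These changes replace hard-coded float16 casts with the dtype the autocast context is actually running in, so the same kernels work under both fp16 and bf16 AMP. A small sketch of the relevant torch calls (`torch.get_autocast_gpu_dtype()` needs a reasonably recent PyTorch, and the behavior shown assumes a CUDA device):

```python
import torch

# Inside an autocast region, ops should cast to the active autocast dtype
# instead of assuming float16.
x = torch.randn(4, 4, device="cuda")

with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
    assert torch.is_autocast_enabled()
    # torch.float16 under fp16 AMP, torch.bfloat16 under bf16 AMP:
    active = torch.get_autocast_gpu_dtype()
    y = x.to(active)
    print(active, y.dtype)  # torch.bfloat16 torch.bfloat16
```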
Streaming decoding script (LibriSpeech Zipformer):

@@ -43,7 +43,7 @@ import torch
 from asr_datamodule import LibriSpeechAsrDataModule
 from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
-from lhotse import CutSet
+from lhotse import CutSet, set_caching_enabled
 from streaming_beam_search import (
     fast_beam_search_one_best,
     greedy_search,
@@ -76,6 +76,13 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )

+    parser.add_argument(
+        "--label",
+        type=str,
+        default="",
+        help="""Extra label of the decoding run.""",
+    )
+
     parser.add_argument(
         "--epoch",
         type=int,
@@ -188,6 +195,14 @@ def get_parser():
         help="The number of streams that can be decoded in parallel.",
     )

+    parser.add_argument(
+        "--skip-scoring",
+        type=str2bool,
+        default=False,
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -640,46 +655,60 @@ def decode_dataset(
     return {key: decode_results}


-def save_results(
+def save_asr_output(
     params: AttributeDict,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
 ):
-    test_set_wers = dict()
+    """
+    Save text produced by ASR.
+    """
     for key, results in results_dict.items():
-        recog_path = (
+        recogs_filename = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
-        logging.info(f"The transcripts are stored in {recog_path}")
+        store_transcripts(filename=recogs_filename, texts=results)
+        logging.info(f"The transcripts are stored in {recogs_filename}")
+
+
+def save_wer_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+    """
+    Save WER and per-utterance word alignments.
+    """
+    test_set_wers = dict()
+    for key, results in results_dict.items():
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
             params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        with open(errs_filename, "w") as f:
+        with open(errs_filename, "w", encoding="utf8") as fd:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True
+                fd, f"{test_set_name}-{key}", results, enable_log=True
             )
             test_set_wers[key] = wer

-        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+        logging.info(f"Wrote detailed error stats to {errs_filename}")

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
+    wer_filename = (
         params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
-    with open(errs_info, "w") as f:
-        print("settings\tWER", file=f)
-        for key, val in test_set_wers:
-            print("{}\t{}".format(key, val), file=f)
+    with open(wer_filename, "w", encoding="utf8") as fd:
+        print("settings\tWER", file=fd)
+        for key, val in test_set_wers:
+            print(f"{key}\t{val}", file=fd)

-    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
-    note = "\tbest for {}".format(test_set_name)
+    s = f"\nFor {test_set_name}, WER of different settings are:\n"
+    note = f"\tbest for {test_set_name}"
     for key, val in test_set_wers:
-        s += "{}\t{}{}\n".format(key, val, note)
+        s += f"{key}\t{val}{note}\n"
         note = ""
     logging.info(s)

@@ -694,6 +723,9 @@ def main():
     params = get_params()
     params.update(vars(args))

+    # enable AudioCache
+    set_caching_enabled(True)  # lhotse
+
     params.res_dir = params.exp_dir / "streaming" / params.decoding_method

     if params.iter > 0:
@@ -706,18 +738,21 @@ def main():
         assert (
             "," not in params.left_context_frames
         ), "left_context_frames should be one value in decoding."
-        params.suffix += f"-chunk-{params.chunk_size}"
-        params.suffix += f"-left-context-{params.left_context_frames}"
+        params.suffix += f"_chunk-{params.chunk_size}"
+        params.suffix += f"_left-context-{params.left_context_frames}"

     # for fast_beam_search
     if params.decoding_method == "fast_beam_search":
-        params.suffix += f"-beam-{params.beam}"
-        params.suffix += f"-max-contexts-{params.max_contexts}"
-        params.suffix += f"-max-states-{params.max_states}"
+        params.suffix += f"_beam-{params.beam}"
+        params.suffix += f"_max-contexts-{params.max_contexts}"
+        params.suffix += f"_max-states-{params.max_states}"

     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"

+    if params.label:
+        params.suffix += f"-{params.label}"
+
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")

@@ -845,12 +880,21 @@ def main():
         decoding_graph=decoding_graph,
     )

-    save_results(
+    save_asr_output(
         params=params,
         test_set_name=test_set,
         results_dict=results_dict,
     )

+    if not params.skip_scoring:
+        save_wer_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
     logging.info("Done!")
./zipformer/train.py (LibriSpeech Zipformer training):

@@ -521,6 +521,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--use-bf16",
+        type=str2bool,
+        default=False,
+        help="Whether to use bf16 in AMP.",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -1027,7 +1034,9 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch.cuda.amp.autocast(
+                enabled=params.use_autocast, dtype=params.dtype
+            ):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1047,9 +1056,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except Exception as e:
-            logging.info(
-                f"Caught exception: {e}."
-            )
+            logging.info(f"Caught exception: {e}.")
             save_bad_model()
             display_and_save_batch(batch, params=params, sp=sp)
             raise
@@ -1090,7 +1097,7 @@ def train_one_epoch(
             rank=rank,
         )

-        if batch_idx % 100 == 0 and params.use_fp16:
+        if batch_idx % 100 == 0 and params.use_autocast:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
@@ -1109,14 +1116,14 @@ def train_one_epoch(

         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
-            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+            cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0

             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "")
             )

             if tb_writer is not None:
@@ -1128,7 +1135,7 @@ def train_one_epoch(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
                 tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
-                if params.use_fp16:
+                if params.use_autocast:
                     tb_writer.add_scalar(
                         "train/grad_scale", cur_grad_scale, params.batch_idx_train
                     )
@@ -1204,9 +1211,25 @@ def run(rank, world_size, args):
         params.ctc_loss_scale = 1.0
     else:
         assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
-            params.ctc_loss_scale, params.attention_decoder_loss_scale
+            params.ctc_loss_scale,
+            params.attention_decoder_loss_scale,
         )

+    if params.use_bf16:  # amp + bf16
+        assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!"
+        assert not params.use_fp16, "You can only use either fp16 or bf16"
+        params.dtype = torch.bfloat16
+        params.use_autocast = True
+    elif params.use_fp16:  # amp + fp16
+        params.dtype = torch.float16
+        params.use_autocast = True
+    else:  # fp32
+        params.dtype = torch.float32
+        params.use_autocast = False
+
+    logging.info(f"Using dtype={params.dtype}")
+    logging.info(f"Use AMP={params.use_autocast}")
+
     logging.info(params)

     logging.info("About to create model")
@@ -1339,7 +1362,7 @@ def run(rank, world_size, args):
         params=params,
     )

-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1439,7 +1462,9 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch.cuda.amp.autocast(
+                enabled=params.use_autocast, dtype=params.dtype
+            ):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
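The precision selection above collapses three training modes into two flags. A condensed, self-contained sketch of the same logic (standalone names, not the script's params object):

```python
import torch

def select_precision(use_fp16: bool, use_bf16: bool):
    # Mirrors the run() logic above: bf16 and fp16 are mutually exclusive,
    # and bf16 additionally needs hardware support.
    if use_bf16:
        assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!"
        assert not use_fp16, "You can only use either fp16 or bf16"
        return torch.bfloat16, True  # dtype, use_autocast
    if use_fp16:
        return torch.float16, True
    return torch.float32, False

dtype, use_autocast = select_precision(use_fp16=False, use_bf16=True)
# Loss scaling is only strictly needed for fp16; the script keeps the scaler
# enabled whenever autocast is on, which is harmless for bf16.
scaler = torch.cuda.amp.GradScaler(enabled=use_autocast, init_scale=1.0)
```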