PromptASR for contextualized ASR with controllable style (#1250)
* Add PromptASR with BERT as text encoder
* Support using word-list based content prompts for context biasing
* Upload the pretrained models to huggingface
* Add usage example
This commit is contained in:
parent cb874e9905
commit 16a2748d6c
205  egs/libriheavy/ASR/RESULTS.md  Normal file
@@ -0,0 +1,205 @@

## Results

### Zipformer PromptASR (zipformer + PromptASR + BERT text encoder)

#### [zipformer_prompt_asr](./zipformer_prompt_asr)

See <https://github.com/k2-fsa/icefall/pull/1250> for the commit history and
our paper <https://arxiv.org/abs/2309.07414> for more details.


##### Training on the medium subset, with content & style prompt, **no** context list

You can find a pre-trained model, training logs, decoding logs, and decoding results at: <https://huggingface.co/marcoyang/icefall-promptasr-libriheavy-zipformer-BERT-2023-10-10>

The training command is:

```bash
causal=0
subset=medium
memory_dropout_rate=0.05
text_encoder_type=BERT
top_k=10000   # not used here, since --use-context-list is 0

python ./zipformer_prompt_asr/train_bert_encoder.py \
    --world-size 4 \
    --start-epoch 1 \
    --num-epochs 60 \
    --exp-dir ./zipformer_prompt_asr/exp \
    --use-fp16 True \
    --memory-dropout-rate $memory_dropout_rate \
    --causal $causal \
    --subset $subset \
    --manifest-dir data/fbank \
    --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \
    --max-duration 1000 \
    --text-encoder-type $text_encoder_type \
    --text-encoder-dim 768 \
    --use-context-list 0 \
    --top-k $top_k \
    --use-style-prompt 1
```

The decoding results using utterance-level context (epoch-60-avg-10):

| decoding method      | lh-test-clean | lh-test-other | comment                                                                  |
|----------------------|---------------|---------------|--------------------------------------------------------------------------|
| modified_beam_search | 3.13          | 6.78          | --use-pre-text False --use-style-prompt False                            |
| modified_beam_search | 2.86          | 5.93          | --pre-text-transform upper-no-punc --style-text-transform upper-no-punc  |
| modified_beam_search | 2.6           | 5.5           | --pre-text-transform mixed-punc --style-text-transform mixed-punc        |
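The rows above differ only in how the prompt text is normalized before it is fed to the BERT text encoder. As a rough illustration, the sketch below shows what the two transforms used in the table amount to (`mixed-punc` keeps the mixed-case, punctuated text as-is, `upper-no-punc` upper-cases it and strips punctuation); the actual implementations live in `zipformer_prompt_asr/text_normalization.py`, and the helper below is only an illustrative approximation, not the recipe's code.

```python
import re


def apply_style(text: str, style: str) -> str:
    """Approximate the prompt/style transforms named in the table above.

    mixed-punc: keep the original mixed-case, punctuated text.
    upper-no-punc: upper-case and keep only letters and spaces.
    """
    if style == "mixed-punc":
        return text
    elif style == "upper-no-punc":
        return re.sub(r"[^A-Z ]+", "", text.upper())
    else:
        raise ValueError(f"Unknown style: {style}")


print(apply_style("Hello, world!", "upper-no-punc"))  # -> "HELLO WORLD"
```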
The decoding command is:

```bash
for style in mixed-punc upper-no-punc; do
    python ./zipformer_prompt_asr/decode_bert.py \
        --epoch 60 \
        --avg 10 \
        --use-averaged-model True \
        --post-normalization True \
        --causal False \
        --exp-dir ./zipformer_prompt_asr/exp \
        --manifest-dir data/fbank \
        --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \
        --max-duration 1000 \
        --decoding-method modified_beam_search \
        --beam-size 4 \
        --text-encoder-type BERT \
        --text-encoder-dim 768 \
        --memory-layer 0 \
        --use-ls-test-set False \
        --use-ls-context-list False \
        --max-prompt-lens 1000 \
        --use-pre-text True \
        --use-style-prompt True \
        --style-text-transform $style \
        --pre-text-transform $style \
        --compute-CER 0
done
```

##### Training on the medium subset, with content & style prompt, **with** context list

You can find a pre-trained model, training logs, decoding logs, and decoding results at: <https://huggingface.co/marcoyang/icefall-promptasr-with-context-libriheavy-zipformer-BERT-2023-10-10>

This model is trained with an extra type of content prompt (context words), so it performs better
at **word-level** context biasing. Note that before training this model, you need to run `prepare_prompt_asr.sh`
to prepare a manifest whose cuts carry the context words (a sketch of what that field looks like is given below).
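Conceptually, `prepare_prompt_asr.sh` attaches to each training cut a space-separated string of rare words that occur in its transcript (everything outside the top-k most frequent words); during training these "correct" words are mixed with random distractors to form the content prompt. The sketch below only illustrates what the stored field looks like; the exact logic lives in `zipformer_prompt_asr/utils.py`, and the field name follows how `dataset.py` reads it (`supervision.custom["context_list"]`).

```python
# A minimal sketch, assuming `rare_words` is the set of words outside the
# top-k frequency list (as produced by prepare_prompt_asr.sh / utils.py).
def attach_context_list(supervision, rare_words: set) -> None:
    words = supervision.text.lower().split()
    context_words = [w for w in words if w in rare_words]
    # dataset.py expects a space-separated string in supervision.custom
    supervision.custom = supervision.custom or {}
    supervision.custom["context_list"] = " ".join(context_words)
```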
The training command is:

```bash
causal=0
subset=medium
memory_dropout_rate=0.05
text_encoder_type=BERT
use_context_list=True

# prepare the required data for context biasing
./prepare_prompt_asr.sh --stage 0 --stop_stage 1

python ./zipformer_prompt_asr/train_bert_encoder.py \
    --world-size 4 \
    --start-epoch 1 \
    --num-epochs 50 \
    --exp-dir ./zipformer_prompt_asr/exp \
    --use-fp16 True \
    --memory-dropout-rate $memory_dropout_rate \
    --causal $causal \
    --subset $subset \
    --manifest-dir data/fbank \
    --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \
    --max-duration 1000 \
    --text-encoder-type $text_encoder_type \
    --text-encoder-dim 768 \
    --use-context-list $use_context_list \
    --top-k 10000 \
    --use-style-prompt 1
```

*Utterance-level biasing:*

| decoding method      | lh-test-clean | lh-test-other | comment                                                                  |
|----------------------|---------------|---------------|--------------------------------------------------------------------------|
| modified_beam_search | 3.17          | 6.72          | --use-pre-text 0 --use-style-prompt 0                                    |
| modified_beam_search | 2.91          | 6.24          | --pre-text-transform upper-no-punc --style-text-transform upper-no-punc  |
| modified_beam_search | 2.72          | 5.72          | --pre-text-transform mixed-punc --style-text-transform mixed-punc        |

The decoding command for the table above is:

```bash
for style in mixed-punc upper-no-punc; do
    python ./zipformer_prompt_asr/decode_bert.py \
        --epoch 50 \
        --avg 10 \
        --use-averaged-model True \
        --post-normalization True \
        --causal False \
        --exp-dir ./zipformer_prompt_asr/exp \
        --manifest-dir data/fbank \
        --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \
        --max-duration 1000 \
        --decoding-method modified_beam_search \
        --beam-size 4 \
        --text-encoder-type BERT \
        --text-encoder-dim 768 \
        --memory-layer 0 \
        --use-ls-test-set False \
        --use-ls-context-list False \
        --max-prompt-lens 1000 \
        --use-pre-text True \
        --use-style-prompt True \
        --style-text-transform $style \
        --pre-text-transform $style \
        --compute-CER 0
done
```

*Word-level biasing:*

The results are reported on the LibriSpeech test sets using the biasing lists provided in <https://arxiv.org/abs/2104.02194>.
You need to set `--use-ls-test-set True` so that the LibriSpeech test sets are used; a sketch of how the per-utterance
biasing prompt is assembled is given after the decoding command below.
| decoding method      | ls-test-clean | ls-test-other | comment                                                                                                       |
|----------------------|---------------|---------------|----------------------------------------------------------------------------------------------------------------|
| modified_beam_search | 2.4           | 5.08          | --use-pre-text 0 --use-style-prompt 0                                                                         |
| modified_beam_search | 2.14          | 4.62          | --use-ls-context-list 1 --pre-text-transform mixed-punc --style-text-transform mixed-punc --ls-distractors 0   |
| modified_beam_search | 2.14          | 4.64          | --use-ls-context-list 1 --pre-text-transform mixed-punc --style-text-transform mixed-punc --ls-distractors 100 |

The decoding command for the table above is:

```bash
use_ls_test_set=1
use_ls_context_list=1

for ls_distractors in 0 100; do
    python ./zipformer_prompt_asr/decode_bert.py \
        --epoch 50 \
        --avg 10 \
        --use-averaged-model True \
        --post-normalization True \
        --causal False \
        --exp-dir ./zipformer_prompt_asr/exp \
        --manifest-dir data/fbank \
        --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \
        --max-duration 1000 \
        --decoding-method modified_beam_search \
        --beam-size 4 \
        --text-encoder-type BERT \
        --text-encoder-dim 768 \
        --memory-layer 0 \
        --use-ls-test-set $use_ls_test_set \
        --use-ls-context-list $use_ls_context_list \
        --ls-distractors $ls_distractors \
        --max-prompt-lens 1000 \
        --use-pre-text True \
        --use-style-prompt True \
        --style-text-transform mixed-punc \
        --pre-text-transform mixed-punc \
        --compute-CER 0
done
```
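For word-level biasing, each test utterance gets its own content prompt: the biasing words for that utterance (taken from the lists of <https://arxiv.org/abs/2104.02194>), optionally mixed with `--ls-distractors` randomly sampled distractor words, are joined into a single word-list prompt. The sketch below shows the general idea only; the actual construction happens inside `decode_bert.py`, so the function below is an illustrative approximation rather than the recipe's implementation.

```python
import random


def build_biasing_prompt(biasing_words, distractor_pool, num_distractors: int) -> str:
    """Assemble a word-list content prompt: the utterance's biasing words
    plus `num_distractors` randomly sampled distractors, shuffled together."""
    distractors = random.sample(distractor_pool, num_distractors) if num_distractors > 0 else []
    words = list(biasing_words) + distractors
    random.shuffle(words)
    return " ".join(words)


# e.g. with --ls-distractors 100, 100 distractor words are mixed into the prompt
```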
36  egs/libriheavy/ASR/prepare_prompt_asr.sh  Executable file
@@ -0,0 +1,36 @@

#!/usr/bin/env bash

set -eou pipefail

# This is the preparation recipe for PromptASR: https://arxiv.org/pdf/2309.07414

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

stage=-1
stop_stage=100
manifest_dir=data/fbank
subset=medium
topk=10000

. shared/parse_options.sh || exit 1

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download the meta biasing list for LibriSpeech"
  mkdir -p data/context_biasing
  cd data/context_biasing
  git clone https://github.com/facebookresearch/fbai-speech.git
  cd ../..
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Add rare-words for context biasing to the manifest"
  python zipformer_prompt_asr/utils.py \
    --manifest-dir $manifest_dir \
    --subset $subset \
    --top-k $topk
fi
1  egs/libriheavy/ASR/shared  Symbolic link
@@ -0,0 +1 @@
../../../icefall/shared
0  egs/libriheavy/ASR/zipformer_prompt_asr/__init__.py  Normal file
520  egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py  Normal file
@@ -0,0 +1,520 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from dataset import PromptASRDataset
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import ( # SingleCutSampler,
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
ExtraPadding,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriHeavyAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
if args.use_context_list:
|
||||
assert args.rare_word_file is not None
|
||||
with open(args.rare_word_file, "r") as f:
|
||||
self.rare_word_list = (
|
||||
f.read().lower().split()
|
||||
) # Use lower-cased for easier style transform
|
||||
else:
|
||||
self.rare_word_list = None
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it "
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
# Libriheavy specific arguments
|
||||
group.add_argument(
|
||||
"--subset",
|
||||
type=str,
|
||||
default="small",
|
||||
help="Select the Libriheavy subset (small|medium|large)",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--use-context-list",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use the context list of libri heavy",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=10000,
|
||||
help="""The top-k words are identified as common words,
|
||||
the rest as rare words""",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--with-decoding",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="If the texts field contain decoding",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--random-left-padding",
|
||||
type=str2bool,
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--rare-word-file",
|
||||
type=str,
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--long-audio-cuts",
|
||||
type=str,
|
||||
default="data/manifest_npr/npr1_cuts_all_guids_0.jsonl.gz",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
text_sampling_func: Callable[[List[str]], str] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = PromptASRDataset(
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
text_sampling_func=text_sampling_func,
|
||||
rare_word_list=self.rare_word_list,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = PromptASRDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
text_sampling_func=text_sampling_func,
|
||||
rare_word_list=self.rare_word_list,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"SingleCutSampler is no longer supported by lhotse"
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(
|
||||
self,
|
||||
cuts_valid: CutSet,
|
||||
text_sampling_func: Callable[[List[str]], str] = None,
|
||||
) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.random_left_padding:
|
||||
logging.info("Enable random left padding")
|
||||
transforms.append(
|
||||
ExtraPadding(extra_frames=16, randomized=True, direction="left")
|
||||
)
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = PromptASRDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
text_sampling_func=text_sampling_func,
|
||||
rare_word_list=self.rare_word_list,
|
||||
)
|
||||
else:
|
||||
validate = PromptASRDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
text_sampling_func=text_sampling_func,
|
||||
rare_word_list=self.rare_word_list,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info(f"About to get {self.args.subset} cuts")
|
||||
|
||||
if self.args.use_context_list:
|
||||
path = (
|
||||
self.args.manifest_dir
|
||||
/ f"libriheavy_cuts_{self.args.subset}_with_context_list_topk_{self.args.top_k}.jsonl.gz"
|
||||
)
|
||||
elif self.args.with_decoding:
|
||||
path = (
|
||||
self.args.manifest_dir
|
||||
/ f"libriheavy_cuts_{self.args.subset}_with_decoding.jsonl.gz"
|
||||
)
|
||||
else:
|
||||
path = (
|
||||
self.args.manifest_dir / f"libriheavy_cuts_{self.args.subset}.jsonl.gz"
|
||||
)
|
||||
|
||||
logging.info(f"Loading manifest from {path}.")
|
||||
cuts_train = CutSet.from_jsonl_lazy(path)
|
||||
return cuts_train
|
||||
|
||||
@lru_cache()
|
||||
def dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
cuts_valid = load_manifest_lazy(
|
||||
self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz"
|
||||
)
|
||||
return cuts_valid
|
||||
|
||||
@lru_cache()
|
||||
def test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-clean cuts")
|
||||
cuts_valid = load_manifest_lazy(
|
||||
self.args.manifest_dir / "libriheavy_cuts_test-clean_official.jsonl.gz"
|
||||
)
|
||||
return cuts_valid
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
cuts_valid = load_manifest_lazy(
|
||||
self.args.manifest_dir / "libriheavy_cuts_test-other_official.jsonl.gz"
|
||||
)
|
||||
return cuts_valid
|
||||
|
||||
@lru_cache()
|
||||
def librispeech_test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def librispeech_test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def long_audio_cuts(self) -> CutSet:
|
||||
logging.info("About to get long audio cuts")
|
||||
cuts = load_manifest_lazy(
|
||||
self.args.long_audio_cuts,
|
||||
)
|
||||
return cuts
|
||||
|
||||
@lru_cache()
|
||||
def test_dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get test dev cuts")
|
||||
cuts = load_manifest_lazy(
|
||||
self.args.manifest_dir / "libriheavy_cuts_test_dev.jsonl.gz"
|
||||
)
|
||||
return cuts
|
1  egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py  Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
586  egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py  Normal file
@@ -0,0 +1,586 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
|
||||
#
|
||||
# See ../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from lhotse import validate
|
||||
from lhotse.cut import CutSet
|
||||
from lhotse.dataset import K2SpeechRecognitionDataset
|
||||
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
|
||||
from lhotse.utils import compute_num_frames, ifnone
|
||||
from text_normalization import (
|
||||
lower_all_char,
|
||||
lower_only_alpha,
|
||||
remove_non_alphabetic,
|
||||
train_text_normalization,
|
||||
upper_all_char,
|
||||
upper_only_alpha,
|
||||
)
|
||||
from torch.utils.data.dataloader import DataLoader, default_collate
|
||||
|
||||
|
||||
class PromptASRDataset(torch.utils.data.Dataset):
|
||||
"""This is a dataset for Prompt ASR. It supports the following features:
|
||||
1. Select a tuple of (text, pre_text, style_text) randomly from a
|
||||
list of texts as supervisions.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
return_cuts: bool = False,
|
||||
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
|
||||
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
|
||||
input_strategy: BatchIO = PrecomputedFeatures(),
|
||||
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
|
||||
rare_word_list: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
|
||||
for more details.
|
||||
|
||||
:param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
|
||||
objects used to create that batch.
|
||||
:param cut_transforms: A list of transforms to be applied on each sampled batch,
|
||||
before converting cuts to an input representation (audio/features).
|
||||
Examples: cut concatenation, noise cuts mixing, etc.
|
||||
:param input_transforms: A list of transforms to be applied on each sampled batch,
|
||||
after the cuts are converted to audio/features.
|
||||
Examples: normalization, SpecAugment, etc.
|
||||
:param input_strategy: Converts cuts into a collated batch of audio/features.
|
||||
By default, reads pre-computed features from disk.
|
||||
:param text_sampling_func: Sampling a text as transcription from a list of texts.
|
||||
"""
|
||||
super().__init__()
|
||||
# Initialize the fields
|
||||
self.return_cuts = return_cuts
|
||||
self.cut_transforms = ifnone(cut_transforms, [])
|
||||
self.input_transforms = ifnone(input_transforms, [])
|
||||
self.input_strategy = input_strategy
|
||||
|
||||
# a text sampling function
|
||||
self.text_sampling_func = text_sampling_func
|
||||
self.rare_word_list = rare_word_list
|
||||
|
||||
def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
|
||||
"""
|
||||
Return a new batch, with the batch size automatically determined using the constraints
|
||||
of max_frames and max_cuts.
|
||||
"""
|
||||
validate_for_asr(cuts)
|
||||
|
||||
# Sort the cuts by duration so that the first one determines the batch time dimensions.
|
||||
cuts = cuts.sort_by_duration(ascending=False)
|
||||
|
||||
# Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
|
||||
# the supervision boundaries.
|
||||
for tnfm in self.cut_transforms:
|
||||
cuts = tnfm(cuts)
|
||||
|
||||
# Sort the cuts again after transforms
|
||||
cuts = cuts.sort_by_duration(ascending=False)
|
||||
|
||||
# Get a tensor with batched feature matrices, shape (B, T, F)
|
||||
# Collation performs auto-padding, if necessary.
|
||||
input_tpl = self.input_strategy(cuts)
|
||||
if len(input_tpl) == 3:
|
||||
# An input strategy with fault tolerant audio reading mode.
|
||||
# "cuts" may be a subset of the original "cuts" variable,
|
||||
# that only has cuts for which we successfully read the audio.
|
||||
inputs, _, cuts = input_tpl
|
||||
else:
|
||||
inputs, _ = input_tpl
|
||||
|
||||
# Get a dict of tensors that encode the positional information about supervisions
|
||||
# in the batch of feature matrices. The tensors are named "sequence_idx",
|
||||
# "start_frame/sample" and "num_frames/samples".
|
||||
supervision_intervals = self.input_strategy.supervision_intervals(cuts)
|
||||
|
||||
# Apply all available transforms on the inputs, i.e. either audio or features.
|
||||
# This could be feature extraction, global MVN, SpecAugment, etc.
|
||||
segments = torch.stack(list(supervision_intervals.values()), dim=1)
|
||||
for tnfm in self.input_transforms:
|
||||
inputs = tnfm(inputs, supervision_segments=segments)
|
||||
|
||||
batch = {
|
||||
"inputs": inputs,
|
||||
"supervisions": default_collate(
|
||||
[
|
||||
self.text_sampling_func(
|
||||
texts=supervision.texts,
|
||||
pre_texts=supervision.pre_texts,
|
||||
context_list=supervision.context_list
|
||||
if "context_list" in supervision.custom
|
||||
else None,
|
||||
rare_word_list=self.rare_word_list,
|
||||
)
|
||||
if self.text_sampling_func is not None
|
||||
else {
|
||||
"text": train_text_normalization(supervision.texts[0]),
|
||||
"pre_text": train_text_normalization(supervision.pre_texts[0]),
|
||||
"style_text": train_text_normalization(
|
||||
supervision.pre_texts[0]
|
||||
),
|
||||
"transform_ids": 0,
|
||||
}
|
||||
for sequence_idx, cut in enumerate(cuts)
|
||||
for supervision in cut.supervisions
|
||||
]
|
||||
),
|
||||
}
|
||||
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
|
||||
batch["supervisions"].update(supervision_intervals)
|
||||
if self.return_cuts:
|
||||
batch["supervisions"]["cut"] = [
|
||||
cut for cut in cuts for sup in cut.supervisions
|
||||
]
|
||||
|
||||
has_word_alignments = all(
|
||||
s.alignment is not None and "word" in s.alignment
|
||||
for c in cuts
|
||||
for s in c.supervisions
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def validate_for_asr(cuts: CutSet) -> None:
|
||||
validate(cuts)
|
||||
tol = 2e-3  # 2 ms tolerance
|
||||
for cut in cuts:
|
||||
for supervision in cut.supervisions:
|
||||
assert supervision.start >= -tol, (
|
||||
f"Supervisions starting before the cut are not supported for ASR"
|
||||
f" (sup id: {supervision.id}, cut id: {cut.id})"
|
||||
)
|
||||
|
||||
# Supervision start time is relative to Cut ...
|
||||
# https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
|
||||
#
|
||||
# 'supervision.end' is end of supervision inside the Cut
|
||||
assert supervision.end <= cut.duration + tol, (
|
||||
f"Supervisions ending after the cut "
|
||||
f"are not supported for ASR"
|
||||
f" (sup id: {supervision.id}, cut id: {cut.id})"
|
||||
)
|
||||
|
||||
|
||||
def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str:
|
||||
"""A helper function that generates a random substring from a given string
|
||||
|
||||
Args:
|
||||
s (str): Input string
|
||||
|
||||
Returns:
|
||||
str: Returned substring
|
||||
"""
|
||||
min_len = min(len(s), min_len)
|
||||
|
||||
start = random.randint(0, len(s) - min_len)
|
||||
end = min(start + max_len, random.randint(start + min_len, len(s)))
|
||||
|
||||
return s[start:end]
|
||||
|
||||
|
||||
def triplet_text_sampling(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
context_list: Optional[str] = None,
|
||||
rare_word_list: Optional[List[str]] = None,
|
||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
||||
min_len_style: Optional[int] = 80,
|
||||
) -> Dict[str, str]:
|
||||
"""This function generates a triplet of
|
||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
||||
should **always** match, whereas the style of pre_text is arbitrary.
|
||||
Suppose we have two different transforms A and B, and the preceding text is
referred to as pre_text. The following tuples are all valid:

(A(pre_text), A(style_text), A(ref_text))
(A(pre_text), B(style_text), B(ref_text))
(B(pre_text), A(style_text), A(ref_text))
(B(pre_text), B(style_text), B(ref_text))
|
||||
|
||||
If transforms is not given, the following pre-defined transforms
|
||||
are available:
|
||||
0: original (mixed-cased, with punc)
|
||||
1: upper_only_alpha (upper-cased, no punc)
|
||||
|
||||
When the transform of text and pre_text match, we can use the whole
|
||||
pre_text as the prompt text.
|
||||
|
||||
Args:
|
||||
texts (List[str]):
|
||||
A list of ref_texts whose first item is the ground truth
|
||||
text from books.
|
||||
pre_texts (List[str]):
|
||||
A list of pre_texts, whose first item is the ground-truth
|
||||
pre_text from books.
|
||||
context_list: Optional[str] = None,
|
||||
A list of biasing words separated by space
|
||||
rare_word_list: Optional[str] = None,
|
||||
A list of rare-words separated by space (used as distractors)
|
||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||
|
||||
Returns:
|
||||
A dictionary of ref_text, pre_text, style_text
|
||||
"""
|
||||
assert len(texts) == len(pre_texts)
|
||||
assert len(texts) == 2
|
||||
|
||||
# we assume the first item to be ground truth
|
||||
gt_text = texts[0]
|
||||
gt_pre_text = pre_texts[0]
|
||||
|
||||
if transforms is None:
|
||||
transforms = [
|
||||
lambda x: x,  # identity transform: return the text unchanged
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
sampling_weight = [
|
||||
0.7,
|
||||
0.3,
|
||||
0.0,
|
||||
0.0,
|
||||
] # Mixed-punc should have the largest sampling prob
|
||||
|
||||
total_transforms = len(transforms) # do not use the recognized trans
|
||||
|
||||
# Randomly sample transforms
|
||||
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
|
||||
|
||||
# get the normalized text and pre_text
|
||||
text = transforms[i_text](gt_text)
|
||||
pre_text = transforms[i_pre_text](gt_pre_text)
|
||||
|
||||
if i_text == i_pre_text:
|
||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
||||
else:
|
||||
# get the pre_text of same style as text
|
||||
# For now, **don't** do transform to the style text, because we do it after the dataloader
|
||||
style_text = gt_pre_text
|
||||
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
|
||||
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(text),
|
||||
"pre_text": train_text_normalization(pre_text),
|
||||
"style_text": train_text_normalization(style_text),
|
||||
"transform_ids": i_text,
|
||||
}
|
||||
|
||||
|
||||
def triplet_text_sampling_with_context_list(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
context_list: str,
|
||||
rare_word_list: List[str],
|
||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
||||
min_len_style: Optional[int] = 80,
|
||||
) -> Dict[str, str]:
|
||||
"""This function generates a triplet of
|
||||
(pre_text, style_text, ref_text). The pre_text is either the preceding text
|
||||
or a list of words (context words + distractors).
|
||||
The style of style_text and ref_text should **always** match, whereas
|
||||
the style of pre_text is arbitrary.
|
||||
Suppose we have two different transforms A and B, and the preceding text is
referred to as pre_text. The following tuples are all valid:

(A(pre_text), A(style_text), A(ref_text))
(A(pre_text), B(style_text), B(ref_text))
(B(pre_text), A(style_text), A(ref_text))
(B(pre_text), B(style_text), B(ref_text))
|
||||
|
||||
If transforms is not given, the following pre-defined transforms
|
||||
are available:
|
||||
0: original (mixed-cased, with punc)
|
||||
1: upper_only_alpha (upper-cased, no punc)
|
||||
|
||||
When the transform of text and pre_text match, we can use the whole
|
||||
pre_text as the prompt text.
|
||||
|
||||
Args:
|
||||
texts (List[str]):
|
||||
A list of ref_texts whose first item is the ground truth
|
||||
text from books.
|
||||
pre_texts (List[str]):
|
||||
A list of pre_texts, whose first item is the ground-truth
|
||||
pre_text from books.
|
||||
context_list: Optional[str] = None,
|
||||
A list of biasing words separated by space
|
||||
rare_word_list: Optional[str] = None,
|
||||
A list of rare-words separated by space (used as distractors)
|
||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||
|
||||
Returns:
A dictionary of ref_text, pre_text, style_text
"""
|
||||
|
||||
assert len(texts) == len(pre_texts)
|
||||
assert len(texts) == 2
|
||||
|
||||
if context_list is not None:
|
||||
context_list = context_list.lower()
|
||||
|
||||
# we assume the first item to be ground truth
|
||||
gt_text = texts[0]
|
||||
gt_pre_text = pre_texts[0]
|
||||
|
||||
if transforms is None:
|
||||
transforms = [
|
||||
lambda x: x,  # identity transform: return the text unchanged
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
sampling_weight = [
|
||||
0.7,
|
||||
0.3,
|
||||
0.0,
|
||||
0.0,
|
||||
] # Mixed-punc should have the largest sampling prob
|
||||
|
||||
total_transforms = len(transforms) # do not use the recognized trans
|
||||
|
||||
# Select a transformation randomly
|
||||
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
|
||||
|
||||
# get the normalized text and pre_text
|
||||
text = transforms[i_text](gt_text)
|
||||
pre_text = get_pre_text_with_context_list2(
|
||||
text=gt_text,
|
||||
pre_text=gt_pre_text,
|
||||
context_list=context_list,
|
||||
rare_words_list=rare_word_list,
|
||||
)
|
||||
pre_text = transforms[i_pre_text](pre_text)
|
||||
|
||||
if i_text == i_pre_text:
|
||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
||||
else:
|
||||
# get the pre_text of same style as text
|
||||
# For now, **don't** do transform to the style text
|
||||
style_text = gt_pre_text
|
||||
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
|
||||
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(text),
|
||||
"pre_text": train_text_normalization(pre_text),
|
||||
"style_text": train_text_normalization(style_text),
|
||||
"transform_ids": i_text,
|
||||
}
|
||||
|
||||
|
||||
def get_pre_text_with_context_list(
|
||||
text: str,
|
||||
pre_text: str,
|
||||
context_list: str,
|
||||
rare_words_list: List[str] = None,
|
||||
) -> str:
|
||||
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
||||
# For a small proportion of the time, use a substring of ref_text as the pre_text
|
||||
|
||||
if context_list != "" and context_list is not None:
|
||||
v = random.random()
|
||||
if v < 0.5:
|
||||
# correct + distractors
|
||||
# sample distractors
|
||||
num_distractors = random.randint(0, 50)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
# sample correct
|
||||
correct = context_list.split()
|
||||
i = random.randint(1, len(correct))
|
||||
correct = random.sample(correct, i)
|
||||
# combine correct and distractors
|
||||
pre_text = distractors + correct
|
||||
random.shuffle(pre_text)
|
||||
pre_text = " ".join(pre_text)
|
||||
elif v < 0.7:
|
||||
splitted = text.split()
|
||||
sampling_weights = [len(w) ** 1.2 for w in splitted]
|
||||
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
|
||||
i = random.randint(1, min(len(splitted), 20))
|
||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||
num_distractors = random.randint(0, 70)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
pre_text = " ".join(splitted)
|
||||
else:
|
||||
pre_text = pre_text
|
||||
else:
|
||||
v = random.random()
|
||||
if v < 0.1:
|
||||
splitted = text.split()
|
||||
sampling_weights = [len(w) ** 1.2 for w in splitted]
|
||||
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
|
||||
i = random.randint(1, min(len(splitted), 20))
|
||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||
pre_text = " ".join(splitted)
|
||||
num_distractors = random.randint(0, 70)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
elif v < 0.2:
|
||||
# full distractors
|
||||
num_distractors = random.randint(5, 100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
pre_text = " ".join(distractors)
|
||||
|
||||
elif v < 0.3:
|
||||
pre_text = get_substring(text, min_len=15, max_len=150)
|
||||
else:
|
||||
pre_text = pre_text
|
||||
|
||||
return pre_text
|
||||
|
||||
|
||||
def get_pre_text_with_context_list2(
|
||||
text: str,
|
||||
pre_text: str,
|
||||
context_list: str,
|
||||
rare_words_list: List[str] = None,
|
||||
) -> str:
|
||||
# Get the pre_text, either the ground truth preceding text or
|
||||
# a list of words consisting of biasing words and distractors
|
||||
# For a small proportion of the time, use a substring of ref_text as the pre_text
|
||||
|
||||
if context_list != "" and context_list is not None:
|
||||
v = random.random()
|
||||
if v < 0.4:
|
||||
# sample distractors
|
||||
num_distractors = random.randint(50, 100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
# sample correct
|
||||
correct = context_list.split()
|
||||
i = random.randint(1, len(correct))
|
||||
correct = random.sample(correct, i)
|
||||
# combine correct and distractors
|
||||
pre_text = distractors + correct
|
||||
random.shuffle(pre_text)
|
||||
pre_text = " ".join(pre_text)
|
||||
elif v < 0.55:
|
||||
splitted = text.split()
|
||||
sampling_weights = [
|
||||
len(w) ** 1.2 for w in splitted
|
||||
] # longer words with higher weights
|
||||
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
|
||||
i = random.randint(1, min(len(splitted), 20))
|
||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||
num_distractors = random.randint(50, 100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
pre_text = " ".join(splitted)
|
||||
else:
|
||||
pre_text = pre_text
|
||||
else:
|
||||
v = random.random()
|
||||
if v < 0.3:
|
||||
splitted = text.split()
|
||||
sampling_weights = [len(w) ** 1.2 for w in splitted]
|
||||
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
|
||||
i = random.randint(1, min(len(splitted), 20))
|
||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||
pre_text = " ".join(splitted)
|
||||
num_distractors = random.randint(50, 100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
elif v < 0.4:
|
||||
# full distractors
|
||||
num_distractors = random.randint(5, 100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
pre_text = " ".join(distractors)
|
||||
elif v < 0.6:
|
||||
pre_text = get_substring(text, min_len=15, max_len=150)
|
||||
else:
|
||||
pre_text = pre_text
|
||||
|
||||
return pre_text
|
||||
|
||||
|
||||
def naive_triplet_text_sampling(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
context_list: str = None,
|
||||
rare_word_list: List[str] = None,
|
||||
min_len_style: Optional[int] = 120,
|
||||
):
|
||||
# The simplest text sampling function, used only for
# evaluation; it uses a fixed sentence as the style text.
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(texts[0]),
|
||||
"pre_text": train_text_normalization(pre_texts[0]),
|
||||
"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. What do you think?",
|
||||
"transform_ids": 0,
|
||||
}
|
||||
|
||||
|
||||
def random_shuffle_subset(
|
||||
data: List[str],
|
||||
p: float = 0.2,
|
||||
p_mask: float = 0.05,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Randomly shuffle a subset of the data: a fraction `p` of the samples
in the original batch are shuffled; the others are kept in their original order.
|
||||
|
||||
With a probability of `p_mask`, replace the original string with an empty string.
|
||||
|
||||
"""
|
||||
|
||||
num_to_shuffle = int(len(data) * p)
|
||||
id_to_shuffle = np.random.choice(len(data), num_to_shuffle, replace=False)
|
||||
item_to_shuffle = [data[id] for id in id_to_shuffle]
|
||||
random.shuffle(item_to_shuffle)
|
||||
|
||||
for id, item in zip(id_to_shuffle, item_to_shuffle):
|
||||
data[id] = item
|
||||
|
||||
# Randomly mask a proportion of the data to empty string
|
||||
if p_mask > 0:
|
||||
for i in range(len(data)):
|
||||
if random.random() < p_mask:
|
||||
data[i] = ""
|
||||
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
texts = [
|
||||
"AA, BB, cC, dD!",
|
||||
"AA BB CC DD",
|
||||
]
|
||||
|
||||
pre_texts = [
|
||||
"EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?",
|
||||
"EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG",
|
||||
]
|
||||
for i in range(10):
|
||||
print(f"Run: {i}")
|
||||
print(triplet_text_sampling(texts, pre_texts))
|
791  egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py  Normal file
@@ -0,0 +1,791 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao,
|
||||
# Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
./zipformer_prompt_asr/decode_baseline.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer_prompt_asr/exp \
    --max-duration 600 \
    --decoding-method greedy_search

(2) modified beam search
./zipformer_prompt_asr/decode_baseline.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer_prompt_asr/exp \
    --max-duration 600 \
    --decoding-method modified_beam_search \
    --beam-size 4
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriHeavyAsrDataModule
|
||||
from beam_search import greedy_search, greedy_search_batch, modified_beam_search
|
||||
from ls_text_normalization import word_normalization
|
||||
from text_normalization import (
|
||||
ref_text_normalization,
|
||||
remove_non_alphabetic,
|
||||
upper_only_alpha,
|
||||
)
|
||||
from train_baseline import add_model_arguments, get_params, get_transducer_model
|
||||
from utils import write_error_stats
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
- modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion
|
||||
- modified_beam_search_LODR
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--post-normalization",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--long-audio-recog",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-ls-test-set",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use librispeech test set for evaluation.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--compute-CER",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Reports CER. By default, only reports WER",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search".
If modified beam search with a beam size of 4 is used, it would
be "beam_size_4".
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
texts = batch["supervisions"]["text"]
|
||||
batch_size = feature.size(0)
|
||||
|
||||
# Get the transducer encoder output
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
encoder_out, encoder_out_lens = model.encode_audio(
|
||||
feature=feature,
|
||||
feature_lens=feature_lens,
|
||||
)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
is used, or "beam_size_4" if a beam size of 4 is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut ID, the reference transcript, and the predicted result,
the latter two given as lists of words.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
if not params.use_ls_test_set:
|
||||
book_names = [
|
||||
cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"]
|
||||
]
|
||||
else:
|
||||
book_names = ["" for _ in cut_ids]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, book_name, hyp_words, ref_text in zip(
|
||||
cut_ids, book_names, hyps, texts
|
||||
):
|
||||
ref_text = ref_text_normalization(ref_text)
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
# if not params.use_ls_test_set:
|
||||
# results[name + " " + book_name].extend(this_batch)
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
biasing_words: List[str] = None,
|
||||
):
|
||||
test_set_wers = dict()
|
||||
test_set_cers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
results,
|
||||
enable_log=True,
|
||||
biasing_words=biasing_words,
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
if params.compute_CER:
|
||||
# Write CER statistics
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
|
||||
)
|
||||
store_transcripts(filename=recog_path, texts=results, char_level=True)
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
cer = write_error_stats(
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
results,
|
||||
enable_log=True,
|
||||
compute_CER=params.compute_CER,
|
||||
)
|
||||
test_set_cers[key] = cer
|
||||
|
||||
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
if params.compute_CER:
|
||||
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tcER", file=f)
|
||||
for key, val in test_set_cers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_cers:
|
||||
s += "{} CER\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
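# The wer-summary / cer-summary files written above are small two-column
# TSVs, e.g. (illustrative numbers):
#
#   settings        WER
#   beam_size_4     2.60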
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriHeavyAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"modified_beam_search",
|
||||
)
|
||||
|
||||
if params.long_audio_recog:
|
||||
params.res_dir = params.exp_dir / (params.decoding_method + "long_audio")
|
||||
else:
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if "ngram" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
),
|
||||
strict=False,
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
),
|
||||
strict=False,
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
LM = None
|
||||
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
libriheavy = LibriHeavyAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = libriheavy.test_clean_cuts()
|
||||
test_other_cuts = libriheavy.test_other_cuts()
|
||||
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
|
||||
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
|
||||
long_audio_cuts = libriheavy.long_audio_cuts()
|
||||
|
||||
test_clean_dl = libriheavy.valid_dataloaders(
|
||||
test_clean_cuts,
|
||||
)
|
||||
test_other_dl = libriheavy.valid_dataloaders(
|
||||
test_other_cuts,
|
||||
)
|
||||
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
|
||||
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
|
||||
long_audio_dl = libriheavy.valid_dataloaders(
|
||||
long_audio_cuts,
|
||||
)
|
||||
|
||||
if params.use_ls_test_set:
|
||||
test_sets = ["ls-test-clean", "ls-test-other"]
|
||||
test_dl = [ls_test_clean_dl, ls_test_other_dl]
|
||||
else:
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
if params.long_audio_recog:
|
||||
test_sets = ["long-audio"]
|
||||
test_dl = [long_audio_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
if params.use_ls_test_set:
|
||||
f = open(
|
||||
"data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", "r"
|
||||
)
|
||||
biasing_words = f.read().strip().split()
|
||||
f.close()
|
||||
else:
|
||||
biasing_words = None
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
if params.post_normalization:
|
||||
if "-post-normalization" not in params.suffix:
|
||||
params.suffix += "-post-normalization"
|
||||
|
||||
new_res = {}
|
||||
for k in results_dict:
|
||||
new_ans = []
|
||||
for item in results_dict[k]:
|
||||
id, ref, hyp = item
|
||||
if params.use_ls_test_set:
|
||||
hyp = (
|
||||
" ".join(hyp).replace("-", " ").split()
|
||||
)  # handle the hyphens
|
||||
hyp = upper_only_alpha(" ".join(hyp)).split()
|
||||
hyp = [word_normalization(w.upper()) for w in hyp]
|
||||
hyp = " ".join(hyp).split()
|
||||
hyp = [w for w in hyp if w != ""]
|
||||
ref = upper_only_alpha(" ".join(ref)).split()
|
||||
else:
|
||||
hyp = upper_only_alpha(" ".join(hyp)).split()
|
||||
ref = upper_only_alpha(" ".join(ref)).split()
|
||||
new_ans.append((id, ref, hyp))
|
||||
new_res[k] = new_ans
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=new_res,
|
||||
biasing_words=biasing_words,
|
||||
)
|
||||
|
||||
if params.suffix.endswith("-post-normalization"):
|
||||
params.suffix = params.suffix.replace("-post-normalization", "")
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1025
egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py
Executable file
File diff suppressed because it is too large
@ -0,0 +1,963 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao,
|
||||
# Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) modified beam search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriHeavyAsrDataModule
|
||||
from beam_search import (
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
greedy_search_batch_with_context,
|
||||
greedy_search_with_context,
|
||||
modified_beam_search,
|
||||
)
|
||||
from dataset import naive_triplet_text_sampling, random_shuffle_subset
|
||||
from lhotse import load_manifest_lazy
|
||||
from text_normalization import (
|
||||
lower_all_char,
|
||||
lower_only_alpha,
|
||||
ref_text_normalization,
|
||||
remove_non_alphabetic,
|
||||
train_text_normalization,
|
||||
upper_all_char,
|
||||
upper_only_alpha,
|
||||
)
|
||||
from train_bert_encoder_with_style import (
|
||||
_encode_texts_as_bytes_with_tokenizer,
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
get_tokenizer,
|
||||
get_transducer_model,
|
||||
)
|
||||
from transformers import BertModel, BertTokenizer
|
||||
from utils import get_facebook_biasing_list
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Where to store the logs",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
- modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion
|
||||
- modified_beam_search_LODR
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input-manifest",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input manifest to be decoded",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-manifest",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Where to store the output manifest (directory)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-pre-text",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use pre-text is available during decoding",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-style-prompt",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use style prompt when evaluation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-context-embedding",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use context fuser when evaluation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--post-normalization",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--compute-CER",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Reports CER. By default, only reports WER",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--style-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of style prompt, i.e style_text",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pre-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of content prompt, i.e pre_text",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-ls-test-set",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use librispeech test set for evaluation.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-ls-context-list",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="If use a fixed context list for LibriSpeech decoding",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||
"""Apply transform to a list of text. By default, the text are in
|
||||
ground truth format, i.e mixed-punc.
|
||||
|
||||
Args:
|
||||
text (List[str]): A list of input text strings
|
||||
transform (str): Transform to be applied
|
||||
|
||||
Returns:
|
||||
List[str]: The transformed list of texts
|
||||
"""
|
||||
if transform == "mixed-punc":
|
||||
return text
|
||||
elif transform == "upper-no-punc":
|
||||
return [upper_only_alpha(s) for s in text]
|
||||
elif transform == "lower-no-punc":
|
||||
return [lower_only_alpha(s) for s in text]
|
||||
elif transform == "lower-punc":
|
||||
return [lower_all_char(s) for s in text]
|
||||
else:
|
||||
raise NotImplementedError(f"Unseen transform: {transform}")
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
tokenizer,
|
||||
batch: dict,
|
||||
biasing_dict: dict = None,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search".
If modified beam search with a beam size of 4 is used, it would
be "beam_size_4".
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
and fast_beam_search_nbest_LG.
tokenizer:
The tokenizer of the text encoder (e.g. a BERT tokenizer). It is used
to prepare the content and style prompts for the text encoder.
biasing_dict:
A dict mapping cut IDs to their biasing (context) word lists. Only
used when --use-ls-context-list is True.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
cuts = batch["supervisions"]["cut"]
|
||||
cut_ids = [c.supervisions[0].id for c in cuts]
|
||||
batch_size = feature.size(0)
|
||||
|
||||
# get pre_text
|
||||
if "pre_text" in batch["supervisions"] and params.use_pre_text:
|
||||
pre_texts = batch["supervisions"][
|
||||
"text"
|
||||
] # use the ground truth ref text as pre_text
|
||||
pre_texts = [train_text_normalization(t) for t in pre_texts]
|
||||
else:
|
||||
pre_texts = ["" for _ in range(batch_size)]
|
||||
|
||||
if params.use_ls_context_list:
|
||||
pre_texts = [biasing_dict[id] for id in cut_ids]
|
||||
|
||||
# get style_text
|
||||
if params.use_style_prompt:
|
||||
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
|
||||
style_texts = batch["supervisions"].get(
|
||||
"style_text", [fixed_sentence for _ in range(batch_size)]
|
||||
)
|
||||
style_texts = [train_text_normalization(t) for t in style_texts]
|
||||
else:
|
||||
style_texts = ["" for _ in range(batch_size)] # use empty string
|
||||
|
||||
# Get the text embedding input
|
||||
if params.use_pre_text or params.use_style_prompt:
|
||||
|
||||
# apply style transform to the pre_text and style_text
|
||||
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
||||
# pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
|
||||
if params.use_style_prompt:
|
||||
style_texts = _apply_style_transform(
|
||||
style_texts, params.style_text_transform
|
||||
)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
# Use tokenizer to prepare input for text encoder
|
||||
encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
|
||||
pre_texts=pre_texts,
|
||||
style_texts=style_texts,
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
)
|
||||
|
||||
memory, memory_key_padding_mask = model.encode_text(
|
||||
encoded_inputs=encoded_inputs,
|
||||
style_lens=style_lens,
|
||||
) # (T,B,C)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
|
||||
# Get the transducer encoder output
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
encoder_out, encoder_out_lens = model.encode_audio(
|
||||
feature=feature,
|
||||
feature_lens=feature_lens,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
if memory is None or not params.use_context_embedding:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
else:
|
||||
memory = memory.permute(1, 0, 2) # (T,N,C) -> (N,T,C)
|
||||
context = model.context_fuser(
|
||||
memory, padding_mask=memory_key_padding_mask
|
||||
) # (N,C)
|
||||
context = model.joiner.context_proj(context) # (N,C)
|
||||
hyp_tokens = greedy_search_batch_with_context(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
context=context,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
if memory is None or not params.use_context_embedding:
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
else:
|
||||
cur_context = context[i : i + 1, :]
|
||||
hyp = greedy_search_with_context(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
context=cur_context,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
tokenizer,
|
||||
biasing_dict: Dict = None,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
is used, or "beam_size_4" if a beam size of 4 is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut ID, the reference transcript, and the predicted result,
the latter two given as lists of words.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 40
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"][
|
||||
"text"
|
||||
] # By default, this should be in mixed-punc format
|
||||
|
||||
# the style of ref_text should match style_text
|
||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
tokenizer=tokenizer,
|
||||
biasing_dict=biasing_dict,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_text = ref_text_normalization(
|
||||
ref_text
|
||||
) # remove full-width symbols & some book marks
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
test_set_cers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
if params.compute_CER:
|
||||
# Write CER statistics
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
|
||||
)
|
||||
store_transcripts(filename=recog_path, texts=results, char_level=True)
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
cer = write_error_stats(
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
results,
|
||||
enable_log=True,
|
||||
compute_CER=params.compute_CER,
|
||||
)
|
||||
test_set_cers[key] = cer
|
||||
|
||||
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
if params.compute_CER:
|
||||
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tCER", file=f)
|
||||
for key, val in test_set_cers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_cers:
|
||||
s += "{} CER\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
def add_decoding_result_to_manifest(
|
||||
in_manifest,
|
||||
out_manifest: str,
|
||||
results_dict: Dict,
|
||||
):
|
||||
# write the decoding results with prompt to the manifest as an
|
||||
# extra ref text
|
||||
new_ans = {}
|
||||
for key, value in results_dict.items():
|
||||
for items in value:
|
||||
id, ref, hyp = items
|
||||
new_ans[id] = " ".join(hyp)
|
||||
|
||||
def _add_decoding(c):
|
||||
key = c.supervisions[0].id
|
||||
c.supervisions[0].texts.append(new_ans[key])
|
||||
return c
|
||||
|
||||
in_manifest = in_manifest.map(_add_decoding)
|
||||
logging.info(f"Saving manifest to {out_manifest}")
|
||||
in_manifest.to_file(out_manifest)
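# After this step each supervision carries one extra text field, e.g.
# (illustrative): supervision.texts == [original_ref_text, ..., decoded_hyp],
# so downstream recipes can, for example, sample among several reference
# texts per cut.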
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriHeavyAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
cuts = load_manifest_lazy(args.input_manifest)
|
||||
|
||||
world_size = args.world_size
|
||||
assert world_size >= 1
|
||||
if world_size > 1:
|
||||
splitted_cuts = cuts.split(num_splits=world_size)
|
||||
mp.spawn(
|
||||
run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True
|
||||
)
|
||||
else:
|
||||
run(rank=0, world_size=1, args=args, cuts=cuts)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def run(rank, world_size, args, cuts):
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.use_pre_text:
|
||||
params.suffix += f"-pre-text-{params.pre_text_transform}"
|
||||
|
||||
if params.use_style_prompt:
|
||||
params.suffix += f"-style-prompt-{params.style_text_transform}"
|
||||
|
||||
params.suffix += f"-{rank}"
|
||||
|
||||
world_size = params.world_size
|
||||
|
||||
params.output_manifest = Path(params.output_manifest)
|
||||
if world_size > 1:
|
||||
cuts = cuts[rank]
|
||||
out_name = params.output_manifest / f"with_decoding_job_{rank}.jsonl.gz"
|
||||
else:
|
||||
out_name = params.output_manifest / "with_decoding.jsonl.gz"
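# Note (illustrative): with --world-size N > 1, N manifests named
# with_decoding_job_0.jsonl.gz ... with_decoding_job_{N-1}.jsonl.gz are
# written, one per rank; they can be combined into a single manifest
# afterwards.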
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
|
||||
setup_logger(f"{params.log_dir}/log-get-manifest-with-decoding-{rank}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
tokenizer = get_tokenizer(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
LM = None
|
||||
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
libriheavy = LibriHeavyAsrDataModule(args)
|
||||
|
||||
dl = libriheavy.valid_dataloaders(
|
||||
cuts, text_sampling_func=naive_triplet_text_sampling
|
||||
)
|
||||
|
||||
test_sets = ["test"]
|
||||
test_dl = [dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
biasing_dict = None
|
||||
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
tokenizer=tokenizer,
|
||||
biasing_dict=biasing_dict,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
# save_results(
|
||||
# params=params,
|
||||
# test_set_name=test_set,
|
||||
# results_dict=results_dict,
|
||||
# )
|
||||
|
||||
add_decoding_result_to_manifest(
|
||||
in_manifest=cuts,
|
||||
out_manifest=out_name,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
# torch.set_num_threads(1)
|
||||
# torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
130
egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py
Normal file
@ -0,0 +1,130 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from scaling import Balancer
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""This class modifies the stateless decoder from the following paper:
|
||||
|
||||
RNN-transducer with stateless prediction network
|
||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
|
||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||
network. Different from the above paper, it adds an extra Conv1d
|
||||
right after the embedding layer.
|
||||
|
||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
decoder_dim: int,
|
||||
blank_id: int,
|
||||
context_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_size:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
decoder_dim:
|
||||
Dimension of the input embedding, and of the decoder output.
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
context_size:
|
||||
Number of previous words to use to predict the next word.
|
||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.embedding = nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=decoder_dim,
|
||||
padding_idx=blank_id,
|
||||
)
|
||||
# the balancers are to avoid any drift in the magnitude of the
|
||||
# embeddings, which would interact badly with parameter averaging.
|
||||
self.balancer = Balancer(
|
||||
decoder_dim,
|
||||
channel_dim=-1,
|
||||
min_positive=0.0,
|
||||
max_positive=1.0,
|
||||
min_abs=0.5,
|
||||
max_abs=1.0,
|
||||
prob=0.05,
|
||||
)
|
||||
|
||||
self.blank_id = blank_id
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
if context_size > 1:
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels=decoder_dim,
|
||||
out_channels=decoder_dim,
|
||||
kernel_size=context_size,
|
||||
padding=0,
|
||||
groups=decoder_dim // 4, # group size == 4
|
||||
bias=False,
|
||||
)
|
||||
self.balancer2 = Balancer(
|
||||
decoder_dim,
|
||||
channel_dim=-1,
|
||||
min_positive=0.0,
|
||||
max_positive=1.0,
|
||||
min_abs=0.5,
|
||||
max_abs=1.0,
|
||||
prob=0.05,
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
need_pad:
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
y = y.to(torch.int64)
|
||||
# this stuff about clamp() is a temporary fix for a mismatch
|
||||
# at utterance start, we use negative ids in beam_search.py
|
||||
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
|
||||
|
||||
embedding_out = self.balancer(embedding_out)
|
||||
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
|
||||
else:
|
||||
# During inference time, there is no need to do extra padding
|
||||
# as we only need one output
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
embedding_out = F.relu(embedding_out)
|
||||
embedding_out = self.balancer2(embedding_out)
|
||||
|
||||
return embedding_out
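# A minimal usage sketch (illustrative only; the numbers are made up):
#
#   decoder = Decoder(vocab_size=500, decoder_dim=512, blank_id=0, context_size=2)
#   y = torch.randint(1, 500, (8, 10))   # (N, U) previous tokens
#   out = decoder(y)                     # (8, 10, 512) during training (need_pad=True)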
|
43
egs/libriheavy/ASR/zipformer_prompt_asr/encoder_interface.py
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class EncoderInterface(nn.Module):
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A tensor of shape (batch_size, input_seq_len, num_features)
|
||||
containing the input features.
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames
|
||||
in `x` before padding.
|
||||
Returns:
|
||||
Return a tuple containing two tensors:
|
||||
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
|
||||
containing unnormalized probabilities, i.e., the output of a
|
||||
linear layer.
|
||||
- encoder_out_lens, a tensor of shape (batch_size,) containing
|
||||
the number of frames in `encoder_out` before padding.
|
||||
"""
|
||||
raise NotImplementedError("Please implement it in a subclass")
|
255
egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py
Normal file
@ -0,0 +1,255 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2023 Xiaomi Corporation (Author: Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
|
||||
"""
|
||||
Export `model.state_dict()`
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
./zipformer_prompt_asr/export_PromptASR.py \
|
||||
--exp-dir ./zipformer_prompt_asr/exp \
|
||||
--tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \
|
||||
--epoch 50 \
|
||||
--avg 10
|
||||
|
||||
- For streaming model:
|
||||
|
||||
./zipformer_prompt_asr/export_PromptASR.py \
|
||||
--exp-dir ./zipformer_prompt_asr/exp \
|
||||
--causal 1 \
|
||||
--tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \
|
||||
--epoch 50 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from train_bert_encoder import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import make_pad_mask, num_tokens, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/tokens.txt",
|
||||
help="Path to the tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate a file named jit_script.pt.
|
||||
Check ./jit_pretrained.py for how to use it.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
# if torch.cuda.is_available():
|
||||
# device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
params.blank_id = token_table["<blk>"]
|
||||
params.vocab_size = num_tokens(token_table) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.eval()
|
||||
|
||||
assert params.jit is False, "Jit is not supported yet"
|
||||
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
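After running the script, the exported file can be loaded back as a plain state dict; a hedged sketch (the path is whatever `--exp-dir` was set to):

```python
import torch

# Load the file written by main() above; it contains a single "model" key
# holding model.state_dict().
checkpoint = torch.load(
    "zipformer_prompt_asr/exp/pretrained.pt", map_location="cpu"
)
print(list(checkpoint.keys()))  # ['model']
# model.load_state_dict(checkpoint["model"])  # with `model` built via get_transducer_model(params)
```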
egs/libriheavy/ASR/zipformer_prompt_asr/joiner.py (Normal file, 86 lines)
@@ -0,0 +1,86 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import ScaledLinear
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
context_dim: int = 512,
|
||||
context_injection: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
|
||||
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
|
||||
self.output_linear = nn.Linear(joiner_dim, vocab_size)
|
||||
if context_injection:
|
||||
self.context_proj = ScaledLinear(
|
||||
context_dim, joiner_dim, initial_scale=0.25
|
||||
)
|
||||
else:
|
||||
self.context_proj = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
context: torch.Tensor = None,
|
||||
project_input: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, s_range, C).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, T, s_range, C).
|
||||
context:
|
||||
An embedding vector representing the previous context information
|
||||
project_input:
|
||||
If true, apply input projections encoder_proj and decoder_proj.
|
||||
If this is false, it is the user's responsibility to do this
|
||||
manually.
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim == 4
|
||||
assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
|
||||
|
||||
if project_input:
|
||||
if context is not None:
|
||||
logit = (
|
||||
self.encoder_proj(encoder_out)
|
||||
+ self.decoder_proj(decoder_out)
|
||||
+ self.context_proj(context)
|
||||
)
|
||||
else:
|
||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
|
||||
else:
|
||||
if context is not None:
|
||||
logit = encoder_out + decoder_out + context.unsqueeze(1).unsqueeze(1)
|
||||
else:
|
||||
logit = encoder_out + decoder_out
|
||||
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
|
||||
return logit
|
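A quick shape check for the joiner above (assuming this recipe directory, which provides `scaling.py`, is on `PYTHONPATH`; the dimensions are arbitrary):

```python
import torch

from joiner import Joiner  # this file; requires scaling.py from the same directory

joiner = Joiner(encoder_dim=384, decoder_dim=512, joiner_dim=512, vocab_size=500)
enc = torch.randn(2, 10, 5, 384)  # (N, T, s_range, encoder_dim)
dec = torch.randn(2, 10, 5, 512)  # (N, T, s_range, decoder_dim)
logits = joiner(enc, dec, project_input=True)
print(logits.shape)  # torch.Size([2, 10, 5, 500])
```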
egs/libriheavy/ASR/zipformer_prompt_asr/ls_text_normalization.py (Normal file, 153 lines)
@@ -0,0 +1,153 @@
|
||||
import re
|
||||
|
||||
words = {
|
||||
0: "zero",
|
||||
1: "one",
|
||||
2: "two",
|
||||
3: "three",
|
||||
4: "four",
|
||||
5: "five",
|
||||
6: "six",
|
||||
7: "seven",
|
||||
8: "eight",
|
||||
9: "nine",
|
||||
10: "ten",
|
||||
11: "eleven",
|
||||
12: "twelve",
|
||||
13: "thirteen",
|
||||
14: "fourteen",
|
||||
15: "fifteen",
|
||||
16: "sixteen",
|
||||
17: "seventeen",
|
||||
18: "eighteen",
|
||||
19: "nineteen",
|
||||
20: "twenty",
|
||||
30: "thirty",
|
||||
40: "forty",
|
||||
50: "fifty",
|
||||
60: "sixty",
|
||||
70: "seventy",
|
||||
80: "eighty",
|
||||
90: "ninety",
|
||||
}
|
||||
ordinal_nums = [
|
||||
"zeroth",
|
||||
"first",
|
||||
"second",
|
||||
"third",
|
||||
"fourth",
|
||||
"fifth",
|
||||
"sixth",
|
||||
"seventh",
|
||||
"eighth",
|
||||
"ninth",
|
||||
"tenth",
|
||||
"eleventh",
|
||||
"twelfth",
|
||||
"thirteenth",
|
||||
"fourteenth",
|
||||
"fifteenth",
|
||||
"sixteenth",
|
||||
"seventeenth",
|
||||
"eighteenth",
|
||||
"nineteenth",
|
||||
"twentieth",
|
||||
]
|
||||
|
||||
num_ordinal_dict = {num: ordinal_nums[num] for num in range(21)}
|
||||
|
||||
|
||||
def year_to_words(num: int):
|
||||
assert isinstance(num, int), num
|
||||
# check if a num is representing a year
|
||||
if num > 1500 and num < 2000:
|
||||
return words[num // 100] + " " + num_to_words(num % 100)
|
||||
elif num == 2000:
|
||||
return "TWO THOUSAND"
|
||||
elif num > 2000:
|
||||
return "TWO THOUSAND AND " + num_to_words(num % 100)
|
||||
else:
|
||||
return num_to_words(num)
|
||||
|
||||
|
||||
def num_to_words(num: int):
|
||||
# Return the English words of an integer number
|
||||
|
||||
# If this is a year number
|
||||
if num > 1500 and num < 2030:
|
||||
return year_to_words(num)
|
||||
|
||||
if num < 20:
|
||||
return words[num]
|
||||
if num < 100:
|
||||
if num % 10 == 0:
|
||||
return words[num // 10 * 10]
|
||||
else:
|
||||
return words[num // 10 * 10] + " " + words[num % 10]
|
||||
if num < 1000:
|
||||
return words[num // 100] + " hundred and " + num_to_words(num % 100)
|
||||
if num < 1000000:
|
||||
return num_to_words(num // 1000) + " thousand " + num_to_words(num % 1000)
|
||||
return num
|
||||
|
||||
|
||||
def num_to_ordinal_word(num: int):
|
||||
|
||||
return num_ordinal_dict.get(num, num_to_words(num)).upper()
|
||||
|
||||
|
||||
def replace_full_width_symbol(s: str) -> str:
|
||||
# Replace full-width symbols with their half-width counterparts
|
||||
s = s.replace("“", '"')
|
||||
s = s.replace("”", '"')
|
||||
s = s.replace("‘", "'")
|
||||
s = s.replace("’", "'")
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def decoding_normalization(text: str) -> str:
|
||||
text = replace_full_width_symbol(text)
|
||||
|
||||
# Keep only alphanumeric characters, hyphens and apostrophes
|
||||
text = text.replace("-", " ")
|
||||
text = re.sub(r"[^a-zA-Z0-9\s']+", "", text)
|
||||
return text
|
||||
|
||||
|
||||
def word_normalization(word: str) -> str:
|
||||
# 1. Use the full word for some abbreviations
# 2. Convert digits to English words
# 3. Convert ordinal numbers to English words
|
||||
if word == "MRS":
|
||||
return "MISSUS"
|
||||
if word == "MR":
|
||||
return "MISTER"
|
||||
if word == "ST":
|
||||
return "SAINT"
|
||||
if word == "ECT":
|
||||
return "ET CETERA"
|
||||
if word.isnumeric():
|
||||
word = num_to_words(int(word))
|
||||
return str(word).upper()
|
||||
# e.g. 9TH, 6TH
|
||||
if word[-2:] == "TH" and word[0].isnumeric():
|
||||
return num_to_ordinal_word(int(word[:-2])).upper()
|
||||
if word[0] == "'":
|
||||
return word[1:]
|
||||
|
||||
return word
|
||||
|
||||
|
||||
def simple_normalization(text: str) -> str:
|
||||
text = replace_full_width_symbol(text)
|
||||
text = text.replace("--", " ")
|
||||
|
||||
return text
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
s = str(1830)
|
||||
out = word_normalization(s)
|
||||
print(s, out)
|
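A few illustrative calls to the helpers above (the expected outputs follow from the code in this file):

```python
from ls_text_normalization import num_to_words, word_normalization

print(num_to_words(45))            # forty five
print(word_normalization("MR"))    # MISTER
print(word_normalization("9TH"))   # NINTH
print(word_normalization("1830"))  # EIGHTEEN THIRTY (treated as a year)
```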
egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py (Normal file, 262 lines)
@@ -0,0 +1,262 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import random
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear, penalize_abs_values_gt
|
||||
from torch import Tensor
|
||||
|
||||
from icefall.utils import add_sos, make_pad_mask
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
"Sequence Transduction with Recurrent Neural Networks"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_embed: nn.Module,
|
||||
encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder_embed:
|
||||
It is a Convolutional 2D subsampling module. It converts
|
||||
an input of shape (N, T, idim) to an output of shape
|
||||
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
|
||||
encoder:
|
||||
It is the transcription network in the paper. It accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
|
||||
self.encoder_embed = encoder_embed
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = ScaledLinear(
|
||||
encoder_dim,
|
||||
vocab_size,
|
||||
initial_scale=0.25,
|
||||
)
|
||||
self.simple_lm_proj = ScaledLinear(
|
||||
decoder_dim,
|
||||
vocab_size,
|
||||
initial_scale=0.25,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
x, x_lens = self.encoder_embed(x, x_lens)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, x_lens = self.encoder(
|
||||
x,
|
||||
x_lens,
|
||||
src_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(encoder_out.size(0), 4),
|
||||
dtype=torch.int64,
|
||||
device=encoder_out.device,
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
|
||||
|
||||
def encode_audio(
|
||||
self,
|
||||
feature: Tensor,
|
||||
feature_lens: Tensor,
|
||||
memory: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Encode the input audio features
|
||||
|
||||
Args:
|
||||
feature (Tensor): Input audio (N,T,C)
|
||||
feature_lens (Tensor): Length of input audio (N,)
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor]: Encoded acoustic features and length
|
||||
"""
|
||||
x, x_lens = self.encoder_embed(feature, feature_lens)
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = self.encoder(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
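The `boundary` tensor built in `forward()` above follows k2's convention for the RNN-T losses: column 2 holds the label lengths and column 3 the frame lengths. A standalone sketch:

```python
import torch

# Two utterances with 3/5 labels and 10/12 frames respectively.
y_lens = torch.tensor([3, 5])
x_lens = torch.tensor([10, 12])
boundary = torch.zeros((2, 4), dtype=torch.int64)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
print(boundary)
# tensor([[ 0,  0,  3, 10],
#         [ 0,  0,  5, 12]])
```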
egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py (Normal file, 392 lines)
@@ -0,0 +1,392 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import random
|
||||
import warnings
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear, penalize_abs_values_gt
|
||||
from torch import Tensor
|
||||
|
||||
from icefall.utils import add_sos, make_pad_mask
|
||||
|
||||
|
||||
class PromptedTransducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
"Sequence Transduction with Recurrent Neural Networks"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_embed: nn.Module,
|
||||
encoder: EncoderInterface,
|
||||
text_encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
use_BERT: bool = True,
|
||||
text_encoder_type: str = "BERT",
|
||||
text_encoder_adapter: bool = False,
|
||||
freeze_text_encoder: bool = True,
|
||||
context_fuser: nn.Module = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder_embed:
|
||||
It is a Convolutional 2D subsampling module. It converts
|
||||
an input of shape (N, T, idim) to an output of of shape
|
||||
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
|
||||
encoder:
|
||||
It is the transcription network in the paper. It accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
||||
`logit_lens` of shape (N,).
|
||||
text_encoder:
|
||||
This is an encoder that processes text information (e.g. content prompt
|
||||
and style prompt). The input is `x` of (N,T) and `x_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
text_encoder_type:
|
||||
The type of the text_encoder. Supported types are BERT, DistilBERT and BERT-UNCASED.
|
||||
context_fuser:
An optional module that fuses the embeddings of the text encoder. The fused
embedding will be added to the joiner.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
|
||||
self.encoder_embed = encoder_embed
|
||||
self.encoder = encoder
|
||||
self.text_encoder = text_encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = ScaledLinear(
|
||||
encoder_dim,
|
||||
vocab_size,
|
||||
initial_scale=0.25,
|
||||
)
|
||||
self.simple_lm_proj = ScaledLinear(
|
||||
decoder_dim,
|
||||
vocab_size,
|
||||
initial_scale=0.25,
|
||||
)
|
||||
|
||||
self.use_BERT = use_BERT # if the text encoder is a pre-trained BERT
|
||||
self.context_fuser = context_fuser
|
||||
|
||||
assert text_encoder_type in (
|
||||
"BERT",
|
||||
"DistilBERT",
|
||||
"BERT-UNCASED",
|
||||
), f"Unseen text_encoder type {text_encoder_type}"
|
||||
self.text_encoder_dim = (
|
||||
self.text_encoder.config.hidden_size
|
||||
if text_encoder_type in ("BERT", "BERT-UNCASED")
|
||||
else self.text_encoder.config.dim
|
||||
)
|
||||
self.freeze_text_encoder = freeze_text_encoder
|
||||
|
||||
if text_encoder_adapter:
|
||||
self.text_encoder_adapter = nn.Sequential(
|
||||
nn.Linear(self.text_encoder_dim, self.text_encoder_dim, bias=False),
|
||||
nn.Tanh(),
|
||||
)
|
||||
else:
|
||||
self.text_encoder_adapter = None
|
||||
|
||||
self.style_prompt_embedding = nn.Parameter(
|
||||
torch.full((self.text_encoder_dim,), 0.5)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
encoded_inputs: Dict,
|
||||
style_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
use_pre_text: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
encoded_inputs:
A Dict of encoded prompt texts produced by the tokenizer. It is expected
to contain the style prompt (first) followed by the content prompt, and it
is consumed by `encode_text()`.
|
||||
style_lens:
|
||||
A 1-D tensor of shape (N,), containing the number of elements (bytes)
|
||||
within each row of the prompt text that correspond to the style prompt (these
|
||||
are expected to come first).
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
if self.freeze_text_encoder:
|
||||
self.text_encoder.eval()
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
x, x_lens = self.encoder_embed(x, x_lens)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
# freeze the BERT text encoder
|
||||
|
||||
if use_pre_text:
|
||||
memory, memory_key_padding_mask = self.encode_text(
|
||||
encoded_inputs, style_lens=style_lens
|
||||
)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
|
||||
encoder_out, x_lens = self.encoder(
|
||||
x,
|
||||
x_lens,
|
||||
src_key_padding_mask,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(encoder_out.size(0), 4),
|
||||
dtype=torch.int64,
|
||||
device=encoder_out.device,
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
if self.context_fuser is not None and memory is not None:
|
||||
memory = memory.permute(1, 0, 2) # (T,N,C) -> (N,T,C)
|
||||
context = self.context_fuser(memory, padding_mask=memory_key_padding_mask)
|
||||
context = self.joiner.context_proj(context)
|
||||
else:
|
||||
context = None
|
||||
|
||||
logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
|
||||
|
||||
def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
|
||||
"""
|
||||
Adds to `memory` an indicator that is 1.0 for positions that correspond to
|
||||
the `style prompt` and 0 elsewhere. The scale can be fixed because the
|
||||
scale of the embedding vector can adjust to compensate.
|
||||
|
||||
Args:
|
||||
memory: (memory_len, batch_size, embed_dim)
|
||||
style_lens: (batch_size,), a vector of lengths of the style prompt.
|
||||
"""
|
||||
|
||||
(memory_len, batch_size, embed_dim) = memory.shape
|
||||
|
||||
indicator = (
|
||||
torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens
|
||||
)
|
||||
indicator = indicator.to(memory.dtype)
|
||||
|
||||
extra_term = torch.zeros_like(memory)
|
||||
extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(
|
||||
memory_len, batch_size, self.text_encoder_dim
|
||||
)
|
||||
|
||||
return memory + extra_term
|
||||
|
||||
def encode_text(
|
||||
self,
|
||||
encoded_inputs: Dict,
|
||||
style_lens: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Get the embeddings of text
|
||||
|
||||
Args:
|
||||
encoded_inputs: The encoded inputs generated by a tokenizer (Dict)
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
|
||||
text_encoder and the attention mask
|
||||
"""
|
||||
text_lens = encoded_inputs.pop("length") # need to use pop to remove this item
|
||||
|
||||
# Freeze the pre-trained text encoder
|
||||
with torch.no_grad():
|
||||
memory = self.text_encoder(**encoded_inputs)["last_hidden_state"] # (B,T,C)
|
||||
memory = memory.permute(1, 0, 2)
|
||||
|
||||
# Text encoder adapter
|
||||
if self.text_encoder_adapter is not None:
|
||||
memory = self.text_encoder_adapter(memory)
|
||||
|
||||
memory = self._add_style_indicator(memory, style_lens)
|
||||
|
||||
memory_key_padding_mask = make_pad_mask(text_lens)
|
||||
|
||||
return memory, memory_key_padding_mask
|
||||
|
||||
def encode_audio(
|
||||
self,
|
||||
feature: Tensor,
|
||||
feature_lens: Tensor,
|
||||
memory: Optional[Tensor],
|
||||
memory_key_padding_mask: Optional[Tensor],
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Encode the input audio features
|
||||
|
||||
Args:
|
||||
feature (Tensor): Input audio (N,T,C)
|
||||
feature_lens (Tensor): Length of input audio (N,)
|
||||
memory (Tensor): Embeddings from the text encoder
|
||||
memory_key_padding_mask (Tensor): Padding mask of the text embeddings (True at padded positions)
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor]: Encoded acoustic features and their lengths
|
||||
"""
|
||||
x, x_lens = self.encoder_embed(feature, feature_lens)
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = self.encoder(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
Transducer = PromptedTransducer # for decoding
|
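The indicator built in `_add_style_indicator()` above marks which prompt positions belong to the style prompt; a standalone sketch of that mask:

```python
import torch

# Two prompts whose first 2 and 4 byte positions, respectively, are the style prompt.
memory_len, style_lens = 6, torch.tensor([2, 4])
indicator = (torch.arange(memory_len).unsqueeze(-1) < style_lens).float()
print(indicator.T)
# tensor([[1., 1., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 0., 0.]])
```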
egs/libriheavy/ASR/zipformer_prompt_asr/optim.py (Normal file, 1168 lines)
File diff suppressed because it is too large
egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py (Normal file, 359 lines)
@@ -0,0 +1,359 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script loads a checkpoint (`pretrained.pt`) and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
./zipformer/export_PromptASR.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \
|
||||
--epoch 50 \
|
||||
--avg 10
|
||||
|
||||
Utterance level context biasing:
|
||||
|
||||
./zipformer/pretrained.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \
|
||||
--method modified_beam_search \
|
||||
--use-pre-text True \
|
||||
--content-prompt "bessy random words hello k2 ASR" \
|
||||
--use-style-prompt True \
|
||||
librispeech.flac
|
||||
|
||||
|
||||
Word level context biasing:
|
||||
|
||||
./zipformer/pretrained.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \
|
||||
--method modified_beam_search \
|
||||
--use-pre-text True \
|
||||
--content-prompt "The topic is about horses." \
|
||||
--use-style-prompt True \
|
||||
test.wav
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import greedy_search_batch, modified_beam_search
|
||||
from text_normalization import _apply_style_transform, train_text_normalization
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train_bert_encoder import (
|
||||
_encode_texts_as_bytes_with_tokenizer,
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
get_tokenizer,
|
||||
get_transducer_model,
|
||||
)
|
||||
|
||||
from icefall.utils import make_pad_mask, num_tokens, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500_fallback_coverage_0.99/bpe.model",
|
||||
help="""Path to tokens.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-pre-text",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use content prompt during decoding",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-style-prompt",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use style prompt during decoding",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pre-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of content prompt, i.e pre_text",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--style-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of style prompt, i.e style_text",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--content-prompt", type=str, default="", help="The content prompt for decoding"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--style-prompt",
|
||||
type=str,
|
||||
default="Mixed-cased English text with punctuations, feel free to change it.",
|
||||
help="The style prompt for decoding",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
if params.causal:
|
||||
assert (
|
||||
"," not in params.chunk_size
|
||||
), "chunk_size should be one value in decoding."
|
||||
assert (
|
||||
"," not in params.left_context_frames
|
||||
), "left_context_frames should be one value in decoding."
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
tokenizer = get_tokenizer(params) # for text encoder
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
assert (
|
||||
len(params.sound_files) == 1
|
||||
), "Only support decoding one audio at this moment"
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
# encode prompts
|
||||
if params.use_pre_text:
|
||||
pre_text = [train_text_normalization(params.content_prompt)]
|
||||
pre_text = _apply_style_transform(pre_text, params.pre_text_transform)
|
||||
else:
|
||||
pre_text = [""]
|
||||
|
||||
if params.use_style_prompt:
|
||||
style_text = [params.style_prompt]
|
||||
style_text = _apply_style_transform(style_text, params.style_text_transform)
|
||||
else:
|
||||
style_text = [""]
|
||||
|
||||
if params.use_pre_text or params.use_style_prompt:
|
||||
encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
|
||||
pre_texts=pre_text,
|
||||
style_texts=style_text,
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
no_limit=True,
|
||||
)
|
||||
|
||||
memory, memory_key_padding_mask = model.encode_text(
|
||||
encoded_inputs=encoded_inputs,
|
||||
style_lens=style_lens,
|
||||
) # (T,B,C)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
encoder_out, encoder_out_lens = model.encode_audio(
|
||||
feature=features,
|
||||
feature_lens=feature_lengths,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
logging.info(msg)
|
||||
|
||||
if params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
hyps.append(sp.decode(hyp_tokens)[0])
|
||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
hyps.append(sp.decode(hyp_tokens)[0])
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
s += f"{filename}:\n{hyp}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
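The Fbank setup in `main()` above can be exercised on synthetic audio; a minimal sketch where one second of random samples stands in for a real 16 kHz recording:

```python
import torch
import kaldifeat

# Same Fbank options as used in main() above, on CPU.
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)

features = fbank([torch.randn(16000)])  # a list in, a list of (num_frames, 80) tensors out
print(features[0].shape)
```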
egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py (Normal file, 1872 lines)
File diff suppressed because it is too large
egs/libriheavy/ASR/zipformer_prompt_asr/subsampling.py (Normal file, 276 lines)
@@ -0,0 +1,276 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from scaling import (
|
||||
Balancer,
|
||||
BiasNorm,
|
||||
Dropout3,
|
||||
FloatLike,
|
||||
Optional,
|
||||
ScaledConv2d,
|
||||
ScaleGrad,
|
||||
ScheduledFloat,
|
||||
SwooshL,
|
||||
SwooshR,
|
||||
Whiten,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class ConvNeXt(nn.Module):
|
||||
"""
|
||||
Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
hidden_ratio: int = 3,
|
||||
kernel_size: Tuple[int, int] = (7, 7),
|
||||
layerdrop_rate: FloatLike = None,
|
||||
):
|
||||
super().__init__()
|
||||
padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
|
||||
hidden_channels = channels * hidden_ratio
|
||||
if layerdrop_rate is None:
|
||||
layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
|
||||
self.layerdrop_rate = layerdrop_rate
|
||||
|
||||
self.depthwise_conv = nn.Conv2d(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
groups=channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
self.pointwise_conv1 = nn.Conv2d(
|
||||
in_channels=channels, out_channels=hidden_channels, kernel_size=1
|
||||
)
|
||||
|
||||
self.hidden_balancer = Balancer(
|
||||
hidden_channels,
|
||||
channel_dim=1,
|
||||
min_positive=0.3,
|
||||
max_positive=1.0,
|
||||
min_abs=0.75,
|
||||
max_abs=5.0,
|
||||
)
|
||||
|
||||
self.activation = SwooshL()
|
||||
self.pointwise_conv2 = ScaledConv2d(
|
||||
in_channels=hidden_channels,
|
||||
out_channels=channels,
|
||||
kernel_size=1,
|
||||
initial_scale=0.01,
|
||||
)
|
||||
|
||||
self.out_balancer = Balancer(
|
||||
channels,
|
||||
channel_dim=1,
|
||||
min_positive=0.4,
|
||||
max_positive=0.6,
|
||||
min_abs=1.0,
|
||||
max_abs=6.0,
|
||||
)
|
||||
self.out_whiten = Whiten(
|
||||
num_groups=1,
|
||||
whitening_limit=5.0,
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01,
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting() or not self.training:
|
||||
return self.forward_internal(x)
|
||||
layerdrop_rate = float(self.layerdrop_rate)
|
||||
|
||||
if layerdrop_rate != 0.0:
|
||||
batch_size = x.shape[0]
|
||||
mask = (
|
||||
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
|
||||
> layerdrop_rate
|
||||
)
|
||||
else:
|
||||
mask = None
|
||||
# turns out this caching idea does not work with --world-size > 1
|
||||
# return caching_eval(self.forward_internal, x, mask)
|
||||
return self.forward_internal(x, mask)
|
||||
|
||||
def forward_internal(
|
||||
self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
|
||||
) -> Tensor:
|
||||
"""
|
||||
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
|
||||
|
||||
The returned value has the same shape as x.
|
||||
"""
|
||||
bypass = x
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.pointwise_conv1(x)
|
||||
x = self.hidden_balancer(x)
|
||||
x = self.activation(x)
|
||||
x = self.pointwise_conv2(x)
|
||||
|
||||
if layer_skip_mask is not None:
|
||||
x = x * layer_skip_mask
|
||||
|
||||
x = bypass + x
|
||||
x = self.out_balancer(x)
|
||||
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
|
||||
x = self.out_whiten(x)
|
||||
x = x.transpose(1, 3) # (N, C, H, W)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Conv2dSubsampling(nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/2 length).
|
||||
|
||||
Convert an input of shape (N, T, idim) to an output
|
||||
with shape (N, T', odim), where
|
||||
T' = (T-3)//2 - 2 == (T-7)//2
|
||||
|
||||
It is based on
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
layer1_channels: int = 8,
|
||||
layer2_channels: int = 32,
|
||||
layer3_channels: int = 128,
|
||||
dropout: FloatLike = 0.1,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
in_channels:
|
||||
Number of channels in. The input shape is (N, T, in_channels).
|
||||
Caution: It requires: T >=7, in_channels >=7
|
||||
out_channels
|
||||
Output dim. The output shape is (N, (T-3)//2, out_channels)
|
||||
layer1_channels:
|
||||
Number of channels in layer1
|
||||
layer2_channels:
|
||||
Number of channels in layer2
|
||||
bottleneck:
|
||||
bottleneck dimension for 1d squeeze-excite
|
||||
"""
|
||||
assert in_channels >= 7
|
||||
super().__init__()
|
||||
|
||||
# The ScaleGrad module is there to prevent the gradients
|
||||
# w.r.t. the weight or bias of the first Conv2d module in self.conv from
|
||||
# exceeding the range of fp16 when using automatic mixed precision (amp)
|
||||
# training. (The second one is necessary to stop its bias from getting
|
||||
# a too-large gradient).
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=1,
|
||||
out_channels=layer1_channels,
|
||||
kernel_size=3,
|
||||
padding=(0, 1), # (time, freq)
|
||||
),
|
||||
ScaleGrad(0.2),
|
||||
Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
|
||||
SwooshR(),
|
||||
nn.Conv2d(
|
||||
in_channels=layer1_channels,
|
||||
out_channels=layer2_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0,
|
||||
),
|
||||
Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
|
||||
SwooshR(),
|
||||
nn.Conv2d(
|
||||
in_channels=layer2_channels,
|
||||
out_channels=layer3_channels,
|
||||
kernel_size=3,
|
||||
stride=(1, 2), # (time, freq)
|
||||
),
|
||||
Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
|
||||
SwooshR(),
|
||||
)
|
||||
|
||||
# just one convnext layer
|
||||
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
|
||||
|
||||
out_width = (((in_channels - 1) // 2) - 1) // 2
|
||||
|
||||
self.out = nn.Linear(out_width * layer3_channels, out_channels)
|
||||
# use a larger than normal grad_scale on this whitening module; there is
|
||||
# only one such module, so there is not a concern about adding together
|
||||
# many copies of this extra gradient term.
|
||||
self.out_whiten = Whiten(
|
||||
num_groups=1,
|
||||
whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.02,
|
||||
)
|
||||
|
||||
# max_log_eps=0.0 is to prevent both eps and the output of self.out from
|
||||
# getting large, there is an unnecessary degree of freedom.
|
||||
self.out_norm = BiasNorm(out_channels)
|
||||
self.dropout = Dropout3(dropout, shared_dim=1)
|
||||
|
||||
def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is (N, T, idim).
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in `x` before padding.
|
||||
|
||||
Returns:
|
||||
- a tensor of shape (N, (T-7)//2, odim)
|
||||
- output lengths, of shape (batch_size,)
|
||||
"""
|
||||
# On entry, x is (N, T, idim)
|
||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
|
||||
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
|
||||
# gradients.
|
||||
x = self.conv(x)
|
||||
x = self.convnext(x)
|
||||
|
||||
# Now x is of shape (N, layer3_channels, (T-7)//2, ((idim-1)//2 - 1)//2)
|
||||
b, c, t, f = x.size()
|
||||
|
||||
x = x.transpose(1, 2).reshape(b, t, c * f)
|
||||
# now x: (N, (T-7)//2, out_width * layer3_channels)
|
||||
|
||||
x = self.out(x)
|
||||
# Now x is of shape (N, (T-7)//2, odim)
|
||||
x = self.out_whiten(x)
|
||||
x = self.out_norm(x)
|
||||
x = self.dropout(x)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
x_lens = (x_lens - 7) // 2
|
||||
assert x.size(1) == x_lens.max().item()
|
||||
|
||||
return x, x_lens
|
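# A minimal sketch (not part of the original file) illustrating the
# subsampling arithmetic above, assuming 80-dim fbank features:
#
#   subsampling = Conv2dSubsampling(in_channels=80, out_channels=192)
#   x = torch.randn(2, 100, 80)
#   x_lens = torch.tensor([100, 100])
#   y, y_lens = subsampling(x, x_lens)
#   # y.shape == (2, (100 - 7) // 2, 192) == (2, 46, 192); y_lens == [46, 46]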
119
egs/libriheavy/ASR/zipformer_prompt_asr/test_model.py
Executable file
@ -0,0 +1,119 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/libriheavy/ASR
|
||||
python ./zipformer_prompt_asr/test_model.py
|
||||
"""
|
||||
|
||||
from scaling import ScheduledFloat
|
||||
from train_subformer import get_params, get_text_encoder, get_transducer_model
|
||||
from zipformer import Zipformer2
|
||||
|
||||
|
||||
def test_model_1():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = 24
|
||||
params.dim_feedforward = 1536 # 384 * 4
|
||||
params.encoder_dim = 384
|
||||
model = get_transducer_model(params)
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
|
||||
# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
|
||||
def test_model_M():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = "2,4,3,2,4"
|
||||
params.feedforward_dims = "1024,1024,2048,2048,1024"
|
||||
params.nhead = "8,8,8,8,8"
|
||||
params.encoder_dims = "384,384,384,384,384"
|
||||
params.attention_dims = "192,192,192,192,192"
|
||||
params.encoder_unmasked_dims = "256,256,256,256,256"
|
||||
params.zipformer_downsampling_factors = "1,2,4,8,2"
|
||||
params.cnn_module_kernels = "31,31,15,15"
|
||||
|
||||
params.text_encoder_dim = (192, 192, 256, 384)
|
||||
params.decoder_dim = 512
|
||||
params.joiner_dim = 512
|
||||
model = Zipformer2(
|
||||
output_downsampling_factor=8,
|
||||
downsampling_factor=(1, 2, 4, 8),
|
||||
num_encoder_layers=(2, 4, 4, 4),
|
||||
encoder_dim=(192, 192, 256, 384),
|
||||
encoder_unmasked_dim=(192, 192, 256, 256),
|
||||
query_head_dim=(32, 32, 32, 32),
|
||||
pos_head_dim=(4, 4, 4, 4),
|
||||
value_head_dim=(12, 12, 12, 12),
|
||||
pos_dim=48,
|
||||
num_heads=(4, 4, 4, 8),
|
||||
feedforward_dim=(
|
||||
384,
|
||||
512,
|
||||
768,
|
||||
1024,
|
||||
), # could increase this if there is enough data
|
||||
cnn_module_kernel=(31, 31, 15, 15),
|
||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||
warmup_batches=4000.0,
|
||||
causal=False,
|
||||
)
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
model = Zipformer2(
|
||||
output_downsampling_factor=8,
|
||||
downsampling_factor=(1, 2, 4, 8),
|
||||
num_encoder_layers=(2, 4, 6, 6),
|
||||
encoder_dim=(256, 256, 384, 512),
|
||||
encoder_unmasked_dim=(196, 196, 256, 256),
|
||||
query_head_dim=(32, 32, 32, 32),
|
||||
pos_head_dim=(4, 4, 4, 4),
|
||||
value_head_dim=(12, 12, 12, 12),
|
||||
pos_dim=48,
|
||||
num_heads=(4, 4, 4, 8),
|
||||
feedforward_dim=(
|
||||
384,
|
||||
512,
|
||||
768,
|
||||
1024,
|
||||
), # could increase this if there is enough data
|
||||
cnn_module_kernel=(31, 31, 15, 15),
|
||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||
warmup_batches=4000.0,
|
||||
causal=False,
|
||||
)
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
|
||||
def main():
|
||||
# test_model_1()
|
||||
test_model_M()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
101
egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py
Normal file
@ -0,0 +1,101 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
|
||||
#
|
||||
# See ../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
|
||||
def train_text_normalization(s: str) -> str:
|
||||
# replace full-width with half-width
|
||||
s = s.replace("“", '"')
|
||||
s = s.replace("”", '"')
|
||||
s = s.replace("‘", "'")
|
||||
s = s.replace("’", "'")
|
||||
if s[:2] == '" ': # remove the starting double quote
|
||||
s = s[2:]
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def ref_text_normalization(ref_text: str) -> str:
|
||||
# Rule 1: Remove footnote markers of the form [FN#123]
|
||||
p = r"\[FN#[0-9]*\]"
|
||||
pattern = re.compile(p)
|
||||
|
||||
res = pattern.findall(ref_text)
|
||||
ref_text = re.sub(p, "", ref_text)
|
||||
|
||||
ref_text = train_text_normalization(ref_text)
|
||||
|
||||
return ref_text
|
||||
|
||||
|
||||
def remove_non_alphabetic(text: str, strict: bool = True) -> str:
|
||||
# It is recommended to set strict to False
|
||||
if not strict:
|
||||
# Note: this also keeps spaces, digits and single quotes ('); hyphens (-) and em-dashes are converted to spaces
|
||||
text = text.replace("-", " ")
|
||||
text = text.replace("—", " ")
|
||||
return re.sub(r"[^a-zA-Z0-9\s']+", "", text)
|
||||
else:
|
||||
# only keeps letters and spaces
|
||||
return re.sub(r"[^a-zA-Z\s]+", "", text)
|
||||
|
||||
|
||||
def upper_only_alpha(text: str) -> str:
|
||||
return remove_non_alphabetic(text.upper(), strict=False)
|
||||
|
||||
|
||||
def lower_only_alpha(text: str) -> str:
|
||||
return remove_non_alphabetic(text.lower(), strict=False)
|
||||
|
||||
|
||||
def lower_all_char(text: str) -> str:
|
||||
return text.lower()
|
||||
|
||||
|
||||
def upper_all_char(text: str) -> str:
|
||||
return text.upper()
|
||||
|
||||
|
||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||
"""Apply transform to a list of text. By default, the text are in
|
||||
ground-truth format, i.e. mixed-punc.
|
||||
|
||||
Args:
|
||||
text (List[str]): A list of input text strings
|
||||
transform (str): Transform to be applied
|
||||
|
||||
Returns:
|
||||
List[str]: The transformed list of texts
|
||||
"""
|
||||
if transform == "mixed-punc":
|
||||
return text
|
||||
elif transform == "upper-no-punc":
|
||||
return [upper_only_alpha(s) for s in text]
|
||||
elif transform == "lower-no-punc":
|
||||
return [lower_only_alpha(s) for s in text]
|
||||
elif transform == "lower-punc":
|
||||
return [lower_all_char(s) for s in text]
|
||||
else:
|
||||
raise NotImplementedError(f"Unseen transform: {transform}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
|
||||
print(ref_text)
|
||||
res = upper_only_alpha(ref_text)
|
||||
print(res)
|
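# The following extra checks are illustrative additions (not in the original
# script): footnote markers such as [FN#12] should be stripped by
# ref_text_normalization, and _apply_style_transform converts a list of texts
# into the requested style.
res = ref_text_normalization("He went home. [FN#12] It was late.")
print(res)
res = _apply_style_transform([ref_text], "upper-no-punc")
print(res)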
1390
egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py
Normal file
File diff suppressed because it is too large
1798
egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
Executable file
File diff suppressed because it is too large
515
egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py
Normal file
@ -0,0 +1,515 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
python ./zipformer_prompt_asr/transcribe_bert.py \
|
||||
--epoch 50 \
|
||||
--avg 10 \
|
||||
--exp-dir ./zipformer_prompt_asr/exp \
|
||||
--manifest-dir data/long_audios/long_audio.jsonl.gz \
|
||||
--pre-text-transform mixed-punc \
|
||||
--style-text-transform mixed-punc \
|
||||
--num-history 5 \
|
||||
--use-pre-text True \
|
||||
--use-gt-pre-text False
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from decode_bert import _apply_style_transform
|
||||
from lhotse import Fbank, load_manifest
|
||||
from text_normalization import (
|
||||
lower_all_char,
|
||||
lower_only_alpha,
|
||||
ref_text_normalization,
|
||||
remove_non_alphabetic,
|
||||
train_text_normalization,
|
||||
upper_all_char,
|
||||
upper_only_alpha,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
from train_bert_encoder import (
|
||||
_encode_texts_as_bytes_with_tokenizer,
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
get_tokenizer,
|
||||
get_transducer_model,
|
||||
)
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manifest-dir",
|
||||
type=str,
|
||||
default="data/long_audios/long_audio.jsonl.gz",
|
||||
help="""This is the manfiest for long audio transcription.
|
||||
The cuts are expected to be sorted, i.e. first by recording ID and
|
||||
then by start timestamp""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-pre-text",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether use pre-text when decoding the current chunk",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-style-prompt",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use style prompt when evaluation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pre-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of content prompt, i.e pre_text",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--style-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of style prompt, i.e style_text",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-history",
|
||||
type=int,
|
||||
default=2,
|
||||
help="How many previous chunks to look if using pre-text for decoding",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-gt-pre-text",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether use gt pre text when using content prompt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--post-normalization",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
params.res_dir = params.exp_dir / "long_audio_transcribe"
|
||||
params.res_dir.mkdir(exist_ok=True)
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "beam_search" in params.method:
|
||||
params.suffix += f"-{params.method}-beam-size-{params.beam_size}"
|
||||
|
||||
if params.use_pre_text:
|
||||
if params.use_gt_pre_text:
|
||||
params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}"
|
||||
else:
|
||||
params.suffix += (
|
||||
f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
|
||||
)
|
||||
|
||||
book_name = params.manifest_dir.split("/")[-1].replace(".jsonl.gz", "")
|
||||
setup_logger(
|
||||
f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info"
|
||||
)
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
tokenizer = get_tokenizer(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
# load manifest
|
||||
manifest = load_manifest(params.manifest_dir)
|
||||
|
||||
results = []
|
||||
count = 0
|
||||
|
||||
last_recording = ""
|
||||
last_end = -1
|
||||
history = []
|
||||
num_pre_texts = []
|
||||
|
||||
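# The cuts of a long recording are decoded in order. `history` keeps the
# decoded (or ground-truth) text of the previous chunks of the same recording
# and is used as the content prompt (pre_text); `last_end` tracks the end
# timestamp of the previous cut to detect overlaps and large gaps.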
for cut in manifest:
|
||||
if cut.has_features:
|
||||
feat = cut.load_features()
|
||||
feat_lens = cut.num_frames
|
||||
else:
|
||||
feat = cut.compute_features(extractor=Fbank())
|
||||
feat_lens = feat.shape[0]
|
||||
|
||||
cur_recording = cut.recording.id
|
||||
|
||||
if cur_recording != last_recording:
|
||||
last_recording = cur_recording
|
||||
history = [] # clean up the history
|
||||
last_end = -1
|
||||
logging.info("Moving on to the next recording")
|
||||
else:
|
||||
if cut.start < last_end - 0.2: # overlap with the previous cuts
|
||||
logging.warning("An overlap exists between current cut and last cut")
|
||||
logging.warning("Skipping this cut!")
|
||||
continue
|
||||
if cut.start > last_end + 10:
|
||||
logging.warning(
|
||||
f"Large time gap between the current and previous utterance: {cut.start - last_end}."
|
||||
)
|
||||
|
||||
# prepare input
|
||||
x = torch.tensor(feat, device=device).unsqueeze(0)
|
||||
x_lens = torch.tensor(
|
||||
[
|
||||
feat_lens,
|
||||
],
|
||||
device=device,
|
||||
)
|
||||
|
||||
if params.use_pre_text:
|
||||
if params.num_history > 0:
|
||||
pre_texts = history[-params.num_history :]
|
||||
else:
|
||||
pre_texts = []
|
||||
num_pre_texts.append(len(pre_texts))
|
||||
pre_texts = [train_text_normalization(" ".join(pre_texts))]
|
||||
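# The style prompt is a fixed sentence whose content is irrelevant; only its
# casing and punctuation style is used to steer the output format.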
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
|
||||
style_texts = [fixed_sentence]
|
||||
|
||||
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
||||
if params.use_style_prompt:
|
||||
style_texts = _apply_style_transform(
|
||||
style_texts, params.style_text_transform
|
||||
)
|
||||
|
||||
# encode prompts
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
|
||||
pre_texts=pre_texts,
|
||||
style_texts=style_texts,
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
no_limit=True,
|
||||
)
|
||||
if params.num_history > 5:
|
||||
logging.info(
|
||||
f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} "
|
||||
)
|
||||
|
||||
memory, memory_key_padding_mask = model.encode_text(
|
||||
encoded_inputs=encoded_inputs,
|
||||
style_lens=style_lens,
|
||||
) # (T,B,C)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
encoder_out, encoder_out_lens = model.encode_audio(
|
||||
feature=x,
|
||||
feature_lens=x_lens,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
|
||||
if params.method == "greedy_search":
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
hyp = sp.decode(hyp_tokens)[0] # in string format
|
||||
ref_text = ref_text_normalization(
|
||||
cut.supervisions[0].texts[0]
|
||||
) # required to match the training
|
||||
|
||||
# extend the history
|
||||
if params.use_gt_pre_text:
|
||||
history.append(ref_text)
|
||||
else:
|
||||
history.append(hyp)
|
||||
last_end = cut.end # update the last end timestamp
|
||||
|
||||
# append the current decoding result
|
||||
hyp = hyp.split()
|
||||
ref = ref_text.split()
|
||||
results.append((cut.id, ref, hyp))
|
||||
|
||||
count += 1
|
||||
if count % 100 == 0:
|
||||
logging.info(f"Cuts processed until now: {count}/{len(manifest)}")
|
||||
logging.info(
|
||||
f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}"
|
||||
)
|
||||
|
||||
logging.info(f"A total of {count} cuts")
|
||||
logging.info(
|
||||
f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}"
|
||||
)
|
||||
|
||||
results = sorted(results)
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
|
||||
)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f,
|
||||
f"long-audio-{params.method}",
|
||||
results,
|
||||
enable_log=True,
|
||||
compute_CER=False,
|
||||
)
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
if params.post_normalization:
|
||||
params.suffix += "-post-normalization"
|
||||
|
||||
new_res = []
|
||||
for item in results:
|
||||
id, ref, hyp = item
|
||||
hyp = upper_only_alpha(" ".join(hyp)).split()
|
||||
ref = upper_only_alpha(" ".join(ref)).split()
|
||||
new_res.append((id, ref, hyp))
|
||||
|
||||
new_res = sorted(new_res)
|
||||
recog_path = (
|
||||
params.res_dir
|
||||
/ f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
|
||||
)
|
||||
store_transcripts(filename=recog_path, texts=new_res)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
errs_filename = (
|
||||
params.res_dir
|
||||
/ f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f,
|
||||
f"long-audio-{params.method}",
|
||||
new_res,
|
||||
enable_log=True,
|
||||
compute_CER=False,
|
||||
)
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
439
egs/libriheavy/ASR/zipformer_prompt_asr/utils.py
Normal file
@ -0,0 +1,439 @@
|
||||
import argparse
|
||||
import ast
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Iterable, List, TextIO, Tuple, Union
|
||||
|
||||
import kaldialign
|
||||
from lhotse import load_manifest, load_manifest_lazy
|
||||
from lhotse.cut import Cut, CutSet
|
||||
from text_normalization import remove_non_alphabetic
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manifest-dir",
|
||||
type=str,
|
||||
default="data/fbank",
|
||||
help="Where are the manifest stored",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--subset", type=str, default="medium", help="Which subset to work with"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=10000,
|
||||
help="How many words to keep",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_facebook_biasing_list(
|
||||
test_set: str,
|
||||
num_distractors: int = 100,
|
||||
) -> Dict:
|
||||
# Get the biasing list from the meta paper: https://arxiv.org/pdf/2104.02194.pdf
|
||||
assert num_distractors in (0, 100, 500, 1000, 2000), num_distractors
|
||||
if num_distractors == 0:
|
||||
if test_set == "test-clean":
|
||||
biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_100.tsv"
|
||||
elif test_set == "test-other":
|
||||
biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_100.tsv"
|
||||
else:
|
||||
raise ValueError(f"Unseen test set {test_set}")
|
||||
else:
|
||||
if test_set == "test-clean":
|
||||
biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
|
||||
elif test_set == "test-other":
|
||||
biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
|
||||
else:
|
||||
raise ValueError(f"Unseen test set {test_set}")
|
||||
|
||||
f = open(biasing_file, "r")
|
||||
data = f.readlines()
|
||||
f.close()
|
||||
|
||||
output = dict()
|
||||
for line in data:
|
||||
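# Each row of the biasing TSV is assumed to contain: the utterance id, a
# second (unused) field, the biasing word list without distractors, and the
# biasing word list with distractors appended.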
id, _, l1, l2 = line.split("\t")
|
||||
if num_distractors > 0: # use distractors
|
||||
biasing_list = ast.literal_eval(l2)
|
||||
else:
|
||||
biasing_list = ast.literal_eval(l1)
|
||||
biasing_list = [w.strip().upper() for w in biasing_list]
|
||||
output[id] = " ".join(biasing_list)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def brian_biasing_list(level: str):
|
||||
# The biasing list from Brian's paper: https://arxiv.org/pdf/2109.00627.pdf
|
||||
root_dir = f"data/context_biasing/LibriSpeechBiasingLists/{level}Level"
|
||||
all_files = glob.glob(root_dir + "/*")
|
||||
biasing_dict = {}
|
||||
for f in all_files:
|
||||
k = f.split("/")[-1]
|
||||
fin = open(f, "r")
|
||||
data = fin.read().strip().split()
|
||||
biasing_dict[k] = " ".join(data)
|
||||
fin.close()
|
||||
|
||||
return biasing_dict
|
||||
|
||||
|
||||
def get_rare_words(
|
||||
subset: str = "medium",
|
||||
top_k: int = 10000,
|
||||
# min_count: int = 10000,
|
||||
):
|
||||
"""Get a list of rare words appearing less than `min_count` times
|
||||
|
||||
Args:
|
||||
subset: The dataset
|
||||
top_k (int): Number of most frequent words to treat as common (the rest are rare)
|
||||
"""
|
||||
txt_path = f"data/tmp/transcript_words_{subset}.txt"
|
||||
rare_word_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt"
|
||||
|
||||
if os.path.exists(rare_word_file):
|
||||
print("File exists, do not proceed!")
|
||||
return
|
||||
|
||||
print("---Identifying rare words in the manifest---")
|
||||
count_file = f"data/tmp/transcript_words_{subset}_count.txt"
|
||||
if not os.path.exists(count_file):
|
||||
with open(txt_path, "r") as file:
|
||||
words = file.read().upper().split()
|
||||
word_count = {}
|
||||
for word in words:
|
||||
word = remove_non_alphabetic(word, strict=False)
|
||||
word = word.split()
|
||||
for w in word:
|
||||
if w not in word_count:
|
||||
word_count[w] = 1
|
||||
else:
|
||||
word_count[w] += 1
|
||||
|
||||
word_count = list(word_count.items()) # convert to a list of tuple
|
||||
word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True)
|
||||
with open(count_file, "w") as fout:
|
||||
for w, count in word_count:
|
||||
fout.write(f"{w}\t{count}\n")
|
||||
|
||||
else:
|
||||
word_count = {}
|
||||
with open(count_file, "r") as fin:
|
||||
word_count = fin.read().strip().split("\n")
|
||||
word_count = [pair.split("\t") for pair in word_count]
|
||||
word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True)
|
||||
|
||||
print(f"A total of {len(word_count)} words appeared!")
|
||||
rare_words = []
|
||||
for word, count in word_count[top_k:]:
|
||||
rare_words.append(word + "\n")
|
||||
print(f"A total of {len(rare_words)} are identified as rare words.")
|
||||
|
||||
with open(rare_word_file, "w") as f:
|
||||
f.writelines(rare_words)
|
||||
|
||||
|
||||
def add_context_list_to_manifest(
|
||||
manifest_dir: str,
|
||||
subset: str = "medium",
|
||||
top_k: int = 10000,
|
||||
):
|
||||
"""Generate a context list of rare words for each utterance in the manifest
|
||||
|
||||
Args:
|
||||
manifest_dir: Directory containing the manifests; the manifest with the context list is written here as well
|
||||
subset (str): Subset
|
||||
top_k (int): How many frequent words
|
||||
|
||||
"""
|
||||
orig_manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}.jsonl.gz"
|
||||
target_manifest_dir = orig_manifest_dir.replace(
|
||||
".jsonl.gz", f"_with_context_list_topk_{top_k}.jsonl.gz"
|
||||
)
|
||||
if os.path.exists(target_manifest_dir):
|
||||
print(f"Target file exits at {target_manifest_dir}!")
|
||||
return
|
||||
|
||||
rare_words_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt"
|
||||
print(f"---Reading rare words from {rare_words_file}---")
|
||||
with open(rare_words_file, "r") as f:
|
||||
rare_words = f.read()
|
||||
rare_words = rare_words.split("\n")
|
||||
rare_words = set(rare_words)
|
||||
print(f"A total of {len(rare_words)} rare words!")
|
||||
|
||||
cuts = load_manifest_lazy(orig_manifest_dir)
|
||||
print(f"Loaded manifest from {orig_manifest_dir}")
|
||||
|
||||
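# For every cut, collect the rare words that occur in its transcript and
# store them (space-separated) as `context_list` on the supervision.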
def _add_context(c: Cut):
|
||||
splits = (
|
||||
remove_non_alphabetic(c.supervisions[0].texts[0], strict=False)
|
||||
.upper()
|
||||
.split()
|
||||
)
|
||||
found = []
|
||||
for w in splits:
|
||||
if w in rare_words:
|
||||
found.append(w)
|
||||
c.supervisions[0].context_list = " ".join(found)
|
||||
return c
|
||||
|
||||
cuts = cuts.map(_add_context)
|
||||
print(f"---Saving manifest with context list to {target_manifest_dir}---")
|
||||
cuts.to_file(target_manifest_dir)
|
||||
print("Finished")
|
||||
|
||||
|
||||
def check(
|
||||
manifest_dir: str,
|
||||
subset: str = "medium",
|
||||
top_k: int = 10000,
|
||||
):
|
||||
# Show how many samples in the training set have a context list
|
||||
# and the average length of context list
|
||||
print("--- Calculating the stats over the manifest ---")
|
||||
|
||||
manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}_with_context_list_topk_{top_k}.jsonl.gz"
|
||||
cuts = load_manifest_lazy(manifest_dir)
|
||||
total_cuts = len(cuts)
|
||||
has_context_list = [c.supervisions[0].context_list != "" for c in cuts]
|
||||
context_list_len = [len(c.supervisions[0].context_list.split()) for c in cuts]
|
||||
print(f"{sum(has_context_list)}/{total_cuts} cuts have context list! ")
|
||||
print(
|
||||
f"Average length of non-empty context list is {sum(context_list_len)/sum(has_context_list)}"
|
||||
)
|
||||
|
||||
|
||||
def write_error_stats(
|
||||
f: TextIO,
|
||||
test_set_name: str,
|
||||
results: List[Tuple[str, str]],
|
||||
enable_log: bool = True,
|
||||
compute_CER: bool = False,
|
||||
biasing_words: List[str] = None,
|
||||
) -> float:
|
||||
"""Write statistics based on predicted results and reference transcripts. It also calculates the
|
||||
biasing word error rate as described in https://arxiv.org/pdf/2104.02194.pdf
|
||||
|
||||
It will write the following to the given file:
|
||||
|
||||
- WER
|
||||
- number of insertions, deletions, substitutions, corrects and total
|
||||
reference words. For example::
|
||||
|
||||
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
|
||||
reference words (2337 correct)
|
||||
|
||||
- The difference between the reference transcript and predicted result.
|
||||
An instance is given below::
|
||||
|
||||
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
|
||||
|
||||
The above example shows that the reference word is `EDISON`,
|
||||
but it is predicted to `ADDISON` (a substitution error).
|
||||
|
||||
Another example is::
|
||||
|
||||
FOR THE FIRST DAY (SIR->*) I THINK
|
||||
|
||||
The reference word `SIR` is missing in the predicted
|
||||
results (a deletion error).
|
||||
results:
|
||||
An iterable of tuples. The first element is the cut_id, the second is
|
||||
the reference transcript and the third element is the predicted result.
|
||||
enable_log:
|
||||
If True, also print detailed WER to the console.
|
||||
Otherwise, it is written only to the given file.
|
||||
biasing_words:
|
||||
All the words in the biasing list
|
||||
Returns:
|
||||
Return the total error rate (WER, in percent) as a float.
|
||||
"""
|
||||
subs: Dict[Tuple[str, str], int] = defaultdict(int)
|
||||
ins: Dict[str, int] = defaultdict(int)
|
||||
dels: Dict[str, int] = defaultdict(int)
|
||||
|
||||
# `words` stores counts per word, as follows:
|
||||
# corr, ref_sub, hyp_sub, ins, dels
|
||||
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
||||
num_corr = 0
|
||||
ERR = "*"
|
||||
|
||||
if compute_CER:
|
||||
for i, res in enumerate(results):
|
||||
cut_id, ref, hyp = res
|
||||
ref = list("".join(ref))
|
||||
hyp = list("".join(hyp))
|
||||
results[i] = (cut_id, ref, hyp)
|
||||
|
||||
for cut_id, ref, hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR)
|
||||
for ref_word, hyp_word in ali:
|
||||
if ref_word == ERR: # INSERTION
|
||||
ins[hyp_word] += 1
|
||||
words[hyp_word][3] += 1
|
||||
elif hyp_word == ERR: # DELETION
|
||||
dels[ref_word] += 1
|
||||
words[ref_word][4] += 1
|
||||
elif hyp_word != ref_word: # SUBSTITUTION
|
||||
subs[(ref_word, hyp_word)] += 1
|
||||
words[ref_word][1] += 1
|
||||
words[hyp_word][2] += 1
|
||||
else:
|
||||
words[ref_word][0] += 1
|
||||
num_corr += 1
|
||||
ref_len = sum([len(r) for _, r, _ in results])
|
||||
sub_errs = sum(subs.values())
|
||||
ins_errs = sum(ins.values())
|
||||
del_errs = sum(dels.values())
|
||||
tot_errs = sub_errs + ins_errs + del_errs
|
||||
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
|
||||
|
||||
if enable_log:
|
||||
logging.info(
|
||||
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
|
||||
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
|
||||
f"{del_errs} del, {sub_errs} sub ]"
|
||||
)
|
||||
|
||||
print(f"%WER = {tot_err_rate}", file=f)
|
||||
print(
|
||||
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
|
||||
f"{sub_errs} substitutions, over {ref_len} reference "
|
||||
f"words ({num_corr} correct)",
|
||||
file=f,
|
||||
)
|
||||
print(
|
||||
"Search below for sections starting with PER-UTT DETAILS:, "
|
||||
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
|
||||
for cut_id, ref, hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR)
|
||||
combine_successive_errors = True
|
||||
if combine_successive_errors:
|
||||
ali = [[[x], [y]] for x, y in ali]
|
||||
for i in range(len(ali) - 1):
|
||||
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
|
||||
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
|
||||
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
|
||||
ali[i] = [[], []]
|
||||
ali = [
|
||||
[
|
||||
list(filter(lambda a: a != ERR, x)),
|
||||
list(filter(lambda a: a != ERR, y)),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
ali = list(filter(lambda x: x != [[], []], ali))
|
||||
ali = [
|
||||
[
|
||||
ERR if x == [] else " ".join(x),
|
||||
ERR if y == [] else " ".join(y),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
|
||||
print(
|
||||
f"{cut_id}:\t"
|
||||
+ " ".join(
|
||||
(
|
||||
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
|
||||
for ref_word, hyp_word in ali
|
||||
)
|
||||
),
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
||||
|
||||
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
|
||||
print(f"{count} {ref} -> {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("DELETIONS: count ref", file=f)
|
||||
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
|
||||
print(f"{count} {ref}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("INSERTIONS: count hyp", file=f)
|
||||
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
|
||||
print(f"{count} {hyp}", file=f)
|
||||
|
||||
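# Accumulators for the biased / unbiased WER breakdown (B-WER / U-WER) as in
# https://arxiv.org/pdf/2104.02194.pdf: errors on words that appear in the
# biasing list are counted separately from errors on all other words.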
unbiased_word_counts = 0
|
||||
unbiased_word_errs = 0
|
||||
biased_word_counts = 0
|
||||
biased_word_errs = 0
|
||||
|
||||
print("", file=f)
|
||||
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
|
||||
|
||||
for _, word, counts in sorted(
|
||||
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
||||
):
|
||||
(corr, ref_sub, hyp_sub, ins, dels) = counts
|
||||
tot_errs = ref_sub + hyp_sub + ins + dels
|
||||
# number of appearances of "word" in reference text
|
||||
ref_count = (
|
||||
corr + ref_sub + dels
|
||||
) # correct + in ref but got substituted + deleted
|
||||
# number of appearances of "word" in hyp text
|
||||
hyp_count = corr + hyp_sub + ins
|
||||
|
||||
if biasing_words is not None:
|
||||
if word in biasing_words:
|
||||
biased_word_counts += ref_count
|
||||
biased_word_errs += ins + dels + ref_sub
|
||||
else:
|
||||
unbiased_word_counts += ref_count
|
||||
unbiased_word_errs += ins + dels + hyp_sub
|
||||
|
||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||
|
||||
if biasing_words is not None:
|
||||
B_WER = "%.2f" % (100 * biased_word_errs / biased_word_counts)
|
||||
U_WER = "%.2f" % (100 * unbiased_word_errs / unbiased_word_counts)
|
||||
logging.info(f"Biased WER: {B_WER} [{biased_word_errs}/{biased_word_counts}] ")
|
||||
logging.info(
|
||||
f"Un-biased WER: {U_WER} [{unbiased_word_errs}/{unbiased_word_counts}]"
|
||||
)
|
||||
|
||||
return float(tot_err_rate)
|
||||
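# A minimal usage sketch for write_error_stats (illustrative only; file name
# and results are hypothetical):
#
#   results = [("cut-1", "HELLO WORLD".split(), "HELLO WORD".split())]
#   with open("errs-example.txt", "w") as f:
#       wer = write_error_stats(f, "example", results, biasing_words=["WORLD"])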
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
manifest_dir = args.manifest_dir
|
||||
subset = args.subset
|
||||
top_k = args.top_k
|
||||
get_rare_words(subset=subset, top_k=top_k)
|
||||
add_context_list_to_manifest(
|
||||
manifest_dir=manifest_dir,
|
||||
subset=subset,
|
||||
top_k=top_k,
|
||||
)
|
||||
check(
|
||||
manifest_dir=manifest_dir,
|
||||
subset=subset,
|
||||
top_k=top_k,
|
||||
)
|
2310
egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py
Normal file
File diff suppressed because it is too large
@ -483,7 +483,7 @@ def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
|
||||
|
||||
|
||||
def store_transcripts(
|
||||
filename: Pathlike, texts: Iterable[Tuple[str, str, str]]
|
||||
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
|
||||
) -> None:
|
||||
"""Save predicted results and reference transcripts to a file.
|
||||
|
||||
@ -500,6 +500,9 @@ def store_transcripts(
|
||||
"""
|
||||
with open(filename, "w") as f:
|
||||
for cut_id, ref, hyp in texts:
|
||||
if char_level:
|
||||
ref = list("".join(ref))
|
||||
hyp = list("".join(hyp))
|
||||
print(f"{cut_id}:\tref={ref}", file=f)
|
||||
print(f"{cut_id}:\thyp={hyp}", file=f)
|
||||
|
||||
@ -557,6 +560,7 @@ def write_error_stats(
|
||||
test_set_name: str,
|
||||
results: List[Tuple[str, str]],
|
||||
enable_log: bool = True,
|
||||
compute_CER: bool = False,
|
||||
sclite_mode: bool = False,
|
||||
) -> float:
|
||||
"""Write statistics based on predicted results and reference transcripts.
|
||||
@ -585,7 +589,7 @@ def write_error_stats(
|
||||
The reference word `SIR` is missing in the predicted
|
||||
results (a deletion error).
|
||||
results:
|
||||
An iterable of tuples. The first element is the cur_id, the second is
|
||||
An iterable of tuples. The first element is the cut_id, the second is
|
||||
the reference transcript and the third element is the predicted result.
|
||||
enable_log:
|
||||
If True, also print detailed WER to the console.
|
||||
@ -602,6 +606,14 @@ def write_error_stats(
|
||||
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
||||
num_corr = 0
|
||||
ERR = "*"
|
||||
|
||||
if compute_CER:
|
||||
for i, res in enumerate(results):
|
||||
cut_id, ref, hyp = res
|
||||
ref = list("".join(ref))
|
||||
hyp = list("".join(hyp))
|
||||
results[i] = (cut_id, ref, hyp)
|
||||
|
||||
for cut_id, ref, hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
|
||||
for ref_word, hyp_word in ali:
|
||||
@ -1426,7 +1438,10 @@ def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, floa
|
||||
|
||||
|
||||
def get_parameter_groups_with_lrs(
|
||||
model: nn.Module, lr: float, include_names: bool = False
|
||||
model: nn.Module,
|
||||
lr: float,
|
||||
include_names: bool = False,
|
||||
freeze_modules: List[str] = [],
|
||||
) -> List[dict]:
|
||||
"""
|
||||
This is for use with the ScaledAdam optimizers (more recent versions that accept lists of
|
||||
@ -1450,6 +1465,8 @@ def get_parameter_groups_with_lrs(
|
||||
... ]
|
||||
|
||||
"""
|
||||
named_modules = list(model.named_modules())
|
||||
|
||||
# flat_lr_scale just contains the lr_scale explicitly specified
|
||||
# for each prefix of the name, e.g. 'encoder.layers.3', these need
|
||||
# to be multiplied for all prefix of the name of any given parameter.
|
||||
@ -1469,6 +1486,15 @@ def get_parameter_groups_with_lrs(
|
||||
split_name = name.split(".")
|
||||
# caution: as a special case, if the name is '', split_name will be [ '' ].
|
||||
prefix = split_name[0]
|
||||
if prefix == "module": # DDP
|
||||
module_name = split_name[1]
|
||||
if module_name in freeze_modules:
|
||||
logging.info(f"Remove {name} from parameters")
|
||||
continue
|
||||
else:
|
||||
if prefix in freeze_modules:
|
||||
logging.info(f"Remove {name} from parameters")
|
||||
continue
|
||||
cur_lr = lr * flat_lr_scale[prefix]
|
||||
if prefix != "":
|
||||
cur_lr *= flat_lr_scale[""]
|
||||
|