mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-09 14:05:33 +00:00
Whisper large fine-tuning on wenetspeech, multi-hans-zh (#1483)
* add whisper fbank for wenetspeech
* add whisper fbank for other datasets
* add str2bool
* add decode for wenetspeech
* add requirements.txt
* add original model decode with 30s
* test feature extractor speed
* add aishell2 feat
* change compute feature batch
* fix overwrite
* fix executor
* regression
* add kaldifeat whisper fbank
* fix io issue
* parallel jobs
* use multi machines
* add wenetspeech fine-tune scripts
* add monkey patch codes
* remove useless file
* fix subsampling factor
* fix too long audios
* add remove long short
* fix whisper version to support multi-batch beam
* decode all wav files
* remove utterances longer than 30s in test_net
* only test net
* use soft links
* add kespeech whisper feats
* fix index error
* add manifests for whisper
* change to LilcomChunkyWriter
* add missing option
* decrease cpu usage
* add speed perturb for kespeech
* fix kespeech speed perturb
* add dataset
* load checkpoint from specific path
* add speechio
* add speechio results

---------

Co-authored-by: zr_jin <peter.jin.cn@gmail.com>
This commit is contained in:
parent cdb3fb5675
commit 5df24c1685
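The change repeated across every `compute_fbank_*.py` recipe in this diff is the same: gate the feature extractor behind a `--whisper-fbank` flag, choosing lhotse's `WhisperFbank` (Whisper-style log-mel filterbanks) instead of the default `Fbank`/`KaldifeatFbank`. Below is a minimal sketch of that selection logic, not part of the commit itself; it only uses the lhotse and icefall names that appear in the diff, while the standalone-script framing is illustrative:

```python
# Sketch of the extractor selection this commit adds to each recipe.
# WhisperFbank/WhisperFbankConfig/Fbank/FbankConfig are the lhotse classes
# used in the diff below; str2bool comes from icefall.utils.
import argparse

from lhotse import Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig

from icefall.utils import str2bool

parser = argparse.ArgumentParser()
parser.add_argument("--num-mel-bins", type=int, default=80)
parser.add_argument("--whisper-fbank", type=str2bool, default=False)
args = parser.parse_args()

if args.whisper_fbank:
    # Whisper-style log-mel features, computed on GPU as in the recipes.
    extractor = WhisperFbank(
        WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
    )
else:
    # Default kaldi-compatible fbank features on CPU.
    extractor = Fbank(FbankConfig(num_mel_bins=args.num_mel_bins))
```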
@@ -75,7 +75,7 @@ It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `voc
| fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 |

```bash
./prepare.sh
./prepare.sh

export CUDA_VISIBLE_DEVICES="0,1"
@@ -1,6 +1,6 @@
## Results

### Aishell2 char-based training results
### Aishell2 char-based training results

#### Pruned transducer stateless 5
@@ -29,7 +29,14 @@ import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
def compute_fbank_aishell2(
    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
    num_jobs = min(8, os.cpu_count())

    dataset_parts = (
        "train",
@@ -68,8 +77,12 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
        list(manifests.keys()),
        dataset_parts,
    )

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
    if whisper_fbank:
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
        )
    else:
        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
@@ -82,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
                supervisions=m["supervisions"],
            )
            if "train" in partition and perturb_speed:
                logging.info(f"Doing speed perturb")
                logging.info("Doing speed perturb")
                cut_set = (
                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                )
@@ -111,7 +124,12 @@ def get_args():
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )

    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )
    return parser.parse_args()


@@ -122,5 +140,7 @@ if __name__ == "__main__":

    args = get_args()
    compute_fbank_aishell2(
        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
        num_mel_bins=args.num_mel_bins,
        perturb_speed=args.perturb_speed,
        whisper_fbank=args.whisper_fbank,
    )
@@ -108,6 +108,16 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  fi
fi

whisper_mel_bins=80
if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
  log "Stage 30: Compute whisper fbank for aishell2"
  if [ ! -f data/fbank/.aishell2.whisper.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
    touch data/fbank/.aishell2.whisper.done
  fi
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Compute fbank for musan"
  if [ ! -f data/fbank/.msuan.done ]; then
@@ -3,7 +3,7 @@

This recipe contains some various ASR models trained with Aishell4 (including S, M and L three subsets).

The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.

(From [Open Speech and Language Resources](https://www.openslr.org/111/))
@@ -29,7 +29,14 @@ import os
from pathlib import Path

import torch
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
def compute_fbank_aishell4(
    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
    src_dir = Path("data/manifests/aishell4")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
    num_jobs = min(8, os.cpu_count())

    dataset_parts = (
        "train_S",
@@ -70,7 +79,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
        dataset_parts,
    )

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
    if whisper_fbank:
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
        )
    else:
        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
@@ -84,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
                supervisions=m["supervisions"],
            )
            if "train" in partition and perturb_speed:
                logging.info(f"Doing speed perturb")
                logging.info("Doing speed perturb")
                cut_set = (
                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                )
@@ -95,7 +109,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
                # when an executor is specified, make more partitions
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=ChunkedLilcomHdf5Writer,
                storage_type=LilcomChunkyWriter,
            )

    logging.info("About splitting cuts into smaller chunks")
@@ -121,7 +135,12 @@ def get_args():
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )

    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )
    return parser.parse_args()


@@ -132,5 +151,7 @@ if __name__ == "__main__":

    args = get_args()
    compute_fbank_aishell4(
        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
        num_mel_bins=args.num_mel_bins,
        perturb_speed=args.perturb_speed,
        whisper_fbank=args.whisper_fbank,
    )
@@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail

stage=-1
stop_stage=100
stop_stage=7
perturb_speed=true


@@ -76,11 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Process aishell4"
  log "Stage 2: Compute fbank for aishell4"
  if [ ! -f data/fbank/aishell4/.fbank.done ]; then
    mkdir -p data/fbank/aishell4
    mkdir -p data/fbank
    ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
    touch data/fbank/aishell4/.fbank.done
    touch data/fbank/.fbank.done
  fi
fi

whisper_mel_bins=80
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
  log "Stage 20: Compute whisper fbank for aishell4"
  if [ ! -f data/fbank/aishell4/.fbank.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
    touch data/fbank/.fbank.done
  fi
fi

@@ -106,16 +116,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Compute fbank for aishell4"
  if [ ! -f data/fbank/.aishell4.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
    touch data/fbank/.aishell4.done
  fi
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Prepare char based lang"
  log "Stage 5: Prepare char based lang"
  lang_char_dir=data/lang_char
  mkdir -p $lang_char_dir
@@ -29,7 +29,14 @@ import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False):
def compute_fbank_alimeeting(
    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
    src_dir = Path("data/manifests/alimeeting")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
    num_jobs = min(8, os.cpu_count())

    dataset_parts = (
        "train",
@@ -53,7 +62,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
        "test",
    )

    prefix = "alimeeting"
    prefix = "alimeeting-far"
    suffix = "jsonl.gz"
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
@@ -70,7 +79,12 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
        dataset_parts,
    )

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
    if whisper_fbank:
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
        )
    else:
        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
@@ -83,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
                supervisions=m["supervisions"],
            )
            if "train" in partition and perturb_speed:
                logging.info(f"Doing speed perturb")
                logging.info("Doing speed perturb")
                cut_set = (
                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                )
@@ -121,7 +135,12 @@ def get_args():
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )

    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use the Whisper Fbank feature extractor. Default: False.",
    )
    return parser.parse_args()


@@ -132,5 +151,7 @@ if __name__ == "__main__":

    args = get_args()
    compute_fbank_alimeeting(
        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
        num_mel_bins=args.num_mel_bins,
        perturb_speed=args.perturb_speed,
        whisper_fbank=args.whisper_fbank,
    )
@@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail

stage=-1
stop_stage=100
stop_stage=7
perturb_speed=true

# We assume dl_dir (download dir) contains the following
@@ -66,10 +66,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Process alimeeting"
  if [ ! -f data/fbank/alimeeting/.fbank.done ]; then
    mkdir -p data/fbank/alimeeting
  log "Stage 2: compute fbank for alimeeting"
  if [ ! -f data/fbank/.fbank.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed}
    touch data/fbank/.fbank.done
  fi
fi

whisper_mel_bins=80
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
  log "Stage 20: compute whisper fbank for alimeeting"
  if [ ! -f data/fbank/.fbank.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
    touch data/fbank/.fbank.done
  fi
fi

@@ -95,16 +106,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Compute fbank for alimeeting"
  if [ ! -f data/fbank/.alimeeting.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_alimeeting.py --perturb-speed True
    touch data/fbank/.alimeeting.done
  fi
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Prepare char based lang"
  log "Stage 5: Prepare char based lang"
  lang_char_dir=data/lang_char
  mkdir -p $lang_char_dir
@@ -36,4 +36,4 @@ This recipe includes scripts for training Zipformer model using multiple Chinese
 3. AliMeeting
 4. MagicData
 5. KeSpeech-ASR
 6. WeNetSpeech
 6. WeNetSpeech
@@ -17,11 +17,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from pathlib import Path

import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
from lhotse import (
    CutSet,
    KaldifeatFbank,
    KaldifeatFbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)

from icefall.utils import str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@@ -31,7 +41,28 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_kespeech_dev_test():
def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=80,
        help="""The number of mel bins for Fbank""",
    )

    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )
    return parser


def compute_fbank_kespeech_dev_test(args):
    in_out_dir = Path("data/fbank/kespeech")
    # number of workers in dataloader
    num_workers = 42
@@ -48,7 +79,12 @@ def compute_fbank_kespeech_dev_test():
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
    if args.whisper_fbank:
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
        )
    else:
        extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))

    logging.info(f"device: {device}")

@@ -86,7 +122,11 @@ def main():
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    compute_fbank_kespeech_dev_test()
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    compute_fbank_kespeech_dev_test(args)


if __name__ == "__main__":
@@ -28,10 +28,14 @@ from lhotse import (
    KaldifeatFbank,
    KaldifeatFbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
    set_audio_duration_mismatch_tolerance,
    set_caching_enabled,
)

from icefall.utils import str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
@@ -88,6 +92,20 @@ def get_parser():
        default=-1,
        help="Stop processing pieces until this number (exclusive).",
    )

    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=80,
        help="""The number of mel bins for Fbank""",
    )

    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )
    return parser


@@ -111,14 +129,19 @@ def compute_fbank_kespeech_splits(args):
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
    if args.whisper_fbank:
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
        )
    else:
        extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
    logging.info(f"device: {device}")

    set_audio_duration_mismatch_tolerance(0.01)  # 10ms tolerance
    set_caching_enabled(False)
    for i in range(start, stop):
        idx = f"{i + 1}".zfill(num_digits)
        logging.info(f"Processing {idx}/{num_splits}")
        idx = f"{i}".zfill(num_digits)
        logging.info(f"Processing {i+1}/{num_splits}")

        cuts_path = output_dir / f"kespeech-asr_cuts_{subset}.{idx}.jsonl.gz"
        if cuts_path.is_file():
@@ -30,10 +30,17 @@ import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@@ -43,10 +50,33 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False):
def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=80,
        help="""The number of mel bins for Fbank""",
    )

    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )
    return parser


def compute_fbank_magicdata(
    num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
):
    src_dir = Path("data/manifests/magicdata")
    output_dir = Path("data/fbank")
    num_jobs = min(30, os.cpu_count())
    num_jobs = min(8, os.cpu_count())

    dataset_parts = ("train", "test", "dev")
    prefix = "magicdata"
@@ -66,7 +96,12 @@ def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False)
        dataset_parts,
    )

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
    if args.whisper_fbank:
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
        )
    else:
        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
@@ -107,7 +142,12 @@ def get_args():
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )

    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )
    return parser.parse_args()


@@ -118,5 +158,7 @@ if __name__ == "__main__":

    args = get_args()
    compute_fbank_magicdata(
        num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
        num_mel_bins=args.num_mel_bins,
        speed_perturb=args.speed_perturb,
        whisper_fbank=args.whisper_fbank,
    )
@@ -30,10 +30,17 @@ import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@@ -43,7 +50,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False):
def compute_fbank_primewords(
    num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
):
    src_dir = Path("data/manifests/primewords")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
@@ -66,7 +75,12 @@ def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False
        dataset_parts,
    )

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
    if whisper_fbank:
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
        )
    else:
        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
@@ -108,6 +122,13 @@ def get_args():
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )

    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )

    return parser.parse_args()


@@ -118,5 +139,7 @@ if __name__ == "__main__":

    args = get_args()
    compute_fbank_primewords(
        num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
        num_mel_bins=args.num_mel_bins,
        speed_perturb=args.speed_perturb,
        whisper_fbank=args.whisper_fbank,
    )
@@ -30,10 +30,17 @@ import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@@ -43,7 +50,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False):
def compute_fbank_stcmds(
    num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
):
    src_dir = Path("data/manifests/stcmds")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
@@ -66,7 +75,12 @@ def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False):
        dataset_parts,
    )

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
    if whisper_fbank:
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
        )
    else:
        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
@@ -107,6 +121,12 @@ def get_args():
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )
    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )
    return parser.parse_args()


@@ -117,5 +137,7 @@ if __name__ == "__main__":

    args = get_args()
    compute_fbank_stcmds(
        num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
        num_mel_bins=args.num_mel_bins,
        speed_perturb=args.speed_perturb,
        whisper_fbank=args.whisper_fbank,
    )
@@ -30,10 +30,17 @@ import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@@ -43,7 +50,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False):
def compute_fbank_thchs30(
    num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
):
    src_dir = Path("data/manifests/thchs30")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
@@ -70,7 +79,12 @@ def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False):
        dataset_parts,
    )

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
    if whisper_fbank:
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
        )
    else:
        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
@@ -113,6 +127,12 @@ def get_args():
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )
    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )
    return parser.parse_args()


@@ -123,5 +143,7 @@ if __name__ == "__main__":

    args = get_args()
    compute_fbank_thchs30(
        num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
        num_mel_bins=args.num_mel_bins,
        speed_perturb=args.speed_perturb,
        whisper_fbank=args.whisper_fbank,
    )
@@ -60,7 +60,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then

  if [ ! -f data/fbank/.thchs30.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_thchs30.py
    ./local/compute_fbank_thchs30.py --speed-perturb true
    touch data/fbank/.thchs30.done
  fi
fi
@@ -86,7 +86,7 @@ fi
log "Dataset: AISHELL-2"
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Prepare AISHELL-2"
  if [ -e ../../aishell/ASR/data/fbank/.aishell2.done ]; then
  if [ -e ../../aishell2/ASR/data/fbank/.aishell2.done ]; then
    cd data/fbank
    ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats_train) .
    ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats_dev) .
@@ -95,30 +95,30 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
    ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_dev.jsonl.gz) .
    ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_test.jsonl.gz) .
    cd ../..
  else
  else
    log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3"
    exit 1
  fi
fi
fi

log "Dataset: AISHELL-4"
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Prepare AISHELL-4"
  if [ -e ../../aishell/ASR/data/fbank/.aishell4.done ]; then
  if [ -e ../../aishell4/ASR/data/fbank/.fbank.done ]; then
    cd data/fbank
    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train) .
    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_dev) .
    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_test) .
    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train_L) .
    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train_M) .
    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train_S) .
    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_L.jsonl.gz) .
    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_M.jsonl.gz) .
    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) .
    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_dev.jsonl.gz) .
    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) .
    cd ../..
  else
  else
    log "Abort! Please run ../../aishell4/ASR/prepare.sh --stage 3 --stop-stage 3"
    exit 1
  fi
fi
fi

log "Dataset: ST-CMDS"
@@ -137,7 +137,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then

  if [ ! -f data/fbank/.stcmds.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_stcmds.py
    ./local/compute_fbank_stcmds.py --speed-perturb true
    touch data/fbank/.stcmds.done
  fi
fi
@@ -151,15 +151,15 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
    lhotse download primewords $dl_dir/primewords
  fi

  if [ ! -f data/manifests/.stcmds.done ]; then
  if [ ! -f data/manifests/.primewords.done ]; then
    mkdir -p data/manifests
    lhotse prepare stcmds $dl_dir/primewords data/manifests/primewords
    lhotse prepare primewords $dl_dir/primewords data/manifests/primewords
    touch data/manifests/.primewords.done
  fi

  if [ ! -f data/fbank/.primewords.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_primewords.py
    ./local/compute_fbank_primewords.py --speed-perturb true
    touch data/fbank/.primewords.done
  fi
fi
@@ -180,7 +180,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then

  if [ ! -f data/fbank/.magicdata.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_magicdata.py
    ./local/compute_fbank_magicdata.py --speed-perturb true
    touch data/fbank/.magicdata.done
  fi
fi
@@ -231,7 +231,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .

    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/L_split_1000) .
    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/L_split_${num_splits}) .
    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/*.lca) .
    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/) ./wenetspeech
    cd ../..
@@ -261,7 +261,7 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then

  if [ ! -f data/manifests/.kespeech.done ]; then
    mkdir -p data/manifests
    lhotse prepare kespeech -j 16 $dl_dir/KeSpeech data/manifests/kespeech
    lhotse prepare kespeech -j 16 $dl_dir/KeSpeech data/manifests/kespeech
    touch data/manifests/.kespeech.done
  fi

@@ -272,29 +272,29 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
  if [ ! -f data/fbank/.kespeech_preprocess_complete ]; then
    python3 ./local/preprocess_kespeech.py
    touch data/fbank/.kespeech_preprocess_complete
  fi

  if [ -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then
  fi

  if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then
    log "Spliting KeSpeech train_phase1"
    lhotse split ${num_splits} \
      data/fbank/kespeech/kespeech-asr_cuts_train_phase1_raw.jsonl.gz \
      data/fbank/kespeech/train_phase1_split_${num_splits}
    touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done
  fi

  if [ -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then

  if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then
    log "Spliting KeSpeech train_phase2"
    lhotse split ${num_splits} \
      data/fbank/kespeech/kespeech-asr_cuts_train_phase2_raw.jsonl.gz \
      data/fbank/kespeech/train_phase2_split_${num_splits}
    touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done
  fi

  log "Compute KeSpeech fbank for train_phase1"
  ./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase1
  ./local/compute_fbank_kespeech_splits.py --speed-perturb true --num-splits ${num_splits} --training-subset train_phase1

  log "Compute KeSpeech fbank for train_phase2"
  ./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase2
  ./local/compute_fbank_kespeech_splits.py --speed-perturb true --num-splits ${num_splits} --training-subset train_phase2

  log "Compute KeSpeech fbank for test/dev"
  ./local/compute_fbank_kespeech_dev_test.py
@@ -303,13 +303,126 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
  fi
fi

whisper_mel_bins=80
if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then
  log "Stage 120: Prepare KeSpeech for whisper"
  if [ ! -d $dl_dir/KeSpeech ]; then
    log "Abort! Please download KeSpeech first."
    log "KeSpeech download link: https://github.com/KeSpeech/KeSpeech"
    exit 1
  fi

  if [ ! -f data/manifests/.kespeech.done ]; then
    mkdir -p data/manifests
    lhotse prepare kespeech -j 8 $dl_dir/KeSpeech data/manifests/kespeech
    touch data/manifests/.kespeech.done
  fi

  if [ ! -f data/fbank/.kespeech.done ]; then
    mkdir -p data/fbank

    log "Preprocess KeSpeech manifest"
    if [ ! -f data/fbank/.kespeech_preprocess_complete ]; then
      python3 ./local/preprocess_kespeech.py --speed-perturb true
      touch data/fbank/.kespeech_preprocess_complete
    fi

    if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then
      log "Spliting KeSpeech train_phase1"
      lhotse split ${num_splits} \
        data/fbank/kespeech/kespeech-asr_cuts_train_phase1_raw.jsonl.gz \
        data/fbank/kespeech/train_phase1_split_${num_splits}
      touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done
    fi

    if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then
      log "Spliting KeSpeech train_phase2"
      lhotse split ${num_splits} \
        data/fbank/kespeech/kespeech-asr_cuts_train_phase2_raw.jsonl.gz \
        data/fbank/kespeech/train_phase2_split_${num_splits}
      touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done
    fi

    log "Compute KeSpeech fbank for train_phase1"
    ./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase1 --num-mel-bins ${whisper_mel_bins} --whisper-fbank true

    log "Compute KeSpeech fbank for train_phase2"
    ./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase2 --num-mel-bins ${whisper_mel_bins} --whisper-fbank true

    log "Compute KeSpeech fbank for test/dev"
    # ./local/compute_fbank_kespeech_dev_test.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true

    if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz ]; then
      pieces=$(find data/fbank/kespeech/train_phase1_split_${num_splits} -name "kespeech-asr_cuts_train_phase1.*.jsonl.gz")
      lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
    fi
    if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz ]; then
      pieces=$(find data/fbank/kespeech/train_phase2_split_${num_splits} -name "kespeech-asr_cuts_train_phase2.*.jsonl.gz")
      lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
    fi
    touch data/fbank/.kespeech.done
  fi
fi

if [ $stage -le 121 ] && [ $stop_stage -ge 121 ]; then
  log "Stage 121: Prepare MagicData, Primewords, ST-CMDS, THCHS-30 for whisper"

  if [ ! -f data/manifests/.magicdata.done ]; then
    mkdir -p data/manifests
    lhotse prepare magicdata $dl_dir/magicdata data/manifests/magicdata
    touch data/manifests/.magicdata.done
  fi

  if [ ! -f data/manifests/.primewords.done ]; then
    mkdir -p data/manifests
    lhotse prepare primewords $dl_dir/primewords data/manifests/primewords
    touch data/manifests/.primewords.done
  fi
  if [ ! -f data/manifests/.stcmds.done ]; then
    mkdir -p data/manifests
    lhotse prepare stcmds $dl_dir/stcmds data/manifests/stcmds
    touch data/manifests/.stcmds.done
  fi

  if [ ! -f data/manifests/.thchs30.done ]; then
    mkdir -p data/manifests
    lhotse prepare thchs-30 $dl_dir/thchs30 data/manifests/thchs30
    touch data/manifests/.thchs30.done
  fi

  if [ ! -f data/fbank/.thchs30.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_thchs30.py --speed-perturb true --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
    touch data/fbank/.thchs30.done
  fi

  if [ ! -f data/fbank/.stcmds.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_stcmds.py --speed-perturb true --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
    touch data/fbank/.stcmds.done
  fi
  if [ ! -f data/fbank/.magicdata.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_magicdata.py --speed-perturb true --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
    touch data/fbank/.magicdata.done
  fi

  if [ ! -f data/fbank/.primewords.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_primewords.py --speed-perturb true --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
    touch data/fbank/.primewords.done
  fi

fi


if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
  log "Stage 13: BPE model training (note that we use transcripts of wenetspeech only for BPE training)"
  ./local/prepare_for_bpe_model.py --lang-dir ./data/lang_char --text ./data/lang_char/text

  for vocab_size in ${vocab_sizes[@]}; do
    lang_dir=data/lang_bpe_${vocab_size}

    mkdir -p $lang_dir
    if [ ! -f $lang_dir/bpe.model ]; then
      ./local/train_bpe_model.py \
@@ -329,7 +442,7 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
        --lexicon $lang_dir/lexicon.txt \
        --bpe-model $lang_dir/bpe.model
    fi

    if [ ! -f $lang_dir/L.fst ]; then
      log "Converting L.pt to L.fst"
      ./shared/convert-k2-to-openfst.py \
@@ -350,7 +463,7 @@ fi

if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
  log "Stage 14: Prepare G (note that we use ngram lm of wenetspeech only for G preparation)"

  if [ -d ../../wenetspeech/ASR/data/lang_char/ ]; then
    cd data
    ln -s ../../../../wenetspeech/ASR/data/lm .
@@ -369,5 +482,3 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
    python ./local/compile_lg.py --lang-dir $lang_dir
  done
fi
1 egs/multi_zh-hans/ASR/whisper/asr_datamodule.py (symbolic link)
@@ -0,0 +1 @@
../zipformer/asr_datamodule.py

519 egs/multi_zh-hans/ASR/whisper/decode.py (normal file)
@@ -0,0 +1,519 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
#                                            Fangjun Kuang,
#                                            Wei Kang)
#           2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# Command for decoding using fine-tuned models:
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt

python3 ./whisper/decode.py \
  --exp-dir whisper/exp_large_v2 \
  --model-name large-v2 \
  --epoch 999 --avg 1 \
  --beam-size 10 --max-duration 50

# Command for decoding using pretrained models (before fine-tuning):

python3 ./whisper/decode.py \
  --exp-dir whisper/exp_large_v2 \
  --model-name large-v2 \
  --epoch -1 --avg 1 \
  --remove-whisper-encoder-input-length-restriction False \
  --beam-size 10 --max-duration 50

"""

import argparse
import logging
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import torch
import torch.nn as nn
import whisper
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert

from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
    AttributeDict,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)


def average_checkpoints(
    filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
    """Average a list of checkpoints.
    The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict.

    Args:
      filenames:
        Filenames of the checkpoints to be averaged. We assume all
        checkpoints are saved by :func:`save_checkpoint`.
      device:
        Move checkpoints to this device before averaging.
    Returns:
      Return a dict (i.e., state_dict) which is the average of all
      model state dicts contained in the checkpoints.
    """
    n = len(filenames)

    if "model" in torch.load(filenames[0], map_location=device):
        avg = torch.load(filenames[0], map_location=device)["model"]
    else:
        avg = torch.load(filenames[0], map_location=device)

    # Identify shared parameters. Two parameters are said to be shared
    # if they have the same data_ptr
    uniqued: Dict[int, str] = dict()

    for k, v in avg.items():
        v_data_ptr = v.data_ptr()
        if v_data_ptr in uniqued:
            continue
        uniqued[v_data_ptr] = k

    uniqued_names = list(uniqued.values())

    for i in range(1, n):
        if "model" in torch.load(filenames[i], map_location=device):
            state_dict = torch.load(filenames[i], map_location=device)["model"]
        else:
            state_dict = torch.load(filenames[i], map_location=device)
        for k in uniqued_names:
            avg[k] += state_dict[k]

    for k in uniqued_names:
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] //= n

    return avg
def remove_punctuation(text: str or List[str]):
    """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py

    Args:
      text: It can be a string or a list of strings.
    Returns:
      Return a string or a list of strings without any punctuation.
    """
    punctuation = "!,.;:?、!,。;:?《》 "
    if isinstance(text, str):
        text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
        return text
    elif isinstance(text, list):
        result_text = []
        for t in text:
            t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
            result_text.append(t)
        return result_text
    else:
        raise Exception(f"Not support type {type(text)}")


def to_simple(text: str or List[str]):
    """Convert traditional Chinese to simplified Chinese.
    Args:
      text: It can be a string or a list of strings.
    Returns:
      Return a string or a list of strings converted to simplified Chinese.
    """
    if isinstance(text, str):
        text = convert(text, "zh-cn")
        return text
    elif isinstance(text, list):
        result_text = []
        for t in text:
            t = convert(t, "zh-cn")
            result_text.append(t)
        return result_text
    else:
        raise Exception(f"Not support type{type(text)}")


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=-1,
        help="It specifies the checkpoint to use for decoding."
        "Note: Epoch counts from 0.",
    )
    parser.add_argument(
        "--avg",
        type=int,
        default=1,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="beam-search",
        help="""Decoding method.
        Supported values are:
          - beam-search
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=1,
        help="beam size for beam search decoding",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="whisper/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="large-v2",
        choices=["large-v2", "large-v3", "medium", "small", "tiny"],
        help="""The model name to use.
        """,
    )

    parser.add_argument(
        "--remove-whisper-encoder-input-length-restriction",
        type=str2bool,
        default=True,
        help="replace whisper encoder forward method to remove input length restriction",
    )

    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "env_info": get_env_info(),
        }
    )
    return params


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    batch: dict,
) -> Dict[str, List[List[int]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

    - key: "beam-search"
    - value: A list of lists. Each sublist is a list of token IDs.
    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      batch:
        It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
    Returns:
      Return a dict, whose key may be "beam-search".
    """
    dtype = torch.float16
    device = torch.device("cuda")

    feature = batch["inputs"]
    assert feature.ndim == 3
    feature = feature.to(device, dtype=dtype).transpose(1, 2)
    if not params.remove_whisper_encoder_input_length_restriction:
        T = 3000
        if feature.shape[2] < T:
            feature = torch.cat(
                [
                    feature,
                    torch.zeros(
                        feature.shape[0], feature.shape[1], T - feature.shape[2]
                    ).to(device, dtype=dtype),
                ],
                2,
            )

    supervisions = batch["supervisions"]
    feature_len = supervisions["num_frames"]
    feature_len = feature_len.to(device, dtype=dtype)
    results = model.decode(feature, params.decoding_options)
    hyps = [result.text for result in results]

    hyps = remove_punctuation(hyps)
    hyps = to_simple(hyps)
    hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
    print(hyps)
    return {"beam-search": hyps}
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
The dataloader.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
Returns:
|
||||
Return a dict, whose key may be "beam-search".
|
||||
"""
|
||||
results = []
|
||||
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
num_cuts += len(batch["supervisions"]["text"])
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    enable_log = True
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        if enable_log:
            logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        # We compute CER for these Chinese test sets, so convert the word-level
        # alignments into character-level ones.
        results_char = []
        for res in results:
            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
            )
            test_set_wers[key] = wer

        if enable_log:
            logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tCER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)
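
# Illustrative sketch of the word-to-character conversion above (hypothetical
# strings): joining the words and re-splitting into characters turns a
# word-level error computation into a character-level (CER) one.
ref_words = ["今天", "天气", "好"]
ref_chars = list("".join(ref_words))
print(ref_chars)  # ['今', '天', '天', '气', '好']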


@torch.no_grad()
def main():
    parser = get_parser()
    AsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))
    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
    setup_logger(
        f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
    )

    options = whisper.DecodingOptions(
        task="transcribe",
        language="zh",
        without_timestamps=True,
        beam_size=params.beam_size,
    )
    params.decoding_options = options
    params.cleaner = BasicTextNormalizer()
    params.normalizer = Normalizer()

    logging.info("Decoding started")
    logging.info(params)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")

    logging.info(f"device: {device}")

    if params.remove_whisper_encoder_input_length_restriction:
        replace_whisper_encoder_forward()
    model = whisper.load_model(params.model_name, "cpu")
    if params.epoch > 0:
        if params.avg > 1:
            start = params.epoch - params.avg
            assert start >= 1, start
            checkpoint = torch.load(
                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
            )
            if "model" not in checkpoint:
                # deepspeed converted checkpoint only contains model state_dict
                filenames = [
                    f"{params.exp_dir}/epoch-{epoch}.pt"
                    for epoch in range(start, params.epoch + 1)
                ]
                model.load_state_dict(average_checkpoints(filenames))
            else:
                filename_start = f"{params.exp_dir}/epoch-{start}.pt"
                filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
                logging.info(
                    f"Calculating the averaged model over epoch range from "
                    f"{start} (excluded) to {params.epoch}"
                )
                model.to(device)
                model.load_state_dict(
                    average_checkpoints_with_averaged_model(
                        filename_start=filename_start,
                        filename_end=filename_end,
                        device=device,
                    )
                )
            # save the averaged checkpoint
            filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
            torch.save(model.state_dict(), filename)
        else:
            checkpoint = torch.load(
                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
            )
            if "model" not in checkpoint:
                model.load_state_dict(checkpoint, strict=True)
            else:
                load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.to(device)
    model.eval()
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True

    data_module = AsrDataModule(args)
    multi_dataset = MultiDataset(args.manifest_dir)

    def remove_long_utt(c: Cut):
        # Keep only utterances that are at most 30 seconds long.
        if c.duration > 30.0:
            # logging.warning(
            #     f"Exclude cut with ID {c.id} from decoding. Duration: {c.duration}"
            # )
            return False
        return True

    test_sets_cuts = multi_dataset.test_cuts()

    test_sets = test_sets_cuts.keys()
    test_dls = [
        data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))
        for cuts_name in test_sets
    ]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
        )

        save_results(params=params, test_set_name=test_set, results_dict=results_dict)

    logging.info("Done!")


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
1  egs/multi_zh-hans/ASR/whisper/ds_config_zero1.json  Symbolic link
@ -0,0 +1 @@
../../../aishell/ASR/whisper/ds_config_zero1.json

1  egs/multi_zh-hans/ASR/whisper/label_smoothing.py  Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py

296  egs/multi_zh-hans/ASR/whisper/multi_dataset.py  Normal file
@ -0,0 +1,296 @@
# Copyright    2023  Xiaomi Corp.        (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import glob
import logging
import re
from pathlib import Path
from typing import Dict, List

import lhotse
from lhotse import CutSet, load_manifest_lazy


class MultiDataset:
    def __init__(self, fbank_dir: str):
        """
        Args:
          fbank_dir:
            It is expected to contain the following files:
            - aishell_cuts_train.jsonl.gz
            - aishell2_cuts_train.jsonl.gz
            - aishell4_cuts_train_L.jsonl.gz
            - aishell4_cuts_train_M.jsonl.gz
            - aishell4_cuts_train_S.jsonl.gz
            - alimeeting-far_cuts_train.jsonl.gz
            - magicdata_cuts_train.jsonl.gz
            - primewords_cuts_train.jsonl.gz
            - stcmds_cuts_train.jsonl.gz
            - thchs_30_cuts_train.jsonl.gz
            - kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
            - kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
            - wenetspeech/cuts_L.jsonl.gz
        """
        self.fbank_dir = Path(fbank_dir)

    def train_cuts(self) -> CutSet:
        logging.info("About to get multidataset train cuts")

        # THCHS-30
        logging.info("Loading THCHS-30 in lazy mode")
        thchs_30_cuts = load_manifest_lazy(
            self.fbank_dir / "thchs_30_cuts_train.jsonl.gz"
        )

        # AISHELL-1
        logging.info("Loading Aishell-1 in lazy mode")
        aishell_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_train.jsonl.gz"
        )

        # AISHELL-2
        logging.info("Loading Aishell-2 in lazy mode")
        aishell_2_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
        )

        # AISHELL-4
        logging.info("Loading Aishell-4 in lazy mode")
        aishell_4_L_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz"
        )
        aishell_4_M_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz"
        )
        aishell_4_S_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz"
        )

        # ST-CMDS
        logging.info("Loading ST-CMDS in lazy mode")
        stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz")

        # Primewords
        logging.info("Loading Primewords in lazy mode")
        primewords_cuts = load_manifest_lazy(
            self.fbank_dir / "primewords_cuts_train.jsonl.gz"
        )

        # MagicData
        logging.info("Loading MagicData in lazy mode")
        magicdata_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_train.jsonl.gz"
        )

        # Ali-Meeting
        logging.info("Loading Ali-Meeting in lazy mode")
        alimeeting_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz"
        )

        # WeNetSpeech
        logging.info("Loading WeNetSpeech in lazy mode")
        wenetspeech_L_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz"
        )

        # KeSpeech
        logging.info("Loading KeSpeech in lazy mode")
        kespeech_1_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz"
        )
        kespeech_2_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz"
        )

        return CutSet.mux(
            thchs_30_cuts,
            aishell_cuts,
            aishell_2_cuts,
            aishell_4_L_cuts,
            aishell_4_M_cuts,
            aishell_4_S_cuts,
            stcmds_cuts,
            primewords_cuts,
            magicdata_cuts,
            alimeeting_cuts,
            wenetspeech_L_cuts,
            kespeech_1_cuts,
            kespeech_2_cuts,
            weights=[
                len(thchs_30_cuts),
                len(aishell_cuts),
                len(aishell_2_cuts),
                len(aishell_4_L_cuts),
                len(aishell_4_M_cuts),
                len(aishell_4_S_cuts),
                len(stcmds_cuts),
                len(primewords_cuts),
                len(magicdata_cuts),
                len(alimeeting_cuts),
                len(wenetspeech_L_cuts),
                len(kespeech_1_cuts),
                len(kespeech_2_cuts),
            ],
        )

    def dev_cuts(self) -> CutSet:
        logging.info("About to get multidataset dev cuts")

        # AISHELL
        logging.info("Loading Aishell DEV set in lazy mode")
        aishell_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
        )

        # AISHELL-2
        logging.info("Loading Aishell-2 DEV set in lazy mode")
        aishell2_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
        )

        # Ali-Meeting
        logging.info("Loading Ali-Meeting DEV set in lazy mode")
        alimeeting_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
        )

        # MagicData
        logging.info("Loading MagicData DEV set in lazy mode")
        magicdata_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
        )

        # KeSpeech
        logging.info("Loading KeSpeech DEV set in lazy mode")
        kespeech_dev_phase1_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
        )
        kespeech_dev_phase2_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
        )

        # WeNetSpeech
        logging.info("Loading WeNetSpeech DEV set in lazy mode")
        wenetspeech_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
        )

        return wenetspeech_dev_cuts
        # return [
        #     aishell_dev_cuts,
        #     aishell2_dev_cuts,
        #     alimeeting_dev_cuts,
        #     magicdata_dev_cuts,
        #     kespeech_dev_phase1_cuts,
        #     kespeech_dev_phase2_cuts,
        #     wenetspeech_dev_cuts,
        # ]

    def test_cuts(self) -> Dict[str, CutSet]:
        logging.info("About to get multidataset test cuts")

        # AISHELL
        logging.info("Loading Aishell set in lazy mode")
        aishell_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_test.jsonl.gz"
        )
        aishell_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
        )

        # AISHELL-2
        logging.info("Loading Aishell-2 set in lazy mode")
        aishell2_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
        )
        aishell2_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
        )

        # AISHELL-4
        logging.info("Loading Aishell-4 TEST set in lazy mode")
        aishell4_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_test.jsonl.gz"
        )

        # Ali-Meeting
        logging.info("Loading Ali-Meeting set in lazy mode")
        alimeeting_test_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz"
        )
        alimeeting_eval_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
        )

        # MagicData
        logging.info("Loading MagicData set in lazy mode")
        magicdata_test_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_test.jsonl.gz"
        )
        magicdata_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
        )

        # KeSpeech
        logging.info("Loading KeSpeech set in lazy mode")
        kespeech_test_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz"
        )
        kespeech_dev_phase1_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
        )
        kespeech_dev_phase2_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
        )

        # WeNetSpeech
        logging.info("Loading WeNetSpeech set in lazy mode")
        wenetspeech_test_meeting_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
        )
        wenetspeech_test_net_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
        )
        wenetspeech_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
        )

        return {
            "aishell-2_test": aishell2_test_cuts,
            "aishell-4": aishell4_test_cuts,
            "magicdata_test": magicdata_test_cuts,
            "kespeech-asr_test": kespeech_test_cuts,
        }

        # return {
        #     "alimeeting_test": alimeeting_test_cuts,
        #     "alimeeting_eval": alimeeting_eval_cuts,
        #     "aishell_test": aishell_test_cuts,
        #     "aishell_dev": aishell_dev_cuts,
        #     "aishell-2_test": aishell2_test_cuts,
        #     "aishell-2_dev": aishell2_dev_cuts,
        #     "aishell-4": aishell4_test_cuts,
        #     "magicdata_test": magicdata_test_cuts,
        #     "magicdata_dev": magicdata_dev_cuts,
        #     "kespeech-asr_test": kespeech_test_cuts,
        #     "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
        #     "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
        #     "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
        #     "wenetspeech-net_test": wenetspeech_test_net_cuts,
        #     "wenetspeech_dev": wenetspeech_dev_cuts,
        # }
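
# Illustrative sketch (toy manifests built with lhotse's test helpers; not
# part of this recipe) of the CutSet.mux call in train_cuts(): cut sets are
# sampled in proportion to the given weights, so larger corpora contribute
# proportionally more cuts to the mixed training stream.
from lhotse import CutSet
from lhotse.testing.dummies import DummyManifest

small = DummyManifest(CutSet, begin_id=0, end_id=10)
large = DummyManifest(CutSet, begin_id=10, end_id=40)
mixed = CutSet.mux(small, large, weights=[len(small), len(large)])
print(sum(1 for _ in mixed))  # 40 cuts in total, ~75% drawn from `large`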
1  egs/multi_zh-hans/ASR/whisper/optim.py  Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

1  egs/multi_zh-hans/ASR/whisper/requirements.txt  Symbolic link
@ -0,0 +1 @@
../../../aishell/ASR/whisper/requirements.txt

983  egs/multi_zh-hans/ASR/whisper/train.py  Normal file
@ -0,0 +1,983 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Xiaoyu Yang)
#              2024                      Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:

# fine-tuning with DeepSpeed ZeRO stage 1
torchrun --nproc_per_node 8 ./whisper/train.py \
  --max-duration 200 \
  --exp-dir whisper/exp_large_v2 \
  --model-name large-v2 \
  --deepspeed \
  --deepspeed_config ./whisper/ds_config_zero1.json

# fine-tuning with DDP
torchrun --nproc_per_node 8 ./whisper/train.py \
  --max-duration 200 \
  --exp-dir whisper/exp_medium \
  --base-lr 1e-5 \
  --model-name medium
"""

import argparse
import copy
import logging
import os
import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union

import deepspeed
import k2
import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import whisper
from asr_datamodule import AsrDataModule
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from multi_dataset import MultiDataset
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.functional import pad as pad_tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import update_averaged_model
from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
    AttributeDict,
    MetricsTracker,
    filter_uneven_sized_batch,
    setup_logger,
    str2bool,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
    if isinstance(model, DDP):
        # get underlying nn.Module
        model = model.module
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--tensorboard",
        type=str2bool,
        default=True,
        help="Should various information be logged in tensorboard.",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,
        default=10,
        help="Number of epochs to train.",
    )

    parser.add_argument(
        "--start-epoch",
        type=int,
        default=1,
        help="""Resume training from this epoch. It should be positive.
        If larger than 1, it will load the checkpoint from
        exp-dir/epoch-{start_epoch-1}.pt
        """,
    )

    parser.add_argument(
        "--start-batch",
        type=int,
        default=0,
        help="""If positive, --start-epoch is ignored and
        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="whisper/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, logs, etc., are saved
        """,
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="large-v2",
        choices=["large-v2", "large-v3", "medium", "small", "tiny"],
        help="""The model name to use.
        """,
    )

    parser.add_argument(
        "--pretrained-model-path",
        type=str,
        default=None,
        help="""The path to the pretrained model if it is not None. Training will
        start from this model, e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt
        """,
    )

    parser.add_argument(
        "--base-lr", type=float, default=1e-5, help="The base learning rate."
    )

    parser.add_argument(
        "--lr-batches",
        type=float,
        default=5000,
        help="""Number of steps that affects how rapidly the learning rate
        decreases. We suggest not to change this.""",
    )

    parser.add_argument(
        "--lr-epochs",
        type=float,
        default=6,
        help="""Number of epochs that affects how rapidly the learning rate decreases.
        """,
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The seed for random generators, intended for reproducibility.",
    )

    parser.add_argument(
        "--print-diagnostics",
        type=str2bool,
        default=False,
        help="Accumulate stats on activations, print them and exit.",
    )

    parser.add_argument(
        "--inf-check",
        type=str2bool,
        default=False,
        help="Add hooks to check for infinite module outputs and gradients.",
    )

    parser.add_argument(
        "--keep-last-k",
        type=int,
        default=30,
        help="""Only keep this number of checkpoints on disk.
        For instance, if it is 3, there are only 3 checkpoints
        in the exp-dir with filenames `checkpoint-xxx.pt`.
        It does not affect checkpoints with name `epoch-xxx.pt`.
        """,
    )

    parser.add_argument(
        "--average-period",
        type=int,
        default=200,
        help="""Update the averaged model, namely `model_avg`, after processing
        this number of batches. `model_avg` is a separate version of model,
        in which each floating-point parameter is the average of all the
        parameters from the start of training. Each time we take the average,
        we do: `model_avg = model * (average_period / batch_idx_train) +
        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
        """,
    )

    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=True,
        help="Whether to use half precision training.",
    )

    parser = deepspeed.add_config_arguments(parser)

    return parser


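# Quick numeric check (hypothetical values) of the running-average formula
# quoted in the --average-period help text: with average_period=200 and
# batch_idx_train=600, the current model gets weight 1/3 and the previous
# average gets weight 2/3.
average_period, batch_idx_train = 200, 600
w_cur = average_period / batch_idx_train
w_avg = (batch_idx_train - average_period) / batch_idx_train
assert abs(w_cur - 1 / 3) < 1e-9 and abs(w_avg - 2 / 3) < 1e-9

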
def get_params() -> AttributeDict:
    """Return a dict containing training parameters.

    All training related parameters that are not passed from the commandline
    are saved in the variable `params`.

    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

        - frame_shift_ms: The frame shift in milliseconds.
        - allowed_excess_duration_ratio: The allowed excess duration ratio.
        - best_train_loss: The best training loss so far.
        - best_valid_loss: The best validation loss so far.
        - best_train_epoch: The epoch where the best training loss is achieved.
        - best_valid_epoch: The epoch where the best validation loss is achieved.
        - batch_idx_train: The batch index of the current batch.
        - log_interval: Log training stats every `log_interval` batches.
        - reset_interval: Reset the stats every `reset_interval` batches.
        - valid_interval: Run validation every `valid_interval` batches.
        - env_info: The environment information.
    """
    params = AttributeDict(
        {
            "frame_shift_ms": 10.0,
            "subsampling_factor": 2,
            "allowed_excess_duration_ratio": 0.1,
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
            "log_interval": 50,
            "reset_interval": 200,
            "valid_interval": 10000,
            "env_info": get_env_info(),
        }
    )

    return params


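# Minimal illustration of AttributeDict from icefall.utils: a dict whose keys
# are also readable/writable as attributes, which is how `params` is accessed
# throughout this file (e.g. params.base_lr vs. params["base_lr"]).
from icefall.utils import AttributeDict

p = AttributeDict({"base_lr": 1e-5})
assert p.base_lr == p["base_lr"]
p.batch_idx_train = 0  # attribute assignment adds a key

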
def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
    model_avg: nn.Module = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
) -> Optional[Dict[str, Any]]:
    """Load checkpoint from file.

    If params.start_batch is positive, it will load the checkpoint from
    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
    params.start_epoch is larger than 1, it will load the checkpoint from
    `params.exp_dir/epoch-{params.start_epoch - 1}.pt`.

    Apart from loading state dict for `model` and `optimizer` it also updates
    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
    and `best_valid_loss` in `params`.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
      model_avg:
        The stored model averaged from the start of training.
      optimizer:
        The optimizer that we are using.
      scheduler:
        The scheduler that we are using.
    Returns:
      Return a dict containing previously saved training info.
    """
    if params.start_batch > 0:
        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
    elif params.start_epoch > 1:
        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    else:
        return None

    assert filename.is_file(), f"{filename} does not exist!"

    saved_params = load_checkpoint(
        filename,
        model=model,
        model_avg=model_avg,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    keys = [
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
    ]
    for k in keys:
        params[k] = saved_params[k]

    if params.start_batch > 0:
        if "cur_epoch" in saved_params:
            params["start_epoch"] = saved_params["cur_epoch"]

    return saved_params

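
# Tiny illustration (hypothetical values) of the resume rule documented above:
# --start-batch takes precedence over --start-epoch.
from pathlib import Path

exp_dir, start_batch, start_epoch = Path("whisper/exp"), 0, 4
if start_batch > 0:
    filename = exp_dir / f"checkpoint-{start_batch}.pt"
elif start_epoch > 1:
    filename = exp_dir / f"epoch-{start_epoch - 1}.pt"
print(filename)  # whisper/exp/epoch-3.pt
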
def save_checkpoint(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    model_avg: Optional[nn.Module] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    sampler: Optional[CutSampler] = None,
    scaler: Optional[GradScaler] = None,
    rank: int = 0,
) -> None:
    """Save model, optimizer, scheduler and training stats to file.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The training model.
      model_avg:
        The stored model averaged from the start of training.
      optimizer:
        The optimizer used in the training.
      sampler:
        The sampler for the training dataset.
      scaler:
        The scaler used for mixed precision training.
    """
    if rank != 0:
        return
    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
    save_checkpoint_impl(
        filename=filename,
        model=model,
        model_avg=model_avg,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        sampler=sampler,
        scaler=scaler,
        rank=rank,
    )

    if params.best_train_epoch == params.cur_epoch:
        best_train_filename = params.exp_dir / "best-train-loss.pt"
        copyfile(src=filename, dst=best_train_filename)

    if params.best_valid_epoch == params.cur_epoch:
        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
        copyfile(src=filename, dst=best_valid_filename)


def compute_loss(
    params: AttributeDict,
    tokenizer: whisper.tokenizer.Tokenizer,
    model: Union[nn.Module, DDP],
    batch: dict,
    is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute the loss for the given batch.
    Args:
      params:
        It is returned by :func:`get_params`.
      tokenizer:
        The tokenizer used to encode the text.
      model:
        The model for training.
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      is_training:
        Whether it is training.
    Returns:
      Return a tuple of two elements. The first element is the loss tensor.
    """
    # For the uneven-sized batch, the total duration after padding would possibly
    # cause OOM. Hence, for each batch, which is sorted descendingly by length,
    # we simply drop the last few shortest samples, so that the retained total frames
    # (after padding) would not exceed `allowed_max_frames`:
    # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
    # where `max_frames = max_duration * 1000 // frame_shift_ms`.
    # We set allowed_excess_duration_ratio=0.1.
    if isinstance(model, DDP):
        # get underlying nn.Module
        model = model.module

    def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
        padding_size = max(tensor.shape[0] for tensor in tensors)
        dims = len(tensors[0].shape)
        padded_tensors = []
        for tensor in tensors:
            padding = [0] * 2 * dims
            padding[-1] = padding_size - tensor.shape[0]
            padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
        return torch.stack([tensor for tensor in padded_tensors], dim=0)

    max_frames = params.max_duration * 1000 // params.frame_shift_ms
    allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
    batch = filter_uneven_sized_batch(batch, allowed_max_frames)

    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    feature = batch["inputs"]

    assert feature.ndim == 3
    feature = feature.to(device)
    feature = feature.transpose(1, 2)  # (N, C, T)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    batch_idx_train = params.batch_idx_train

    texts = batch["supervisions"]["text"]
    # remove spaces in texts
    texts = [text.replace(" ", "") for text in texts]

    text_tokens_list = [
        list(tokenizer.sot_sequence_including_notimestamps)
        + tokenizer.encode(text)
        + [tokenizer.eot]
        for text in texts
    ]
    # convert it to torch tensor
    text_tokens_list = [
        torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
    ]

    # 50256 is the index of <pad> for all whisper models
    prev_outputs_tokens = _batch_tensors(
        [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
    )
    target_tokens = _batch_tensors(
        [tokens[1:] for tokens in text_tokens_list], pad_value=50256
    )
    target_lengths = torch.LongTensor(
        [tokens.shape[0] - 1 for tokens in text_tokens_list]
    )

    decoder_criterion = LabelSmoothingLoss(
        ignore_index=50256, label_smoothing=0.1, reduction="sum"
    )

    # ignore the first 3 tokens, which are always <|lang_id|>, <|transcribe|>, <|notimestamps|>
    ignore_prefix_size = 3
    with torch.set_grad_enabled(is_training):
        encoder_out = model.encoder(feature)
        text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
        text_logits = text_logits[:, ignore_prefix_size:, :]
        target_tokens = target_tokens[:, ignore_prefix_size:]
        loss = decoder_criterion(text_logits, target_tokens.to(device))

    assert loss.requires_grad == is_training

    info = MetricsTracker()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()

    return loss, info


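# Standalone sketch (assumes the openai-whisper package) of the target layout
# built in compute_loss: SOT sequence + text tokens + EOT, shifted by one
# position so the decoder is trained with teacher forcing.
import whisper

tok = whisper.tokenizer.get_tokenizer(
    multilingual=True, language="zh", task="transcribe"
)
tokens = (
    list(tok.sot_sequence_including_notimestamps)
    + tok.encode("今天天气不错")
    + [tok.eot]
)
prev_tokens, target = tokens[:-1], tokens[1:]  # decoder input / shifted target

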
def compute_validation_loss(
    params: AttributeDict,
    tokenizer: whisper.tokenizer.Tokenizer,
    model: Union[nn.Module, DDP],
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
) -> MetricsTracker:
    """Run the validation process."""
    model.eval()

    tot_loss = MetricsTracker()

    for batch_idx, batch in enumerate(valid_dl):
        with torch.cuda.amp.autocast(enabled=params.use_fp16):
            loss, loss_info = compute_loss(
                params=params,
                tokenizer=tokenizer,
                model=model,
                batch=batch,
                is_training=False,
            )
        assert loss.requires_grad is False
        tot_loss = tot_loss + loss_info

    if world_size > 1:
        tot_loss.reduce(loss.device)

    loss_value = tot_loss["loss"] / tot_loss["frames"]
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss


def train_one_epoch(
    params: AttributeDict,
    tokenizer: whisper.tokenizer.Tokenizer,
    model: Union[nn.Module, DDP],
    optimizer: torch.optim.Optimizer,
    scheduler: LRSchedulerType,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    model_avg: Optional[nn.Module] = None,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    The training loss from the mean of all frames is saved in
    `params.train_loss`. It runs the validation process every
    `params.valid_interval` batches.

    Args:
      params:
        It is returned by :func:`get_params`.
      tokenizer:
        The tokenizer used to encode the text.
      model:
        The model for training.
      optimizer:
        The optimizer we are using.
      scheduler:
        The learning rate scheduler, we call step() every step.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mixed precision training.
      model_avg:
        The stored model averaged from the start of training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of nodes in DDP training. If it is 1, DDP is disabled.
      rank:
        The rank of the node in DDP training. If no DDP is used, it should
        be set to 0.
    """
    model.train()

    tot_loss = MetricsTracker()

    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
                params=params,
                tokenizer=tokenizer,
                model=model,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )
            if params.deepspeed:
                model.save_checkpoint(
                    save_dir=params.exp_dir,
                    tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
                    client_state={},
                )
                if rank == 0:
                    convert_zero_checkpoint_to_fp32_state_dict(
                        params.exp_dir,
                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
                        tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
                    )
                    os.system(
                        f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
                    )

        try:
            with torch.cuda.amp.autocast(enabled=params.use_fp16):
                loss, loss_info = compute_loss(
                    params=params,
                    tokenizer=tokenizer,
                    model=model,
                    batch=batch,
                    is_training=True,
                )
            # summary stats
            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

            # NOTE: We use reduction==sum and loss is computed over utterances
            # in the batch and there is no normalization to it so far.
            if params.deepspeed:
                # The DeepSpeed engine handles loss scaling, backward and the
                # optimizer step internally, so the GradScaler branch below
                # is not needed here.
                model.backward(loss)
                model.step()
            else:
                scaler.scale(loss).backward()
                set_batch_count(model, params.batch_idx_train)
                scheduler.step_batch(params.batch_idx_train)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        except:  # noqa
            display_and_save_batch(batch, params=params)
            raise

        if params.print_diagnostics and batch_idx == 5:
            return

        if (
            rank == 0
            and params.batch_idx_train > 0
            and params.batch_idx_train % params.average_period == 0
            and not params.deepspeed
        ):
            update_averaged_model(
                params=params,
                model_cur=model,
                model_avg=model_avg,
            )

        if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()
            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
            if cur_grad_scale < 1.0e-05:
                raise RuntimeError(
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )
        if batch_idx % params.log_interval == 0:
            try:
                cur_lr = scheduler.get_last_lr()[0]
            except:  # noqa
                cur_lr = 0.0
            cur_grad_scale = (
                scaler._scale.item()
                if (params.use_fp16 and not params.deepspeed)
                else 1.0
            )

            logging.info(
                f"Epoch {params.cur_epoch}, "
                f"batch {batch_idx}, loss[{loss_info}], "
                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                f"lr: {cur_lr:.2e}, "
                + (
                    f"grad_scale: {scaler._scale.item()}"
                    if (params.use_fp16 and not params.deepspeed)
                    else ""
                )
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate", cur_lr, params.batch_idx_train
                )

                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale",
                        cur_grad_scale,
                        params.batch_idx_train,
                    )

    loss_value = tot_loss["loss"] / tot_loss["frames"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss


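# Sketch (hypothetical numbers) of the running-loss smoothing used in
# train_one_epoch above: tot_loss decays by a factor of (1 - 1/reset_interval)
# each batch, so the displayed tot_loss is dominated by recent batches.
reset_interval = 200
decay = 1 - 1 / reset_interval
tot = 0.0
for step_loss in (4.0, 3.0, 2.0):
    tot = tot * decay + step_loss
print(tot)  # older losses are down-weighted geometrically

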
def run(rank, world_size, args):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoints.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info(params)

    logging.info("About to create model")

    replace_whisper_encoder_forward()
    model = whisper.load_model(params.model_name, "cpu")
    del model.alignment_heads

    if params.pretrained_model_path:
        checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
        if "model" not in checkpoint:
            model.load_state_dict(checkpoint, strict=True)
        else:
            load_checkpoint(params.pretrained_model_path, model)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    tokenizer = whisper.tokenizer.get_tokenizer(
        model.is_multilingual,
        num_languages=model.num_languages,
        language="zh",
        task="transcribe",
    )

    model_avg: Optional[nn.Module] = None
    if rank == 0:
        # model_avg is only used with rank 0
        model_avg = copy.deepcopy(model).to(torch.float64)

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(
        params=params, model=model, model_avg=model_avg
    )

    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    else:
        device = torch.device("cpu")
    logging.info(f"Device: {device}")
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

    if checkpoints and "optimizer" in checkpoints:
        logging.info("Loading optimizer state dict")
        optimizer.load_state_dict(checkpoints["optimizer"])

    if (
        checkpoints
        and "scheduler" in checkpoints
        and checkpoints["scheduler"] is not None
    ):
        logging.info("Loading scheduler state dict")
        scheduler.load_state_dict(checkpoints["scheduler"])

    if world_size > 1:
        if params.deepspeed:
            logging.info("Using DeepSpeed")
            model, optimizer, _, scheduler = deepspeed.initialize(
                args=params, model=model, model_parameters=model.parameters()
            )
        else:
            logging.info("Using DDP")
            setup_dist(use_ddp_launch=True)
            model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)

    data_module = AsrDataModule(args)
    multi_dataset = MultiDataset(args.manifest_dir)

    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only load the sampler's state dict when it loads a checkpoint
        # saved in the middle of an epoch
        sampler_state_dict = checkpoints["sampler"]
    else:
        sampler_state_dict = None

    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 20 seconds
        #
        # Caution: There is a reason to select 20.0 here. Please see
        # ../local/display_manifest_statistics.py
        #
        # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
        # the threshold
        if c.duration < 1.0 or c.duration > 20.0:
            # logging.warning(
            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
            # )
            return False
        return True

    train_cuts = multi_dataset.train_cuts()
    train_cuts = train_cuts.filter(remove_short_and_long_utt)

    train_dl = data_module.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )

    valid_cuts = multi_dataset.dev_cuts()
    valid_dl = data_module.valid_dataloaders(valid_cuts)

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    logging.info(f"start training from epoch {params.start_epoch}")
    for epoch in range(params.start_epoch, params.num_epochs + 1):
        if not params.deepspeed:
            scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            tokenizer=tokenizer,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        if params.deepspeed:
            model.save_checkpoint(
                save_dir=params.exp_dir,
                tag=f"epoch-{params.cur_epoch}",
                client_state={},
            )
            if rank == 0:
                convert_zero_checkpoint_to_fp32_state_dict(
                    params.exp_dir,
                    f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
                    tag=f"epoch-{params.cur_epoch}",
                )
                os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
        else:
            save_checkpoint(
                params=params,
                model=model,
                model_avg=model_avg,
                optimizer=optimizer,
                scheduler=scheduler,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )

    logging.info("Done!")

    if world_size > 1 and not params.deepspeed:
        torch.distributed.barrier()
        cleanup_dist()


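# Sketch of the Eden schedule created in run() (formula as implemented in
# icefall's zipformer optim.py, reproduced here only for illustration of how
# base_lr decays with both steps and epochs).
def eden_factor(step: int, epoch: int, lr_batches=5000.0, lr_epochs=6.0) -> float:
    f_step = ((step**2 + lr_batches**2) / lr_batches**2) ** -0.25
    f_epoch = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    return f_step * f_epoch

print(eden_factor(0, 0))       # 1.0: training starts at base_lr
print(eden_factor(20000, 4))   # < 1.0: the lr has decayed later in training

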
def display_and_save_batch(
    batch: dict,
    params: AttributeDict,
) -> None:
    """Display the batch statistics and save the batch into disk.

    Args:
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      params:
        Parameters for training. See :func:`get_params`.
    """
    from lhotse.utils import uuid4

    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
    logging.info(f"Saving batch to {filename}")
    torch.save(batch, filename)

    supervisions = batch["supervisions"]
    features = batch["inputs"]

    logging.info(f"features shape: {features.shape}")


def main():
    parser = get_parser()
    AsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    world_size = get_world_size()
    rank = get_rank()

    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    run(rank=rank, world_size=world_size, args=args)


if __name__ == "__main__":
    main()
1  egs/multi_zh-hans/ASR/whisper/whisper_encoder_forward_monkey_patch.py  Symbolic link
@ -0,0 +1 @@
../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py

15  egs/speechio/ASR/README.md  Normal file
@ -0,0 +1,15 @@

# Introduction

This recipe collects decoding results of several pretrained ASR models on the [SpeechIO](https://github.com/SpeechColab/Leaderboard) test sets.

[./RESULTS.md](./RESULTS.md) contains the latest results.

# Pretrained Models

The following table lists the pretrained models.

| Model       | Huggingface                                               | Comment                                                                      |
|-------------|-----------------------------------------------------------|------------------------------------------------------------------------------|
| `zipformer` | zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24  | Trained with the [multi-hans-zh recipe](../../multi_zh-hans/ASR/zipformer/)  |
| `whisper`   | yuekai/icefall_asr_wenetspeech_whisper                    | Fine-tuned with the [wenetspeech recipe](../../wenetspeech/ASR/whisper/)     |
92  egs/speechio/ASR/RESULTS.md  Normal file
@ -0,0 +1,92 @@
## Results

### SpeechIO Test Set Decoding Results

##### Decoding results using pretrained [multi-hans-zh zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24), [whisper-large-v2](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L27), and [whisper-large-v2-wenetspeech-ft](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper)

| Test Set               | zipformer_transducer | zipformer_transducer_blank_penalty_2 | whisper_large_v2 | whisper_large_v2_wenetspeech | whisper_large_v2_wenetspeech_zipformer_fusion |
|------------------------|----------------------|--------------------------------------|------------------|------------------------------|-----------------------------------------------|
| SPEECHIO_ASR_ZH00000   | 10.04                | 8.04                                 | 11.4             | 9.88                         | 7.78                                          |
| SPEECHIO_ASR_ZH00001   | 1.67                 | 1.51                                 | 2.49             | 1.57                         | 1.38                                          |
| SPEECHIO_ASR_ZH00002   | 5.89                 | 5.27                                 | 7.89             | 5.65                         | 4.99                                          |
| SPEECHIO_ASR_ZH00003   | 2.66                 | 2.79                                 | 5.94             | 2.27                         | 2.33                                          |
| SPEECHIO_ASR_ZH00004   | 3.6                  | 3.34                                 | 4.57             | 3.62                         | 3.26                                          |
| SPEECHIO_ASR_ZH00005   | 7.54                 | 5.81                                 | 8.39             | 7.26                         | 5.43                                          |
| SPEECHIO_ASR_ZH00006   | 15.59                | 13.34                                | 19.07            | 13.64                        | 11.96                                         |
| SPEECHIO_ASR_ZH00007   | 15.9                 | 15.05                                | 16.7             | 14.06                        | 13.73                                         |
| SPEECHIO_ASR_ZH00008   | 11.07                | 9.68                                 | 14.69            | 10.34                        | 8.87                                          |
| SPEECHIO_ASR_ZH00009   | 7.38                 | 6.23                                 | 8.32             | 6.74                         | 5.96                                          |
| SPEECHIO_ASR_ZH00010   | 9.19                 | 7.33                                 | 11.2             | 8.85                         | 6.97                                          |
| SPEECHIO_ASR_ZH00011   | 4.16                 | 3.84                                 | 54.56            | 4.09                         | 3.72                                          |
| SPEECHIO_ASR_ZH00012   | 7.61                 | 6.58                                 | 10.53            | 8.35                         | 6.27                                          |
| SPEECHIO_ASR_ZH00013   | 8.72                 | 7.66                                 | 9.32             | 7.26                         | 6.7                                           |
| SPEECHIO_ASR_ZH00014   | 9.69                 | 8.71                                 | 9.03             | 7.03                         | 6.59                                          |
| SPEECHIO_ASR_ZH00015   | 11.94                | 11.37                                | 16.58            | 12.02                        | 11.11                                         |
| SPEECHIO_ASR_ZH00016   | 9.79                 | 8.79                                 | 14.1             | 10.19                        | 8.15                                          |
| SPEECHIO_ASR_ZH00017   | 8                    | 6.72                                 | 9.04             | 8.9                          | 6.44                                          |
| SPEECHIO_ASR_ZH00018   | 5.42                 | 5.02                                 | 6.06             | 4.86                         | 4.4                                           |
| SPEECHIO_ASR_ZH00019   | 11.26                | 9.06                                 | 14.8             | 9.83                         | 8.22                                          |
| SPEECHIO_ASR_ZH00020   | 4.37                 | 4.23                                 | 5.97             | 4.23                         | 4.13                                          |
| SPEECHIO_ASR_ZH00021   | 7.81                 | 6.34                                 | 8.53             | 7.08                         | 5.88                                          |
| SPEECHIO_ASR_ZH00022   | 9.11                 | 8.54                                 | 9.7              | 8.97                         | 8.02                                          |
| SPEECHIO_ASR_ZH00023   | 9.98                 | 8.98                                 | 6.31             | 9.44                         | 8.57                                          |
| SPEECHIO_ASR_ZH00024   | 16.15                | 12.95                                | 20.54            | 15.92                        | 12.28                                         |
| SPEECHIO_ASR_ZH00025   | 10.38                | 9.82                                 | 11.4             | 10.26                        | 9.27                                          |
| SPEECHIO_ASR_ZH00026   | 5.69                 | 5.63                                 | 9.09             | 5.95                         | 5.51                                          |
| Average WER (001-026)  | 8.48                 | 7.48                                 | 12.11            | 8.01                         | 6.93                                          |


Command for decoding using fine-tuned whisper:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper
ln -s icefall_asr_wenetspeech_whisper/exp_large_v2/epoch-4-avg3.pt whisper/exp_large_v2_wenetspeech/epoch-999.pt

python3 ./whisper/decode.py \
  --exp-dir whisper/exp_large_v2_wenetspeech \
  --model-name large-v2 \
  --epoch 999 --avg 1 \
  --start-index 0 --end-index 26 \
  --remove-whisper-encoder-input-length-restriction True \
  --manifest-dir data/fbank \
  --beam-size 1 --max-duration 50
```
Command for decoding using pretrained zipformer:
```bash
git lfs install
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24
cd icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24
git lfs pull --include "exp/pretrained.pt"
git lfs pull --include "data/lang_bpe_2000/*"
ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/exp/pretrained.pt zipformer/exp_pretrain/epoch-999.pt
ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/data/lang_bpe_2000/ ./data
wget https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615/resolve/main/data/lang_char/words.txt
mv words.txt ./data/lang_bpe_2000/

./zipformer/decode.py \
  --epoch 999 \
  --avg 1 \
  --blank-penalty 2.0 \
  --use-averaged-model false \
  --exp-dir ./zipformer/exp_pretrain \
  --max-duration 600 \
  --start-index 0 --end-index 26 \
  --manifest-dir data/fbank_kaldi \
  --decoding-method greedy_search
```
Command for fusing the above decoding results from whisper and zipformer:
```bash
python local/whisper_zipformer_fusion.py \
  --whisper-log-dir ./whisper/exp_large_v2_wenetspeech \
  --zipformer-log-dir ./zipformer/exp_pretrain/greedy_search \
  --output-log-dir ./results_fusion
```

See why the fusion helps [here](./local/whisper_zipformer_fusion.py).

SpeechIO fbank features, decoding scripts, logs, and decoding results are available at <https://huggingface.co/yuekai/icefall_asr_speechio>.
148  egs/speechio/ASR/local/compute_fbank_speechio.py  Normal file
@ -0,0 +1,148 @@
|
||||
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang
#                                                  Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the SpeechIO dataset.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
"""

import argparse
import logging
import os
from pathlib import Path

import torch
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

SPEECHIO_TESTSET_INDEX = 26  # Currently, test sets 0 - 26 are open source.


def compute_fbank_speechio(
    num_mel_bins: int = 80,
    speed_perturb: bool = False,
    fbank_dir: str = "data/fbank",
    whisper_fbank: bool = False,
):
    src_dir = Path("data/manifests")
    output_dir = Path(fbank_dir)
    num_jobs = min(8, os.cpu_count())

    dataset_parts = []
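    # Test set names run from SPEECHIO_ASR_ZH00000 to SPEECHIO_ASR_ZH00026.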
    for i in range(SPEECHIO_TESTSET_INDEX + 1):
        idx = f"{i}".zfill(2)
        dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")

    prefix = "speechio"
    suffix = "jsonl.gz"
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=src_dir,
        prefix=prefix,
        suffix=suffix,
    )
    assert manifests is not None

    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )

    if whisper_fbank:
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
        )
    else:
        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
            if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
                logging.info(f"{partition} already exists - skipping.")
                continue
            logging.info(f"Processing {partition}")
            cut_set = CutSet.from_manifests(
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            cut_set = cut_set.compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                # when an executor is specified, make more partitions
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=LilcomChunkyWriter,
            )
            cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=80,
        help="""The number of mel bins for Fbank""",
    )
    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )
    parser.add_argument(
        "--fbank-dir",
        type=Path,
        default=Path("data/fbank"),
        help="Path to the directory where the computed fbank features and cuts are saved.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    compute_fbank_speechio(
        num_mel_bins=args.num_mel_bins,
        fbank_dir=args.fbank_dir,
        whisper_fbank=args.whisper_fbank,
    )
1162
egs/speechio/ASR/local/display_manifest_statistics.py
Normal file
File diff suppressed because it is too large

217
egs/speechio/ASR/local/whisper_zipformer_fusion.py
Normal file
@ -0,0 +1,217 @@
#!/usr/bin/env python3
#
# Copyright 2024 Author: Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file uses whisper and zipformer decoding results to generate fusion decoding results.
Since the whisper model is more likely to make deletion errors, and the zipformer model is
more likely to make substitution and insertion errors, we trust the whisper hypothesis on
substitutions and insertions, and the zipformer hypothesis on deletions.

Usage:
    python whisper_zipformer_fusion.py --whisper-log-dir ./whisper_decoding_log_dir --zipformer-log-dir ./zipformer_decoding_log_dir --output-log-dir ./results_fusion
"""

import argparse
from pathlib import Path
from typing import Dict, List, Tuple

import kaldialign

from icefall.utils import store_transcripts, write_error_stats


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--whisper-log-dir",
        type=str,
        default="./recogs_whisper",
        help="The directory to store the whisper logs: e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt",
    )
    parser.add_argument(
        "--zipformer-log-dir",
        type=str,
        default="./recogs_zipformer",
        help="The directory to store the zipformer logs",
    )
    parser.add_argument(
        "--output-log-dir",
        type=str,
        default="./results_fusion",
        help="The directory to store the fusion logs",
    )
    return parser


def save_results(
    res_dir: Path,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    test_set_wers = dict()

    suffix = "epoch-999-avg-1"

    for key, results in results_dict.items():
        recog_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        print(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = res_dir / f"errs-{test_set_name}-{key}-{suffix}.txt"
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        print("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = res_dir / f"wer-summary-{test_set_name}-{key}-{suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    print(s)


def extract_hyp_ref_wavname(filename):
    """
    0Phqz8RWYuE_0007-5: ref=['R', 'Y', 'Y', 'B', '它最大的优势就是进光量或者说是对光线利用率的提升']
    0Phqz8RWYuE_0007-5: hyp=而YB它最大的优势是近光量或者说是对光线利用率的提升
    """
    hyps, refs, wav_name = [], [], []
    with open(filename, "r") as f:
        for line in f:
            if "ref" in line:
                ref = line.split("ref=")[1].strip()
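                # ref is printed as a Python list literal, e.g. "['R', 'Y', ...]";
                # strip the surrounding brackets/quotes and rejoin the elements.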
                ref = ref[2:-2]
                list_elements = ref.split("', '")
                ref = "".join(list_elements)
                refs.append(ref)
            elif "hyp" in line:
                hyp = line.split("hyp=")[1].strip()
                hyps.append(hyp)
                wav_name.append(line.split(":")[0])
    return hyps, refs, wav_name


def get_pair_filenames(
    whisper_log_dir,
    zipformer_log_dir,
    whisper_suffix="beam-search-epoch-999-avg-1",
    zipformer_suffix="greedy_search_blank_penalty_2.0-epoch-999-avg-1-context-2-max-sym-per-frame-1-blank-penalty-2.0",
):
    results = []
    start_index, end_index = 0, 26
    dataset_parts = []
    for i in range(start_index, end_index + 1):
        idx = f"{i}".zfill(2)
        dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
    for partition in dataset_parts:
        whisper_filename = f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt"
        zipformer_filename = (
            f"{zipformer_log_dir}/recogs-{partition}-{zipformer_suffix}.txt"
        )
        results.append((whisper_filename, zipformer_filename))
    return results


def fusion_hyps_trust_substitution_insertion(
    hyps_whisper, hyps_zipformer, refs, ERR="*"
):
    """
    alignment example:
    [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')]
    left is whisper, right is zipformer
    for whisper substitution, use left
    for whisper insertion, use left
    for whisper deletion, use right
    """
    hyps_fusion = []
    for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs):
        ali = kaldialign.align(hyp_w, hyp_z, ERR)
        hyp_f = ""
        for a in ali:
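            # a is a (whisper_token, zipformer_token) pair; ERR marks a gap.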
            if a[0] == ERR:
                hyp_f += a[1]
            else:
                hyp_f += a[0]
        hyps_fusion.append(hyp_f)
    return hyps_fusion


def fusion_hyps_trust_substitution(hyps_whisper, hyps_zipformer, refs, ERR="*"):
    """
    alignment example:
    [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')]
    left is whisper, right is zipformer
    for whisper substitution, use left
    for whisper insertion, use right
    for whisper deletion, use right
    """
    hyps_fusion = []
    for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs):
        ali = kaldialign.align(hyp_w, hyp_z, ERR)
        hyp_f = ""
        for a in ali:
            if a[0] == ERR:
                hyp_f += a[1]
            elif a[1] == ERR:
                pass
            else:
                hyp_f += a[0]
        hyps_fusion.append(hyp_f)
    return hyps_fusion


def main():
    parser = get_parser()
    args = parser.parse_args()
    # mkdir output_log_dir
    Path(args.output_log_dir).mkdir(parents=True, exist_ok=True)
    pair_logs = get_pair_filenames(args.whisper_log_dir, args.zipformer_log_dir)
    for pair in pair_logs:
        hyps_whisper, refs, wav_name = extract_hyp_ref_wavname(pair[0])
        hyps_zipformer, _, _ = extract_hyp_ref_wavname(pair[1])

        hyps_fusion = fusion_hyps_trust_substitution_insertion(
            hyps_whisper, hyps_zipformer, refs
        )

        partition_name = pair[0].split("/")[-1].split("-")[1]
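        # e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch-999-avg-1.txt
        # yields the partition name SPEECHIO_ASR_ZH00014.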
        save_results(
            Path(args.output_log_dir),
            partition_name,
            {"fusion": list(zip(wav_name, refs, hyps_fusion))},
        )

        print(f"Processed {partition_name}")


if __name__ == "__main__":
    main()
67
egs/speechio/ASR/prepare.sh
Normal file
@ -0,0 +1,67 @@
#!/usr/bin/env bash

set -eou pipefail

stage=3
stop_stage=3

# We assume dl_dir (download dir) contains the following
# directories and files. If not, please download them first,
# following https://github.com/SpeechColab/Leaderboard.
#
#  - $dl_dir/SPEECHIO_ASR_ZH00000
#      This directory contains the following files downloaded from
#      https://github.com/SpeechColab/Leaderboard
#
#     - metadata.tsv
#     - wav
#     - wav.scp
#     - trans.txt
#

dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare speechio manifest"
  # We assume that you have downloaded the speechio dataset
  # to $dl_dir
  mkdir -p data/manifests
  if [ ! -e data/manifests/.speechio.done ]; then
    lhotse prepare speechio $dl_dir data/manifests
    touch data/manifests/.speechio.done
  fi
fi

whisper_mel_bins=80
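# Note: 80 mel bins matches whisper large-v2; large-v3 expects 128 bins.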
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Compute whisper fbank for speechio"
  if [ ! -f data/fbank/.speechio.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_speechio.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
    touch data/fbank/.speechio.done
  fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Compute kaldi fbank for speechio"
  if [ ! -f data/fbank/.speechio.kaldi.done ]; then
    fbank_dir=data/fbank_kaldi
    mkdir -p $fbank_dir
    ./local/compute_fbank_speechio.py --fbank-dir $fbank_dir
    touch data/fbank/.speechio.kaldi.done
  fi
fi
1
egs/speechio/ASR/shared
Symbolic link
@ -0,0 +1 @@
../../../icefall/shared//

195
egs/speechio/ASR/whisper/asr_datamodule.py
Normal file
@ -0,0 +1,195 @@
# Copyright      2021  Piotr Żelasko
# Copyright      2022  Xiaomi Corporation     (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from lhotse import CutSet, load_manifest, load_manifest_lazy
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutConcatenate,
    CutMix,
    DynamicBucketingSampler,
    K2SpeechRecognitionDataset,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
)
from lhotse.dataset.input_strategies import AudioSamples  # noqa F401 For AudioSamples
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class AsrDataModule:
    """
    DataModule for k2 ASR experiments.
    There are no train or valid dataloaders for the speechio dataset,
    but there can be multiple test dataloaders.

    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation,

    This class should be derived for specific corpora used in ASR tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/fbank"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=300,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--concatenate-cuts",
            type=str2bool,
            default=False,
            help="When enabled, utterances (cuts) will be concatenated "
            "to minimize the amount of padding.",
        )
        group.add_argument(
            "--duration-factor",
            type=float,
            default=1.0,
            help="Determines the maximum duration of a concatenated cut "
            "relative to the duration of the longest cut in a batch.",
        )
        group.add_argument(
            "--gap",
            type=float,
            default=1.0,
            help="The amount of padding (in seconds) inserted between "
            "concatenated cuts. This padding is filled with noise when "
            "noise augmentation is used.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )
        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )
        parser.add_argument(
            "--start-index",
            type=int,
            default=0,
            help="Decoding will start from dataset SPEECHIO_ASR_ZH000<index>",
        )

        parser.add_argument(
            "--end-index",
            type=int,
            default=26,
            help="Decoding will end with dataset SPEECHIO_ASR_ZH000<index>",
        )

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        test = K2SpeechRecognitionDataset(
            input_strategy=eval(self.args.input_strategy)(),
            return_cuts=self.args.return_cuts,
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl
520
egs/speechio/ASR/whisper/decode.py
Normal file
@ -0,0 +1,520 @@
#!/usr/bin/env python3
# Copyright    2021  Xiaomi Corporation (Author: Liyong Guo,
#                                                Fangjun Kuang,
#                                                Wei Kang)
#              2024  Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# Command for decoding using fine-tuned models:
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt

python3 ./whisper/decode.py \
  --exp-dir whisper/exp_large_v2 \
  --model-name large-v2 \
  --epoch 999 --avg 1 \
  --beam-size 10 --max-duration 50

# Command for decoding using pretrained models (before fine-tuning):

python3 ./whisper/decode.py \
  --exp-dir whisper/exp_large_v2_pretrained \
  --model-name large-v2 \
  --epoch -1 --avg 1 \
  --start-index 14 --end-index 15 \
  --remove-whisper-encoder-input-length-restriction False \
  --beam-size 1 --max-duration 50

"""

import argparse
import logging
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Union

import k2
import torch
import torch.nn as nn
import whisper
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert

from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
    AttributeDict,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)


def average_checkpoints(
    filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
    """Average a list of checkpoints.
    The function is mainly used for averaging deepspeed-converted checkpoints,
    which only include the model state_dict.

    Args:
      filenames:
        Filenames of the checkpoints to be averaged. We assume all
        checkpoints are saved by :func:`save_checkpoint`.
      device:
        Move checkpoints to this device before averaging.
    Returns:
      Return a dict (i.e., state_dict) which is the average of all
      model state dicts contained in the checkpoints.
    """
    n = len(filenames)

    checkpoint = torch.load(filenames[0], map_location=device)
    avg = checkpoint["model"] if "model" in checkpoint else checkpoint

    # Identify shared parameters. Two parameters are said to be shared
    # if they have the same data_ptr
    uniqued: Dict[int, str] = dict()

    for k, v in avg.items():
        v_data_ptr = v.data_ptr()
        if v_data_ptr in uniqued:
            continue
        uniqued[v_data_ptr] = k

    uniqued_names = list(uniqued.values())

    for i in range(1, n):
        checkpoint = torch.load(filenames[i], map_location=device)
        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
        for k in uniqued_names:
            avg[k] += state_dict[k]

    for k in uniqued_names:
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] //= n

    return avg


def remove_punctuation(text: Union[str, List[str]]):
    """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py

    Args:
      text: It can be a string or a list of strings.
    Returns:
      Return a string or a list of strings without any punctuation.
    """
    punctuation = "!,.;:?、!,。;:?《》 "
    if isinstance(text, str):
        text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
        return text
    elif isinstance(text, list):
        result_text = []
        for t in text:
            t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
            result_text.append(t)
        return result_text
    else:
        raise Exception(f"Unsupported type: {type(text)}")


def to_simple(text: Union[str, List[str]]):
    """Convert traditional Chinese to simplified Chinese.
    Args:
      text: It can be a string or a list of strings.
    Returns:
      Return a string or a list of strings converted to simplified Chinese.
    """
    if isinstance(text, str):
        text = convert(text, "zh-cn")
        return text
    elif isinstance(text, list):
        result_text = []
        for t in text:
            t = convert(t, "zh-cn")
            result_text.append(t)
        return result_text
    else:
        raise Exception(f"Unsupported type: {type(text)}")


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=-1,
        help="It specifies the checkpoint to use for decoding. "
        "Note: Epoch counts from 0.",
    )
    parser.add_argument(
        "--avg",
        type=int,
        default=1,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="beam-search",
        help="""Decoding method.
        Supported values are:
          - beam-search
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=1,
        help="beam size for beam search decoding",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="whisper/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="large-v2",
        choices=["large-v2", "large-v3", "medium", "small", "tiny"],
        help="""The model name to use.
        """,
    )

    parser.add_argument(
        "--remove-whisper-encoder-input-length-restriction",
        type=str2bool,
        default=True,
        help="replace whisper encoder forward method to remove input length restriction",
    )

    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "env_info": get_env_info(),
        }
    )
    return params


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    batch: dict,
) -> Dict[str, List[str]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

    - key: "beam-search"
    - value: A list of strings, one hypothesis per utterance.
    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      batch:
        It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
    Returns:
      Return a dict, whose key may be "beam-search".
    """
    dtype = torch.float16
    device = torch.device("cuda")

    feature = batch["inputs"]
    assert feature.ndim == 3
    feature = feature.to(device, dtype=dtype).transpose(1, 2)
    if not params.remove_whisper_encoder_input_length_restriction:
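        # Without the monkey patch, whisper's encoder expects a fixed 30 s
        # input, i.e. 3000 frames at a 10 ms frame shift, so pad with zeros.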
        T = 3000
        if feature.shape[2] < T:
            feature = torch.cat(
                [
                    feature,
                    torch.zeros(
                        feature.shape[0], feature.shape[1], T - feature.shape[2]
                    ).to(device, dtype=dtype),
                ],
                2,
            )

    supervisions = batch["supervisions"]
    feature_len = supervisions["num_frames"]
    feature_len = feature_len.to(device, dtype=dtype)
    results = model.decode(feature, params.decoding_options)
    hyps = [result.text for result in results]

    hyps = remove_punctuation(hyps)
    hyps = to_simple(hyps)
    hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
    print(hyps)
    return {"beam-search": hyps}


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        The dataloader.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
    Returns:
      Return a dict, whose key may be "beam-search".
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            batch=batch,
        )

        for lm_scale, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                ref_words = ref_text.split()
                this_batch.append((cut_id, ref_words, hyp_words))

            results[lm_scale].extend(this_batch)

        num_cuts += len(batch["supervisions"]["text"])

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):

    enable_log = True
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        if enable_log:
            logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        # we compute CER for the speechio dataset.
        results_char = []
        for res in results:
            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
            )
            test_set_wers[key] = wer

        if enable_log:
            logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tCER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)


@torch.no_grad()
def main():
    parser = get_parser()
    AsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))
    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
    setup_logger(
        f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
    )

    options = whisper.DecodingOptions(
        task="transcribe",
        language="zh",
        without_timestamps=True,
        beam_size=params.beam_size,
    )
    params.decoding_options = options
    params.cleaner = BasicTextNormalizer()
    params.normalizer = Normalizer()

    logging.info("Decoding started")
    logging.info(params)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")

    logging.info(f"device: {device}")

    if params.remove_whisper_encoder_input_length_restriction:
        replace_whisper_encoder_forward()
    model = whisper.load_model(params.model_name, "cpu")
    if params.epoch > 0:
        if params.avg > 1:
            start = params.epoch - params.avg
            assert start >= 1, start
            checkpoint = torch.load(
                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
            )
            if "model" not in checkpoint:
                # deepspeed converted checkpoint only contains model state_dict
                filenames = [
                    f"{params.exp_dir}/epoch-{epoch}.pt"
                    for epoch in range(start, params.epoch + 1)
                ]
                model.load_state_dict(average_checkpoints(filenames))
            else:
                filename_start = f"{params.exp_dir}/epoch-{start}.pt"
                filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
                logging.info(
                    f"Calculating the averaged model over epoch range from "
                    f"{start} (excluded) to {params.epoch}"
                )
                model.to(device)
                model.load_state_dict(
                    average_checkpoints_with_averaged_model(
                        filename_start=filename_start,
                        filename_end=filename_end,
                        device=device,
                    )
                )
            # save checkpoints
            filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
            torch.save(model.state_dict(), filename)
        else:
            checkpoint = torch.load(
                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
            )
            if "model" not in checkpoint:
                model.load_state_dict(checkpoint, strict=True)
            else:
                load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.to(device)
    model.eval()
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True

    data_module = AsrDataModule(args)
    multi_dataset = MultiDataset(args.manifest_dir, args.start_index, args.end_index)

    def remove_long_utt(c: Cut):
        # Keep only utterances whose duration is at most 30 seconds
        if c.duration > 30.0:
            # logging.warning(
            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
            # )
            return False
        return True

    test_sets_cuts = multi_dataset.test_cuts()

    test_sets = test_sets_cuts.keys()
    test_dls = [
        data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))
        for cuts_name in test_sets
    ]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
        )

        save_results(params=params, test_set_name=test_set, results_dict=results_dict)

    logging.info("Done!")


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
59
egs/speechio/ASR/whisper/multi_dataset.py
Normal file
@ -0,0 +1,59 @@
# Copyright    2023  Xiaomi Corp.        (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from typing import Dict

from lhotse import CutSet, load_manifest_lazy


class MultiDataset:
    def __init__(self, fbank_dir: str, start_index: int = 0, end_index: int = 26):
        """
        Args:
          fbank_dir:
            It is expected to contain the following files:
            - speechio_cuts_SPEECHIO_ASR_ZH00000.jsonl.gz
            ...
            - speechio_cuts_SPEECHIO_ASR_ZH00026.jsonl.gz
        """
        self.fbank_dir = Path(fbank_dir)
        self.start_index = start_index
        self.end_index = end_index

    def test_cuts(self) -> Dict[str, CutSet]:
        logging.info("About to get multidataset test cuts")

        dataset_parts = []
        for i in range(self.start_index, self.end_index + 1):
            idx = f"{i}".zfill(2)
            dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")

        prefix = "speechio"
        suffix = "jsonl.gz"

        results_dict = {}
        for partition in dataset_parts:
            path = f"{prefix}_cuts_{partition}.{suffix}"

            logging.info(f"Loading {path} set in lazy mode")
            test_cuts = load_manifest_lazy(self.fbank_dir / path)
            results_dict[partition] = test_cuts

        return results_dict
1
egs/speechio/ASR/whisper/requirements.txt
Symbolic link
@ -0,0 +1 @@
../../../aishell/ASR/whisper/requirements.txt

1
egs/speechio/ASR/whisper/whisper_encoder_forward_monkey_patch.py
Symbolic link
@ -0,0 +1 @@
../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py

1
egs/speechio/ASR/zipformer/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
../whisper/asr_datamodule.py

1
egs/speechio/ASR/zipformer/beam_search.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/beam_search.py

623
egs/speechio/ASR/zipformer/ctc_decode.py
Normal file
@ -0,0 +1,623 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
#                                                 Liyong Guo,
#                                                 Quandong Wang,
#                                                 Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:

(1) ctc-decoding
./zipformer/ctc_decode.py \
  --epoch 30 \
  --avg 15 \
  --exp-dir ./zipformer/exp \
  --max-duration 600 \
  --decoding-method ctc-decoding

"""


import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from train import add_model_arguments, get_model, get_params

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.decode import get_lattice, one_best_decoding
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    get_texts,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)

LOG_EPS = math.log(1e-10)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_2000/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--lang-dir",
        type=Path,
        default="data/lang_bpe_2000",
        help="The lang dir containing word table and LG graph",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="ctc-decoding",
        help="""Decoding method.
        Supported values are:
        - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
          It needs neither a lexicon nor an n-gram LM.
        """,
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=100,
        help="""Number of paths for n-best based decoding method.
        Used only when "method" is one of the following values:
        nbest, nbest-rescoring, and nbest-oracle
        """,
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=1.0,
        help="""The scale to be applied to `lattice.scores`.
        It's needed if you use any kinds of n-best based rescoring.
        Used only when "method" is one of the following values:
        nbest, nbest-rescoring, and nbest-oracle
        A smaller value results in more unique paths.
        """,
    )

    add_model_arguments(parser)

    return parser


def get_decoding_params() -> AttributeDict:
    """Parameters for decoding."""
    params = AttributeDict(
        {
            "frame_shift_ms": 10,
            "search_beam": 20,
            "output_beam": 8,
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
        }
    )
    return params


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    batch: dict,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
    - key: It indicates the setting used for decoding. For example,
      if no rescoring is used, the key is the string `no_rescore`.
      If LM rescoring is used, the key is the string `lm_scale_xxx`,
      where `xxx` is the value of `lm_scale`. An example key is
      `lm_scale_0.7`
    - value: It contains the decoding result. `len(value)` equals to
      batch size. `value[i]` is the decoding result for the i-th
      utterance in the given batch.

    Args:
      params:
        It's the return value of :func:`get_params`.

        - params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
        - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
        - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring.
        - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM
          rescoring.

      model:
        The neural model.
      H:
        The ctc topo. Used only when params.decoding_method is ctc-decoding.
      bpe_model:
        The BPE model. Used only when params.decoding_method is ctc-decoding.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table.
      G:
        An LM. It is not None when params.decoding_method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
        is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict. Note: If it decodes to nothing, then return None.
    """
    device = H.device
    feature = batch["inputs"]
    assert feature.ndim == 3
    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    if params.causal:
        # this seems to cause insertions at the end of the utterance if used with zipformer.
        pad_len = 30
        feature_lens += pad_len
        feature = torch.nn.functional.pad(
            feature,
            pad=(0, 0, 0, pad_len),
            value=LOG_EPS,
        )

    encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
    ctc_output = model.ctc_output(encoder_out)  # (N, T, C)
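    # Each row of supervision_segments is (sequence_idx, start_frame,
    # num_frames), with frame indices measured after subsampling.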
    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            torch.div(
                supervisions["start_frame"],
                params.subsampling_factor,
                rounding_mode="floor",
            ),
            torch.div(
                supervisions["num_frames"],
                params.subsampling_factor,
                rounding_mode="floor",
            ),
        ),
        1,
    ).to(torch.int32)

    assert bpe_model is not None
    decoding_graph = H

    lattice = get_lattice(
        nnet_output=ctc_output,
        decoding_graph=decoding_graph,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
        subsampling_factor=params.subsampling_factor,
    )

    if params.decoding_method == "ctc-decoding":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        # Note: `best_path.aux_labels` contains token IDs, not word IDs
        # since we are using H, not HLG here.
        #
        # token_ids is a list-of-list of IDs
        token_ids = get_texts(best_path)

        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
        hyps = bpe_model.decode(token_ids)

        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
        hyps = [s.split() for s in hyps]
        key = "ctc-decoding"
        return {key: hyps}


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      HLG:
        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
      H:
        The ctc topo. Used only when params.decoding_method is ctc-decoding.
      bpe_model:
        The BPE model. Used only when params.decoding_method is ctc-decoding.
      word_table:
        It is the word symbol table.
      G:
        An LM. It is not None when params.decoding_method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
        is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return a dict, whose key may be "no-rescore" if no LM rescoring
      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
      Its value is a list of tuples. Each tuple contains two elements:
      The first is the reference transcript, and the second is the
      predicted result.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            H=H,
            bpe_model=bpe_model,
            batch=batch,
        )

        for name, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                ref_words = list(ref_text.replace(" ", ""))
                hyp_words = list("".join(hyp_words))
                this_batch.append((cut_id, ref_words, hyp_words))

            results[name].extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
        with open(errs_filename, "w") as f:
            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)


@torch.no_grad()
def main():
    parser = get_parser()
    AsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)

    params = get_params()
    # add decoding params
    params.update(get_decoding_params())
    params.update(vars(args))

    assert params.decoding_method in ("ctc-decoding",)
    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.causal:
        assert (
            "," not in params.chunk_size
        ), "chunk_size should be one value in decoding."
        assert (
            "," not in params.left_context_frames
        ), "left_context_frames should be one value in decoding."
        params.suffix += f"-chunk-{params.chunk_size}"
        params.suffix += f"-left-context-{params.left_context_frames}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")
    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    num_classes = max_token_id + 1  # +1 for the blank

    params.vocab_size = num_classes
    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = 0

    HLG = None
    H = k2.ctc_topo(
        max_token=max_token_id,
        modified=True,
        device=device,
    )
    bpe_model = spm.SentencePieceProcessor()
    bpe_model.load(str(params.lang_dir / "bpe.model"))

    G = None
    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    data_module = AsrDataModule(args)
    multi_dataset = MultiDataset(args.manifest_dir, args.start_index, args.end_index)

    test_sets_cuts = multi_dataset.test_cuts()

    def remove_short_utt(c: Cut):
|
||||
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||
if T <= 0:
|
||||
logging.warning(
|
||||
f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}"
|
||||
)
|
||||
return T > 0
|
||||
|
||||
test_sets = test_sets_cuts.keys()
|
||||
test_dl = [
|
||||
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
|
||||
for cuts_name in test_sets
|
||||
]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
logging.info(f"Start decoding test set: {test_set}")
|
||||
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
word_table=lexicon.word_table,
|
||||
G=G,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
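The `remove_short_utt` filter above drops cuts whose feature sequence would shrink to nothing after the encoder front end, using `T = ((num_frames - 7) // 2 + 1) // 2`. As a minimal standalone sketch (not part of the recipe) of what that bound implies:

```python
# Smallest number of input frames the remove_short_utt filter accepts,
# assuming the front end reduces T as T_out = ((T_in - 7) // 2 + 1) // 2.
def min_usable_frames() -> int:
    t_in = 1
    while ((t_in - 7) // 2 + 1) // 2 <= 0:
        t_in += 1
    return t_in


if __name__ == "__main__":
    print(min_usable_frames())  # 9: cuts with fewer frames are excluded
```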
843
egs/speechio/ASR/zipformer/decode.py
Normal file
@ -0,0 +1,843 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
#                                                 Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method greedy_search

(2) beam search (not recommended)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method beam_search \
    --beam-size 4

(3) modified beam search
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method modified_beam_search \
    --beam-size 4

(4) fast beam search (one best)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64

(5) fast beam search (nbest)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search_nbest \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64 \
    --num-paths 200 \
    --nbest-scale 0.5

(6) fast beam search (nbest oracle WER)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search_nbest_oracle \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64 \
    --num-paths 200 \
    --nbest-scale 0.5

(7) fast beam search (with LG)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search_nbest_LG \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64
"""


import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AsrDataModule
from beam_search import (
    beam_search,
    fast_beam_search_nbest,
    fast_beam_search_nbest_LG,
    fast_beam_search_nbest_oracle,
    fast_beam_search_one_best,
    greedy_search,
    greedy_search_batch,
    modified_beam_search,
)
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from train import add_model_arguments, get_model, get_params

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    make_pad_mask,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)

LOG_EPS = math.log(1e-10)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_2000/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--lang-dir",
        type=Path,
        default="data/lang_bpe_2000",
        help="The lang dir containing word table and LG graph",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
        - greedy_search
        - beam_search
        - modified_beam_search
        - fast_beam_search
        - fast_beam_search_nbest
        - fast_beam_search_nbest_oracle
        - fast_beam_search_nbest_LG
        If you use fast_beam_search_nbest_LG, you have to specify
        `--lang-dir`, which should contain `LG.pt`.
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --decoding-method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=20.0,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --decoding-method is fast_beam_search,
        fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle
        """,
    )

    parser.add_argument(
        "--ngram-lm-scale",
        type=float,
        default=0.01,
        help="""
        Used only when --decoding_method is fast_beam_search_nbest_LG.
        It specifies the scale for n-gram LM scores.
        """,
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=8,
        help="""Used only when --decoding-method is
        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=64,
        help="""Used only when --decoding-method is
        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
    )
    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame.
        Used only when --decoding_method is greedy_search""",
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=200,
        help="""Number of paths for nbest decoding.
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""Scale applied to lattice scores when computing nbest paths.
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--blank-penalty",
        type=float,
        default=0.0,
        help="""
        The penalty applied on blank symbol during decoding.
        Note: It is a positive value that would be applied to logits like
        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
        [batch_size, vocab] and blank id is 0).
        """,
    )
    add_model_arguments(parser)

    return parser


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

    - key: It indicates the setting used for decoding. For example,
           if greedy_search is used, it would be "greedy_search"
           If beam search with a beam size of 7 is used, it would be
           "beam_7"
    - value: It contains the decoding result. `len(value)` equals to
             batch size. `value[i]` is the decoding result for the i-th
             utterance in the given batch.
    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    """
    device = next(model.parameters()).device
    feature = batch["inputs"]
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    if params.causal:
        # this seems to cause insertions at the end of the utterance if used with zipformer.
        pad_len = 30
        feature_lens += pad_len
        feature = torch.nn.functional.pad(
            feature,
            pad=(0, 0, 0, pad_len),
            value=LOG_EPS,
        )

    encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)

    hyps = []

    if params.decoding_method == "fast_beam_search":
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "fast_beam_search_nbest_LG":
        hyp_tokens = fast_beam_search_nbest_LG(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
        )
        for hyp in hyp_tokens:
            hyps.append([word_table[i] for i in hyp])
    elif params.decoding_method == "fast_beam_search_nbest":
        hyp_tokens = fast_beam_search_nbest(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "fast_beam_search_nbest_oracle":
        hyp_tokens = fast_beam_search_nbest_oracle(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            ref_texts=sp.encode(supervisions["text"]),
            nbest_scale=params.nbest_scale,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            blank_penalty=params.blank_penalty,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    else:
        batch_size = encoder_out.size(0)

        for i in range(batch_size):
            # fmt: off
            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            if params.decoding_method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                    blank_penalty=params.blank_penalty,
                )
            elif params.decoding_method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
            hyps.append(sp.decode(hyp).split())

    key = f"blank_penalty_{params.blank_penalty}"
    if params.decoding_method == "greedy_search":
        return {"greedy_search_" + key: hyps}
    elif "fast_beam_search" in params.decoding_method:
        key = f"beam_{params.beam}_"
        key += f"max_contexts_{params.max_contexts}_"
        key += f"max_states_{params.max_states}"
        if "nbest" in params.decoding_method:
            key += f"_num_paths_{params.num_paths}_"
            key += f"nbest_scale_{params.nbest_scale}"
            if "LG" in params.decoding_method:
                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"

        return {key: hyps}
    else:
        return {f"beam_size_{params.beam_size}": hyps}


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if beam size of 7 is used.
      Its value is a list of tuples. Each tuple contains two elements:
      The first is the reference transcript, and the second is the
      predicted result.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    if params.decoding_method == "greedy_search":
        log_interval = 50
    else:
        log_interval = 20

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        texts = [list(str(text).replace(" ", "")) for text in texts]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            sp=sp,
            decoding_graph=decoding_graph,
            word_table=word_table,
            batch=batch,
        )

        for name, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                hyp_text = "".join(hyp_words)
                this_batch.append((cut_id, ref_text, hyp_text))

            results[name].extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = (
        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)


@torch.no_grad()
def main():
    parser = get_parser()
    AsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
        "fast_beam_search",
        "fast_beam_search_nbest",
        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.causal:
        assert (
            "," not in params.chunk_size
        ), "chunk_size should be one value in decoding."
        assert (
            "," not in params.left_context_frames
        ), "left_context_frames should be one value in decoding."
        params.suffix += f"-chunk-{params.chunk_size}"
        params.suffix += f"-left-context-{params.left_context_frames}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
        if "nbest" in params.decoding_method:
            params.suffix += f"-nbest-scale-{params.nbest_scale}"
            params.suffix += f"-num-paths-{params.num_paths}"
            if "LG" in params.decoding_method:
                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    params.suffix += f"-blank-penalty-{params.blank_penalty}"
    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

    if "fast_beam_search" in params.decoding_method:
        if params.decoding_method == "fast_beam_search_nbest_LG":
            lexicon = Lexicon(params.lang_dir)
            word_table = lexicon.word_table
            lg_filename = params.lang_dir / "LG.pt"
            logging.info(f"Loading {lg_filename}")
            decoding_graph = k2.Fsa.from_dict(
                torch.load(lg_filename, map_location=device)
            )
            decoding_graph.scores *= params.ngram_lm_scale
        else:
            word_table = None
            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
    else:
        decoding_graph = None
        word_table = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    data_module = AsrDataModule(args)
    multi_dataset = MultiDataset(args.manifest_dir, args.start_index, args.end_index)

    def remove_short_utt(c: Cut):
        T = ((c.num_frames - 7) // 2 + 1) // 2
        if T <= 0:
            logging.warning(
                f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}"
            )
        return T > 0

    test_sets_cuts = multi_dataset.test_cuts()

    test_sets = test_sets_cuts.keys()
    test_dl = [
        data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
        for cuts_name in test_sets
    ]

    for test_set, test_dl in zip(test_sets, test_dl):
        logging.info(f"Start decoding test set: {test_set}")

        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()
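The `--blank-penalty` option in the file above is documented as subtracting a constant from the blank logit (`logits[:, 0] -= blank_penalty`) to discourage emitting blanks. A tiny self-contained illustration with made-up numbers:

```python
import torch

# Hypothetical per-frame logits with shape [batch_size, vocab]; blank id is 0.
logits = torch.tensor([[2.0, 1.5, 0.1], [0.3, 0.2, 0.1]])
print(logits.argmax(dim=-1))  # tensor([0, 0]): both frames pick blank

blank_penalty = 1.0
logits[:, 0] -= blank_penalty  # the operation described in the help string
print(logits.argmax(dim=-1))  # tensor([1, 1]): non-blank symbols now win
```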
1
egs/speechio/ASR/zipformer/decoder.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decoder.py

1
egs/speechio/ASR/zipformer/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/encoder_interface.py

1
egs/speechio/ASR/zipformer/joiner.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/joiner.py

1
egs/speechio/ASR/zipformer/model.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/model.py

1
egs/speechio/ASR/zipformer/multi_dataset.py
Symbolic link
@ -0,0 +1 @@
../whisper/multi_dataset.py

1
egs/speechio/ASR/zipformer/optim.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

1
egs/speechio/ASR/zipformer/scaling.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py

1
egs/speechio/ASR/zipformer/scaling_converter.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling_converter.py

1
egs/speechio/ASR/zipformer/subsampling.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

1
egs/speechio/ASR/zipformer/train.py
Symbolic link
@ -0,0 +1 @@
../../../multi_zh-hans/ASR/zipformer/train.py

1
egs/speechio/ASR/zipformer/zipformer.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py
@ -16,11 +16,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from pathlib import Path

import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
from lhotse import (
    CutSet,
    KaldifeatFbank,
    KaldifeatFbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@ -30,8 +38,31 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")

from icefall.utils import str2bool

def compute_fbank_wenetspeech_dev_test():

def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=80,
        help="""The number of mel bins for Fbank""",
    )

    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )
    return parser


def compute_fbank_wenetspeech_dev_test(args):
    in_out_dir = Path("data/fbank")
    # number of workers in dataloader
    num_workers = 42
@ -44,7 +75,12 @@ def compute_fbank_wenetspeech_dev_test():
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
    if args.whisper_fbank:
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
        )
    else:
        extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))

    logging.info(f"device: {device}")

@ -82,7 +118,11 @@ def main():
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    compute_fbank_wenetspeech_dev_test()
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    compute_fbank_wenetspeech_dev_test(args)


if __name__ == "__main__":

@ -22,15 +22,19 @@ from datetime import datetime
from pathlib import Path

import torch
from lhotse import (
from lhotse import (  # KaldifeatWhisperFbank,; KaldifeatWhisperFbankConfig,
    CutSet,
    KaldifeatFbank,
    KaldifeatFbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
    set_audio_duration_mismatch_tolerance,
    set_caching_enabled,
)

from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
@ -87,6 +91,27 @@ def get_parser():
        default=-1,
        help="Stop processing pieces until this number (excluded).",
    )

    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=80,
        help="""The number of mel bins for Fbank""",
    )

    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )

    parser.add_argument(
        "--output-dir-prefix",
        type=str,
        default="",
        help="Prefix of the output directory.",
    )
    return parser


@ -96,6 +121,7 @@ def compute_fbank_wenetspeech_splits(args):
    num_splits = args.num_splits
    output_dir = f"data/fbank/{subset}_split_{num_splits}"
    output_dir = Path(output_dir)
    output_dir = Path(args.output_dir_prefix) / output_dir
    assert output_dir.exists(), f"{output_dir} does not exist!"

    num_digits = len(str(num_splits))
@ -110,14 +136,21 @@ def compute_fbank_wenetspeech_splits(args):
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
    if args.whisper_fbank:
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
        )
        # extractor = KaldifeatWhisperFbank(KaldifeatWhisperFbankConfig(num_filters=args.num_mel_bins, device=device))
    else:
        extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
    logging.info(f"device: {device}")

    set_audio_duration_mismatch_tolerance(0.01)  # 10ms tolerance
    set_caching_enabled(False)
    # with get_executor() as ex:  # Initialize the executor only once.
    for i in range(start, stop):
        idx = f"{i + 1}".zfill(num_digits)
        logging.info(f"Processing {idx}/{num_splits}")
        idx = f"{i}".zfill(num_digits)
        logging.info(f"Processing {i+1}/{num_splits}")

        cuts_path = output_dir / f"cuts_{subset}.{idx}.jsonl.gz"
        if cuts_path.is_file():
@ -143,7 +176,6 @@ def compute_fbank_wenetspeech_splits(args):
            storage_type=LilcomChunkyWriter,
            overwrite=True,
        )

        logging.info(f"Saving to {cuts_path}")
        cut_set.to_file(cuts_path)
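Both scripts in the diff above select the feature extractor the same way: Whisper checkpoints expect Whisper-style log-mel features (80 bins here, passed via `--num-mel-bins`), while the other recipes keep Kaldi-style fbank. A condensed sketch of that pattern, with device handling simplified:

```python
import torch
from lhotse import (
    KaldifeatFbank,
    KaldifeatFbankConfig,
    WhisperFbank,
    WhisperFbankConfig,
)


def make_extractor(whisper_fbank: bool, num_mel_bins: int = 80):
    # Mirrors the selection logic in the scripts above.
    device = (
        torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    )
    if whisper_fbank:
        return WhisperFbank(
            WhisperFbankConfig(num_filters=num_mel_bins, device=str(device))
        )
    return KaldifeatFbank(KaldifeatFbankConfig(device=device))
```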
@ -182,6 +182,43 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
  fi
fi

whisper_mel_bins=80
if [ $stage -le 129 ] && [ $stop_stage -ge 129 ]; then
  log "Stage 129: compute whisper fbank for dev and test sets"
  python3 ./local/compute_fbank_wenetspeech_dev_test.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
fi
if [ $stage -le 130 ] && [ $stop_stage -ge 130 ]; then
  log "Stage 130: Compute features for whisper training set"

  split_dir=data/fbank/L_split_${num_splits}
  if [ ! -f $split_dir/.split_completed ]; then
    lhotse split $num_splits ./data/fbank/cuts_L_raw.jsonl.gz $split_dir
    touch $split_dir/.split_completed
  fi

  python3 ./local/compute_fbank_wenetspeech_splits.py \
    --training-subset L \
    --num-workers 8 \
    --batch-duration 1600 \
    --start 0 \
    --num-mel-bins ${whisper_mel_bins} --whisper-fbank true \
    --num-splits $num_splits

  if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then
    pieces=$(find data/fbank/L_split_${num_splits} -name "cuts_L.*.jsonl.gz")
    lhotse combine $pieces data/fbank/cuts_L.jsonl.gz
  fi
fi

if [ $stage -le 131 ] && [ $stop_stage -ge 131 ]; then
  log "Stage 131: concat feats into train set"
  if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then
    pieces=$(find data/fbank/L_split_${num_splits} -name "cuts_L.*.jsonl.gz")
    lhotse combine $pieces data/fbank/cuts_L.jsonl.gz
  fi
fi


if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
  log "Stage 14: Compute fbank for musan"
  mkdir -p data/fbank
@ -272,7 +309,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
  mkdir -p $text_out_dir

  log "Generating training text data"

  if [ ! -f $text_out_dir/lm_data.pt ]; then
    ./local/prepare_char_lm_training_data.py \
      --lang-char data/lang_char \
@ -281,14 +318,14 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
  fi

  log "Generating DEV text data"
  # prepare validation text data
  # prepare validation text data
  if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then
    valid_text=${text_out_dir}/

    gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \
      | jq '.text' | sed 's/"//g' \
      | ./local/text2token.py -t "char" > $text_out_dir/valid_text

    python3 ./local/text2segments.py \
      --num-process $nj \
      --input-file $text_out_dir/valid_text \
@ -300,7 +337,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
      --lm-data $text_out_dir/valid_text_words_segmentation \
      --lm-archive $text_out_dir/lm_data_valid.pt

  # prepare TEST text data
  # prepare TEST text data
  if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then
    log "Prepare text for test set."
    for test_set in TEST_MEETING TEST_NET; do
@ -313,7 +350,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
      --input-file $text_out_dir/${test_set}_text \
      --output-file $text_out_dir/${test_set}_text_words_segmentation
  done

  cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation
fi
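Stages 130/131 split the raw manifest, compute features per split, then merge the pieces back with `lhotse combine`. For reference, a minimal Python equivalent of the final merge step, assuming lhotse's top-level `combine` helper and using `L_split_1000` as a stand-in for the actual `${num_splits}` directory:

```python
from pathlib import Path

from lhotse import CutSet, combine

# Gather per-split cut manifests produced by stage 130 and merge them,
# mirroring: lhotse combine $pieces data/fbank/cuts_L.jsonl.gz
pieces = sorted(Path("data/fbank/L_split_1000").glob("cuts_L.*.jsonl.gz"))
cuts = combine(*[CutSet.from_file(p) for p in pieces])
cuts.to_file("data/fbank/cuts_L.jsonl.gz")
```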
1
egs/wenetspeech/ASR/whisper/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

526
egs/wenetspeech/ASR/whisper/decode.py
Executable file
@ -0,0 +1,526 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
|
||||
# Fangjun Kuang,
|
||||
# Wei Kang)
|
||||
# 2024 Yuekai Zhang
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
# Command for decoding using fine-tuned models:
|
||||
git lfs install
|
||||
git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
|
||||
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
|
||||
|
||||
python3 ./whisper/decode.py \
|
||||
--exp-dir whisper/exp_large_v2 \
|
||||
--model-name large-v2 \
|
||||
--epoch 999 --avg 1 \
|
||||
--beam-size 10 --max-duration 50
|
||||
|
||||
# Command for decoding using pretrained models (before fine-tuning):
|
||||
|
||||
python3 ./whisper/decode.py \
|
||||
--exp-dir whisper/exp_large_v2 \
|
||||
--model-name large-v2 \
|
||||
--epoch -1 --avg 1 \
|
||||
--remove-whisper-encoder-input-length-restriction False \
|
||||
--beam-size 10 --max-duration 50
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import whisper
|
||||
from asr_datamodule import WenetSpeechAsrDataModule
|
||||
from lhotse.cut import Cut
|
||||
from tn.chinese.normalizer import Normalizer
|
||||
from whisper.normalizers import BasicTextNormalizer
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
from zhconv import convert
|
||||
|
||||
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def average_checkpoints(
|
||||
filenames: List[Path], device: torch.device = torch.device("cpu")
|
||||
) -> dict:
|
||||
"""Average a list of checkpoints.
|
||||
The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict.
|
||||
|
||||
Args:
|
||||
filenames:
|
||||
Filenames of the checkpoints to be averaged. We assume all
|
||||
checkpoints are saved by :func:`save_checkpoint`.
|
||||
device:
|
||||
Move checkpoints to this device before averaging.
|
||||
Returns:
|
||||
Return a dict (i.e., state_dict) which is the average of all
|
||||
model state dicts contained in the checkpoints.
|
||||
"""
|
||||
n = len(filenames)
|
||||
|
||||
if "model" in torch.load(filenames[0], map_location=device):
|
||||
avg = torch.load(filenames[0], map_location=device)["model"]
|
||||
else:
|
||||
avg = torch.load(filenames[0], map_location=device)
|
||||
|
||||
# Identify shared parameters. Two parameters are said to be shared
|
||||
# if they have the same data_ptr
|
||||
uniqued: Dict[int, str] = dict()
|
||||
|
||||
for k, v in avg.items():
|
||||
v_data_ptr = v.data_ptr()
|
||||
if v_data_ptr in uniqued:
|
||||
continue
|
||||
uniqued[v_data_ptr] = k
|
||||
|
||||
uniqued_names = list(uniqued.values())
|
||||
|
||||
for i in range(1, n):
|
||||
if "model" in torch.load(filenames[i], map_location=device):
|
||||
state_dict = torch.load(filenames[i], map_location=device)["model"]
|
||||
else:
|
||||
state_dict = torch.load(filenames[i], map_location=device)
|
||||
for k in uniqued_names:
|
||||
avg[k] += state_dict[k]
|
||||
|
||||
for k in uniqued_names:
|
||||
if avg[k].is_floating_point():
|
||||
avg[k] /= n
|
||||
else:
|
||||
avg[k] //= n
|
||||
|
||||
return avg
|
||||
|
||||
|
||||
def remove_punctuation(text: str or List[str]):
|
||||
"""Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
|
||||
|
||||
Args:
|
||||
text: It can be a string or a list of strings.
|
||||
Returns:
|
||||
Return a string or a list of strings without any punctuation.
|
||||
"""
|
||||
punctuation = "!,.;:?、!,。;:?《》 "
|
||||
if isinstance(text, str):
|
||||
text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
|
||||
return text
|
||||
elif isinstance(text, list):
|
||||
result_text = []
|
||||
for t in text:
|
||||
t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
|
||||
result_text.append(t)
|
||||
return result_text
|
||||
else:
|
||||
raise Exception(f"Not support type {type(text)}")
|
||||
|
||||
|
||||
def to_simple(text: str or List[str]):
|
||||
"""Convert traditional Chinese to simplified Chinese.
|
||||
Args:
|
||||
text: It can be a string or a list of strings.
|
||||
Returns:
|
||||
Return a string or a list of strings converted to simplified Chinese.
|
||||
"""
|
||||
if isinstance(text, str):
|
||||
text = convert(text, "zh-cn")
|
||||
return text
|
||||
elif isinstance(text, list):
|
||||
result_text = []
|
||||
for t in text:
|
||||
t = convert(t, "zh-cn")
|
||||
result_text.append(t)
|
||||
return result_text
|
||||
else:
|
||||
raise Exception(f"Not support type{type(text)}")
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="beam-search",
|
||||
help="""Decoding method.
|
||||
Supported values are:
|
||||
- beam-search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="beam size for beam search decoding",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="whisper/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="large-v2",
|
||||
choices=["large-v2", "large-v3", "medium", "small", "tiny"],
|
||||
help="""The model name to use.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--remove-whisper-encoder-input-length-restriction",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="replace whisper encoder forward method to remove input length restriction",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
batch: dict,
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: "beam-search"
|
||||
- value: A list of lists. Each sublist is a list of token IDs.
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
batch:
|
||||
It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
|
||||
Returns:
|
||||
Return a dict, whose key may be "beam-search".
|
||||
"""
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda")
|
||||
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device, dtype=dtype).transpose(1, 2)
|
||||
if not params.remove_whisper_encoder_input_length_restriction:
|
||||
T = 3000
|
||||
if feature.shape[2] < T:
|
||||
feature = torch.cat(
|
||||
[
|
||||
feature,
|
||||
torch.zeros(
|
||||
feature.shape[0], feature.shape[1], T - feature.shape[2]
|
||||
).to(device, dtype=dtype),
|
||||
],
|
||||
2,
|
||||
)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_len = supervisions["num_frames"]
|
||||
feature_len = feature_len.to(device, dtype=dtype)
|
||||
results = model.decode(feature, params.decoding_options)
|
||||
hyps = [result.text for result in results]
|
||||
|
||||
hyps = remove_punctuation(hyps)
|
||||
hyps = to_simple(hyps)
|
||||
hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
|
||||
print(hyps)
|
||||
return {"beam-search": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
The dataloader.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
Returns:
|
||||
Return a dict, whose key may be "beam-search".
|
||||
"""
|
||||
results = []
|
||||
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
num_cuts += len(batch["supervisions"]["text"])
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
|
||||
enable_log = True
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
if enable_log:
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tCER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    parser = get_parser()
    WenetSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))
    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
    setup_logger(
        f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
    )

    options = whisper.DecodingOptions(
        task="transcribe",
        language="zh",
        without_timestamps=True,
        beam_size=params.beam_size,
    )
    params.decoding_options = options
    params.cleaner = BasicTextNormalizer()
    params.normalizer = Normalizer()

    logging.info("Decoding started")
    logging.info(params)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")

    logging.info(f"device: {device}")

    if params.remove_whisper_encoder_input_length_restriction:
        replace_whisper_encoder_forward()
    model = whisper.load_model(params.model_name, "cpu")
    if params.epoch > 0:
        if params.avg > 1:
            start = params.epoch - params.avg
            assert start >= 1, start
            checkpoint = torch.load(
                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
            )
            if "model" not in checkpoint:
                # A deepspeed-converted checkpoint contains only the model
                # state_dict, so average the plain state_dicts directly.
                filenames = [
                    f"{params.exp_dir}/epoch-{epoch}.pt"
                    for epoch in range(start, params.epoch + 1)
                ]
                model.load_state_dict(average_checkpoints(filenames))
            else:
                filename_start = f"{params.exp_dir}/epoch-{start}.pt"
                filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
                logging.info(
                    f"Calculating the averaged model over epoch range from "
                    f"{start} (excluded) to {params.epoch}"
                )
                model.to(device)
                model.load_state_dict(
                    average_checkpoints_with_averaged_model(
                        filename_start=filename_start,
                        filename_end=filename_end,
                        device=device,
                    )
                )
            # Save the averaged model so it can be reused directly.
            filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
            torch.save(model.state_dict(), filename)
        else:
            checkpoint = torch.load(
                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
            )
            if "model" not in checkpoint:
                model.load_state_dict(checkpoint, strict=True)
            else:
                load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.to(device)
    model.eval()
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # We need cut ids to display recognition results.
    args.return_cuts = True
    wenetspeech = WenetSpeechAsrDataModule(args)
    dev_cuts = wenetspeech.valid_cuts()
    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)

    def remove_long_utt(c: Cut):
        # Keep only utterances no longer than 30 seconds, since the original
        # Whisper model expects inputs of at most 30 seconds.
        if c.duration > 30.0:
            # logging.warning(
            #     f"Exclude cut with ID {c.id} from decoding. Duration: {c.duration}"
            # )
            return False
        return True

    test_net_cuts = wenetspeech.test_net_cuts()
    test_net_cuts = test_net_cuts.filter(remove_long_utt)
    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)

    test_meeting_cuts = wenetspeech.test_meeting_cuts()
    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)

    # test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
    # test_dls = [dev_dl, test_net_dl, test_meeting_dl]

    test_sets = ["TEST_NET"]
    test_dls = [test_net_dl]

    # test_sets = ["TEST_MEETING"]
    # test_dls = [test_meeting_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
        )

        save_results(params=params, test_set_name=test_set, results_dict=results_dict)

    logging.info("Done!")


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
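
# Example invocation (illustrative only; decode.py's argument parser is not
# shown here, so the exact flag names should be checked against get_parser()
# and the WenetSpeech data module):
#
#   python3 ./whisper/decode.py \
#     --exp-dir whisper/exp_large_v2 \
#     --model-name large-v2 \
#     --epoch 999 --avg 1 \
#     --beam-size 10 --max-duration 50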
1  egs/wenetspeech/ASR/whisper/ds_config_zero1.json  Symbolic link
@ -0,0 +1 @@
../../../aishell/ASR/whisper/ds_config_zero1.json

1  egs/wenetspeech/ASR/whisper/label_smoothing.py  Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py

1  egs/wenetspeech/ASR/whisper/optim.py  Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

1  egs/wenetspeech/ASR/whisper/requirements.txt  Symbolic link
@ -0,0 +1 @@
../../../aishell/ASR/whisper/requirements.txt

955  egs/wenetspeech/ASR/whisper/train.py  Normal file
@ -0,0 +1,955 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#           2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:

# fine-tuning with DeepSpeed ZeRO stage 1
torchrun --nproc_per_node 8 ./whisper/train.py \
  --max-duration 200 \
  --exp-dir whisper/exp_large_v2 \
  --model-name large-v2 \
  --deepspeed \
  --deepspeed_config ./whisper/ds_config_zero1.json

# fine-tuning with DDP
torchrun --nproc_per_node 8 ./whisper/train.py \
  --max-duration 200 \
  --exp-dir whisper/exp_medium \
  --base-lr 1e-5 \
  --model-name medium
"""

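# Note: ds_config_zero1.json is symlinked from the aishell recipe. As a rough
# sketch (illustrative only, not the actual file), a DeepSpeed ZeRO stage-1
# config has this shape:
#   {
#     "train_micro_batch_size_per_gpu": 1,
#     "gradient_accumulation_steps": 1,
#     "fp16": {"enabled": true},
#     "zero_optimization": {"stage": 1}
#   }
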
import argparse
import copy
import logging
import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union

import deepspeed
import k2
import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import whisper
from asr_datamodule import WenetSpeechAsrDataModule
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.functional import pad as pad_tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import update_averaged_model
from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
    AttributeDict,
    MetricsTracker,
    filter_uneven_sized_batch,
    setup_logger,
    str2bool,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
    if isinstance(model, DDP):
        # get underlying nn.Module
        model = model.module
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count

def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--tensorboard",
        type=str2bool,
        default=True,
        help="Should various information be logged in tensorboard.",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,
        default=10,
        help="Number of epochs to train.",
    )

    parser.add_argument(
        "--start-epoch",
        type=int,
        default=1,
        help="""Resume training from this epoch. It should be positive.
        If larger than 1, it will load checkpoint from
        exp-dir/epoch-{start_epoch-1}.pt
        """,
    )

    parser.add_argument(
        "--start-batch",
        type=int,
        default=0,
        help="""If positive, --start-epoch is ignored and
        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="whisper/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, logs, etc, are saved
        """,
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="large-v2",
        choices=["large-v2", "large-v3", "medium", "small", "tiny"],
        help="""The Whisper model variant to fine-tune.""",
    )

    parser.add_argument(
        "--base-lr", type=float, default=1e-5, help="The base learning rate."
    )

    parser.add_argument(
        "--lr-batches",
        type=float,
        default=5000,
        help="""Number of steps that affects how rapidly the learning rate
        decreases. We suggest not to change this.""",
    )

    parser.add_argument(
        "--lr-epochs",
        type=float,
        default=6,
        help="""Number of epochs that affects how rapidly the learning rate
        decreases.""",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The seed for random generators, intended for reproducibility.",
    )

    parser.add_argument(
        "--print-diagnostics",
        type=str2bool,
        default=False,
        help="Accumulate stats on activations, print them and exit.",
    )

    parser.add_argument(
        "--inf-check",
        type=str2bool,
        default=False,
        help="Add hooks to check for infinite module outputs and gradients.",
    )

    parser.add_argument(
        "--keep-last-k",
        type=int,
        default=30,
        help="""Only keep this number of checkpoints on disk.
        For instance, if it is 3, there are only 3 checkpoints
        in the exp-dir with filenames `checkpoint-xxx.pt`.
        It does not affect checkpoints with name `epoch-xxx.pt`.
        """,
    )

    parser.add_argument(
        "--average-period",
        type=int,
        default=200,
        help="""Update the averaged model, namely `model_avg`, after processing
        this number of batches. `model_avg` is a separate version of model,
        in which each floating-point parameter is the average of all the
        parameters from the start of training. Each time we take the average,
        we do: `model_avg = model * (average_period / batch_idx_train) +
        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
        """,
    )

    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=True,
        help="Whether to use half precision training.",
    )

    parser = deepspeed.add_config_arguments(parser)

    return parser

def get_params() -> AttributeDict:
    """Return a dict containing training parameters.

    All training related parameters that are not passed from the commandline
    are saved in the variable `params`.

    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

        - frame_shift_ms: The frame shift in milliseconds.
        - allowed_excess_duration_ratio: The allowed excess duration ratio.
        - best_train_loss: The best training loss so far.
        - best_valid_loss: The best validation loss so far.
        - best_train_epoch: The epoch where the best training loss is achieved.
        - best_valid_epoch: The epoch where the best validation loss is achieved.
        - batch_idx_train: The batch index of the current batch.
        - log_interval: Log training stats every `log_interval` batches.
        - reset_interval: Reset the stats every `reset_interval` batches.
        - valid_interval: Run validation every `valid_interval` batches.
        - env_info: The environment information.
    """
    params = AttributeDict(
        {
            "frame_shift_ms": 10.0,
            "subsampling_factor": 2,
            "allowed_excess_duration_ratio": 0.1,
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
            "log_interval": 50,
            "reset_interval": 200,
            "valid_interval": 10000,
            "env_info": get_env_info(),
        }
    )

    return params

def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
    model_avg: nn.Module = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
) -> Optional[Dict[str, Any]]:
    """Load checkpoint from file.

    If params.start_batch is positive, it will load the checkpoint from
    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
    params.start_epoch is larger than 1, it will load the checkpoint from
    `params.exp_dir/epoch-{params.start_epoch - 1}.pt`.

    Apart from loading state dict for `model` and `optimizer` it also updates
    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
    and `best_valid_loss` in `params`.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
      model_avg:
        The stored model averaged from the start of training.
      optimizer:
        The optimizer that we are using.
      scheduler:
        The scheduler that we are using.
    Returns:
      Return a dict containing previously saved training info.
    """
    if params.start_batch > 0:
        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
    elif params.start_epoch > 1:
        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    else:
        return None

    assert filename.is_file(), f"{filename} does not exist!"

    saved_params = load_checkpoint(
        filename,
        model=model,
        model_avg=model_avg,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    keys = [
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
    ]
    for k in keys:
        params[k] = saved_params[k]

    if params.start_batch > 0:
        if "cur_epoch" in saved_params:
            params["start_epoch"] = saved_params["cur_epoch"]

    return saved_params

def save_checkpoint(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    model_avg: Optional[nn.Module] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    sampler: Optional[CutSampler] = None,
    scaler: Optional[GradScaler] = None,
    rank: int = 0,
) -> None:
    """Save model, optimizer, scheduler and training stats to file.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The training model.
      model_avg:
        The stored model averaged from the start of training.
      optimizer:
        The optimizer used in the training.
      sampler:
        The sampler for the training dataset.
      scaler:
        The scaler used for mixed precision training.
    """
    if rank != 0:
        return
    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
    save_checkpoint_impl(
        filename=filename,
        model=model,
        model_avg=model_avg,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        sampler=sampler,
        scaler=scaler,
        rank=rank,
    )

    if params.best_train_epoch == params.cur_epoch:
        best_train_filename = params.exp_dir / "best-train-loss.pt"
        copyfile(src=filename, dst=best_train_filename)

    if params.best_valid_epoch == params.cur_epoch:
        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
        copyfile(src=filename, dst=best_valid_filename)

def compute_loss(
    params: AttributeDict,
    tokenizer: whisper.tokenizer.Tokenizer,
    model: Union[nn.Module, DDP],
    batch: dict,
    is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute the loss for the given batch.
    Args:
      params:
        It is returned by :func:`get_params`.
      tokenizer:
        The tokenizer used to encode the text.
      model:
        The model for training.
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      is_training:
        Whether it is training.
    Returns:
      Return a tuple of two elements. The first element is the loss tensor;
      the second is a MetricsTracker with the statistics of this batch.
    """
    # For an uneven-sized batch, the total duration after padding could
    # cause OOM. Hence, for each batch, which is sorted in descending order
    # of length, we simply drop the last few shortest samples, so that the
    # retained total frames (after padding) do not exceed `allowed_max_frames`:
    # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
    # where `max_frames = max_duration * 1000 // frame_shift_ms`.
    # We set allowed_excess_duration_ratio=0.1.
    if isinstance(model, DDP):
        # get underlying nn.Module
        model = model.module

    def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
        # Right-pad every tensor along dim 0 to the length of the longest
        # one, then stack them into a single batch tensor.
        padding_size = max(tensor.shape[0] for tensor in tensors)
        dims = len(tensors[0].shape)
        padded_tensors = []
        for tensor in tensors:
            padding = [0] * 2 * dims
            padding[-1] = padding_size - tensor.shape[0]
            padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
        return torch.stack([tensor for tensor in padded_tensors], dim=0)
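    # For example (illustrative): _batch_tensors([LongTensor([1, 2, 3]),
    # LongTensor([4, 5])], pad_value=50256) returns the (2, 3) tensor
    # [[1, 2, 3], [4, 5, 50256]].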

    max_frames = params.max_duration * 1000 // params.frame_shift_ms
    allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
    batch = filter_uneven_sized_batch(batch, allowed_max_frames)
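    # Worked example: with --max-duration 200 and frame_shift_ms=10.0,
    # max_frames = 200 * 1000 // 10 = 20000 frames, so
    # allowed_max_frames = int(20000 * 1.1) = 22000 frames after padding.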

    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    feature = batch["inputs"]

    assert feature.ndim == 3
    feature = feature.to(device)
    feature = feature.transpose(1, 2)  # (N, C, T)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    batch_idx_train = params.batch_idx_train

    texts = batch["supervisions"]["text"]
    # remove spaces in texts, since the Chinese transcripts are scored
    # at the character level
    texts = [text.replace(" ", "") for text in texts]

    text_tokens_list = [
        list(tokenizer.sot_sequence_including_notimestamps)
        + tokenizer.encode(text)
        + [tokenizer.eot]
        for text in texts
    ]
    # convert it to torch tensors
    text_tokens_list = [
        torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
    ]

    # 50256 is the index of <pad> for all whisper models
    prev_outputs_tokens = _batch_tensors(
        [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
    )
    target_tokens = _batch_tensors(
        [tokens[1:] for tokens in text_tokens_list], pad_value=50256
    )
    target_lengths = torch.LongTensor(
        [tokens.shape[0] - 1 for tokens in text_tokens_list]
    )

    decoder_criterion = LabelSmoothingLoss(
        ignore_index=50256, label_smoothing=0.1, reduction="sum"
    )

    # ignore the first 3 tokens, which are always
    # <|lang_id|>, <|transcribe|>, <|notimestamps|>
    ignore_prefix_size = 3
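    # Illustrative layout of one training sequence (token names, not ids):
    #   decoder input:  <|startoftranscript|> <|zh|> <|transcribe|>
    #                   <|notimestamps|> t1 ... tn
    #   target:         <|zh|> <|transcribe|> <|notimestamps|> t1 ... tn
    #                   <|endoftext|>
    # Dropping the first 3 target positions leaves only the transcript
    # tokens (plus <|endoftext|>) to be scored by the loss.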
    with torch.set_grad_enabled(is_training):
        encoder_out = model.encoder(feature)
        text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
        text_logits = text_logits[:, ignore_prefix_size:, :]
        target_tokens = target_tokens[:, ignore_prefix_size:]
        loss = decoder_criterion(text_logits, target_tokens.to(device))

    assert loss.requires_grad == is_training

    info = MetricsTracker()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()

    return loss, info

def compute_validation_loss(
    params: AttributeDict,
    tokenizer: whisper.tokenizer.Tokenizer,
    model: Union[nn.Module, DDP],
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
) -> MetricsTracker:
    """Run the validation process."""
    model.eval()

    tot_loss = MetricsTracker()

    for batch_idx, batch in enumerate(valid_dl):
        with torch.cuda.amp.autocast(enabled=params.use_fp16):
            loss, loss_info = compute_loss(
                params=params,
                tokenizer=tokenizer,
                model=model,
                batch=batch,
                is_training=False,
            )
        assert loss.requires_grad is False
        tot_loss = tot_loss + loss_info

    if world_size > 1:
        tot_loss.reduce(loss.device)

    loss_value = tot_loss["loss"] / tot_loss["frames"]
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss

def train_one_epoch(
    params: AttributeDict,
    tokenizer: whisper.tokenizer.Tokenizer,
    model: Union[nn.Module, DDP],
    optimizer: torch.optim.Optimizer,
    scheduler: LRSchedulerType,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    model_avg: Optional[nn.Module] = None,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    The training loss from the mean of all frames is saved in
    `params.train_loss`. It runs the validation process every
    `params.valid_interval` batches.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      optimizer:
        The optimizer we are using.
      scheduler:
        The learning rate scheduler; we call step() every step.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mixed precision training.
      model_avg:
        The stored model averaged from the start of training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of nodes in DDP training. If it is 1, DDP is disabled.
      rank:
        The rank of the node in DDP training. If no DDP is used, it should
        be set to 0.
    """
    model.train()

    tot_loss = MetricsTracker()

    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
                params=params,
                tokenizer=tokenizer,
                model=model,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )
            if params.deepspeed:
                model.save_checkpoint(
                    save_dir=params.exp_dir,
                    tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
                    client_state={},
                )
                if rank == 0:
                    convert_zero_checkpoint_to_fp32_state_dict(
                        params.exp_dir,
                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
                        tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
                    )

        try:
            with torch.cuda.amp.autocast(enabled=params.use_fp16):
                loss, loss_info = compute_loss(
                    params=params,
                    tokenizer=tokenizer,
                    model=model,
                    batch=batch,
                    is_training=True,
                )
            # summary stats: an exponential moving average over roughly
            # `reset_interval` recent batches
            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

            # NOTE: We use reduction==sum, so the loss is summed over
            # utterances in the batch without any normalization so far.
            if params.deepspeed:
                # The DeepSpeed engine manages loss scaling, gradient
                # accumulation, the optimizer step and zeroing of gradients
                # internally, so backward() and step() replace the usual
                # scaler/optimizer calls below.
                model.backward(loss)
                model.step()
            else:
                scaler.scale(loss).backward()
                set_batch_count(model, params.batch_idx_train)
                scheduler.step_batch(params.batch_idx_train)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        except:  # noqa
            display_and_save_batch(batch, params=params)
            raise

        if params.print_diagnostics and batch_idx == 5:
            return

        if (
            rank == 0
            and params.batch_idx_train > 0
            and params.batch_idx_train % params.average_period == 0
            and not params.deepspeed
        ):
            update_averaged_model(
                params=params,
                model_cur=model,
                model_avg=model_avg,
            )
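            # Worked example of the update in update_averaged_model: with
            # average_period=200 at batch_idx_train=1000, the weights are
            # model_avg = model * (200 / 1000) + model_avg * (800 / 1000).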

        if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have
            # different behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()
            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
            if cur_grad_scale < 1.0e-05:
                raise RuntimeError(
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )
        if batch_idx % params.log_interval == 0:
            try:
                cur_lr = scheduler.get_last_lr()[0]
            except:  # noqa
                cur_lr = 0.0
            cur_grad_scale = (
                scaler._scale.item()
                if (params.use_fp16 and not params.deepspeed)
                else 1.0
            )

            logging.info(
                f"Epoch {params.cur_epoch}, "
                f"batch {batch_idx}, loss[{loss_info}], "
                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                f"lr: {cur_lr:.2e}, "
                + (
                    f"grad_scale: {scaler._scale.item()}"
                    if (params.use_fp16 and not params.deepspeed)
                    else ""
                )
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate", cur_lr, params.batch_idx_train
                )

                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale",
                        cur_grad_scale,
                        params.batch_idx_train,
                    )

    loss_value = tot_loss["loss"] / tot_loss["frames"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss

def run(rank, world_size, args):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoints.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info(params)

    logging.info("About to create model")

    replace_whisper_encoder_forward()
    model = whisper.load_model(params.model_name, "cpu")
    del model.alignment_heads
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    tokenizer = whisper.tokenizer.get_tokenizer(
        model.is_multilingual,
        num_languages=model.num_languages,
        language="zh",
        task="transcribe",
    )

    model_avg: Optional[nn.Module] = None
    if rank == 0:
        # model_avg is only used with rank 0
        model_avg = copy.deepcopy(model).to(torch.float64)

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(
        params=params, model=model, model_avg=model_avg
    )

    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    else:
        device = torch.device("cpu")
    logging.info(f"Device: {device}")
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

    if checkpoints and "optimizer" in checkpoints:
        logging.info("Loading optimizer state dict")
        optimizer.load_state_dict(checkpoints["optimizer"])

    if (
        checkpoints
        and "scheduler" in checkpoints
        and checkpoints["scheduler"] is not None
    ):
        logging.info("Loading scheduler state dict")
        scheduler.load_state_dict(checkpoints["scheduler"])

    if world_size > 1:
        if params.deepspeed:
            logging.info("Using DeepSpeed")
            model, optimizer, _, scheduler = deepspeed.initialize(
                args=params, model=model, model_parameters=model.parameters()
            )
        else:
            logging.info("Using DDP")
            setup_dist(use_ddp_launch=True)
            model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)

    wenetspeech = WenetSpeechAsrDataModule(args)

    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only load the sampler's state dict when it loads a checkpoint
        # saved in the middle of an epoch
        sampler_state_dict = checkpoints["sampler"]
    else:
        sampler_state_dict = None

    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 15 seconds
        #
        # Caution: There is a reason to select 15.0 here. Please see
        # ../local/display_manifest_statistics.py
        #
        # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
        # the threshold
        if c.duration < 1.0 or c.duration > 15.0:
            # logging.warning(
            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
            # )
            return False
        return True

    train_cuts = wenetspeech.train_cuts()
    train_cuts = train_cuts.filter(remove_short_and_long_utt)
    train_dl = wenetspeech.train_dataloaders(train_cuts)
    valid_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts())

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    logging.info(f"start training from epoch {params.start_epoch}")
    for epoch in range(params.start_epoch, params.num_epochs + 1):
        if not params.deepspeed:
            scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            tokenizer=tokenizer,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        if params.deepspeed:
            model.save_checkpoint(
                save_dir=params.exp_dir,
                tag=f"epoch-{params.cur_epoch}",
                client_state={},
            )
            if rank == 0:
                convert_zero_checkpoint_to_fp32_state_dict(
                    params.exp_dir,
                    f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
                    tag=f"epoch-{params.cur_epoch}",
                )
        else:
            save_checkpoint(
                params=params,
                model=model,
                model_avg=model_avg,
                optimizer=optimizer,
                scheduler=scheduler,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )

    logging.info("Done!")

    if world_size > 1 and not params.deepspeed:
        torch.distributed.barrier()
        cleanup_dist()

def display_and_save_batch(
    batch: dict,
    params: AttributeDict,
) -> None:
    """Display the batch statistics and save the batch into disk.

    Args:
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      params:
        Parameters for training. See :func:`get_params`.
    """
    from lhotse.utils import uuid4

    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
    logging.info(f"Saving batch to {filename}")
    torch.save(batch, filename)

    supervisions = batch["supervisions"]
    features = batch["inputs"]

    logging.info(f"features shape: {features.shape}")

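# A saved batch can be inspected offline for debugging (illustrative):
#   batch = torch.load("whisper/exp/batch-<uuid>.pt")
#   print(batch["inputs"].shape)
#   print(batch["supervisions"]["text"][:2])
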
def main():
    parser = get_parser()
    WenetSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    world_size = get_world_size()
    rank = get_rank()

    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    run(rank=rank, world_size=world_size, args=args)


if __name__ == "__main__":
    main()
1  egs/wenetspeech/ASR/whisper/whisper_encoder_forward_monkey_patch.py  Symbolic link
@ -0,0 +1 @@
../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py