From 25873de7b695e3edc572e3e82197ec4bcb6ff629 Mon Sep 17 00:00:00 2001 From: glynpu Date: Thu, 16 Mar 2023 12:25:41 +0800 Subject: [PATCH 01/24] dataloader --- egs/himia/wuw/ctc_tdnn/asr_datamodule.py | 431 +++++++++++++++++++++++ 1 file changed, 431 insertions(+) create mode 100644 egs/himia/wuw/ctc_tdnn/asr_datamodule.py diff --git a/egs/himia/wuw/ctc_tdnn/asr_datamodule.py b/egs/himia/wuw/ctc_tdnn/asr_datamodule.py new file mode 100644 index 000000000..c02a8e634 --- /dev/null +++ b/egs/himia/wuw/ctc_tdnn/asr_datamodule.py @@ -0,0 +1,431 @@ +# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures + +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class HiMiaWuwDataModule: + """ + DataModule for Himia wake word experiments. + + It contains common data pipeline modules e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=6000.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. 
You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=False, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="PrecomputedFeatures", + ) + group.add_argument( + "--train-channel", + type=str, + default="_7_01", + help="""channel of HI_MIA train dataset. + All channels are used if it is set "all". + """, + ) + group.add_argument( + "--dev-channel", + type=str, + default="_7_01", + help="""channel of HI_MIA dev dataset. + All channels are used if it is set "all". + """, + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
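+        Returns:
+          A PyTorch DataLoader over ``cuts_train``, with the sampling
+          strategy and augmentations selected by the options above.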
+ """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
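+        # Each dataloader worker then derives its own seed as
+        # seed + worker_id via _SeedWorkers above, so augmentation
+        # differs across workers while remaining reproducible.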
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + num_buckets=2, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + train_cuts_file = ( + f"cuts_train_himia{self.args.train_channel}-aishell-shuf.jsonl.gz" + ) + if "all" == self.args.train_channel: + train_cuts_file = "cuts_train_himia-aishell-shuf.jsonl.gz" + return load_manifest_lazy(self.args.manifest_dir / f"{train_cuts_file}") + + @lru_cache() + def aishell_test_cuts(self) -> CutSet: + logging.info("About to get aishell test cuts") + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz") + + @lru_cache() + def cw_test_cuts(self) -> CutSet: + logging.info("About to get HI-MIA-CW test cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_cw_test.jsonl.gz") + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + dev_cuts_file = "cuts_dev.jsonl.gz" + if "all" != self.args.dev_channel: + dev_cuts_file = f"cuts_dev{self.args.dev_channel}.jsonl.gz" + return load_manifest_lazy(self.args.manifest_dir / f"{dev_cuts_file}") + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + # 7_01 is short for microphone 7 and channel 1. 
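+        # This matches the default value "_7_01" of the --train-channel
+        # and --dev-channel options above.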
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_test_7_01.jsonl.gz")

From b55ae4fd5368af928231054741d8c140f81731d4 Mon Sep 17 00:00:00 2001
From: glynpu
Date: Thu, 16 Mar 2023 12:28:59 +0800
Subject: [PATCH 02/24] data preparation

---
 egs/himia/wuw/local/compute_fbank_aishell.py | 1 +
 egs/himia/wuw/local/compute_fbank_himia.py | 139 ++++++++++++++
 egs/himia/wuw/local/compute_fbank_musan.py | 1 +
 egs/himia/wuw/prepare.sh | 189 +++++++++++++++++++
 4 files changed, 330 insertions(+)
 create mode 120000 egs/himia/wuw/local/compute_fbank_aishell.py
 create mode 100755 egs/himia/wuw/local/compute_fbank_himia.py
 create mode 120000 egs/himia/wuw/local/compute_fbank_musan.py
 create mode 100755 egs/himia/wuw/prepare.sh

diff --git a/egs/himia/wuw/local/compute_fbank_aishell.py b/egs/himia/wuw/local/compute_fbank_aishell.py
new file mode 120000
index 000000000..f66261581
--- /dev/null
+++ b/egs/himia/wuw/local/compute_fbank_aishell.py
@@ -0,0 +1 @@
+../../../aishell/ASR/local/compute_fbank_aishell.py
\ No newline at end of file
diff --git a/egs/himia/wuw/local/compute_fbank_himia.py b/egs/himia/wuw/local/compute_fbank_himia.py
new file mode 100755
index 000000000..f930a8c4e
--- /dev/null
+++ b/egs/himia/wuw/local/compute_fbank_himia.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the HI_MIA and HI_MIA_CW datasets.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor, str2bool
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--train-set-channel",
+        type=str,
+        default="_7_01",
+        help="""channel of the HI_MIA dataset.
+        All channels are used if it is set to "all".
+        """,
+    )
+
+    parser.add_argument(
+        "--enable-speed-perturb",
+        type=str2bool,
+        default=False,
+        help="""Whether to apply speed perturbation to the training set.
+ """, + ) + return parser.parse_args() + + +def compute_fbank_himia( + train_set_channel: str = None, + enable_speed_perturb: bool = True, +): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + num_jobs = min(40, os.cpu_count()) + num_mel_bins = 80 + + if "all" == train_set_channel: + dataset_parts = ( + "train", + "dev", + "test", + "cw_test", + ) + else: + dataset_parts = ( + f"train{train_set_channel}", + f"dev{train_set_channel}", + f"test{train_set_channel}", + "cw_test", + ) + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, prefix="himia", output_dir=src_dir + ) + assert manifests is not None + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + if (output_dir / f"cuts_{partition}.jsonl.gz").is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if "train" in partition and enable_speed_perturb: + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) + cut_set = cut_set.resample(16000) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomHdf5Writer, + ) + output_file_name = f"cuts_{partition}.jsonl.gz" + if "all" != train_set_channel: + output_file_name = f"cuts_{partition}{train_set_channel}.jsonl.gz" + + cut_set.to_file(output_dir / f"{output_file_name}") + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + args = get_args() + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_fbank_himia( + train_set_channel=args.train_set_channel, + enable_speed_perturb=args.enable_speed_perturb, + ) + + +if __name__ == "__main__": + main() diff --git a/egs/himia/wuw/local/compute_fbank_musan.py b/egs/himia/wuw/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/himia/wuw/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/himia/wuw/prepare.sh b/egs/himia/wuw/prepare.sh new file mode 100755 index 000000000..bb4f0f36c --- /dev/null +++ b/egs/himia/wuw/prepare.sh @@ -0,0 +1,189 @@ +#!/usr/bin/env bash + +set -eou pipefail + +stage=6 +stop_stage=6 + +# HI_MIA and aishell dataset are used in this experiment. +# musan dataset is used for data augmentation. +# +# For aishell dataset downlading and preparation, +# refer to icefall/egs/aishell/ASR/prepare.sh. +# +# For HI_MIA and HI_MIA_CW dataset, +# we assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# Then these files will be extracted to $dl_dir/HiMia/ +# +# - $dl_dir/train.tar.gz +# Himia training dataset. +# From https://www.openslr.org/85 +# +# - $dl_dir/dev.tar.gz +# Himia Devlopment dataset. +# From https://www.openslr.org/85 +# +# - $dl_dir/test_v2.tar.gz +# Himia test dataset. +# From https://www.openslr.org/85 +# +# - $dl_dir/data.tgz +# Himia confusion words(HI_MIA_CW) test dataset. +# From https://www.openslr.org/120 + +# - $dl_dir/resource.tgz +# Transcripts of (HI_MIA_CW) test dataset. 
+# From https://www.openslr.org/120
+
+dl_dir=$PWD/download
+train_set_channel=_7_01
+enable_speed_perturb=False
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  # If you have pre-downloaded the HI_MIA and HI_MIA_CW datasets to /path/to/himia/,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/himia $dl_dir/
+  #
+  if [ ! -f $dl_dir/train.tar.gz ]; then
+    lhotse download himia $dl_dir/
+  fi
+
+  # If you have pre-downloaded it to /path/to/musan,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/musan $dl_dir/
+  #
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
+  fi
+
+  # If you have pre-downloaded it to /path/to/aishell,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/aishell $dl_dir/aishell
+  #
+  # The directory structure is
+  # aishell/
+  # |-- data_aishell
+  # |   |-- transcript
+  # |   `-- wav
+  # `-- resource_aishell
+  #     |-- lexicon.txt
+  #     `-- speaker.info
+
+  if [ ! -d $dl_dir/aishell/data_aishell/wav/train ]; then
+    lhotse download aishell $dl_dir
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare HI_MIA and HI_MIA_CW manifests"
+  mkdir -p data/manifests
+  if [ ! -e data/manifests/.himia.done ]; then
+    lhotse prepare himia $dl_dir/HiMia data/manifests
+    touch data/manifests/.himia.done
+  fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to $dl_dir/musan
+  mkdir -p data/manifests
+  if [ ! -e data/manifests/.musan.done ]; then
+    lhotse prepare musan $dl_dir/musan data/manifests
+    touch data/manifests/.musan.done
+  fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Prepare aishell manifest"
+  # We assume that you have downloaded the aishell corpus
+  # to $dl_dir/aishell
+  if [ ! -f data/manifests/.aishell_manifests.done ]; then
+    mkdir -p data/manifests
+    lhotse prepare aishell $dl_dir/aishell data/manifests
+    touch data/manifests/.aishell_manifests.done
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for aishell"
+  if [ ! -f data/fbank/.aishell.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_aishell.py \
+      --enable-speed-perturb=${enable_speed_perturb}
+    touch data/fbank/.aishell.done
+  fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Compute fbank for musan"
+  mkdir -p data/fbank
+  if [ ! -e data/fbank/.musan.done ]; then
+    ./local/compute_fbank_musan.py
+    touch data/fbank/.musan.done
+  fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Compute fbank for the HI_MIA and HI_MIA_CW datasets"
+  # Format of train_set_channel is "microphone position"_"channel".
+  # Microphones 1 to 6 are arrays with 16 channels each.
+  # Microphone 7 only has a single channel.
+  # So valid examples of train_set_channel could be:
+  #   1_01, ..., 1_16
+  #   2_01, ..., 2_16
+  #   ...
+  #   6_01, ..., 6_16
+  #   7_01
+  train_set_channel="_7_01"
+  for subset in train dev test; do
+    for file_type in recordings supervisions; do
+      src=data/manifests/himia_${file_type}_${subset}.jsonl.gz
+      dst=data/manifests/himia_${file_type}_${subset}${train_set_channel}.jsonl.gz
+      cat <(gunzip -c ${src}) | \
+        grep ${train_set_channel} | \
+        gzip -c > ${dst}
+    done
+  done
+
+  mkdir -p data/fbank
+  if [ ! -e data/fbank/.himia.done ]; then
+    ./local/compute_fbank_himia.py \
+      --train-set-channel=${train_set_channel} \
+      --enable-speed-perturb=${enable_speed_perturb}
+    touch data/fbank/.himia.done
+  fi
+
+  train_file=data/fbank/cuts_train_himia${train_set_channel}-aishell-shuf.jsonl.gz
+  if [ ! -f ${train_file} ]; then
+    # SingleCutSampler is preferred for this experiment,
+    # so `shuf` the training dataset here.
+    cat <(gunzip -c data/fbank/aishell_cuts_train.jsonl.gz) \
+      <(gunzip -c data/fbank/cuts_train${train_set_channel}.jsonl.gz) | \
+      grep -v _sp | \
+      shuf | shuf | gzip -c > ${train_file}
+  fi
+
+fi

From a49817385a6b82a7d2f9bc06039727b3b34adf8b Mon Sep 17 00:00:00 2001
From: glynpu
Date: Thu, 16 Mar 2023 12:31:22 +0800
Subject: [PATCH 03/24] train and inference

---
 egs/himia/wuw/ctc_tdnn/graph.py | 49 ++
 egs/himia/wuw/ctc_tdnn/inference.py | 203 +++++++++
 egs/himia/wuw/ctc_tdnn/tokenizer.py | 94 ++++
 egs/himia/wuw/ctc_tdnn/train.py | 678 ++++++++++++++++++++++++++++
 egs/himia/wuw/prepare.sh | 2 +-
 egs/himia/wuw/run_ctc_tdnn.sh | 55 +++
 6 files changed, 1080 insertions(+), 1 deletion(-)
 create mode 100644 egs/himia/wuw/ctc_tdnn/graph.py
 create mode 100755 egs/himia/wuw/ctc_tdnn/inference.py
 create mode 100644 egs/himia/wuw/ctc_tdnn/tokenizer.py
 create mode 100755 egs/himia/wuw/ctc_tdnn/train.py
 create mode 100644 egs/himia/wuw/run_ctc_tdnn.sh

diff --git a/egs/himia/wuw/ctc_tdnn/graph.py b/egs/himia/wuw/ctc_tdnn/graph.py
new file mode 100644
index 000000000..184e01ed1
--- /dev/null
+++ b/egs/himia/wuw/ctc_tdnn/graph.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (Author: Weiji Zhuang,
+#                                      Liyong Guo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+
+def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
+    """
+    A graph that starts with blank/unknown and is then followed by the
+    wakeup word.
+
+    Args:
+      wakeup_word_tokens: A sequence of token ids corresponding to the
+        wakeup word. It should not contain 0 and 1.
+        We assume 0 is for blank and 1 is for unknown.
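+
+        For example, wakeup_word_tokens=[2, 3] would yield the following
+        FST text, where each arc line is "src_state dst_state ilabel olabel"
+        and the last line marks the final state:
+
+          0 0 0 0
+          0 0 1 0
+          0 1 2 0
+          1 1 2 0
+          1 2 3 1
+          2 2 3 0
+          2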
+ """ + assert 0 not in wakeup_word_tokens + assert 1 not in wakeup_word_tokens + assert len(wakeup_word_tokens) >= 2 + keyword_ilabel_start = wakeup_word_tokens[0] + fst_graph = "" + for non_wake_word_token in range(keyword_ilabel_start): + fst_graph += f"0 0 {non_wake_word_token} 0\n" + cur_state = 1 + for token_idx in range(len(wakeup_word_tokens) - 1): + token = wakeup_word_tokens[token_idx] + fst_graph += f"{cur_state - 1} {cur_state} {token} 0\n" + fst_graph += f"{cur_state} {cur_state} {token} 0\n" + cur_state += 1 + + token = wakeup_word_tokens[-1] + fst_graph += f"{cur_state - 1} {cur_state} {token} 1\n" + fst_graph += f"{cur_state} {cur_state} {token} 0\n" + fst_graph += f"{cur_state}\n" + return fst_graph diff --git a/egs/himia/wuw/ctc_tdnn/inference.py b/egs/himia/wuw/ctc_tdnn/inference.py new file mode 100755 index 000000000..eae9c5333 --- /dev/null +++ b/egs/himia/wuw/ctc_tdnn/inference.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corporation (Author: Liyong Guo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from pathlib import Path + +import torch +from lhotse.features.io import NumpyHdf5Writer + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + setup_logger, +) + +from asr_datamodule import HiMiaWuwDataModule +from tdnn import Tdnn + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=10, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 1.", + ) + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="ctc_tdnn/exp", + help="The experiment dir", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "env_info": get_env_info(), + "feature_dim": 80, + "number_class": 9, + } + ) + return params + + +def inference_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: torch.nn.Module, + test_set: str, +): + """Compute and save model output of each utterance. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + test_set: + Name of test set. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
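+
+    # Posteriors are stored per cut id in an HDF5 file under params.out_dir;
+    # ctc_tdnn/decode.py later reads them back with NumpyHdf5Reader.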
+ + writer = NumpyHdf5Writer(f"{params.out_dir}/{test_set}") + for batch_idx, batch in enumerate(dl): + device = params.device + feature = batch["inputs"] + assert feature.ndim == 3 + supervisions = batch["supervisions"] + start_frames = supervisions["start_frame"] + end_frames = start_frames + supervisions["num_frames"] + + feature = feature.to(device) + # model_output is log_softmax(logit) with shape [N, T, C] + model_output = model(feature) + + for i in range(feature.size(0)): + assert start_frames[i] == 0 + cut = batch["supervisions"]["cut"][i] + cur_target = model_output[i][start_frames[i] : end_frames[i]] + writer.store_array(key=cut.id, value=cur_target.cpu().numpy()) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + +@torch.no_grad() +def main(): + parser = get_parser() + HiMiaWuwDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + out_dir = f"{params.exp_dir}/post/epoch_{params.epoch}-avg_{params.avg}/" + Path(out_dir).mkdir(parents=True, exist_ok=True) + params.out_dir = out_dir + setup_logger(f"{out_dir}/log-decode") + logging.info("Decoding started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = Tdnn(params.feature_dim, params.number_class) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=True + ) + + model.to(device) + model.eval() + params.device = device + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + himia = HiMiaWuwDataModule(args) + + aishell_test_cuts = himia.aishell_test_cuts() + test_cuts = himia.test_cuts() + cw_test_cuts = himia.cw_test_cuts() + + aishell_test_dl = himia.test_dataloaders(aishell_test_cuts) + test_dl = himia.test_dataloaders(test_cuts) + cw_test_dl = himia.test_dataloaders(cw_test_cuts) + + test_sets = ["aishell_test", "test", "cw_test"] + test_dls = [aishell_test_dl, test_dl, cw_test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + inference_dataset( + dl=test_dl, + params=params, + model=model, + test_set=test_set, + ) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/himia/wuw/ctc_tdnn/tokenizer.py b/egs/himia/wuw/ctc_tdnn/tokenizer.py new file mode 100644 index 000000000..bb988da6d --- /dev/null +++ b/egs/himia/wuw/ctc_tdnn/tokenizer.py @@ -0,0 +1,94 @@ +# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import torch
+
+from typing import List, Tuple
+
+
+class WakeupWordTokenizer(object):
+    def __init__(
+        self,
+        wakeup_word: str = "",
+        wakeup_word_tokens: List[int] = None,
+    ) -> None:
+        """
+        Args:
+          wakeup_word: content of positive samples.
+            A sample will be treated as a negative sample unless its text
+            is exactly the same as wakeup_word.
+          wakeup_word_tokens: A list of ints representing the token ids of
+            wakeup_word.
+            For example: the pronunciation of "你好米雅" is
+            "n i h ao m i y a".
+            Suppose we are using the following lexicon:
+              blk 0
+              unk 1
+              n 2
+              i 3
+              h 4
+              ao 5
+              m 6
+              y 7
+              a 8
+            Then wakeup_word_tokens for "你好米雅" is:
+               n  i  h  ao m  i  y  a
+              [2, 3, 4, 5, 6, 3, 7, 8]
+        """
+        super().__init__()
+        assert wakeup_word is not None
+        assert wakeup_word_tokens is not None
+        assert (
+            0 not in wakeup_word_tokens
+        ), f"0 is kept for blank. Please remove 0 from {wakeup_word_tokens}"
+        assert 1 not in wakeup_word_tokens, (
+            f"1 is kept for unknown and negative samples. "
+            f"Please remove 1 from {wakeup_word_tokens}"
+        )
+        self.wakeup_word = wakeup_word
+        self.wakeup_word_tokens = wakeup_word_tokens
+        self.positive_number_tokens = len(wakeup_word_tokens)
+        self.negative_word_tokens = [1]
+        self.negative_number_tokens = 1
+
+    def texts_to_token_ids(
+        self, texts: List[str]
+    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
+        """Convert a list of texts to token id sequences for CTC training.
+
+        Args:
+          texts:
+            It is a list of strings.
+        Returns:
+          Return a tuple of (target, target_lengths, number_positive_samples).
+          If an element equals `wakeup_word`, the token ids of the wakeup
+          word are appended to the targets; otherwise, the single token id
+          reserved for unknown/negative samples is appended.
+
+          The number of positive samples is also returned to track their
+          proportion.
+        """
+        batch_token_ids = []
+        target_lengths = []
+        number_positive_samples = 0
+        for utt_text in texts:
+            if utt_text == self.wakeup_word:
+                batch_token_ids.append(self.wakeup_word_tokens)
+                target_lengths.append(self.positive_number_tokens)
+                number_positive_samples += 1
+            else:
+                batch_token_ids.append(self.negative_word_tokens)
+                target_lengths.append(self.negative_number_tokens)
+
+        target = torch.tensor(list(itertools.chain.from_iterable(batch_token_ids)))
+        target_lengths = torch.tensor(target_lengths)
+        return target, target_lengths, number_positive_samples
diff --git a/egs/himia/wuw/ctc_tdnn/train.py b/egs/himia/wuw/ctc_tdnn/train.py
new file mode 100755
index 000000000..95ad6c324
--- /dev/null
+++ b/egs/himia/wuw/ctc_tdnn/train.py
@@ -0,0 +1,678 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+    export CUDA_VISIBLE_DEVICES="0,1,2,3"
+    ./ctc_tdnn/train.py \
+        --exp-dir ./ctc_tdnn/exp \
+        --world-size 4 \
+        --max-duration 200 \
+        --num-epochs 20
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Dict, Optional, Tuple
+
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import HiMiaWuwDataModule
+from tdnn import Tdnn
+
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from torch import Tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.tensorboard import SummaryWriter
+
+from tokenizer import WakeupWordTokenizer
+from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=20,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch.
+        If it is larger than 1, it will load the checkpoint from
+        ctc_tdnn/exp/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="ctc_tdnn/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lr-factor",
+        type=float,
+        default=0.001,
+        help="The lr_factor for the optimizer",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+    - best_train_loss: Best training loss so far. It is used to select
+                       the model that has the lowest training loss. It is
+                       updated during the training.
+
+    - best_valid_loss: Best validation loss so far. It is used to select
+                       the model that has the lowest validation loss. It is
+                       updated during the training.
+
+    - best_train_epoch: It is the epoch that has the best training loss.
+
+    - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+                       epochs.
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+    - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+    - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+    - feature_dim: The model input dim. It has to match the one used
+                   in computing features.
+
+    - number_class: Number of classes. Each token will have a token id
+                    from [0, number_class).
+                    In this recipe, 0 is usually kept for blank,
+                    and 1 is usually kept for negative words.
+    - wakeup_word: Text of the wakeup word, i.e. positive samples.
+    - wakeup_word_tokens: A sequence of token ids corresponding to
+                          wakeup_word.
+    - weight_decay: The weight_decay for the optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 5,
+            "reset_interval": 200,
+            "valid_interval": 3000,
+            # parameters for model
+            "feature_dim": 80,
+            "number_class": 9,
+            # parameters for tokenizer
+            "wakeup_word": "你好米雅",
+            "wakeup_word_tokens": [2, 3, 4, 5, 6, 3, 7, 8],
+            # parameters for Optimizer
+            "weight_decay": 1e-6,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+) -> Optional[Dict]:
+    """Load checkpoint from file.
+
+    If params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`. Otherwise, this function does nothing.
+
+    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
+    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The learning rate scheduler we are using.
+    Returns:
+      Return the contents of the loaded checkpoint, or None if no
+      checkpoint was loaded.
+    """
+    if params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
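+      optimizer:
+        The optimizer used in training; its state is saved for resuming.
+      scheduler:
+        The learning rate scheduler used in training, if any.
+      rank:
+        Rank of this node in DDP training; checkpoints are saved only
+        on rank 0.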
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: nn.Module,
+    batch: dict,
+    tokenizer: WakeupWordTokenizer,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute CTC loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Tdnn in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      tokenizer:
+        Maps the texts of positive samples to the corresponding token id
+        sequence; negative samples are mapped to the unknown token no
+        matter what their texts are.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+    """
+    device = model.device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    N, T, C = feature.shape
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    texts = supervisions["text"]
+    with torch.set_grad_enabled(is_training):
+        # model_output is log_softmax(logit) with shape [N, T, C]
+        model_output = model(feature)
+
+    assert torch.all(supervisions["start_frame"] == 0)
+    num_frames = supervisions["num_frames"].to(device)
+
+    target, target_lengths, number_positive_samples = tokenizer.texts_to_token_ids(
+        texts
+    )  # noqa E501
+    target = target.to(device)
+    target_lengths = target_lengths.to(device)
+    ctc_loss = nn.CTCLoss(reduction="sum")
+    # [N, T, C] --> [T, N, C]
+    model_output = model_output.transpose(0, 1)
+    loss = ctc_loss(model_output, target, num_frames, target_lengths)
+    loss /= num_frames.sum()
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    info["frames"] = num_frames.sum().item()
+
+    info["loss"] = loss.detach().cpu().item() * info["frames"]
+
+    # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    info["utterances"] = feature.size(0)
+    # averaged input duration in frames over utterances
+    info["utt_duration"] = supervisions["num_frames"].sum().item()
+    # averaged padding proportion over utterances
+    info["utt_pad_proportion"] = (
+        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item()
+    )
+
+    info["number_positive_cuts_ratio"] = (number_positive_samples / N) * info["frames"]
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: nn.Module,
+    tokenizer: WakeupWordTokenizer,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            batch=batch,
+            tokenizer=tokenizer,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + tokenizer: WakeupWordTokenizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + tokenizer: + For positive samples, map their texts to corresponding token index sequence. + While for negative samples, map their texts to unknown no matter what they are. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + tokenizer=tokenizer, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+    logging.info(params)
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+
+    tokenizer = WakeupWordTokenizer(
+        wakeup_word=params.wakeup_word,
+        wakeup_word_tokens=params.wakeup_word_tokens,
+    )
+
+    logging.info("About to create model")
+
+    model = Tdnn(params.feature_dim, params.number_class)
+
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+    model.to(device)
+    if world_size > 1:
+        model = DDP(model, device_ids=[rank])
+    model.device = device
+
+    optimizer = torch.optim.Adam(
+        model.parameters(),
+        lr=params.lr_factor,
+        weight_decay=params.weight_decay,
+    )
+
+    if checkpoints:
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    himia = HiMiaWuwDataModule(args)
+
+    train_cuts = himia.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 0.5 seconds and
+        # 20 seconds.
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 0.5 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    train_dl = himia.train_dataloaders(train_cuts)
+
+    valid_cuts = himia.dev_cuts()
+    valid_dl = himia.valid_dataloaders(valid_cuts)
+
+    scan_pessimistic_batches_for_oom(
+        model=model,
+        train_dl=train_dl,
+        optimizer=optimizer,
+        tokenizer=tokenizer,
+        params=params,
+    )
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        fix_random_seed(params.seed + epoch)
+        train_dl.sampler.set_epoch(epoch)
+
+        # TODO: Support lr scheduler
+        cur_lr = params.lr_factor
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        if rank == 0:
+            logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            tokenizer=tokenizer,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            tb_writer=tb_writer,
+            world_size=world_size,
+        )
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: nn.Module,
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    tokenizer: WakeupWordTokenizer,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + tokenizer=tokenizer, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + HiMiaWuwDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/himia/wuw/prepare.sh b/egs/himia/wuw/prepare.sh index bb4f0f36c..a47a20682 100755 --- a/egs/himia/wuw/prepare.sh +++ b/egs/himia/wuw/prepare.sh @@ -2,7 +2,7 @@ set -eou pipefail -stage=6 +stage=0 stop_stage=6 # HI_MIA and aishell dataset are used in this experiment. diff --git a/egs/himia/wuw/run_ctc_tdnn.sh b/egs/himia/wuw/run_ctc_tdnn.sh new file mode 100644 index 000000000..6556eab93 --- /dev/null +++ b/egs/himia/wuw/run_ctc_tdnn.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash + +set -eou pipefail + +# You need to execute ./prepare.sh to prepare datasets. +stage=1 +stop_stage=2 + +epoch=10 +avg=1 +exp_dir=./ctc_tdnn/exp/ +epoch_avg=epoch_${epoch}-avg_${avg} +post_dir=${exp_dir}/post/${epoch_avg} + +. 
shared/parse_options.sh || exit 1 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Model training" + python ./ctc_tdnn/train.py \ + --num-epochs $epoch +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Get posterior of test sets" + python ctc_tdnn/inference.py \ + --avg $avg \ + --epoch $epoch \ + --exp-dir ${exp_dir} +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Decode and compute area under curve(AUC)" + for test_set in test aishell_test cw_test; do + python ctc_tdnn/decode.py \ + --decoding-graph ./data/LG.int \ + --post-h5 ${post_dir}/${test_set}.h5 \ + --score-file ${post_dir}/fst_${test_set}_pos_h5.txt + done + python ./local/auc.py \ + --legend himia_cw \ + --positive-score-file ${post_dir}/fst_test_pos_h5.txt \ + --negative-score-file ${post_dir}/fst_cw_test_pos_h5.txt + + python ./local/auc.py \ + --legend himia_aishell \ + --positive-score-file ${post_dir}/fst_test_pos_h5.txt \ + --negative-score-file ${post_dir}/fst_aishell_test_pos_h5.txt +fi From 39c0ae7749f56397f91d2cc9dc4eb6e1f7f371d1 Mon Sep 17 00:00:00 2001 From: glynpu Date: Thu, 16 Mar 2023 12:37:46 +0800 Subject: [PATCH 04/24] auc --- egs/himia/wuw/ctc_tdnn/decode.py | 279 +++++++++++++++++++++++++++++++ egs/himia/wuw/local/auc.py | 115 +++++++++++++ 2 files changed, 394 insertions(+) create mode 100755 egs/himia/wuw/ctc_tdnn/decode.py create mode 100755 egs/himia/wuw/local/auc.py diff --git a/egs/himia/wuw/ctc_tdnn/decode.py b/egs/himia/wuw/ctc_tdnn/decode.py new file mode 100755 index 000000000..7acc7d595 --- /dev/null +++ b/egs/himia/wuw/ctc_tdnn/decode.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (Author: Weiji Zhuang, +# Liyong Guo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import copy +import logging +from concurrent.futures import ProcessPoolExecutor +from typing import Tuple + +import numpy as np +from lhotse.features.io import NumpyHdf5Reader +from tqdm import tqdm + +from icefall.utils import AttributeDict + +from train import get_params +from graph import ctc_trivial_decoding_graph + + +class Arc: + def __init__( + self, src_state: int, dst_state: int, ilabel: int, olabel: int + ) -> None: + self.src_state = int(src_state) + self.dst_state = int(dst_state) + self.ilabel = int(ilabel) + self.olabel = int(olabel) + + def next_state(self) -> None: + return self.dst_state + + +class State: + def __init__(self) -> None: + self.arc_list = list() + + def add_arc(self, arc: Arc) -> None: + self.arc_list.append(arc) + + +class FiniteStateTransducer: + """Represents a decoding graph for wake word detection.""" + + def __init__(self, graph: str) -> None: + self.state_list = list() + for arc_str in graph.split("\n"): + arc = arc_str.strip().split() + if len(arc) == 0: + continue + # 1 and 2 for final state + # 4 for non-final state + assert len(arc) in [1, 2, 4], f"{len(arc)} {arc_str}" + if len(arc) == 4: # Non-final state + # FST must be sorted + if len(self.state_list) <= int(arc[0]): + new_state = State() + self.state_list.append(new_state) + self.state_list[int(arc[0])].add_arc( + Arc(arc[0], arc[1], arc[2], arc[3]) + ) + else: + self.final_state_id = int(arc[0]) + + def to_str(self) -> None: + fst_str = "" + for state_idx in range(len(self.state_list)): + cur_state = self.state_list[state_idx] + for arc_idx in range(len(cur_state.arc_list)): + cur_arc = cur_state.arc_list[arc_idx] + ilabel = cur_arc.ilabel + olabel = cur_arc.olabel + src_state = cur_arc.src_state + dst_state = cur_arc.dst_state + fst_str += f"{src_state} {dst_state} {ilabel} {olabel}\n" + fst_str += f"{dst_state}\n" + return fst_str + + +class Token: + def __init__(self) -> None: + self.is_active = False + self.total_score = -float("inf") + self.keyword_frames = 0 + self.average_keyword_score = -float("inf") + self.average_max_keyword_score = 0.0 + + def set_token( + self, + src_token, + is_keyword_ilabel: bool, + acoustic_score: float, + ) -> None: + """ + A dynamic programming process computing the highest score for a token + from all possible paths which could reach this token. + + Args: + src_token: The source token connected to current token with an arc. + is_keyword_ilabel: If true, the arc consumes an input label which is + a part of wake word. Otherwhise, the input label is + blank or unknown, i.e. current token is still not part of wake word. + acoustic_score: acoustic score of this arc. + """ + + if ( + not self.is_active + or self.total_score < src_token.total_score + acoustic_score + ): + self.is_active = True + self.total_score = src_token.total_score + acoustic_score + + if is_keyword_ilabel: + self.average_keyword_score = ( + acoustic_score + + src_token.average_keyword_score * src_token.keyword_frames + ) / (src_token.keyword_frames + 1) + + self.keyword_frames = src_token.keyword_frames + 1 + else: + self.average_keyword_score = 0.0 + + +class SingleDecodable: + def __init__( + self, + model_output, + keyword_ilabel_start, + graph, + ): + """ + Args: + model_output: log_softmax(logit) with shape [T, C] + keyword_ilabel_start: index of the first token of the wake word. + In this recipe, tokens not for wake word has smaller token index, + i.e. blank 0; unk 1. + graph: decoding graph of the wake word. 
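+
+        Decoding is a frame-synchronous token passing over ``graph``:
+        each call to process_oneframe() consumes one frame of
+        ``model_output`` and updates ``utt_score``, the best average
+        keyword score seen so far.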
+ + """ + self.init_token_list = [Token() for i in range(len(graph.state_list))] + self.reset_token_list() + self.model_output = model_output + self.T = model_output.shape[0] + self.utt_score = 0.0 + self.current_frame_index = 0 + self.keyword_ilabel_start = keyword_ilabel_start + self.graph = graph + self.number_tokens = len(self.cur_token_list) + + def reset_token_list(self) -> None: + """ + Reset all tokens to a condition without consuming any acoustic frames. + """ + self.cur_token_list = copy.deepcopy(self.init_token_list) + self.expand_token_list = copy.deepcopy(self.init_token_list) + self.cur_token_list[0].is_active = True + self.cur_token_list[0].total_score = 0 + self.cur_token_list[0].average_keyword_score = 0 + + def process_oneframe(self) -> None: + """ + Decode a frame and update all tokens. + """ + for state_id, cur_token in enumerate(self.cur_token_list): + if cur_token.is_active: + for arc_id in self.graph.state_list[state_id].arc_list: + acoustic_score = self.model_output[self.current_frame_index][ + arc_id.ilabel + ] + is_keyword_ilabel = arc_id.ilabel >= self.keyword_ilabel_start + self.expand_token_list[arc_id.next_state()].set_token( + cur_token, + is_keyword_ilabel, + acoustic_score, + ) + # use best_score to keep total_score in a good range + self.best_state_id = 0 + best_score = self.expand_token_list[0].total_score + for state_id in range(self.number_tokens): + if self.expand_token_list[state_id].is_active: + if best_score < self.expand_token_list[state_id].total_score: + best_score = self.expand_token_list[state_id].total_score + self.best_state_id = state_id + + self.cur_token_list = self.expand_token_list + for state_id in range(self.number_tokens): + self.cur_token_list[state_id].total_score -= best_score + self.expand_token_list = copy.deepcopy(self.init_token_list) + potential_score = np.exp( + self.cur_token_list[self.graph.final_state_id].average_keyword_score + ) + if potential_score > self.utt_score: + self.utt_score = potential_score + self.current_frame_index += 1 + + +def decode_utt( + params: AttributeDict, utt_id: str, post_file, graph: FiniteStateTransducer +) -> Tuple[str, float]: + """ + Decode a single utterance. + + Args: + params: + The return value of :func:`get_params`. + utt_id: utt_id to be decoded, used to fetch posterior matrix from post_file. + post_file: file to save posterior for all test set. + graph: decoding graph. + + Returns: + utt_id and its corresponding probability to be a wake word. 
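+
+    Example (illustrative values):
+        >>> decode_utt(params, "utt_001", "post/test.h5", graph)
+        ('utt_001', 0.93)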
+ """ + reader = NumpyHdf5Reader(post_file) + model_output = reader.read(utt_id) + keyword_ilabel_start = params.wakeup_word_tokens[0] + decodable = SingleDecodable( + model_output=model_output, + keyword_ilabel_start=keyword_ilabel_start, + graph=graph, + ) + for t in range(decodable.T): + decodable.process_oneframe() + return utt_id, decodable.utt_score + + +def get_parser(): + parser = argparse.ArgumentParser( + description="A simple FST decoder for the wake word detection\n" + ) + parser.add_argument( + "--decoding-graph", help="decoding graph", default="himia_ctc_graph.txt" + ) + parser.add_argument("--post-h5", help="model output in h5 format") + parser.add_argument("--score-file", help="file to save scores of each utterance") + return parser + + +def main(): + logging.basicConfig( + level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s" + ) + parser = get_parser() + args = parser.parse_args() + params = get_params() + params.update(vars(args)) + + keys = NumpyHdf5Reader(params.post_h5).hdf.keys() + graph = FiniteStateTransducer(ctc_trivial_decoding_graph(params.wakeup_word_tokens)) + logging.info(f"Graph used:\n{graph.to_str()}") + logging.info("About to load data to decoder.") + with ProcessPoolExecutor() as executor, open( + params.score_file, "w", encoding="utf8" + ) as fout: + futures = [ + executor.submit(decode_utt, params, key, params.post_h5, graph) + for key in tqdm(keys) + ] + logging.info("Decoding.") + for future in tqdm(futures): + k, v = future.result() + fout.write(str(k) + " " + str(v) + "\n") + + +if __name__ == "__main__": + main() diff --git a/egs/himia/wuw/local/auc.py b/egs/himia/wuw/local/auc.py new file mode 100755 index 000000000..d7357a0f1 --- /dev/null +++ b/egs/himia/wuw/local/auc.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (Author: Weiji Zhuang, +# Liyong Guo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from typing import Dict, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from pathlib import Path +from sklearn.metrics import roc_curve, auc + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--positive-score-file", required=True, help="score file of positive data" + ) + parser.add_argument( + "--negative-score-file", required=True, help="score file of negative data" + ) + parser.add_argument("--legend", required=True, help="utt2dur file of negative data") + return parser.parse_args() + + +def load_score(score_file: Path) -> Dict[str, float]: + """ + Args: + score_file: Path to score file. Each line has two columns. + The first colume is utt-id, and the second one is score. + This score could be viewed as probability of being wakeup word. + + Returns: + A dict with that key is utt-id and value is corresponding score. 
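+
+    Example of the expected score-file layout (illustrative utt-ids and
+    scores; one "utt-id score" pair per line):
+        utt_001 0.97
+        utt_002 0.03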
+ """ + pos_dict = {} + with open(score_file, "r", encoding="utf8") as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + key = arr[0] + score = float(arr[1]) + pos_dict[key] = score + return pos_dict + + +def get_roc_and_auc( + pos_dict: Dict, + neg_dict: Dict, +) -> Tuple[np.array, np.array, float]: + """ + Args: + pos_dict: scores of positive samples. + neg_dict: scores of negative samples. + Return: + A tuple of three elements, which will be used to plot roc curve. + Refer to sklearn.metrics.roc_curve for meaning of the first and second elements. + The third element is area under the roc curve(AUC). + """ + pos_scores = np.fromiter(pos_dict.values(), dtype=float) + neg_scores = np.fromiter(neg_dict.values(), dtype=float) + + pos_y = np.ones_like(pos_scores, dtype=int) + neg_y = np.zeros_like(neg_scores, dtype=int) + + scores = np.concatenate([pos_scores, neg_scores]) + y = np.concatenate([pos_y, neg_y]) + + fpr, tpr, thresholds = roc_curve(y, scores, pos_label=1) + roc_auc = auc(fpr, tpr) + + return fpr, tpr, roc_auc + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + args = get_args() + logging.basicConfig(format=formatter, level=logging.INFO) + pos_dict = load_score(args.positive_score_file) + neg_dict = load_score(args.negative_score_file) + fpr, tpr, roc_auc = get_roc_and_auc(pos_dict, neg_dict) + + plt.figure(figsize=(16, 9)) + plt.plot(fpr, tpr, label=f"{args.legend}(AUC = %1.8f)" % roc_auc) + + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.0]) + plt.xlabel("False Positive Rate") + plt.ylabel("True Positive Rate") + plt.title("Receiver operating characteristic(ROC)") + plt.legend(loc="lower right") + + output_path = Path(args.positive_score_file).parent + logging.info(f"AUC of {args.legend} {output_path}: {roc_auc}") + plt.savefig(f"{output_path}/{args.legend}.pdf", bbox_inches="tight") + + +if __name__ == "__main__": + main() From 07a8f050b7476097a03664b431b26993f4bfc2c4 Mon Sep 17 00:00:00 2001 From: glynpu Date: Thu, 16 Mar 2023 12:56:25 +0800 Subject: [PATCH 05/24] update comments --- egs/himia/wuw/ctc_tdnn/train.py | 15 +-------------- egs/himia/wuw/run_ctc_tdnn.sh | 2 +- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/egs/himia/wuw/ctc_tdnn/train.py b/egs/himia/wuw/ctc_tdnn/train.py index 95ad6c324..0b140020e 100755 --- a/egs/himia/wuw/ctc_tdnn/train.py +++ b/egs/himia/wuw/ctc_tdnn/train.py @@ -19,7 +19,7 @@ Usage: export CUDA_VISIBLE_DEVICES="0,1,2,3" ./ctc_tdnn/train.py \ - --exp-dir ./tdnn/exp \ + --exp-dir ./ctc_tdnn/exp \ --world-size 4 \ --max-duration 200 \ --num-epochs 20 @@ -552,19 +552,6 @@ def run(rank, world_size, args): train_cuts = himia.train_cuts() - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 0.5 <= c.duration <= 20.0 - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_dl = himia.train_dataloaders(train_cuts) valid_cuts = himia.dev_cuts() diff --git a/egs/himia/wuw/run_ctc_tdnn.sh b/egs/himia/wuw/run_ctc_tdnn.sh index 6556eab93..16ecacf6a 100644 --- a/egs/himia/wuw/run_ctc_tdnn.sh +++ b/egs/himia/wuw/run_ctc_tdnn.sh @@ -3,7 +3,7 @@ set -eou pipefail # You need to execute ./prepare.sh to prepare datasets. 
-stage=1 +stage=0 stop_stage=2 epoch=10 From 3feef0a7d05586b2971521e6b3392d9fc535ffe6 Mon Sep 17 00:00:00 2001 From: glynpu Date: Thu, 16 Mar 2023 13:05:46 +0800 Subject: [PATCH 06/24] update tokenizer comments --- egs/himia/wuw/ctc_tdnn/tokenizer.py | 17 +++++++++++------ egs/himia/wuw/ctc_tdnn/train.py | 1 - 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/egs/himia/wuw/ctc_tdnn/tokenizer.py b/egs/himia/wuw/ctc_tdnn/tokenizer.py index bb988da6d..bc207ec04 100644 --- a/egs/himia/wuw/ctc_tdnn/tokenizer.py +++ b/egs/himia/wuw/ctc_tdnn/tokenizer.py @@ -64,18 +64,23 @@ class WakeupWordTokenizer(object): self.negative_word_tokens = [1] self.negative_number_tokens = 1 - def texts_to_token_ids(self, texts: List[str]) -> Tuple[torch.Tensor, int]: + def texts_to_token_ids(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, int]: """Convert a list of texts to a list of k2.Fsa based texts. Args: texts: - It is a list of strings. + It is a list of strings, + each element is a reference text for an audio. Returns: - Return a list of k2.Fsa, one for an element in texts. - If the element is `wakeup_word`, a graph for positive samples is appneded - into resulting graph_vec, otherwise, a graph for negative samples is appended. + Return a tuple of 3 elements. + The first one is torch.Tensor(List[List[int]]), + each List[int] is tokens sequence for each a reference text. - Number of positive samples is also returned to track its proportion. + The second one is number of tokens for each sample, + mainly used by CTC loss. + + The last one is number_positive_samples, + used to track proportion of positive samples in each batch. """ batch_token_ids = [] target_lengths = [] diff --git a/egs/himia/wuw/ctc_tdnn/train.py b/egs/himia/wuw/ctc_tdnn/train.py index 0b140020e..249821c29 100755 --- a/egs/himia/wuw/ctc_tdnn/train.py +++ b/egs/himia/wuw/ctc_tdnn/train.py @@ -37,7 +37,6 @@ import torch.nn as nn from asr_datamodule import HiMiaWuwDataModule from tdnn import Tdnn -from lhotse.cut import Cut from lhotse.utils import fix_random_seed from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP From 93a168ab063431dc9f48bcef7a1358add1690cb6 Mon Sep 17 00:00:00 2001 From: glynpu Date: Thu, 16 Mar 2023 13:11:11 +0800 Subject: [PATCH 07/24] tdnn model --- egs/himia/wuw/ctc_tdnn/tdnn.py | 108 +++++++++++++++++++++++++++++++++ egs/himia/wuw/shared | 1 + 2 files changed, 109 insertions(+) create mode 100644 egs/himia/wuw/ctc_tdnn/tdnn.py create mode 120000 egs/himia/wuw/shared diff --git a/egs/himia/wuw/ctc_tdnn/tdnn.py b/egs/himia/wuw/ctc_tdnn/tdnn.py new file mode 100644 index 000000000..0f685b6c2 --- /dev/null +++ b/egs/himia/wuw/ctc_tdnn/tdnn.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
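+
+# Note: the Tdnn module below writes out a stack of identical
+# Conv1d -> ReLU -> BatchNorm1d(affine=False) stanzas by hand.  An
+# equivalent loop-based construction (an illustrative sketch, not the
+# checked-in code) could look like:
+#
+#     channels = [num_features] + [240] * 9
+#     layers = []
+#     for in_c, out_c in zip(channels[:-1], channels[1:]):
+#         layers += [
+#             nn.Conv1d(in_c, out_c, kernel_size=3, stride=1, padding=1),
+#             nn.ReLU(inplace=True),
+#             nn.BatchNorm1d(out_c, affine=False),
+#         ]
+#     # ... plus the two kernel_size=1 layers and nn.LogSoftmax(1).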
+ +from torch import nn, Tensor + + +class Tdnn(nn.Module): + """ + Args: + num_features (int): Number of input features + num_classes (int): Number of output classes + """ + + def __init__(self, num_features: int, num_classes: int) -> None: + super().__init__() + self.num_features = num_features + self.num_classes = num_classes + self.tdnn = nn.Sequential( + nn.Conv1d( + in_channels=num_features, + out_channels=240, + kernel_size=3, + stride=1, + padding=1, + ), + nn.ReLU(inplace=True), + nn.BatchNorm1d(num_features=240, affine=False), + nn.Conv1d( + in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1 + ), + nn.ReLU(inplace=True), + nn.BatchNorm1d(num_features=240, affine=False), + nn.Conv1d( + in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1 + ), + nn.ReLU(inplace=True), + nn.BatchNorm1d(num_features=240, affine=False), + nn.Conv1d( + in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1 + ), + nn.ReLU(inplace=True), + nn.BatchNorm1d(num_features=240, affine=False), + nn.Conv1d( + in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1 + ), + nn.ReLU(inplace=True), + nn.BatchNorm1d(num_features=240, affine=False), + nn.Conv1d( + in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1 + ), + nn.ReLU(inplace=True), + nn.BatchNorm1d(num_features=240, affine=False), + nn.Conv1d( + in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1 + ), + nn.ReLU(inplace=True), + nn.BatchNorm1d(num_features=240, affine=False), + nn.Conv1d( + in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1 + ), + nn.ReLU(inplace=True), + nn.BatchNorm1d(num_features=240, affine=False), + nn.Conv1d( + in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1 + ), + nn.ReLU(inplace=True), + nn.BatchNorm1d(num_features=240, affine=False), + nn.Conv1d( + in_channels=240, out_channels=240, kernel_size=1, stride=1, padding=0 + ), + nn.ReLU(inplace=True), + nn.BatchNorm1d(num_features=240, affine=False), + nn.Conv1d( + in_channels=240, + out_channels=num_classes, + kernel_size=1, + stride=1, + padding=0, + ), + nn.LogSoftmax(1), + ) + + def forward(self, x: Tensor) -> Tensor: + r""" + Args: + x (torch.Tensor): Tensor of dimension (N, T, C). + Returns: + Tensor: Predictor tensor of dimension (N, T, C). 
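+
+        Example (illustrative): with the recipe's 80-dim fbank features
+            and 9 output classes, an input of shape (16, 200, 80) yields
+            per-frame log-probabilities of shape (16, 200, 9).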
+ """ + + x = x.transpose(1, 2) + x = self.tdnn(x) + x = x.transpose(1, 2) + return x diff --git a/egs/himia/wuw/shared b/egs/himia/wuw/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/himia/wuw/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file From 27428187d0b5ccf75279c7928183cc3b6e93b389 Mon Sep 17 00:00:00 2001 From: glynpu Date: Thu, 16 Mar 2023 14:50:24 +0800 Subject: [PATCH 08/24] update logging information --- egs/himia/wuw/ctc_tdnn/decode.py | 33 ++++++++++++++++++++--------- egs/himia/wuw/ctc_tdnn/inference.py | 5 ++++- egs/himia/wuw/ctc_tdnn/tokenizer.py | 4 +++- egs/himia/wuw/ctc_tdnn/train.py | 2 +- egs/himia/wuw/local/auc.py | 26 ++++++++++++++++++----- egs/himia/wuw/run_ctc_tdnn.sh | 1 - 6 files changed, 52 insertions(+), 19 deletions(-) diff --git a/egs/himia/wuw/ctc_tdnn/decode.py b/egs/himia/wuw/ctc_tdnn/decode.py index 7acc7d595..6715c8b9c 100755 --- a/egs/himia/wuw/ctc_tdnn/decode.py +++ b/egs/himia/wuw/ctc_tdnn/decode.py @@ -21,12 +21,16 @@ import copy import logging from concurrent.futures import ProcessPoolExecutor from typing import Tuple +from pathlib import Path import numpy as np from lhotse.features.io import NumpyHdf5Reader from tqdm import tqdm -from icefall.utils import AttributeDict +from icefall.utils import ( + AttributeDict, + setup_logger, +) from train import get_params from graph import ctc_trivial_decoding_graph @@ -242,26 +246,33 @@ def get_parser(): description="A simple FST decoder for the wake word detection\n" ) parser.add_argument( - "--decoding-graph", help="decoding graph", default="himia_ctc_graph.txt" + "--post-h5", + type=str, + help="model output in h5 format", + ) + parser.add_argument( + "--score-file", + type=str, + help="file to save scores of each utterance", ) - parser.add_argument("--post-h5", help="model output in h5 format") - parser.add_argument("--score-file", help="file to save scores of each utterance") return parser def main(): - logging.basicConfig( - level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s" - ) parser = get_parser() args = parser.parse_args() params = get_params() params.update(vars(args)) + post_dir = Path(params.post_h5).parent + test_set = Path(params.post_h5).stem + setup_logger(f"{post_dir}/log/log-decode-{test_set}") - keys = NumpyHdf5Reader(params.post_h5).hdf.keys() graph = FiniteStateTransducer(ctc_trivial_decoding_graph(params.wakeup_word_tokens)) + logging.info(f"Graph used:\n{graph.to_str()}") - logging.info("About to load data to decoder.") + + logging.info(f"About to load {test_set}.") + keys = NumpyHdf5Reader(params.post_h5).hdf.keys() with ProcessPoolExecutor() as executor, open( params.score_file, "w", encoding="utf8" ) as fout: @@ -269,11 +280,13 @@ def main(): executor.submit(decode_utt, params, key, params.post_h5, graph) for key in tqdm(keys) ] - logging.info("Decoding.") + logging.info(f"Decoding {test_set}.") for future in tqdm(futures): k, v = future.result() fout.write(str(k) + " " + str(v) + "\n") + logging.info(f"Finish decoding {test_set}.") + if __name__ == "__main__": main() diff --git a/egs/himia/wuw/ctc_tdnn/inference.py b/egs/himia/wuw/ctc_tdnn/inference.py index eae9c5333..10950cec9 100755 --- a/egs/himia/wuw/ctc_tdnn/inference.py +++ b/egs/himia/wuw/ctc_tdnn/inference.py @@ -140,7 +140,7 @@ def main(): out_dir = f"{params.exp_dir}/post/epoch_{params.epoch}-avg_{params.avg}/" Path(out_dir).mkdir(parents=True, exist_ok=True) params.out_dir = out_dir - setup_logger(f"{out_dir}/log-decode") + 
setup_logger(f"{out_dir}/log/log-inference") logging.info("Decoding started") logging.info(params) @@ -186,6 +186,7 @@ def main(): test_dls = [aishell_test_dl, test_dl, cw_test_dl] for test_set, test_dl in zip(test_sets, test_dls): + logging.info(f"About to inference {test_set}") inference_dataset( dl=test_dl, params=params, @@ -193,6 +194,8 @@ def main(): test_set=test_set, ) + logging.info(f"finish inferencing {test_set}") + logging.info("Done!") diff --git a/egs/himia/wuw/ctc_tdnn/tokenizer.py b/egs/himia/wuw/ctc_tdnn/tokenizer.py index bc207ec04..e019ebb86 100644 --- a/egs/himia/wuw/ctc_tdnn/tokenizer.py +++ b/egs/himia/wuw/ctc_tdnn/tokenizer.py @@ -64,7 +64,9 @@ class WakeupWordTokenizer(object): self.negative_word_tokens = [1] self.negative_number_tokens = 1 - def texts_to_token_ids(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, int]: + def texts_to_token_ids( + self, texts: List[str] + ) -> Tuple[torch.Tensor, torch.Tensor, int]: """Convert a list of texts to a list of k2.Fsa based texts. Args: diff --git a/egs/himia/wuw/ctc_tdnn/train.py b/egs/himia/wuw/ctc_tdnn/train.py index 249821c29..fd9d42cad 100755 --- a/egs/himia/wuw/ctc_tdnn/train.py +++ b/egs/himia/wuw/ctc_tdnn/train.py @@ -564,7 +564,7 @@ def run(rank, world_size, args): params=params, ) - for epoch in range(params.start_epoch, params.num_epochs): + for epoch in range(params.start_epoch, params.num_epochs + 1): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) diff --git a/egs/himia/wuw/local/auc.py b/egs/himia/wuw/local/auc.py index d7357a0f1..7b35ef06b 100755 --- a/egs/himia/wuw/local/auc.py +++ b/egs/himia/wuw/local/auc.py @@ -25,16 +25,29 @@ import numpy as np from pathlib import Path from sklearn.metrics import roc_curve, auc +from icefall.utils import setup_logger + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--positive-score-file", required=True, help="score file of positive data" + "--positive-score-file", + type=str, + required=True, + help="score file of positive data", ) parser.add_argument( - "--negative-score-file", required=True, help="score file of negative data" + "--negative-score-file", + type=str, + required=True, + help="score file of negative data", + ) + parser.add_argument( + "--legend", + type=str, + required=True, + help="legend of ROC curve picture.", ) - parser.add_argument("--legend", required=True, help="utt2dur file of negative data") return parser.parse_args() @@ -88,10 +101,13 @@ def get_roc_and_auc( def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" args = get_args() - logging.basicConfig(format=formatter, level=logging.INFO) + + score_dir = Path(args.positive_score_file).parent + setup_logger(f"{score_dir}/log/log-auc-{args.legend}") + logging.info(f"About to compute AUC of {args.legend}") + pos_dict = load_score(args.positive_score_file) neg_dict = load_score(args.negative_score_file) fpr, tpr, roc_auc = get_roc_and_auc(pos_dict, neg_dict) diff --git a/egs/himia/wuw/run_ctc_tdnn.sh b/egs/himia/wuw/run_ctc_tdnn.sh index 16ecacf6a..7dfaeefea 100644 --- a/egs/himia/wuw/run_ctc_tdnn.sh +++ b/egs/himia/wuw/run_ctc_tdnn.sh @@ -39,7 +39,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Decode and compute area under curve(AUC)" for test_set in test aishell_test cw_test; do python ctc_tdnn/decode.py \ - --decoding-graph ./data/LG.int \ --post-h5 ${post_dir}/${test_set}.h5 \ --score-file ${post_dir}/fst_${test_set}_pos_h5.txt done From 
d5471c5284f9867c72230ebb8a8f533171ffb266 Mon Sep 17 00:00:00 2001 From: glynpu Date: Thu, 16 Mar 2023 14:59:16 +0800 Subject: [PATCH 09/24] rename score file --- egs/himia/wuw/run_ctc_tdnn.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/himia/wuw/run_ctc_tdnn.sh b/egs/himia/wuw/run_ctc_tdnn.sh index 7dfaeefea..fdf5ec1db 100644 --- a/egs/himia/wuw/run_ctc_tdnn.sh +++ b/egs/himia/wuw/run_ctc_tdnn.sh @@ -40,15 +40,15 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then for test_set in test aishell_test cw_test; do python ctc_tdnn/decode.py \ --post-h5 ${post_dir}/${test_set}.h5 \ - --score-file ${post_dir}/fst_${test_set}_pos_h5.txt + --score-file ${post_dir}/fst_${test_set}_score.txt done python ./local/auc.py \ --legend himia_cw \ - --positive-score-file ${post_dir}/fst_test_pos_h5.txt \ - --negative-score-file ${post_dir}/fst_cw_test_pos_h5.txt + --positive-score-file ${post_dir}/fst_test_score.txt \ + --negative-score-file ${post_dir}/fst_cw_test_score.txt python ./local/auc.py \ --legend himia_aishell \ - --positive-score-file ${post_dir}/fst_test_pos_h5.txt \ - --negative-score-file ${post_dir}/fst_aishell_test_pos_h5.txt + --positive-score-file ${post_dir}/fst_test_score.txt \ + --negative-score-file ${post_dir}/fst_aishell_test_score.txt fi From 20f36efedd97d2b3fc0752da367c7eae59cdbec5 Mon Sep 17 00:00:00 2001 From: glynpu Date: Thu, 16 Mar 2023 16:49:05 +0800 Subject: [PATCH 10/24] url for pretrained models --- egs/himia/wuw/ctc_tdnn/README.md | 5 +++++ egs/himia/wuw/ctc_tdnn/asr_datamodule.py | 2 +- egs/himia/wuw/ctc_tdnn/train.py | 6 +++--- egs/himia/wuw/run_ctc_tdnn.sh | 6 ++++-- 4 files changed, 13 insertions(+), 6 deletions(-) create mode 100644 egs/himia/wuw/ctc_tdnn/README.md diff --git a/egs/himia/wuw/ctc_tdnn/README.md b/egs/himia/wuw/ctc_tdnn/README.md new file mode 100644 index 000000000..6eeb9161f --- /dev/null +++ b/egs/himia/wuw/ctc_tdnn/README.md @@ -0,0 +1,5 @@ +# Pretrained models and releated logs/results. + +## ctc tdnn baseline + +https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline diff --git a/egs/himia/wuw/ctc_tdnn/asr_datamodule.py b/egs/himia/wuw/ctc_tdnn/asr_datamodule.py index c02a8e634..785e06a7c 100644 --- a/egs/himia/wuw/ctc_tdnn/asr_datamodule.py +++ b/egs/himia/wuw/ctc_tdnn/asr_datamodule.py @@ -83,7 +83,7 @@ class HiMiaWuwDataModule: group.add_argument( "--max-duration", type=int, - default=6000.0, + default=200.0, help="Maximum pooled recordings duration (seconds) in a " "single batch. 
You can reduce it if it causes CUDA OOM.", ) diff --git a/egs/himia/wuw/ctc_tdnn/train.py b/egs/himia/wuw/ctc_tdnn/train.py index fd9d42cad..61842de79 100755 --- a/egs/himia/wuw/ctc_tdnn/train.py +++ b/egs/himia/wuw/ctc_tdnn/train.py @@ -17,11 +17,11 @@ """ Usage: - export CUDA_VISIBLE_DEVICES="0,1,2,3" + export CUDA_VISIBLE_DEVICES="0" ./ctc_tdnn/train.py \ --exp-dir ./ctc_tdnn/exp \ - --world-size 4 \ - --max-duration 200 \ + --world-size 1 \ + --max-duration 100 \ --num-epochs 20 """ diff --git a/egs/himia/wuw/run_ctc_tdnn.sh b/egs/himia/wuw/run_ctc_tdnn.sh index fdf5ec1db..8a65d9c54 100644 --- a/egs/himia/wuw/run_ctc_tdnn.sh +++ b/egs/himia/wuw/run_ctc_tdnn.sh @@ -8,7 +8,8 @@ stop_stage=2 epoch=10 avg=1 -exp_dir=./ctc_tdnn/exp/ +max_duration=150 +exp_dir=./ctc_tdnn/exp_max_duration_${max_duration}/ epoch_avg=epoch_${epoch}-avg_${avg} post_dir=${exp_dir}/post/${epoch_avg} @@ -24,7 +25,8 @@ log() { if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Model training" python ./ctc_tdnn/train.py \ - --num-epochs $epoch + --num-epochs $epoch \ + --max-duration $max_duration fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then From 2230669129330827855c1918fd0d35a6a4fc2761 Mon Sep 17 00:00:00 2001 From: glynpu Date: Thu, 16 Mar 2023 16:53:24 +0800 Subject: [PATCH 11/24] Short intro to results files in huggingface --- egs/himia/wuw/ctc_tdnn/README.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/egs/himia/wuw/ctc_tdnn/README.md b/egs/himia/wuw/ctc_tdnn/README.md index 6eeb9161f..4bf30774e 100644 --- a/egs/himia/wuw/ctc_tdnn/README.md +++ b/egs/himia/wuw/ctc_tdnn/README.md @@ -2,4 +2,9 @@ ## ctc tdnn baseline -https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline +Auc results for different epochs could be found at + +E.g. for epoch 2 and avg 1, auc log file is: + +Corresponding ROC curve is: + From e64a6e7becfeeccd1e7df2abeaf7c49c7a74dd8b Mon Sep 17 00:00:00 2001 From: glynpu Date: Thu, 16 Mar 2023 20:03:57 +0800 Subject: [PATCH 12/24] update comments --- egs/himia/wuw/README.md | 10 +++++++ egs/himia/wuw/ctc_tdnn/README.md | 10 ------- egs/himia/wuw/ctc_tdnn/decode.py | 33 ++++++++++++++++++---- egs/himia/wuw/ctc_tdnn/graph.py | 2 +- egs/himia/wuw/ctc_tdnn/inference.py | 4 +-- egs/himia/wuw/ctc_tdnn/tdnn.py | 2 +- egs/himia/wuw/ctc_tdnn/tokenizer.py | 2 +- egs/himia/wuw/ctc_tdnn/train.py | 8 +++--- egs/himia/wuw/local/auc.py | 6 ++-- egs/himia/wuw/local/compute_fbank_himia.py | 4 +-- egs/himia/wuw/prepare.sh | 12 +++++--- egs/himia/wuw/run_ctc_tdnn.sh | 15 +++++----- 12 files changed, 67 insertions(+), 41 deletions(-) create mode 100644 egs/himia/wuw/README.md delete mode 100644 egs/himia/wuw/ctc_tdnn/README.md diff --git a/egs/himia/wuw/README.md b/egs/himia/wuw/README.md new file mode 100644 index 000000000..59dba046e --- /dev/null +++ b/egs/himia/wuw/README.md @@ -0,0 +1,10 @@ +# Pretrained models and related logs/results. + +## ctc tdnn baseline + +AUC results for different epochs could be found at + +E.g. for epoch 15 and avg 1, result log file is: + +Corresponding ROC curve is: + diff --git a/egs/himia/wuw/ctc_tdnn/README.md b/egs/himia/wuw/ctc_tdnn/README.md deleted file mode 100644 index 4bf30774e..000000000 --- a/egs/himia/wuw/ctc_tdnn/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Pretrained models and releated logs/results. - -## ctc tdnn baseline - -Auc results for different epochs could be found at - -E.g. 
for epoch 2 and avg 1, auc log file is: - -Corresponding ROC curve is: - diff --git a/egs/himia/wuw/ctc_tdnn/decode.py b/egs/himia/wuw/ctc_tdnn/decode.py index 6715c8b9c..9d05a3310 100755 --- a/egs/himia/wuw/ctc_tdnn/decode.py +++ b/egs/himia/wuw/ctc_tdnn/decode.py @@ -61,28 +61,49 @@ class FiniteStateTransducer: """Represents a decoding graph for wake word detection.""" def __init__(self, graph: str) -> None: + """ + Construct a decoding graph in FST format given string format graph. + + Args: + graph: A string format fst. Each arc is separated by "\n". + """ self.state_list = list() for arc_str in graph.split("\n"): arc = arc_str.strip().split() if len(arc) == 0: continue + # An arc may contain 1, 2 or 4 elements, with format: + # src_state [dst_state] [ilabel] [olabel] # 1 and 2 for final state # 4 for non-final state assert len(arc) in [1, 2, 4], f"{len(arc)} {arc_str}" + arc = [int(element) for element in arc] + src_state_id = arc[0] + max_state_id = len(self.state_list) - 1 if len(arc) == 4: # Non-final state - # FST must be sorted - if len(self.state_list) <= int(arc[0]): + assert max_state_id <= src_state_id, ( + f"Fsa must be sorted by src_state, " + f"while {cur_number_states} <= {src_state_id}. Check your graph." + ) + if max_state_id < src_state_id: new_state = State() self.state_list.append(new_state) - self.state_list[int(arc[0])].add_arc( - Arc(arc[0], arc[1], arc[2], arc[3]) + + self.state_list[src_state_id].add_arc( + Arc(src_state_id, arc[1], arc[2], arc[3]) ) else: - self.final_state_id = int(arc[0]) + assert ( + max_state_id == src_state_id + ), f"Final state seems unreachable. Check your graph." + self.final_state_id = src_state_id def to_str(self) -> None: fst_str = "" - for state_idx in range(len(self.state_list)): + number_states = len(self.state_list) + if number_states == 0: + return fst_str + for state_idx in range(number_states): cur_state = self.state_list[state_idx] for arc_idx in range(len(cur_state.arc_list)): cur_arc = cur_state.arc_list[arc_idx] diff --git a/egs/himia/wuw/ctc_tdnn/graph.py b/egs/himia/wuw/ctc_tdnn/graph.py index 184e01ed1..60e8afe2e 100644 --- a/egs/himia/wuw/ctc_tdnn/graph.py +++ b/egs/himia/wuw/ctc_tdnn/graph.py @@ -21,7 +21,7 @@ from typing import List def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]): """ - A graph starts with blank/unknown and follwoing by wakeup word. + A graph starts with blank/unknown and following by wakeup word. Args: wakeup_word_tokens: A sequence of token ids corresponding wakeup_word. diff --git a/egs/himia/wuw/ctc_tdnn/inference.py b/egs/himia/wuw/ctc_tdnn/inference.py index 10950cec9..b530eda62 100755 --- a/egs/himia/wuw/ctc_tdnn/inference.py +++ b/egs/himia/wuw/ctc_tdnn/inference.py @@ -69,7 +69,7 @@ def get_params() -> AttributeDict: { "env_info": get_env_info(), "feature_dim": 80, - "number_class": 9, + "num_class": 9, } ) return params @@ -150,7 +150,7 @@ def main(): logging.info(f"device: {device}") - model = Tdnn(params.feature_dim, params.number_class) + model = Tdnn(params.feature_dim, params.num_class) if params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True) diff --git a/egs/himia/wuw/ctc_tdnn/tdnn.py b/egs/himia/wuw/ctc_tdnn/tdnn.py index 0f685b6c2..3425d4cca 100644 --- a/egs/himia/wuw/ctc_tdnn/tdnn.py +++ b/egs/himia/wuw/ctc_tdnn/tdnn.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo) +# Copyright 2023 Xiaomi Corp. 
(Author: Liyong Guo) # # See ../../../../LICENSE for clarification regarding multiple authors # diff --git a/egs/himia/wuw/ctc_tdnn/tokenizer.py b/egs/himia/wuw/ctc_tdnn/tokenizer.py index e019ebb86..5bd54d2f0 100644 --- a/egs/himia/wuw/ctc_tdnn/tokenizer.py +++ b/egs/himia/wuw/ctc_tdnn/tokenizer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo) +# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo) # # See ../../LICENSE for clarification regarding multiple authors # diff --git a/egs/himia/wuw/ctc_tdnn/train.py b/egs/himia/wuw/ctc_tdnn/train.py index 61842de79..04953d9c3 100755 --- a/egs/himia/wuw/ctc_tdnn/train.py +++ b/egs/himia/wuw/ctc_tdnn/train.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo) +# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -162,7 +162,7 @@ def get_params() -> AttributeDict: - feature_dim: The model input dim. It has to match the one used in computing features. - - number_class: Numer of classes. Each token will have a token id + - num_class: Number of classes. Each token will have a token id from [0, num_class). In this recipe, 0 is usually kept for blank, and 1 is usually kept for negative words. @@ -182,7 +182,7 @@ def get_params() -> AttributeDict: "valid_interval": 3000, # parameters for model "feature_dim": 80, - "number_class": 9, + "num_class": 9, # parameters for tokenizer "wakeup_word": "你好米雅", "wakeup_word_tokens": [2, 3, 4, 5, 6, 3, 7, 8], @@ -529,7 +529,7 @@ def run(rank, world_size, args): logging.info("About to create model") - model = Tdnn(params.feature_dim, params.number_class) + model = Tdnn(params.feature_dim, params.numb_class) checkpoints = load_checkpoint_if_available(params=params, model=model) diff --git a/egs/himia/wuw/local/auc.py b/egs/himia/wuw/local/auc.py index 7b35ef06b..f5a210d87 100755 --- a/egs/himia/wuw/local/auc.py +++ b/egs/himia/wuw/local/auc.py @@ -55,7 +55,7 @@ def load_score(score_file: Path) -> Dict[str, float]: """ Args: score_file: Path to score file. Each line has two columns. - The first colume is utt-id, and the second one is score. + The first column is utt-id, and the second one is score. This score could be viewed as probability of being wakeup word. Returns: @@ -81,9 +81,9 @@ def get_roc_and_auc( pos_dict: scores of positive samples. neg_dict: scores of negative samples. Return: - A tuple of three elements, which will be used to plot roc curve. + A tuple of three elements, which will be used to plot ROC curve. Refer to sklearn.metrics.roc_curve for meaning of the first and second elements. - The third element is area under the roc curve(AUC). + The third element is area under the ROC curve(AUC). """ pos_scores = np.fromiter(pos_dict.values(), dtype=float) neg_scores = np.fromiter(neg_dict.values(), dtype=float) diff --git a/egs/himia/wuw/local/compute_fbank_himia.py b/egs/himia/wuw/local/compute_fbank_himia.py index f930a8c4e..3acac8b0f 100755 --- a/egs/himia/wuw/local/compute_fbank_himia.py +++ b/egs/himia/wuw/local/compute_fbank_himia.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo) +# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -57,7 +57,7 @@ def get_args(): "--enable-speed-perturb", type=str2bool, default=False, - help="""channel of trianing set. + help="""channel of training set. 
""", ) return parser.parse_args() diff --git a/egs/himia/wuw/prepare.sh b/egs/himia/wuw/prepare.sh index a47a20682..96df29097 100755 --- a/egs/himia/wuw/prepare.sh +++ b/egs/himia/wuw/prepare.sh @@ -8,7 +8,7 @@ stop_stage=6 # HI_MIA and aishell dataset are used in this experiment. # musan dataset is used for data augmentation. # -# For aishell dataset downlading and preparation, +# For aishell dataset downloading and preparation, # refer to icefall/egs/aishell/ASR/prepare.sh. # # For HI_MIA and HI_MIA_CW dataset, @@ -96,7 +96,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare HI_MIA and HI_MIA_CWmanifest" + log "Stage 1: Prepare HI_MIA and HI_MIA_CW manifest" mkdir -p data/manifests if [ ! -e data/manifests/.himia.done ]; then lhotse prepare himia $dl_dir/HiMia data/manifests @@ -177,8 +177,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then train_file=data/fbank/cuts_train_himia${train_set_channel}-aishell-shuf.jsonl.gz if [ ! -f ${train_file} ]; then - # SingleCutSampler is prefered for this experiment. - # So `shuf` the training dataset here. + # SingleCutSampler is preferred for this experiment + # rather than DynamicBucketingSampler. + # Since negative audios(Aishell) tends to be longer than positive ones(HiMia). + # if DynamicBucketingSample is used, a batch may contain either all negative sample + # or positive sample. + # So `shuf` the training dataset here and use SingleCutSampler to load data. cat <(gunzip -c data/fbank/aishell_cuts_train.jsonl.gz) \ <(gunzip -c data/fbank/cuts_train${train_set_channel}.jsonl.gz) | \ grep -v _sp | \ diff --git a/egs/himia/wuw/run_ctc_tdnn.sh b/egs/himia/wuw/run_ctc_tdnn.sh index 8a65d9c54..258c2b2b1 100644 --- a/egs/himia/wuw/run_ctc_tdnn.sh +++ b/egs/himia/wuw/run_ctc_tdnn.sh @@ -26,6 +26,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Model training" python ./ctc_tdnn/train.py \ --num-epochs $epoch \ + --exp_dir $exp_dir --max-duration $max_duration fi @@ -34,7 +35,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then python ctc_tdnn/inference.py \ --avg $avg \ --epoch $epoch \ - --exp-dir ${exp_dir} + --exp-dir $exp_dir fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then @@ -45,12 +46,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then --score-file ${post_dir}/fst_${test_set}_score.txt done python ./local/auc.py \ - --legend himia_cw \ - --positive-score-file ${post_dir}/fst_test_score.txt \ - --negative-score-file ${post_dir}/fst_cw_test_score.txt + --legend himia_cw \ + --positive-score-file ${post_dir}/fst_test_score.txt \ + --negative-score-file ${post_dir}/fst_cw_test_score.txt python ./local/auc.py \ - --legend himia_aishell \ - --positive-score-file ${post_dir}/fst_test_score.txt \ - --negative-score-file ${post_dir}/fst_aishell_test_score.txt + --legend himia_aishell \ + --positive-score-file ${post_dir}/fst_test_score.txt \ + --negative-score-file ${post_dir}/fst_aishell_test_score.txt fi From c11faa8c77f51e24db6bd50e484561fc87e9407a Mon Sep 17 00:00:00 2001 From: glynpu Date: Thu, 16 Mar 2023 20:05:44 +0800 Subject: [PATCH 13/24] update result url --- egs/himia/wuw/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/himia/wuw/README.md b/egs/himia/wuw/README.md index 59dba046e..5b642ee7f 100644 --- a/egs/himia/wuw/README.md +++ b/egs/himia/wuw/README.md @@ -2,7 +2,7 @@ ## ctc tdnn baseline -AUC results for different epochs could be found at +AUC results for different epochs could be found 
at

E.g. for epoch 15 and avg 1, result log file is: 

Corresponding ROC curve is: 

From fe6dcd74297f7d62c98469f14f1f7dc01df7874a Mon Sep 17 00:00:00 2001
From: glynpu
Date: Thu, 16 Mar 2023 20:08:51 +0800
Subject: [PATCH 14/24] update result url

---
 egs/himia/wuw/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/himia/wuw/README.md b/egs/himia/wuw/README.md
index 5b642ee7f..50d92ccd9 100644
--- a/egs/himia/wuw/README.md
+++ b/egs/himia/wuw/README.md
@@ -4,7 +4,7 @@
 
 AUC results for different epochs could be found at
 
-E.g. for epoch 15 and avg 1, result log file is:
+E.g. for epoch 15 and avg 1, result log file is:
 
 Corresponding ROC curve is:
 
From b769f10bf6832a6babba1275e22851b63c75e043 Mon Sep 17 00:00:00 2001
From: glynpu
Date: Thu, 16 Mar 2023 20:17:44 +0800
Subject: [PATCH 15/24] fix comments

---
 egs/himia/wuw/ctc_tdnn/decode.py | 4 ++--
 egs/himia/wuw/ctc_tdnn/train.py  | 2 +-
 egs/himia/wuw/run_ctc_tdnn.sh    | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/egs/himia/wuw/ctc_tdnn/decode.py b/egs/himia/wuw/ctc_tdnn/decode.py
index 9d05a3310..747f7ff52 100755
--- a/egs/himia/wuw/ctc_tdnn/decode.py
+++ b/egs/himia/wuw/ctc_tdnn/decode.py
@@ -83,7 +83,7 @@ class FiniteStateTransducer:
             if len(arc) == 4:  # Non-final state
                 assert max_state_id <= src_state_id, (
                     f"Fsa must be sorted by src_state, "
-                    f"while {cur_number_states} <= {src_state_id}. Check your graph."
+                    f"while {max_state_id} > {src_state_id}. Check your graph."
                 )
                 if max_state_id < src_state_id:
                     new_state = State()
@@ -95,7 +95,7 @@ class FiniteStateTransducer:
             else:
                 assert (
                     max_state_id == src_state_id
-                ), f"Final state seems unreachable. Check your graph."
+                ), "Final state seems unreachable. Check your graph."
                 self.final_state_id = src_state_id
 
     def to_str(self) -> None:
diff --git a/egs/himia/wuw/ctc_tdnn/train.py b/egs/himia/wuw/ctc_tdnn/train.py
index 04953d9c3..d35744e20 100755
--- a/egs/himia/wuw/ctc_tdnn/train.py
+++ b/egs/himia/wuw/ctc_tdnn/train.py
@@ -529,7 +529,7 @@ def run(rank, world_size, args):
 
     logging.info("About to create model")
 
-    model = Tdnn(params.feature_dim, params.numb_class)
+    model = Tdnn(params.feature_dim, params.num_class)
 
     checkpoints = load_checkpoint_if_available(params=params, model=model)
 
diff --git a/egs/himia/wuw/run_ctc_tdnn.sh b/egs/himia/wuw/run_ctc_tdnn.sh
index 258c2b2b1..2f6a1788f 100644
--- a/egs/himia/wuw/run_ctc_tdnn.sh
+++ b/egs/himia/wuw/run_ctc_tdnn.sh
@@ -26,7 +26,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     log "Stage 0: Model training"
     python ./ctc_tdnn/train.py \
         --num-epochs $epoch \
-        --exp_dir $exp_dir
+        --exp-dir $exp_dir
         --max-duration $max_duration
 fi
 
From af048b279c0f31b9dbf67f7fcc467381c4ad637b Mon Sep 17 00:00:00 2001
From: glynpu
Date: Thu, 16 Mar 2023 20:54:45 +0800
Subject: [PATCH 16/24] update type hints

---
 egs/himia/wuw/ctc_tdnn/decode.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/egs/himia/wuw/ctc_tdnn/decode.py b/egs/himia/wuw/ctc_tdnn/decode.py
index 747f7ff52..3369f5403 100755
--- a/egs/himia/wuw/ctc_tdnn/decode.py
+++ b/egs/himia/wuw/ctc_tdnn/decode.py
@@ -126,7 +126,7 @@ class Token:
 
     def set_token(
         self,
-        src_token,
+        src_token,  # Token connected to current token.
is_keyword_ilabel: bool, acoustic_score: float, ) -> None: @@ -163,9 +163,9 @@ class Token: class SingleDecodable: def __init__( self, - model_output, - keyword_ilabel_start, - graph, + model_output: np.array, + keyword_ilabel_start: int, + graph: FiniteStateTransducer, ): """ Args: @@ -234,7 +234,10 @@ class SingleDecodable: def decode_utt( - params: AttributeDict, utt_id: str, post_file, graph: FiniteStateTransducer + params: AttributeDict, + utt_id: str, + post_file: str, + graph: FiniteStateTransducer, ) -> Tuple[str, float]: """ Decode a single utterance. @@ -244,7 +247,7 @@ def decode_utt( The return value of :func:`get_params`. utt_id: utt_id to be decoded, used to fetch posterior matrix from post_file. post_file: file to save posterior for all test set. - graph: decoding graph. + graph: decoding graph in FiniteStateTransducer format. Returns: utt_id and its corresponding probability to be a wake word. From c055f0cc491c59567355341728f37c20c9f6c75a Mon Sep 17 00:00:00 2001 From: glynpu Date: Fri, 17 Mar 2023 10:08:06 +0800 Subject: [PATCH 17/24] update comments --- egs/himia/wuw/run_ctc_tdnn.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/himia/wuw/run_ctc_tdnn.sh b/egs/himia/wuw/run_ctc_tdnn.sh index 2f6a1788f..b6c80bcf2 100644 --- a/egs/himia/wuw/run_ctc_tdnn.sh +++ b/egs/himia/wuw/run_ctc_tdnn.sh @@ -6,9 +6,9 @@ set -eou pipefail stage=0 stop_stage=2 -epoch=10 +epoch=20 avg=1 -max_duration=150 +max_duration=200 exp_dir=./ctc_tdnn/exp_max_duration_${max_duration}/ epoch_avg=epoch_${epoch}-avg_${avg} post_dir=${exp_dir}/post/${epoch_avg} @@ -26,12 +26,12 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Model training" python ./ctc_tdnn/train.py \ --num-epochs $epoch \ - --exp-dir $exp_dir + --exp-dir $exp_dir \ --max-duration $max_duration fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Get posterior of test sets" + log "Stage 1: Get posterior(log_softmax(logit)) of test sets" python ctc_tdnn/inference.py \ --avg $avg \ --epoch $epoch \ From 0a5b639ec1cd3bd292b8ae84210f171c0355dcda Mon Sep 17 00:00:00 2001 From: glynpu Date: Fri, 17 Mar 2023 12:17:08 +0800 Subject: [PATCH 18/24] update comments --- egs/himia/wuw/ctc_tdnn/asr_datamodule.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/egs/himia/wuw/ctc_tdnn/asr_datamodule.py b/egs/himia/wuw/ctc_tdnn/asr_datamodule.py index 785e06a7c..565039062 100644 --- a/egs/himia/wuw/ctc_tdnn/asr_datamodule.py +++ b/egs/himia/wuw/ctc_tdnn/asr_datamodule.py @@ -16,7 +16,6 @@ import argparse -import inspect import logging from functools import lru_cache from pathlib import Path @@ -190,7 +189,7 @@ class HiMiaWuwDataModule: "--input-strategy", type=str, default="PrecomputedFeatures", - help="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", ) group.add_argument( "--train-channel", @@ -198,6 +197,8 @@ class HiMiaWuwDataModule: default="_7_01", help="""channel of HI_MIA train dataset. All channels are used if it is set "all". + Please refer state 6 in prepare.sh for its meaning and other + potential values. Currently, Only "_7_01" is verified. """, ) group.add_argument( @@ -206,6 +207,8 @@ class HiMiaWuwDataModule: default="_7_01", help="""channel of HI_MIA dev dataset. All channels are used if it is set "all". + Please refer state 6 in prepare.sh for its meaning and other + potential values. Currently, Only "_7_01" is verified. 
""", ) @@ -248,22 +251,11 @@ class HiMiaWuwDataModule: input_transforms = [] if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, + num_frame_masks=10, features_mask_size=27, num_feature_masks=2, frames_mask_size=100, From 9f94984dbb8feb7b532d21446f9cc8ef8a7f2031 Mon Sep 17 00:00:00 2001 From: glynpu Date: Fri, 17 Mar 2023 12:21:35 +0800 Subject: [PATCH 19/24] fix typo --- egs/himia/wuw/ctc_tdnn/asr_datamodule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/himia/wuw/ctc_tdnn/asr_datamodule.py b/egs/himia/wuw/ctc_tdnn/asr_datamodule.py index 565039062..72eb2dc8b 100644 --- a/egs/himia/wuw/ctc_tdnn/asr_datamodule.py +++ b/egs/himia/wuw/ctc_tdnn/asr_datamodule.py @@ -197,7 +197,7 @@ class HiMiaWuwDataModule: default="_7_01", help="""channel of HI_MIA train dataset. All channels are used if it is set "all". - Please refer state 6 in prepare.sh for its meaning and other + Please refer to state 6 in prepare.sh for its meaning and other potential values. Currently, Only "_7_01" is verified. """, ) @@ -207,7 +207,7 @@ class HiMiaWuwDataModule: default="_7_01", help="""channel of HI_MIA dev dataset. All channels are used if it is set "all". - Please refer state 6 in prepare.sh for its meaning and other + Please refer to state 6 in prepare.sh for its meaning and other potential values. Currently, Only "_7_01" is verified. """, ) From 17768da017697205d7831fb42c00c3baa164fbb4 Mon Sep 17 00:00:00 2001 From: glynpu Date: Fri, 17 Mar 2023 14:26:50 +0800 Subject: [PATCH 20/24] update comments --- egs/himia/wuw/ctc_tdnn/asr_datamodule.py | 4 ++-- egs/himia/wuw/ctc_tdnn/graph.py | 6 +++++- egs/himia/wuw/ctc_tdnn/tokenizer.py | 18 +++++++++--------- egs/himia/wuw/ctc_tdnn/train.py | 3 +++ 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/egs/himia/wuw/ctc_tdnn/asr_datamodule.py b/egs/himia/wuw/ctc_tdnn/asr_datamodule.py index 72eb2dc8b..db633f9f9 100644 --- a/egs/himia/wuw/ctc_tdnn/asr_datamodule.py +++ b/egs/himia/wuw/ctc_tdnn/asr_datamodule.py @@ -197,7 +197,7 @@ class HiMiaWuwDataModule: default="_7_01", help="""channel of HI_MIA train dataset. All channels are used if it is set "all". - Please refer to state 6 in prepare.sh for its meaning and other + Please refer to stage 6 in prepare.sh for its meaning and other potential values. Currently, Only "_7_01" is verified. """, ) @@ -207,7 +207,7 @@ class HiMiaWuwDataModule: default="_7_01", help="""channel of HI_MIA dev dataset. All channels are used if it is set "all". - Please refer to state 6 in prepare.sh for its meaning and other + Please refer to stage 6 in prepare.sh for its meaning and other potential values. Currently, Only "_7_01" is verified. 
""", ) diff --git a/egs/himia/wuw/ctc_tdnn/graph.py b/egs/himia/wuw/ctc_tdnn/graph.py index 60e8afe2e..d1ff3114d 100644 --- a/egs/himia/wuw/ctc_tdnn/graph.py +++ b/egs/himia/wuw/ctc_tdnn/graph.py @@ -19,7 +19,7 @@ from typing import List -def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]): +def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]) -> str: """ A graph starts with blank/unknown and following by wakeup word. @@ -27,6 +27,10 @@ def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]): wakeup_word_tokens: A sequence of token ids corresponding wakeup_word. It should not contain 0 and 1. We assume 0 is for blank and 1 is for unknown. + Returns: + Returns a finite-state transducer in string format, + used as a decoding graph. + Arcs are separated with "\n". """ assert 0 not in wakeup_word_tokens assert 1 not in wakeup_word_tokens diff --git a/egs/himia/wuw/ctc_tdnn/tokenizer.py b/egs/himia/wuw/ctc_tdnn/tokenizer.py index 5bd54d2f0..b6225c66c 100644 --- a/egs/himia/wuw/ctc_tdnn/tokenizer.py +++ b/egs/himia/wuw/ctc_tdnn/tokenizer.py @@ -23,15 +23,15 @@ from typing import List, Tuple class WakeupWordTokenizer(object): def __init__( self, - wakeup_word: str = "", - wakeup_word_tokens: List[int] = None, + wakeup_word: str, + wakeup_word_tokens: List[int], ) -> None: """ Args: wakeup_word: content of positive samples. - A sample will be treated as a negative sample unless its context + A sample will be treated as a negative sample unless its content is exactly the same to key_words. - wakeup_word_tokens: A list if int represents token ids of wakeup_word. + wakeup_word_tokens: A list of int representing token ids of wakeup_word. For example: the pronunciation of "你好米雅" is "n i h ao m i y a". Suppose we are using following lexicon: @@ -67,7 +67,7 @@ class WakeupWordTokenizer(object): def texts_to_token_ids( self, texts: List[str] ) -> Tuple[torch.Tensor, torch.Tensor, int]: - """Convert a list of texts to a list of k2.Fsa based texts. + """Convert a list of texts to parameters needed by CTC loss. Args: texts: @@ -76,7 +76,7 @@ class WakeupWordTokenizer(object): Returns: Return a tuple of 3 elements. The first one is torch.Tensor(List[List[int]]), - each List[int] is tokens sequence for each a reference text. + each List[int] is tokens sequence for each reference text. The second one is number of tokens for each sample, mainly used by CTC loss. 
@@ -89,13 +89,13 @@ class WakeupWordTokenizer(object):
         number_positive_samples = 0
         for utt_text in texts:
             if utt_text == self.wakeup_word:
-                batch_token_ids.append(self.wakeup_word_tokens)
+                batch_token_ids.extend(self.wakeup_word_tokens)
                 target_lengths.append(self.positive_number_tokens)
                 number_positive_samples += 1
             else:
-                batch_token_ids.append(self.negative_word_tokens)
+                batch_token_ids.extend(self.negative_word_tokens)
                 target_lengths.append(self.negative_number_tokens)
 
-        target = torch.tensor(list(itertools.chain.from_iterable(batch_token_ids)))
+        target = torch.tensor(batch_token_ids)
         target_lengths = torch.tensor(target_lengths)
 
         return target, target_lengths, number_positive_samples
diff --git a/egs/himia/wuw/ctc_tdnn/train.py b/egs/himia/wuw/ctc_tdnn/train.py
index d35744e20..62d71b0bf 100755
--- a/egs/himia/wuw/ctc_tdnn/train.py
+++ b/egs/himia/wuw/ctc_tdnn/train.py
@@ -531,6 +531,9 @@ def run(rank, world_size, args):
 
     model = Tdnn(params.feature_dim, params.num_class)
 
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
     checkpoints = load_checkpoint_if_available(params=params, model=model)
 
     model.to(device)
From 1162d74dc4ae3ab5b5ced5a934cb7765f60fda5e Mon Sep 17 00:00:00 2001
From: glynpu
Date: Fri, 17 Mar 2023 14:27:15 +0800
Subject: [PATCH 21/24] add results.md

---
 egs/himia/wuw/RESULTS.md | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)
 create mode 100644 egs/himia/wuw/RESULTS.md

diff --git a/egs/himia/wuw/RESULTS.md b/egs/himia/wuw/RESULTS.md
new file mode 100644
index 000000000..7b7006d36
--- /dev/null
+++ b/egs/himia/wuw/RESULTS.md
@@ -0,0 +1,17 @@
+## Results
+
+### ctc tdnn model with 1,502,169 parameters
+
+AUC results for different epochs could be found at
+
+Here is the result for epoch_15-avg_1 (with the highest AUC).
+
+| test set | HiMia-Aishell | HiMia-CW |
+| ---- | ---- | ---- |
+| AUC | 0.9597 | 0.9292 |
+
+HiMia-CW
+![himia_cw](https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/resolve/main/exp_max_duration_100/post/epoch_15-avg_1/himia_cw.pdf)
+
+HiMia-Aishell
+![himia_aishell](https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/resolve/main/exp_max_duration_100/post/epoch_15-avg_1/himia_aishell.pdf)
From 29760e255b511f83422d83d67e86a137ad131be6 Mon Sep 17 00:00:00 2001
From: glynpu
Date: Fri, 17 Mar 2023 16:04:58 +0800
Subject: [PATCH 22/24] auc images

---
 egs/himia/wuw/README.md    | 2 +-
 egs/himia/wuw/RESULTS.md   | 7 +++----
 egs/himia/wuw/local/auc.py | 2 +-
 3 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/egs/himia/wuw/README.md b/egs/himia/wuw/README.md
index 50d92ccd9..f557aaaaf 100644
--- a/egs/himia/wuw/README.md
+++ b/egs/himia/wuw/README.md
@@ -6,5 +6,5 @@
 
 E.g. for epoch 15 and avg 1, result log file is:
 
-Corresponding ROC curve is:
+Corresponding ROC curve is:
 
diff --git a/egs/himia/wuw/RESULTS.md b/egs/himia/wuw/RESULTS.md
index 7b7006d36..21a331202 100644
--- a/egs/himia/wuw/RESULTS.md
+++ b/egs/himia/wuw/RESULTS.md
@@ -10,8 +10,7 @@ Here is the result for epoch_15-avg_1 (with the highest AUC).
 | ---- | ---- | ---- |
 | AUC | 0.9597 | 0.9292 |
 
-HiMia-CW
-![himia_cw](https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/resolve/main/exp_max_duration_100/post/epoch_15-avg_1/himia_cw.pdf)
+![himia_aishell](https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/resolve/main/exp_max_duration_100/post/epoch_15-avg_1/himia_aishell.png)
+
+![himia_cw](https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/resolve/main/exp_max_duration_100/post/epoch_15-avg_1/himia_cw.png)
 
-HiMia-Aishell
-![himia_aishell](https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/resolve/main/exp_max_duration_100/post/epoch_15-avg_1/himia_aishell.pdf)
diff --git a/egs/himia/wuw/local/auc.py b/egs/himia/wuw/local/auc.py
index f5a210d87..09f2a276b 100755
--- a/egs/himia/wuw/local/auc.py
+++ b/egs/himia/wuw/local/auc.py
@@ -124,7 +124,7 @@ def main():
 
     output_path = Path(args.positive_score_file).parent
     logging.info(f"AUC of {args.legend} {output_path}: {roc_auc}")
-    plt.savefig(f"{output_path}/{args.legend}.pdf", bbox_inches="tight")
+    plt.savefig(f"{output_path}/{args.legend}.png", bbox_inches="tight")
 
 
 if __name__ == "__main__":
From f17b0e8035cf2113f5a1b2a6c1d2bac388e0dcd2 Mon Sep 17 00:00:00 2001
From: glynpu
Date: Fri, 17 Mar 2023 16:12:17 +0800
Subject: [PATCH 23/24] add number of parameters

---
 egs/himia/wuw/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/himia/wuw/README.md b/egs/himia/wuw/README.md
index f557aaaaf..2d5ccde97 100644
--- a/egs/himia/wuw/README.md
+++ b/egs/himia/wuw/README.md
@@ -1,6 +1,6 @@
 # Pretrained models and related logs/results.
 
-## ctc tdnn baseline
+## ctc tdnn model with 1,502,169 parameters
 
 AUC results for different epochs could be found at
 
From 0fb43289f477aa1a9f1d88215684f7808c1c0fd8 Mon Sep 17 00:00:00 2001
From: "LIyong.Guo" <839019390@qq.com>
Date: Fri, 17 Mar 2023 17:50:40 +0800
Subject: [PATCH 24/24] Update egs/himia/wuw/ctc_tdnn/graph.py

Co-authored-by: Fangjun Kuang
---
 egs/himia/wuw/ctc_tdnn/graph.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/egs/himia/wuw/ctc_tdnn/graph.py b/egs/himia/wuw/ctc_tdnn/graph.py
index d1ff3114d..94b17a435 100644
--- a/egs/himia/wuw/ctc_tdnn/graph.py
+++ b/egs/himia/wuw/ctc_tdnn/graph.py
@@ -40,8 +40,7 @@ def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]) -> str:
     for non_wake_word_token in range(keyword_ilabel_start):
         fst_graph += f"0 0 {non_wake_word_token} 0\n"
     cur_state = 1
-    for token_idx in range(len(wakeup_word_tokens) - 1):
-        token = wakeup_word_tokens[token_idx]
+    for token in wakeup_word_tokens[:-1]:
         fst_graph += f"{cur_state - 1} {cur_state} {token} 0\n"
         fst_graph += f"{cur_state} {cur_state} {token} 0\n"
         cur_state += 1
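
For reference, below is a minimal standalone sketch of the decoding-graph builder that the series converges on in PATCH 24. The token ids and the tail of the function (consuming the last wake word token and emitting the final state, which fall outside the quoted hunk) are assumptions inferred from the FiniteStateTransducer parser in ctc_tdnn/decode.py, not the checked-in code.

from typing import List


def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]) -> str:
    # Sketch of egs/himia/wuw/ctc_tdnn/graph.py after PATCH 24; everything
    # past the "cur_state += 1" loop is an assumption, not quoted code.
    assert 0 not in wakeup_word_tokens
    assert 1 not in wakeup_word_tokens
    fst_graph = ""
    keyword_ilabel_start = 2  # assumed: blank is 0 and unknown is 1
    # State 0 loops over non-keyword labels before the wake word starts.
    for non_wake_word_token in range(keyword_ilabel_start):
        fst_graph += f"0 0 {non_wake_word_token} 0\n"
    cur_state = 1
    for token in wakeup_word_tokens[:-1]:
        fst_graph += f"{cur_state - 1} {cur_state} {token} 0\n"
        fst_graph += f"{cur_state} {cur_state} {token} 0\n"
        cur_state += 1
    # Assumed tail: consume the last token, allow its CTC repeats, and
    # mark the destination as final with a bare state id, which is what
    # FiniteStateTransducer.__init__ in decode.py expects.
    last_token = wakeup_word_tokens[-1]
    fst_graph += f"{cur_state - 1} {cur_state} {last_token} 0\n"
    fst_graph += f"{cur_state} {cur_state} {last_token} 0\n"
    fst_graph += f"{cur_state}\n"
    return fst_graph


if __name__ == "__main__":
    # For a toy two-token wake word this prints the arcs
    # 0 0 0 0 / 0 0 1 0 / 0 1 2 0 / 1 1 2 0 / 1 2 3 0 / 2 2 3 0
    # followed by the final state "2".
    print(ctc_trivial_decoding_graph([2, 3]))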