Merge 0fb43289f477aa1a9f1d88215684f7808c1c0fd8 into abd9437e6d5419a497707748eb935e50976c3b7b

Liyong.Guo, 2025-06-27 11:33:00 +00:00, committed by GitHub
commit b5ae175e8e
16 changed files with 2422 additions and 0 deletions

egs/himia/wuw/README.md Normal file

@@ -0,0 +1,10 @@
# Pretrained models and related logs/results.
## CTC TDNN model (number of model parameters: 1,502,169)
AUC results for different epochs can be found at <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/tree/main>
E.g. for epoch 15 and avg 1, the result log file is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_15-avg_1/log/log-auc-himia_aishell-2023-03-16-17-42-14>
The corresponding ROC curve is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_15-avg_1/himia_aishell.png>

egs/himia/wuw/RESULTS.md Normal file

@@ -0,0 +1,16 @@
## Results
### CTC TDNN model (number of model parameters: 1,502,169)
AUC results for different epochs can be found at <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/tree/main>
Here is the result for epoch_15-avg_1 (the checkpoint with the highest AUC).

| test set | HiMia-Aishell | HiMia-CW |
| ---- | ---- | ---- |
| AUC | 0.9597 | 0.9292 |
![himia_aishell](https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/resolve/main/exp_max_duration_100/post/epoch_15-avg_1/himia_aishell.png)
![himia_cw](https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/resolve/main/exp_max_duration_100/post/epoch_15-avg_1/himia_cw.png)

egs/himia/wuw/ctc_tdnn/asr_datamodule.py

@@ -0,0 +1,423 @@
# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class HiMiaWuwDataModule:
"""
DataModule for Himia wake word experiments.
It contains common data pipeline modules e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="Data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=False,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
group.add_argument(
"--train-channel",
type=str,
default="_7_01",
help="""channel of HI_MIA train dataset.
All channels are used if it is set "all".
Please refer to stage 6 in prepare.sh for its meaning and other
potential values. Currently, Only "_7_01" is verified.
""",
)
group.add_argument(
"--dev-channel",
type=str,
default="_7_01",
help="""channel of HI_MIA dev dataset.
All channels are used if it is set "all".
Please refer to stage 6 in prepare.sh for its meaning and other
potential values. Currently, Only "_7_01" is verified.
""",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=10,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
num_buckets=2,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
train_cuts_file = (
f"cuts_train_himia{self.args.train_channel}-aishell-shuf.jsonl.gz"
)
if "all" == self.args.train_channel:
train_cuts_file = "cuts_train_himia-aishell-shuf.jsonl.gz"
return load_manifest_lazy(self.args.manifest_dir / f"{train_cuts_file}")
@lru_cache()
def aishell_test_cuts(self) -> CutSet:
logging.info("About to get aishell test cuts")
return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz")
@lru_cache()
def cw_test_cuts(self) -> CutSet:
logging.info("About to get HI-MIA-CW test cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_cw_test.jsonl.gz")
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
dev_cuts_file = "cuts_dev.jsonl.gz"
if "all" != self.args.dev_channel:
dev_cuts_file = f"cuts_dev{self.args.dev_channel}.jsonl.gz"
return load_manifest_lazy(self.args.manifest_dir / f"{dev_cuts_file}")
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
# 7_01 is short for microphone 7 and channel 1.
return load_manifest_lazy(self.args.manifest_dir / "cuts_test_7_01.jsonl.gz")
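A minimal usage sketch of this data module (the argument values below are illustrative, not recipe defaults):

```python
# Sketch: HiMiaWuwDataModule is driven entirely by argparse options,
# so a script wires it up roughly like this.
import argparse

from asr_datamodule import HiMiaWuwDataModule

parser = argparse.ArgumentParser()
HiMiaWuwDataModule.add_arguments(parser)
args = parser.parse_args(["--manifest-dir", "data/fbank", "--max-duration", "100"])

himia = HiMiaWuwDataModule(args)
train_dl = himia.train_dataloaders(himia.train_cuts())
valid_dl = himia.valid_dataloaders(himia.dev_cuts())
test_dl = himia.test_dataloaders(himia.test_cuts())
```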

egs/himia/wuw/ctc_tdnn/decode.py Executable file

@@ -0,0 +1,316 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (Author: Weiji Zhuang,
# Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import logging
from concurrent.futures import ProcessPoolExecutor
from typing import Tuple
from pathlib import Path
import numpy as np
from lhotse.features.io import NumpyHdf5Reader
from tqdm import tqdm
from icefall.utils import (
AttributeDict,
setup_logger,
)
from train import get_params
from graph import ctc_trivial_decoding_graph
class Arc:
def __init__(
self, src_state: int, dst_state: int, ilabel: int, olabel: int
) -> None:
self.src_state = int(src_state)
self.dst_state = int(dst_state)
self.ilabel = int(ilabel)
self.olabel = int(olabel)
def next_state(self) -> int:
return self.dst_state
class State:
def __init__(self) -> None:
self.arc_list = list()
def add_arc(self, arc: Arc) -> None:
self.arc_list.append(arc)
class FiniteStateTransducer:
"""Represents a decoding graph for wake word detection."""
def __init__(self, graph: str) -> None:
"""
Construct a decoding graph in FST format given string format graph.
Args:
graph: A string format fst. Each arc is separated by "\n".
"""
self.state_list = list()
for arc_str in graph.split("\n"):
arc = arc_str.strip().split()
if len(arc) == 0:
continue
# An arc may contain 1, 2 or 4 elements, with format:
# src_state [dst_state] [ilabel] [olabel]
# 1 and 2 for final state
# 4 for non-final state
assert len(arc) in [1, 2, 4], f"{len(arc)} {arc_str}"
arc = [int(element) for element in arc]
src_state_id = arc[0]
max_state_id = len(self.state_list) - 1
if len(arc) == 4: # Non-final state
assert max_state_id <= src_state_id, (
f"Fsa must be sorted by src_state, "
f"while {max_state_id} <= {src_state_id}. Check your graph."
)
if max_state_id < src_state_id:
new_state = State()
self.state_list.append(new_state)
self.state_list[src_state_id].add_arc(
Arc(src_state_id, arc[1], arc[2], arc[3])
)
else:
assert (
max_state_id == src_state_id
), "Final state seems unreachable. Check your graph."
self.final_state_id = src_state_id
def to_str(self) -> str:
fst_str = ""
number_states = len(self.state_list)
if number_states == 0:
return fst_str
for state_idx in range(number_states):
cur_state = self.state_list[state_idx]
for arc_idx in range(len(cur_state.arc_list)):
cur_arc = cur_state.arc_list[arc_idx]
ilabel = cur_arc.ilabel
olabel = cur_arc.olabel
src_state = cur_arc.src_state
dst_state = cur_arc.dst_state
fst_str += f"{src_state} {dst_state} {ilabel} {olabel}\n"
fst_str += f"{dst_state}\n"
return fst_str
class Token:
def __init__(self) -> None:
self.is_active = False
self.total_score = -float("inf")
self.keyword_frames = 0
self.average_keyword_score = -float("inf")
self.average_max_keyword_score = 0.0
def set_token(
self,
src_token, # Token connected to the current token.
is_keyword_ilabel: bool,
acoustic_score: float,
) -> None:
"""
A dynamic programming process computing the highest score for a token
from all possible paths which could reach this token.
Args:
src_token: The source token connected to current token with an arc.
is_keyword_ilabel: If true, the arc consumes an input label which is
part of the wake word. Otherwise, the input label is
blank or unknown, i.e. the current token is still not part of the wake word.
acoustic_score: acoustic score of this arc.
"""
if (
not self.is_active
or self.total_score < src_token.total_score + acoustic_score
):
self.is_active = True
self.total_score = src_token.total_score + acoustic_score
if is_keyword_ilabel:
self.average_keyword_score = (
acoustic_score
+ src_token.average_keyword_score * src_token.keyword_frames
) / (src_token.keyword_frames + 1)
self.keyword_frames = src_token.keyword_frames + 1
else:
self.average_keyword_score = 0.0
class SingleDecodable:
def __init__(
self,
model_output: np.array,
keyword_ilabel_start: int,
graph: FiniteStateTransducer,
):
"""
Args:
model_output: log_softmax(logit) with shape [T, C]
keyword_ilabel_start: index of the first token of the wake word.
In this recipe, tokens not belonging to the wake word have smaller
token indices, i.e. blank 0 and unk 1.
graph: decoding graph of the wake word.
"""
self.init_token_list = [Token() for i in range(len(graph.state_list))]
self.reset_token_list()
self.model_output = model_output
self.T = model_output.shape[0]
self.utt_score = 0.0
self.current_frame_index = 0
self.keyword_ilabel_start = keyword_ilabel_start
self.graph = graph
self.number_tokens = len(self.cur_token_list)
def reset_token_list(self) -> None:
"""
Reset all tokens to a condition without consuming any acoustic frames.
"""
self.cur_token_list = copy.deepcopy(self.init_token_list)
self.expand_token_list = copy.deepcopy(self.init_token_list)
self.cur_token_list[0].is_active = True
self.cur_token_list[0].total_score = 0
self.cur_token_list[0].average_keyword_score = 0
def process_oneframe(self) -> None:
"""
Decode a frame and update all tokens.
"""
for state_id, cur_token in enumerate(self.cur_token_list):
if cur_token.is_active:
for arc_id in self.graph.state_list[state_id].arc_list:
acoustic_score = self.model_output[self.current_frame_index][
arc_id.ilabel
]
is_keyword_ilabel = arc_id.ilabel >= self.keyword_ilabel_start
self.expand_token_list[arc_id.next_state()].set_token(
cur_token,
is_keyword_ilabel,
acoustic_score,
)
# use best_score to keep total_score in a good range
self.best_state_id = 0
best_score = self.expand_token_list[0].total_score
for state_id in range(self.number_tokens):
if self.expand_token_list[state_id].is_active:
if best_score < self.expand_token_list[state_id].total_score:
best_score = self.expand_token_list[state_id].total_score
self.best_state_id = state_id
self.cur_token_list = self.expand_token_list
for state_id in range(self.number_tokens):
self.cur_token_list[state_id].total_score -= best_score
self.expand_token_list = copy.deepcopy(self.init_token_list)
potential_score = np.exp(
self.cur_token_list[self.graph.final_state_id].average_keyword_score
)
if potential_score > self.utt_score:
self.utt_score = potential_score
self.current_frame_index += 1
def decode_utt(
params: AttributeDict,
utt_id: str,
post_file: str,
graph: FiniteStateTransducer,
) -> Tuple[str, float]:
"""
Decode a single utterance.
Args:
params:
The return value of :func:`get_params`.
utt_id: utt_id to be decoded, used to fetch posterior matrix from post_file.
post_file: file containing posteriors for the whole test set.
graph: decoding graph in FiniteStateTransducer format.
Returns:
utt_id and its corresponding probability to be a wake word.
"""
reader = NumpyHdf5Reader(post_file)
model_output = reader.read(utt_id)
keyword_ilabel_start = params.wakeup_word_tokens[0]
decodable = SingleDecodable(
model_output=model_output,
keyword_ilabel_start=keyword_ilabel_start,
graph=graph,
)
for t in range(decodable.T):
decodable.process_oneframe()
return utt_id, decodable.utt_score
def get_parser():
parser = argparse.ArgumentParser(
description="A simple FST decoder for the wake word detection\n"
)
parser.add_argument(
"--post-h5",
type=str,
help="model output in h5 format",
)
parser.add_argument(
"--score-file",
type=str,
help="file to save scores of each utterance",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
post_dir = Path(params.post_h5).parent
test_set = Path(params.post_h5).stem
setup_logger(f"{post_dir}/log/log-decode-{test_set}")
graph = FiniteStateTransducer(ctc_trivial_decoding_graph(params.wakeup_word_tokens))
logging.info(f"Graph used:\n{graph.to_str()}")
logging.info(f"About to load {test_set}.")
keys = NumpyHdf5Reader(params.post_h5).hdf.keys()
with ProcessPoolExecutor() as executor, open(
params.score_file, "w", encoding="utf8"
) as fout:
futures = [
executor.submit(decode_utt, params, key, params.post_h5, graph)
for key in tqdm(keys)
]
logging.info(f"Decoding {test_set}.")
for future in tqdm(futures):
k, v = future.result()
fout.write(str(k) + " " + str(v) + "\n")
logging.info(f"Finish decoding {test_set}.")
if __name__ == "__main__":
main()
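A hedged, self-contained sketch of how the classes above fit together, using a fake posterior matrix and the wake-word token ids from train.py:

```python
# Build the trivial CTC graph, wrap it in the FST class, and score
# randomly generated log-posteriors frame by frame.
import numpy as np

from decode import FiniteStateTransducer, SingleDecodable
from graph import ctc_trivial_decoding_graph

wakeup_word_tokens = [2, 3, 4, 5, 6, 3, 7, 8]  # "你好米雅", see get_params() in train.py
graph = FiniteStateTransducer(ctc_trivial_decoding_graph(wakeup_word_tokens))

T, C = 50, 9  # 50 frames, 9 classes (blank, unk, 7 wake-word tokens)
model_output = np.log(np.random.dirichlet(np.ones(C), size=T))  # fake log-softmax output

decodable = SingleDecodable(
    model_output=model_output,
    keyword_ilabel_start=wakeup_word_tokens[0],
    graph=graph,
)
for _ in range(decodable.T):
    decodable.process_oneframe()
print(f"wake-word score: {decodable.utt_score:.4f}")
```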

egs/himia/wuw/ctc_tdnn/graph.py

@@ -0,0 +1,52 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (Author: Weiji Zhuang,
# Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]) -> str:
"""
A trivial decoding graph: blank/unknown self-loops followed by the wake word.
Args:
wakeup_word_tokens: A sequence of token ids corresponding to wakeup_word.
It should not contain 0 or 1.
We assume 0 is for blank and 1 is for unknown.
Returns:
A finite-state transducer in string format,
used as a decoding graph.
Arcs are separated by "\n".
"""
assert 0 not in wakeup_word_tokens
assert 1 not in wakeup_word_tokens
assert len(wakeup_word_tokens) >= 2
keyword_ilabel_start = wakeup_word_tokens[0]
fst_graph = ""
for non_wake_word_token in range(keyword_ilabel_start):
fst_graph += f"0 0 {non_wake_word_token} 0\n"
cur_state = 1
for token in wakeup_word_tokens[:-1]:
fst_graph += f"{cur_state - 1} {cur_state} {token} 0\n"
fst_graph += f"{cur_state} {cur_state} {token} 0\n"
cur_state += 1
token = wakeup_word_tokens[-1]
fst_graph += f"{cur_state - 1} {cur_state} {token} 1\n"
fst_graph += f"{cur_state} {cur_state} {token} 0\n"
fst_graph += f"{cur_state}\n"
return fst_graph
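For illustration, here is the string the function returns for a toy wake word with token ids [2, 3, 4] (ids 0 and 1 are reserved for blank and unknown):

```python
from graph import ctc_trivial_decoding_graph

print(ctc_trivial_decoding_graph([2, 3, 4]))
# 0 0 0 0   <- state 0 self-loops consuming blank (0) and unknown (1)
# 0 0 1 0
# 0 1 2 0   <- one new state per wake-word token, each with a self-loop
# 1 1 2 0
# 1 2 3 0
# 2 2 3 0
# 2 3 4 1   <- the last token emits olabel 1, marking a detection
# 3 3 4 0
# 3         <- final state
```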

egs/himia/wuw/ctc_tdnn/inference.py

@@ -0,0 +1,206 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
import torch
from lhotse.features.io import NumpyHdf5Writer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
)
from asr_datamodule import HiMiaWuwDataModule
from tdnn import Tdnn
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=10,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 1.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="ctc_tdnn/exp",
help="The experiment dir",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"env_info": get_env_info(),
"feature_dim": 80,
"num_class": 9,
}
)
return params
def inference_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: torch.nn.Module,
test_set: str,
):
"""Compute and save model output of each utterance.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
test_set:
Name of test set.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
writer = NumpyHdf5Writer(f"{params.out_dir}/{test_set}")
for batch_idx, batch in enumerate(dl):
device = params.device
feature = batch["inputs"]
assert feature.ndim == 3
supervisions = batch["supervisions"]
start_frames = supervisions["start_frame"]
end_frames = start_frames + supervisions["num_frames"]
feature = feature.to(device)
# model_output is log_softmax(logit) with shape [N, T, C]
model_output = model(feature)
for i in range(feature.size(0)):
assert start_frames[i] == 0
cut = batch["supervisions"]["cut"][i]
cur_target = model_output[i][start_frames[i] : end_frames[i]]
writer.store_array(key=cut.id, value=cur_target.cpu().numpy())
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
@torch.no_grad()
def main():
parser = get_parser()
HiMiaWuwDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
out_dir = f"{params.exp_dir}/post/epoch_{params.epoch}-avg_{params.avg}/"
Path(out_dir).mkdir(parents=True, exist_ok=True)
params.out_dir = out_dir
setup_logger(f"{out_dir}/log/log-inference")
logging.info("Decoding started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = Tdnn(params.feature_dim, params.num_class)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:  # epoch counts from 1
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=True
)
model.to(device)
model.eval()
params.device = device
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
himia = HiMiaWuwDataModule(args)
aishell_test_cuts = himia.aishell_test_cuts()
test_cuts = himia.test_cuts()
cw_test_cuts = himia.cw_test_cuts()
aishell_test_dl = himia.test_dataloaders(aishell_test_cuts)
test_dl = himia.test_dataloaders(test_cuts)
cw_test_dl = himia.test_dataloaders(cw_test_cuts)
test_sets = ["aishell_test", "test", "cw_test"]
test_dls = [aishell_test_dl, test_dl, cw_test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
logging.info(f"About to inference {test_set}")
inference_dataset(
dl=test_dl,
params=params,
model=model,
test_set=test_set,
)
logging.info(f"finish inferencing {test_set}")
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()
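A hedged sketch of reading back the posteriors written above (the path and keys are illustrative); this is the same NumpyHdf5Reader access pattern decode.py relies on:

```python
from lhotse.features.io import NumpyHdf5Reader

reader = NumpyHdf5Reader("ctc_tdnn/exp/post/epoch_15-avg_1/test.h5")
for key in reader.hdf.keys():   # one entry per cut id
    post = reader.read(key)     # log-posteriors, shape (T, num_class)
    print(key, post.shape)
    break
```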

egs/himia/wuw/ctc_tdnn/tdnn.py

@@ -0,0 +1,108 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import nn, Tensor
class Tdnn(nn.Module):
"""
Args:
num_features (int): Number of input features
num_classes (int): Number of output classes
"""
def __init__(self, num_features: int, num_classes: int) -> None:
super().__init__()
self.num_features = num_features
self.num_classes = num_classes
self.tdnn = nn.Sequential(
nn.Conv1d(
in_channels=num_features,
out_channels=240,
kernel_size=3,
stride=1,
padding=1,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=240, affine=False),
nn.Conv1d(
in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=240, affine=False),
nn.Conv1d(
in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=240, affine=False),
nn.Conv1d(
in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=240, affine=False),
nn.Conv1d(
in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=240, affine=False),
nn.Conv1d(
in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=240, affine=False),
nn.Conv1d(
in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=240, affine=False),
nn.Conv1d(
in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=240, affine=False),
nn.Conv1d(
in_channels=240, out_channels=240, kernel_size=3, stride=1, padding=1
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=240, affine=False),
nn.Conv1d(
in_channels=240, out_channels=240, kernel_size=1, stride=1, padding=0
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=240, affine=False),
nn.Conv1d(
in_channels=240,
out_channels=num_classes,
kernel_size=1,
stride=1,
padding=0,
),
nn.LogSoftmax(1),
)
def forward(self, x: Tensor) -> Tensor:
r"""
Args:
x (torch.Tensor): Tensor of dimension (N, T, num_features).
Returns:
Tensor: Output tensor of dimension (N, T, num_classes).
"""
x = x.transpose(1, 2)
x = self.tdnn(x)
x = x.transpose(1, 2)
return x
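A minimal shape check for the model (the batch size, frame count, and feature/class dims below are the ones assumed in this recipe):

```python
import torch

from tdnn import Tdnn

model = Tdnn(num_features=80, num_classes=9)
model.eval()
with torch.no_grad():
    x = torch.randn(4, 100, 80)  # (N, T, num_features) fbank features
    y = model(x)                 # (N, T, num_classes) log-probabilities
assert y.shape == (4, 100, 9)
# LogSoftmax over classes: probabilities sum to 1 per frame.
assert torch.allclose(y.exp().sum(-1), torch.ones(4, 100), atol=1e-5)
```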

egs/himia/wuw/ctc_tdnn/tokenizer.py

@@ -0,0 +1,101 @@
# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import torch
from typing import List, Tuple
class WakeupWordTokenizer(object):
def __init__(
self,
wakeup_word: str,
wakeup_word_tokens: List[int],
) -> None:
"""
Args:
wakeup_word: content of positive samples.
A sample will be treated as a negative sample unless its content
is exactly the same as wakeup_word.
wakeup_word_tokens: A list of int representing token ids of wakeup_word.
For example: the pronunciation of "你好米雅" is
"n i h ao m i y a".
Suppose we are using the following lexicon:
blk 0
unk 1
n 2
i 3
h 4
ao 5
m 6
y 7
a 8
Then wakeup_word_tokens for "你好米雅" is:
n i h ao m i y a
[2, 3, 4, 5, 6, 3, 7, 8]
"""
super().__init__()
assert wakeup_word is not None
assert wakeup_word_tokens is not None
assert (
0 not in wakeup_word_tokens
), f"0 is kept for blank. Please Remove 0 from {wakeup_word_tokens}"
assert 1 not in wakeup_word_tokens, (
f"1 is kept for unknown and negative samples. "
f" Please Remove 1 from {wakeup_word_tokens}"
)
self.wakeup_word = wakeup_word
self.wakeup_word_tokens = wakeup_word_tokens
self.positive_number_tokens = len(wakeup_word_tokens)
self.negative_word_tokens = [1]
self.negative_number_tokens = 1
def texts_to_token_ids(
self, texts: List[str]
) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""Convert a list of texts to parameters needed by CTC loss.
Args:
texts:
It is a list of strings,
each element is a reference text for an audio.
Returns:
Return a tuple of 3 elements.
The first one is a 1-D torch.Tensor containing the token sequences
of all samples concatenated, as expected by CTC loss.
The second one is the number of tokens for each sample,
mainly used by CTC loss.
The last one is number_positive_samples,
used to track the proportion of positive samples in each batch.
"""
batch_token_ids = []
target_lengths = []
number_positive_samples = 0
for utt_text in texts:
if utt_text == self.wakeup_word:
batch_token_ids.extend(self.wakeup_word_tokens)
target_lengths.append(self.positive_number_tokens)
number_positive_samples += 1
else:
batch_token_ids.extend(self.negative_word_tokens)
target_lengths.append(self.negative_number_tokens)
target = torch.tensor(batch_token_ids)
target_lengths = torch.tensor(target_lengths)
return target, target_lengths, number_positive_samples
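A hedged usage sketch with the token ids assumed elsewhere in this recipe:

```python
from tokenizer import WakeupWordTokenizer

tokenizer = WakeupWordTokenizer(
    wakeup_word="你好米雅",
    wakeup_word_tokens=[2, 3, 4, 5, 6, 3, 7, 8],
)
texts = ["你好米雅", "anything else"]  # one positive, one negative sample
target, target_lengths, num_pos = tokenizer.texts_to_token_ids(texts)
# target:         tensor([2, 3, 4, 5, 6, 3, 7, 8, 1])  (flat CTC-style concatenation)
# target_lengths: tensor([8, 1])
# num_pos:        1
```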

egs/himia/wuw/ctc_tdnn/train.py Executable file

@@ -0,0 +1,667 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0"
./ctc_tdnn/train.py \
--exp-dir ./ctc_tdnn/exp \
--world-size 1 \
--max-duration 100 \
--num-epochs 20
"""
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import HiMiaWuwDataModule
from tdnn import Tdnn
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from tokenizer import WakeupWordTokenizer
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=20,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
ctc_tdnn/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="ctc_tdnn/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lr-factor",
type=float,
default=0.001,
help="The lr_factor for optimizer",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used for writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- num_class: Number of classes. Each token will have a token id
from [0, num_class).
In this recipe, 0 is usually kept for blank,
and 1 is usually kept for negative words.
- wakeup_word: Text of wakeup word, i.e. positive samples.
- wakeup_word_tokens: A sequence of token ids corresponding to wakeup_word.
- weight_decay: The weight_decay for the optimizer.
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 5,
"reset_interval": 200,
"valid_interval": 3000,
# parameters for model
"feature_dim": 80,
"num_class": 9,
# parameters for tokenizer
"wakeup_word": "你好米雅",
"wakeup_word_tokens": [2, 3, 4, 5, 6, 3, 7, 8],
# parameters for Optimizer
"weight_decay": 1e-6,
"env_info": get_env_info(),
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
tokenizer: WakeupWordTokenizer,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Tdnn in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
tokenizer:
Maps the text of a positive sample to the wake-word token sequence,
and the text of a negative sample to unknown regardless of its content.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = model.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
N, T, C = feature.shape
feature = feature.to(device)
supervisions = batch["supervisions"]
texts = supervisions["text"]
with torch.set_grad_enabled(is_training):
# model_output is log_softmax(logit) with shape [N, T, C]
model_output = model(feature)
assert torch.all(supervisions["start_frame"] == 0)
num_frames = supervisions["num_frames"].to(device)
target, target_lengths, number_positive_samples = tokenizer.texts_to_token_ids(
texts
) # noqa E501
target = target.to(device)
target_lengths = target_lengths.to(device)
ctc_loss = nn.CTCLoss(reduction="sum")
# [N, T, C] --> [T, N, C]
model_output = model_output.transpose(0, 1)
loss = ctc_loss(model_output, target, num_frames, target_lengths)
loss /= num_frames.sum()
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = num_frames.sum().item()
info["loss"] = loss.detach().cpu().item() * info["frames"]
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
info["utterances"] = feature.size(0)
# averaged input duration in frames over utterances
info["utt_duration"] = supervisions["num_frames"].sum().item()
# averaged padding proportion over utterances
info["utt_pad_proportion"] = (
((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item()
)
info["number_positive_cuts_ratio"] = (number_positive_samples / N) * info["frames"]
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
tokenizer: WakeupWordTokenizer,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
tokenizer=tokenizer,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
tokenizer: WakeupWordTokenizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
tokenizer:
Maps the text of a positive sample to the wake-word token sequence,
and the text of a negative sample to unknown regardless of its content.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
tokenizer=tokenizer,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
tokenizer=tokenizer,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
tokenizer = WakeupWordTokenizer(
wakeup_word=params.wakeup_word,
wakeup_word_tokens=params.wakeup_word_tokens,
)
logging.info("About to create model")
model = Tdnn(params.feature_dim, params.num_class)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = torch.optim.Adam(
model.parameters(),
lr=params.lr_factor,
weight_decay=params.weight_decay,
)
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
himia = HiMiaWuwDataModule(args)
train_cuts = himia.train_cuts()
train_dl = himia.train_dataloaders(train_cuts)
valid_cuts = himia.dev_cuts()
valid_dl = himia.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
tokenizer=tokenizer,
params=params,
)
for epoch in range(params.start_epoch, params.num_epochs + 1):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
# TODO: Support lr scheduler
cur_lr = params.lr_factor
if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
tokenizer=tokenizer,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
tokenizer: WakeupWordTokenizer,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
params=params,
model=model,
batch=batch,
tokenizer=tokenizer,
is_training=True,
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
def main():
parser = get_parser()
HiMiaWuwDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()
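For reference, a hedged illustration of the CTC loss call conventions used in compute_loss() above (shapes and token ids are illustrative):

```python
import torch
import torch.nn as nn

N, T, C = 2, 100, 9
# log-probs must be (T, N, C) for nn.CTCLoss; the model emits (N, T, C).
log_probs = torch.randn(N, T, C).log_softmax(-1).transpose(0, 1)
# Targets are a flat 1-D concatenation: one positive + one negative sample.
targets = torch.tensor([2, 3, 4, 5, 6, 3, 7, 8, 1])
input_lengths = torch.tensor([100, 100])
target_lengths = torch.tensor([8, 1])
loss = nn.CTCLoss(reduction="sum")(log_probs, targets, input_lengths, target_lengths)
loss = loss / input_lengths.sum()  # per-frame normalization, as in compute_loss()
```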

egs/himia/wuw/local/auc.py Executable file

@@ -0,0 +1,131 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (Author: Weiji Zhuang,
# Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from typing import Dict, Tuple
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from sklearn.metrics import roc_curve, auc
from icefall.utils import setup_logger
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--positive-score-file",
type=str,
required=True,
help="score file of positive data",
)
parser.add_argument(
"--negative-score-file",
type=str,
required=True,
help="score file of negative data",
)
parser.add_argument(
"--legend",
type=str,
required=True,
help="legend of ROC curve picture.",
)
return parser.parse_args()
def load_score(score_file: Path) -> Dict[str, float]:
"""
Args:
score_file: Path to score file. Each line has two columns.
The first column is utt-id, and the second one is score.
This score could be viewed as probability of being wakeup word.
Returns:
A dict whose keys are utt-ids and whose values are the corresponding scores.
"""
pos_dict = {}
with open(score_file, "r", encoding="utf8") as fin:
for line in fin:
arr = line.strip().split()
assert len(arr) == 2
key = arr[0]
score = float(arr[1])
pos_dict[key] = score
return pos_dict
def get_roc_and_auc(
pos_dict: Dict,
neg_dict: Dict,
) -> Tuple[np.array, np.array, float]:
"""
Args:
pos_dict: scores of positive samples.
neg_dict: scores of negative samples.
Return:
A tuple of three elements, which will be used to plot the ROC curve.
Refer to sklearn.metrics.roc_curve for the meaning of the first two elements.
The third element is the area under the ROC curve (AUC).
"""
pos_scores = np.fromiter(pos_dict.values(), dtype=float)
neg_scores = np.fromiter(neg_dict.values(), dtype=float)
pos_y = np.ones_like(pos_scores, dtype=int)
neg_y = np.zeros_like(neg_scores, dtype=int)
scores = np.concatenate([pos_scores, neg_scores])
y = np.concatenate([pos_y, neg_y])
fpr, tpr, thresholds = roc_curve(y, scores, pos_label=1)
roc_auc = auc(fpr, tpr)
return fpr, tpr, roc_auc
def main():
args = get_args()
score_dir = Path(args.positive_score_file).parent
setup_logger(f"{score_dir}/log/log-auc-{args.legend}")
logging.info(f"About to compute AUC of {args.legend}")
pos_dict = load_score(args.positive_score_file)
neg_dict = load_score(args.negative_score_file)
fpr, tpr, roc_auc = get_roc_and_auc(pos_dict, neg_dict)
plt.figure(figsize=(16, 9))
plt.plot(fpr, tpr, label=f"{args.legend} (AUC = {roc_auc:.8f})")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver operating characteristic(ROC)")
plt.legend(loc="lower right")
output_path = Path(args.positive_score_file).parent
logging.info(f"AUC of {args.legend} {output_path}: {roc_auc}")
plt.savefig(f"{output_path}/{args.legend}.png", bbox_inches="tight")
if __name__ == "__main__":
main()
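A tiny hedged example of the scoring path on made-up utterance scores:

```python
from auc import get_roc_and_auc

pos = {"pos_utt1": 0.95, "pos_utt2": 0.80}  # wake-word utterances
neg = {"neg_utt1": 0.10, "neg_utt2": 0.40}  # everything else
fpr, tpr, roc_auc = get_roc_and_auc(pos, neg)
print(roc_auc)  # 1.0 -- the toy positives are perfectly separated
```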

egs/himia/wuw/local/compute_fbank_aishell.py

@@ -0,0 +1 @@
../../../aishell/ASR/local/compute_fbank_aishell.py

egs/himia/wuw/local/compute_fbank_himia.py

@@ -0,0 +1,139 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes the fbank features of the HI_MIA and HI_MIA_CW datasets.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--train-set-channel",
type=str,
default="_7_01",
help="""channel of HI_MIA dataset.
All channels are used if it is set "all".
""",
)
parser.add_argument(
"--enable-speed-perturb",
type=str2bool,
default=False,
help="""channel of training set.
""",
)
return parser.parse_args()
def compute_fbank_himia(
train_set_channel: str = None,
enable_speed_perturb: bool = True,
):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(40, os.cpu_count())
num_mel_bins = 80
if "all" == train_set_channel:
dataset_parts = (
"train",
"dev",
"test",
"cw_test",
)
else:
dataset_parts = (
f"train{train_set_channel}",
f"dev{train_set_channel}",
f"test{train_set_channel}",
"cw_test",
)

    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts, prefix="himia", output_dir=src_dir
    )
    assert manifests is not None

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
            if (output_dir / f"cuts_{partition}.jsonl.gz").is_file():
                logging.info(f"{partition} already exists - skipping.")
                continue
            logging.info(f"Processing {partition}")
            cut_set = CutSet.from_manifests(
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            if "train" in partition and enable_speed_perturb:
                cut_set = (
                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                )
            cut_set = cut_set.resample(16000)
            cut_set = cut_set.compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/feats_{partition}",
                # when an executor is specified, make more partitions
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=LilcomHdf5Writer,
            )
            # `partition` already carries the channel suffix (e.g. "train_7_01")
            # when a single channel is selected, so the channel must not be
            # appended again; prepare.sh expects e.g. cuts_train_7_01.jsonl.gz.
            cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
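
# The cut manifests written by compute_fbank_himia() can later be loaded
# lazily, e.g. (a sketch; the path assumes the default "_7_01" channel):
#   from lhotse import load_manifest_lazy
#   cuts = load_manifest_lazy("data/fbank/cuts_train_7_01.jsonl.gz")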


def main():
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    args = get_args()
    logging.basicConfig(format=formatter, level=logging.INFO)

    compute_fbank_himia(
        train_set_channel=args.train_set_channel,
        enable_speed_perturb=args.enable_speed_perturb,
    )


if __name__ == "__main__":
    main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

193
egs/himia/wuw/prepare.sh Executable file
View File

@ -0,0 +1,193 @@
#!/usr/bin/env bash
set -eou pipefail
stage=0
stop_stage=6
# The HI_MIA and aishell datasets are used in this experiment.
# The musan dataset is used for data augmentation.
#
# For aishell dataset downloading and preparation,
# refer to icefall/egs/aishell/ASR/prepare.sh.
#
# For the HI_MIA and HI_MIA_CW datasets,
# we assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
# These files will then be extracted to $dl_dir/HiMia/
#
# - $dl_dir/train.tar.gz
#   HiMia training dataset.
#   From https://www.openslr.org/85
#
# - $dl_dir/dev.tar.gz
#   HiMia development dataset.
#   From https://www.openslr.org/85
#
# - $dl_dir/test_v2.tar.gz
#   HiMia test dataset.
#   From https://www.openslr.org/85
#
# - $dl_dir/data.tgz
#   HiMia confusion words (HI_MIA_CW) test dataset.
#   From https://www.openslr.org/120
#
# - $dl_dir/resource.tgz
#   Transcripts of the HI_MIA_CW test dataset.
#   From https://www.openslr.org/120
dl_dir=$PWD/download
train_set_channel=_7_01
enable_speed_perturb=False
. shared/parse_options.sh || exit 1
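
# The variables above can be overridden from the command line via
# shared/parse_options.sh, e.g. (a hypothetical invocation):
#   ./prepare.sh --stage 4 --stop-stage 4 --enable-speed-perturb True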
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"

  # If you have pre-downloaded the HI_MIA and HI_MIA_CW datasets to /path/to/himia,
  # you can create a symlink
  #
  #   ln -sfv /path/to/himia $dl_dir/
  #
  if [ ! -f $dl_dir/train.tar.gz ]; then
    lhotse download himia $dl_dir/
  fi

  # If you have pre-downloaded it to /path/to/musan,
  # you can create a symlink
  #
  #   ln -sfv /path/to/musan $dl_dir/
  #
  if [ ! -d $dl_dir/musan ]; then
    lhotse download musan $dl_dir
  fi

  # If you have pre-downloaded it to /path/to/aishell,
  # you can create a symlink
  #
  #   ln -sfv /path/to/aishell $dl_dir/aishell
  #
  # The directory structure is
  #   aishell/
  #   |-- data_aishell
  #   |   |-- transcript
  #   |   `-- wav
  #   `-- resource_aishell
  #       |-- lexicon.txt
  #       `-- speaker.info
  if [ ! -d $dl_dir/aishell/data_aishell/wav/train ]; then
    lhotse download aishell $dl_dir
  fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare HI_MIA and HI_MIA_CW manifests"
  mkdir -p data/manifests
  if [ ! -e data/manifests/.himia.done ]; then
    lhotse prepare himia $dl_dir/HiMia data/manifests
    touch data/manifests/.himia.done
  fi
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Prepare musan manifest"
  # We assume that you have downloaded the musan corpus
  # to $dl_dir/musan
  mkdir -p data/manifests
  if [ ! -e data/manifests/.musan.done ]; then
    lhotse prepare musan $dl_dir/musan data/manifests
    touch data/manifests/.musan.done
  fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Prepare aishell manifest"
  # We assume that you have downloaded the aishell corpus
  # to $dl_dir/aishell
  if [ ! -f data/manifests/.aishell_manifests.done ]; then
    mkdir -p data/manifests
    lhotse prepare aishell $dl_dir/aishell data/manifests
    touch data/manifests/.aishell_manifests.done
  fi
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Compute fbank for aishell"
  if [ ! -f data/fbank/.aishell.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_aishell.py \
      --enable-speed-perturb=${enable_speed_perturb}
    touch data/fbank/.aishell.done
  fi
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Compute fbank for musan"
  mkdir -p data/fbank
  if [ ! -e data/fbank/.musan.done ]; then
    ./local/compute_fbank_musan.py
    touch data/fbank/.musan.done
  fi
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Compute fbank for the HI_MIA and HI_MIA_CW datasets"
  # The format of train_set_channel is "microphone position"_"channel".
  # Microphones 1 to 6 are arrays with 16 channels each;
  # microphone 7 has only a single channel.
  # So valid values of train_set_channel are:
  #   _1_01, ..., _1_16
  #   _2_01, ..., _2_16
  #   ...
  #   _6_01, ..., _6_16
  #   _7_01
  for subset in train dev test; do
    for file_type in recordings supervisions; do
      src=data/manifests/himia_${file_type}_${subset}.jsonl.gz
      dst=data/manifests/himia_${file_type}_${subset}${train_set_channel}.jsonl.gz
      gunzip -c ${src} | \
        grep ${train_set_channel} | \
        gzip -c > ${dst}
    done
  done
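
  # For example, with train_set_channel=_7_01 the loop above filters
  #   data/manifests/himia_recordings_train.jsonl.gz
  # down to the single-channel
  #   data/manifests/himia_recordings_train_7_01.jsonl.gz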

  mkdir -p data/fbank
  if [ ! -e data/fbank/.himia.done ]; then
    ./local/compute_fbank_himia.py \
      --train-set-channel=${train_set_channel} \
      --enable-speed-perturb=${enable_speed_perturb}
    touch data/fbank/.himia.done
  fi

  train_file=data/fbank/cuts_train_himia${train_set_channel}-aishell-shuf.jsonl.gz
  if [ ! -f ${train_file} ]; then
    # SingleCutSampler is preferred over DynamicBucketingSampler for this
    # experiment: negative audio (Aishell) tends to be longer than positive
    # audio (HiMia), so with DynamicBucketingSampler a batch could contain
    # only negative or only positive samples. Hence the combined training
    # data is shuffled here with `shuf` and loaded with SingleCutSampler.
    # `grep -v _sp` drops speed-perturbed copies.
    cat <(gunzip -c data/fbank/aishell_cuts_train.jsonl.gz) \
      <(gunzip -c data/fbank/cuts_train${train_set_channel}.jsonl.gz) | \
      grep -v _sp | \
      shuf | gzip -c > ${train_file}
  fi
fi

View File

@ -0,0 +1,57 @@
#!/usr/bin/env bash
set -eou pipefail
# You need to execute ./prepare.sh to prepare datasets.
stage=0
stop_stage=2
epoch=20
avg=1
max_duration=200
exp_dir=./ctc_tdnn/exp_max_duration_${max_duration}/
epoch_avg=epoch_${epoch}-avg_${avg}
post_dir=${exp_dir}/post/${epoch_avg}
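
# With the defaults above, post_dir resolves to
#   ./ctc_tdnn/exp_max_duration_200/post/epoch_20-avg_1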
. shared/parse_options.sh || exit 1
log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Model training"
  python ./ctc_tdnn/train.py \
    --num-epochs $epoch \
    --exp-dir $exp_dir \
    --max-duration $max_duration
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Get posteriors (log_softmax(logits)) of the test sets"
  python ctc_tdnn/inference.py \
    --avg $avg \
    --epoch $epoch \
    --exp-dir $exp_dir
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Decode and compute the area under the ROC curve (AUC)"
  for test_set in test aishell_test cw_test; do
    python ctc_tdnn/decode.py \
      --post-h5 ${post_dir}/${test_set}.h5 \
      --score-file ${post_dir}/fst_${test_set}_score.txt
  done

  python ./local/auc.py \
    --legend himia_cw \
    --positive-score-file ${post_dir}/fst_test_score.txt \
    --negative-score-file ${post_dir}/fst_cw_test_score.txt

  python ./local/auc.py \
    --legend himia_aishell \
    --positive-score-file ${post_dir}/fst_test_score.txt \
    --negative-score-file ${post_dir}/fst_aishell_test_score.txt
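
  # local/auc.py writes its log and the ROC curve next to the score files,
  # e.g. with the settings above:
  #   ${post_dir}/log/log-auc-himia_aishell-<timestamp>
  #   ${post_dir}/himia_aishell.png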
fi

1
egs/himia/wuw/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared