Add gigaspeech kws recipe

Author: pkufool
Date:   2023-12-25 18:04:24 +08:00
Parent: 17dab02dc9
Commit: 44bc60ff38
15 changed files with 2576 additions and 8 deletions


@@ -312,6 +312,8 @@ class GigaSpeechAsrDataModule:
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
)
else:
logging.info("Using SimpleCutSampler.")
@@ -366,6 +368,8 @@ class GigaSpeechAsrDataModule:
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle=False,
)
logging.info("About to create dev dataloader")
@@ -415,6 +419,7 @@ class GigaSpeechAsrDataModule:
logging.info(
f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode"
)
cuts_train = lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
)


@@ -1171,9 +1171,16 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
def remove_short_utt(c: Cut):
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
return T > 0
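# A quick sanity check of the arithmetic above (a sketch, assuming the usual
# 10 ms frame shift): num_frames = 9 gives T = ((9 - 7) // 2 + 1) // 2 = 1 > 0,
# while num_frames = 8 gives T = 0, so this filter drops cuts shorter than
# 9 frames (roughly 90 ms).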
gigaspeech = GigaSpeechAsrDataModule(args)
train_cuts = gigaspeech.train_cuts()
train_cuts = train_cuts.filter(remove_short_utt)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
@@ -1187,16 +1194,17 @@ def run(rank, world_size, args):
)
valid_cuts = gigaspeech.dev_cuts()
valid_cuts = valid_cuts.filter(remove_short_utt)
valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
# if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer=optimizer,
# sp=sp,
# params=params,
# )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:


@@ -0,0 +1,449 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import inspect
import logging
import re
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import lhotse
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class GigaSpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
# GigaSpeech specific arguments
group.add_argument(
"--subset",
type=str,
default="XL",
help="Select the GigaSpeech subset (XS|S|M|L|XL)",
)
group.add_argument(
"--small-dev",
type=str2bool,
default=False,
help="Should we use only 1000 utterances for dev (speeds up training)",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info(f"About to get train {self.args.subset} cuts")
if self.args.subset == "XL":
filenames = glob.glob(
f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz"
)
pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
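# For example (hypothetical split index), a matched file would look like
#   data/fbank/XL_split/gigaspeech_cuts_XL.1.jsonl.gz
# and the splits are sorted by their numeric index before being combined.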
logging.info(
f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode"
)
cuts_train = lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
)
else:
path = (
self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz"
)
cuts_train = CutSet.from_jsonl_lazy(path)
return cuts_train
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
)
if self.args.small_dev:
return cuts_valid.subset(first=1000)
else:
return cuts_valid
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
)


@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py


@@ -0,0 +1,648 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--keywords-file <path-to-keywords-file> \
--beam 4
"""
import argparse
import logging
import math
import os
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
from beam_search import (
keywords_search,
)
from train import add_model_arguments, get_model, get_params
from lhotse.cut import Cut
from icefall import ContextGraph
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
@dataclass
class KwMetric:
TP: int = 0 # True positive
FN: int = 0 # False negative
FP: int = 0 # False positive
TN: int = 0 # True negative
FN_list: List[str] = field(default_factory=list)
FP_list: List[str] = field(default_factory=list)
TP_list: List[str] = field(default_factory=list)
def __str__(self) -> str:
return f"(TP:{self.TP}, FN:{self.FN}, FP:{self.FP}, TN:{self.TN})"
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--beam",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--keywords-file",
type=str,
help="File contains keywords.",
)
parser.add_argument(
"--keywords-score",
type=float,
default=3.0,
help="""
The default boosting score (token level) for keywords. It will boost the
paths that match keywords to make them survive beam search.
""",
)
parser.add_argument(
"--keywords-threshold",
type=float,
default=0.75,
help="The default threshold (probability) to trigger the keyword.",
)
parser.add_argument(
"--num-tailing-blanks",
type=int,
default=8,
help="The number of tailing blanks should have after hitting one keyword.",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
kws_graph: Optional[ContextGraph] = None,
) -> List[List[Tuple[str, Tuple[int, int]]]]:
"""Decode one batch and return the result in a list.
The length of the list equals the batch size; the i-th element contains the
triggered keywords for the i-th utterance in the given batch. The triggered
keywords are themselves a list, each element of which is a tuple of the
triggered keyword phrase and its start and end timestamps.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
kws_graph:
The graph containing keywords.
Returns:
Return the decoding result. See above description for the format of
the returned list.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
pad_len = 30
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
)
encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
ans_dict = keywords_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
kws_graph=kws_graph,
beam=params.beam,
num_tailing_blanks=params.num_tailing_blanks,
blank_penalty=params.blank_penalty,
)
hyps = []
for ans in ans_dict:
hyp = []
for hit in ans:
hyp.append((hit.phrase, (hit.timestamps[0], hit.timestamps[-1])))
hyps.append(hyp)
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
kws_graph: ContextGraph,
keywords: Set[str],
) -> Tuple[List[Tuple[str, List[str], List[str]]], Dict[str, KwMetric]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
kws_graph:
The graph containing keywords.
Returns:
Return a tuple of (results, metric). `results` is a list of tuples
(cut_id, ref_words, hyp_words), where hyp_words are the words of the
triggered keyword phrases. `metric` is a dict mapping each keyword
(plus the special key "all") to its accumulated KwMetric counts.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
log_interval = 50
results = []
metric = {"all": KwMetric()}
for k in keywords:
metric[k] = KwMetric()
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps = decode_one_batch(
params=params,
model=model,
sp=sp,
kws_graph=kws_graph,
batch=batch,
)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = ref_text.upper()
ref_words = ref_text.split()
hyp_words = [x[0] for x in hyp_words]
this_batch.append((cut_id, ref_words, " ".join(hyp_words).split()))
hyp_set = set(hyp_words)
hyp_str = " | ".join(hyp_words)
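# Keyword-level bookkeeping (a summary of the rules coded below): a detected
# keyword counts as a TP if it occurs in the reference text and as an FP
# otherwise; a keyword that occurs in the reference counts as an FN only when
# nothing correct was detected for the utterance (and, as written, only when
# the reference ends with that keyword); a keyword absent from both reference
# and hypothesis counts as a TN.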
for x in hyp_set:
assert x in keywords, x
if x in ref_text and x in keywords:
metric["all"].TP += 1
metric[x].TP += 1
metric[x].TP_list.append(f"({ref_text} -> {x})")
if x not in ref_text and x in keywords:
metric["all"].FP += 1
metric[x].FP += 1
metric[x].FP_list.append(f"({ref_text} -> {x})")
for x in keywords:
if x not in ref_text and x not in hyp_set:
metric["all"].TN += 1
metric[x].TN += 1
if x in ref_text:
fn = True
for y in hyp_set:
if y in ref_text:
fn = False
break
if fn and ref_text.endswith(x):
metric["all"].FN += 1
metric[x].FN += 1
metric[x].FN_list.append(f"({ref_text} -> {hyp_str})")
results.extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results, metric
def save_results(
params: AttributeDict,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
metric: Dict[str, KwMetric],
):
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt"
print_s = ""
with open(metric_filename, "w") as of:
width = 10
for key, item in sorted(
metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True
):
acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN)
precision = (item.TP + 1) / (item.TP + item.FP + 1)
recall = (item.TP + 1) / (item.TP + item.FN + 1)
fpr = (item.FP + 1) / (item.FP + item.TN + 1)
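# The +1 terms above act as add-one smoothing so that keywords with zero
# detections do not cause division by zero in precision, recall, or FPR.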
s = f"{key}:\n"
s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n"
s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n"
s += f"\tAccuracy: {acc:.3f}\n"
s += f"\tPrecision: {precision:.3f}\n"
s += f"\tRecall(PPR): {recall:.3f}\n"
s += f"\tFPR: {fpr:.3f}\n"
s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n"
s += f"\tTP list: {' # '.join(item.TP_list)}\n"
s += f"\tFP list: {' # '.join(item.FP_list)}\n"
s += f"\tFN list: {' # '.join(item.FN_list)}\n"
of.write(s + "\n")
if key == "all":
logging.info(s)
logging.info("Wrote metric stats to {}".format(metric_filename))
@torch.no_grad()
def main():
parser = get_parser()
GigaSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "kws"
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
params.suffix += f"-score-{params.keywords_score}"
params.suffix += f"-threshold-{params.keywords_threshold}"
params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
if params.blank_penalty != 0:
params.suffix += f"-blank-penalty-{params.blank_penalty}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
phrases = []
token_ids = []
keywords_scores = []
keywords_thresholds = []
with open(params.keywords_file, "r") as f:
for line in f.readlines():
score = 0
threshold = 0
keyword = []
words = line.strip().upper().split()
for word in words:
word = word.strip()
if word[0] == ":":
score = float(word[1:])
continue
if word[0] == "#":
threshold = float(word[1:])
continue
keyword.append(word)
keyword = " ".join(keyword)
phrases.append(keyword)
token_ids.append(sp.encode(keyword))
keywords_scores.append(score)
keywords_thresholds.append(threshold)
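# An example keywords file consistent with the parsing above (the phrases are
# made up): each line is one keyword phrase, optionally followed by a
# token-level boosting score prefixed with ":" and/or a trigger threshold
# prefixed with "#", e.g.
#   HEY SIRI :3.5 #0.6
#   TURN ON THE LIGHT
# Lines without an explicit score/threshold leave them at 0, presumably
# falling back to the graph-level --keywords-score / --keywords-threshold.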
kws_graph = ContextGraph(
context_score=params.keywords_score, ac_threshold=params.keywords_threshold
)
kws_graph.build(
token_ids=token_ids,
phrases=phrases,
scores=keywords_scores,
ac_thresholds=keywords_thresholds,
)
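# ContextGraph comes from icefall's context-biasing code; to our understanding
# it is an Aho-Corasick-style automaton over the keyword token sequences, with
# per-phrase boosting scores and activation thresholds that keywords_search
# consults during beam search.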
keywords = set(phrases)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
gigaspeech = GigaSpeechAsrDataModule(args)
test_cuts = gigaspeech.test_cuts()
test_dl = gigaspeech.test_dataloaders(test_cuts)
def select_keyword_cuts(c: Cut):
text = c.supervisions[0].text
text = text.strip().upper()
return text in keywords
test_sc1_cuts = gigaspeech.test_speechcommands1_cuts()
test_sc2_cuts = gigaspeech.test_speechcommands2_cuts()
test_fsc_cuts = gigaspeech.test_fluent_speechcommands_cuts()
test_fsc_cuts = test_fsc_cuts.filter(select_keyword_cuts)
test_sc1_dl = gigaspeech.test_dataloaders(test_sc1_cuts)
test_sc2_dl = gigaspeech.test_dataloaders(test_sc2_cuts)
test_fsc_dl = gigaspeech.test_dataloaders(test_fsc_cuts)
test_sets = ["test-fsc", "test", "test-sc1", "test-sc2"]
test_dls = [test_fsc_dl, test_dl, test_sc1_dl, test_sc2_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results, metric = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
kws_graph=kws_graph,
keywords=keywords,
)
save_results(
params=params,
test_set_name=test_set,
results=results,
metric=metric,
)
logging.info("Done!")
if __name__ == "__main__":
main()


@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decoder.py


@@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/encoder_interface.py


@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/joiner.py


@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/model.py


@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py


@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py


@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

File diff suppressed because it is too large


@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py


@@ -28,6 +28,8 @@ from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from pypinyin import pinyin, lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials, to_finals_tone, to_finals
from shutil import copyfile
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
@@ -327,6 +329,19 @@ def encode_supervisions_otc(
return supervision_segments, res, sorted_ids, sorted_verbatim_texts
@dataclass
class KeywordResult:
# timestamps[k] contains the frame number on which tokens[k]
# is decoded
timestamps: List[int]
# hyps is the keyword, i.e., word IDs or token IDs
hyps: List[int]
# The triggered phrase
phrase: str
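# A hypothetical instance, just to illustrate the fields:
#   KeywordResult(timestamps=[10, 12, 15], hyps=[23, 7, 41], phrase="HEY SIRI")
# i.e. token 23 was decoded on frame 10, token 7 on frame 12, and so on.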
@dataclass
class DecodingResults:
# timestamps[i][k] contains the frame number on which tokens[i][k]
@@ -1583,6 +1598,87 @@ def load_averaged_model(
return model
def text_to_pinyin(
txt: str, mode: str = "full_with_tone", errors: str = "default"
) -> List[str]:
"""
Convert a Chinese text (which might contain some Latin characters) to a pinyin sequence.
Args:
txt:
The input Chinese text.
mode:
The style of the output pinyin, should be:
full_with_tone : zhong1 guo2
full_no_tone : zhong guo
partial_with_tone : zh ong1 g uo2
partial_no_tone : zh ong g uo
errors:
How to handle (Latin) characters that have no pinyin.
default : output them unchanged.
split : split them into single characters (i.e. letters).
Return:
Return a list of str.
Examples:
txt: 想吃KFC
output: ['xiǎng', 'chī', 'KFC'] # mode=full_with_tone; errors=default
output: ['xiǎng', 'chī', 'K', 'F', 'C'] # mode=full_with_tone; errors=split
output: ['xiang', 'chi', 'KFC'] # mode=full_no_tone; errors=default
output: ['xiang', 'chi', 'K', 'F', 'C'] # mode=full_no_tone; errors=split
output: ['x', 'iǎng', 'ch', 'ī', 'KFC'] # mode=partial_with_tone; errors=default
output: ['x', 'iang', 'ch', 'i', 'KFC'] # mode=partial_no_tone; errors=default
"""
assert mode in (
"full_with_tone",
"full_no_tone",
"partial_no_tone",
"partial_with_tone",
), mode
assert errors in ("default", "split"), errors
txt = txt.strip()
res = []
if "full" in mode:
if errors == "default":
py = pinyin(txt) if mode == "full_with_tone" else lazy_pinyin(txt)
else:
py = (
pinyin(txt, errors=lambda x: list(x))
if mode == "full_with_tone"
else lazy_pinyin(txt, errors=lambda x: list(x))
)
res = [x[0] for x in py] if mode == "full_with_tone" else py
else:
if errors == "default":
py = pinyin(txt) if mode == "partial_with_tone" else lazy_pinyin(txt)
else:
py = (
pinyin(txt, errors=lambda x: list(x))
if mode == "partial_with_tone"
else lazy_pinyin(txt, errors=lambda x: list(x))
)
py = [x[0] for x in py] if mode == "partial_with_tone" else py
for x in py:
initial = to_initials(x, strict=False)
final = (
to_finals(x, strict=False)
if mode == "partial_no_tone"
else to_finals_tone(x, strict=False)
)
if initial == "" and final == "":
res.append(x)
else:
if initial != "":
res.append(initial)
if final != "":
res.append(final)
return res
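# Usage example (taken from the docstring above):
#   text_to_pinyin("想吃KFC", mode="partial_with_tone")
#   # -> ['x', 'iǎng', 'ch', 'ī', 'KFC']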
def tokenize_by_bpe_model(
sp: spm.SentencePieceProcessor,
txt: str,