mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Merge aa3942ac5190b31d026c96205216a8892a6c05a1 into 9af144c26b91065a119d4e67c03004974462d24d
This commit is contained in:
commit
c20ceedafa
0
egs/libriheavy/ASR/zipformer_prompt_asr/__init__.py
Normal file
0
egs/libriheavy/ASR/zipformer_prompt_asr/__init__.py
Normal file
467
egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
Normal file
467
egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
Normal file
@ -0,0 +1,467 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from dataset import PromptASRDataset
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import (
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
ExtraPadding,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriHeavyAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it "
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
# Libriheavy specific arguments
|
||||
group.add_argument(
|
||||
"--subset",
|
||||
type=str,
|
||||
default="small",
|
||||
help="Select the Libriheavy subset (small|medium|large)",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--random-left-padding",
|
||||
type=str2bool,
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
text_sampling_func: Callable[[List[str]], str] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "musan_cuts.jsonl.gz"
|
||||
)
|
||||
transforms.append(
|
||||
CutMix(
|
||||
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(
|
||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||
)
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = PromptASRDataset(
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
text_sampling_func=text_sampling_func,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = PromptASRDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
text_sampling_func=text_sampling_func,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self,
|
||||
cuts_valid: CutSet,
|
||||
text_sampling_func: Callable[[List[str]], str] = None,
|
||||
) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.random_left_padding:
|
||||
logging.info("Enable random left padding")
|
||||
transforms.append(
|
||||
ExtraPadding(extra_frames=16, randomized=True, direction="left")
|
||||
)
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = PromptASRDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
return_cuts=self.args.return_cuts,
|
||||
text_sampling_func=text_sampling_func,
|
||||
)
|
||||
else:
|
||||
validate = PromptASRDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
text_sampling_func=text_sampling_func,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info(f"About to get {self.args.subset} cuts")
|
||||
path = (
|
||||
self.args.manifest_dir
|
||||
/ f"librilight_cuts_train_{self.args.subset}.jsonl.gz"
|
||||
)
|
||||
cuts_train = CutSet.from_jsonl_lazy(path)
|
||||
return cuts_train
|
||||
|
||||
@lru_cache()
|
||||
def dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
cuts_valid = load_manifest_lazy(
|
||||
self.args.manifest_dir / "librilight_cuts_dev.jsonl.gz"
|
||||
)
|
||||
return cuts_valid
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
cuts_valid = load_manifest_lazy(
|
||||
self.args.manifest_dir / "librilight_cuts_test.jsonl.gz"
|
||||
)
|
||||
return cuts_valid
|
||||
|
||||
@lru_cache()
|
||||
def test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-clean cuts")
|
||||
cuts = load_manifest_lazy(
|
||||
self.args.manifest_dir / "librilight_finetuning_clean.jsonl.gz"
|
||||
)
|
||||
return cuts
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
cuts = load_manifest_lazy(
|
||||
self.args.manifest_dir / "librilight_finetuning_other.jsonl.gz"
|
||||
)
|
||||
return cuts
|
||||
|
||||
@lru_cache()
|
||||
def librispeech_test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def librispeech_test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
|
||||
)
|
1714
egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py
Normal file
1714
egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py
Normal file
File diff suppressed because it is too large
Load Diff
354
egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py
Normal file
354
egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py
Normal file
@ -0,0 +1,354 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from lhotse import validate
|
||||
from lhotse.cut import CutSet
|
||||
from lhotse.dataset import K2SpeechRecognitionDataset
|
||||
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
|
||||
from lhotse.utils import compute_num_frames, ifnone
|
||||
from torch.utils.data.dataloader import DataLoader, default_collate
|
||||
|
||||
from text_normalization import (
|
||||
remove_non_alphabetic,
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
upper_all_char,
|
||||
lower_all_char,
|
||||
train_text_normalization,
|
||||
)
|
||||
|
||||
|
||||
class PromptASRDataset(torch.utils.data.Dataset):
|
||||
"""This is a dataset for Prompt ASR. It supports the following features:
|
||||
1. Select a tuple of (text, pre_text, style_text) randomly from a
|
||||
list of texts as supervisions.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
return_cuts: bool = False,
|
||||
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
|
||||
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
|
||||
input_strategy: BatchIO = PrecomputedFeatures(),
|
||||
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
|
||||
):
|
||||
"""
|
||||
Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
|
||||
for more details.
|
||||
|
||||
:param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
|
||||
objects used to create that batch.
|
||||
:param cut_transforms: A list of transforms to be applied on each sampled batch,
|
||||
before converting cuts to an input representation (audio/features).
|
||||
Examples: cut concatenation, noise cuts mixing, etc.
|
||||
:param input_transforms: A list of transforms to be applied on each sampled batch,
|
||||
after the cuts are converted to audio/features.
|
||||
Examples: normalization, SpecAugment, etc.
|
||||
:param input_strategy: Converts cuts into a collated batch of audio/features.
|
||||
By default, reads pre-computed features from disk.
|
||||
:param text_sampling_func: Sampling a text as transcription from a list of texts.
|
||||
"""
|
||||
super().__init__()
|
||||
# Initialize the fields
|
||||
self.return_cuts = return_cuts
|
||||
self.cut_transforms = ifnone(cut_transforms, [])
|
||||
self.input_transforms = ifnone(input_transforms, [])
|
||||
self.input_strategy = input_strategy
|
||||
|
||||
# a text sampling function
|
||||
self.text_sampling_func = text_sampling_func
|
||||
|
||||
def __getitem__(
|
||||
self, cuts: CutSet
|
||||
) -> Dict[str, Union[torch.Tensor, List[str]]]:
|
||||
"""
|
||||
Return a new batch, with the batch size automatically determined using the constraints
|
||||
of max_frames and max_cuts.
|
||||
"""
|
||||
validate_for_asr(cuts)
|
||||
|
||||
# Sort the cuts by duration so that the first one determines the batch time dimensions.
|
||||
cuts = cuts.sort_by_duration(ascending=False)
|
||||
|
||||
# Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
|
||||
# the supervision boundaries.
|
||||
for tnfm in self.cut_transforms:
|
||||
cuts = tnfm(cuts)
|
||||
|
||||
# Sort the cuts again after transforms
|
||||
cuts = cuts.sort_by_duration(ascending=False)
|
||||
|
||||
# Get a tensor with batched feature matrices, shape (B, T, F)
|
||||
# Collation performs auto-padding, if necessary.
|
||||
input_tpl = self.input_strategy(cuts)
|
||||
if len(input_tpl) == 3:
|
||||
# An input strategy with fault tolerant audio reading mode.
|
||||
# "cuts" may be a subset of the original "cuts" variable,
|
||||
# that only has cuts for which we succesfully read the audio.
|
||||
inputs, _, cuts = input_tpl
|
||||
else:
|
||||
inputs, _ = input_tpl
|
||||
|
||||
# Get a dict of tensors that encode the positional information about supervisions
|
||||
# in the batch of feature matrices. The tensors are named "sequence_idx",
|
||||
# "start_frame/sample" and "num_frames/samples".
|
||||
supervision_intervals = self.input_strategy.supervision_intervals(cuts)
|
||||
|
||||
# Apply all available transforms on the inputs, i.e. either audio or features.
|
||||
# This could be feature extraction, global MVN, SpecAugment, etc.
|
||||
segments = torch.stack(list(supervision_intervals.values()), dim=1)
|
||||
for tnfm in self.input_transforms:
|
||||
inputs = tnfm(inputs, supervision_segments=segments)
|
||||
|
||||
batch = {
|
||||
"inputs": inputs,
|
||||
"supervisions": default_collate(
|
||||
[
|
||||
self.text_sampling_func(
|
||||
texts=supervision.texts, pre_texts=supervision.pre_texts
|
||||
)
|
||||
if self.text_sampling_func is not None
|
||||
else {
|
||||
"text": train_text_normalization(supervision.texts[0]),
|
||||
"pre_text": train_text_normalization(
|
||||
supervision.pre_texts[0]
|
||||
),
|
||||
"style_text": train_text_normalization(
|
||||
supervision.pre_texts[0]
|
||||
),
|
||||
"transform_ids": 0,
|
||||
}
|
||||
for sequence_idx, cut in enumerate(cuts)
|
||||
for supervision in cut.supervisions
|
||||
]
|
||||
),
|
||||
}
|
||||
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
|
||||
batch["supervisions"].update(supervision_intervals)
|
||||
if self.return_cuts:
|
||||
batch["supervisions"]["cut"] = [
|
||||
cut for cut in cuts for sup in cut.supervisions
|
||||
]
|
||||
|
||||
has_word_alignments = all(
|
||||
s.alignment is not None and "word" in s.alignment
|
||||
for c in cuts
|
||||
for s in c.supervisions
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def validate_for_asr(cuts: CutSet) -> None:
|
||||
validate(cuts)
|
||||
tol = 2e-3 # 1ms
|
||||
for cut in cuts:
|
||||
for supervision in cut.supervisions:
|
||||
assert supervision.start >= -tol, (
|
||||
f"Supervisions starting before the cut are not supported for ASR"
|
||||
f" (sup id: {supervision.id}, cut id: {cut.id})"
|
||||
)
|
||||
|
||||
# Supervision start time is relative to Cut ...
|
||||
# https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
|
||||
#
|
||||
# 'supervision.end' is end of supervision inside the Cut
|
||||
assert supervision.end <= cut.duration + tol, (
|
||||
f"Supervisions ending after the cut "
|
||||
f"are not supported for ASR"
|
||||
f" (sup id: {supervision.id}, cut id: {cut.id})"
|
||||
)
|
||||
|
||||
|
||||
def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str:
|
||||
"""A helper function that generates a random substring from a given string
|
||||
|
||||
Args:
|
||||
s (str): Input string
|
||||
|
||||
Returns:
|
||||
str: Returned substring
|
||||
"""
|
||||
min_len = min(len(s), min_len)
|
||||
|
||||
start = random.randint(0, len(s) - min_len)
|
||||
end = min(start + max_len, random.randint(start + min_len, len(s)))
|
||||
|
||||
return s[start:end]
|
||||
|
||||
|
||||
def triplet_text_sampling(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
||||
min_len_style: Optional[int] = 80,
|
||||
) -> Dict[str, str]:
|
||||
"""This function generates a tuple of
|
||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
||||
should always match, whereas the style of pre_text is arbitrary.
|
||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
||||
text and pre_text are referred to as text and pre_text.
|
||||
The following three tuples are all valid:
|
||||
|
||||
(A(pre_text), B(style_text), B(text))
|
||||
(A(pre_text), C(style_text), C(text))
|
||||
(A(pre_text), A(style_text), A(text))
|
||||
...
|
||||
|
||||
If transforms is not given, the following pre-defined transforms
|
||||
are available:
|
||||
0: original (normal case, with punc)
|
||||
1: recog (upper, no punc)
|
||||
2: upper_only_alpha (upper, no punc)
|
||||
3: lower_only_alpha (lower, no punc)
|
||||
4: upper_all (upper, with punc)
|
||||
5: lower_all (lower, with punc)
|
||||
|
||||
When the transform of text and pre_text match, we can use the whole
|
||||
pre_text as the prompt text.
|
||||
|
||||
Args:
|
||||
texts (List[str]):
|
||||
A list of ref_texts whose first item is the ground truth
|
||||
text from books.
|
||||
pre_texts (List[str]):
|
||||
A list of pre_texts, whose first item is the groundtruth
|
||||
pre_text from books.
|
||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||
|
||||
Returns:
|
||||
str: A dictionary
|
||||
"""
|
||||
# import pdb; pdb.set_trace()
|
||||
assert len(texts) == len(pre_texts)
|
||||
assert len(texts) == 2
|
||||
|
||||
# we assume the first item to be ground truth
|
||||
gt_text = texts[0]
|
||||
gt_pre_text = pre_texts[0]
|
||||
|
||||
if transforms is None:
|
||||
transforms = [
|
||||
lambda x: x, # return it self
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
||||
|
||||
total_transforms = len(transforms) # do not use the recognized trans
|
||||
|
||||
# Select a transformation randomly
|
||||
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
|
||||
|
||||
# get the normalized text and pre_text
|
||||
text = transforms[i_text](gt_text)
|
||||
pre_text = transforms[i_pre_text](gt_pre_text)
|
||||
|
||||
if i_text == i_pre_text:
|
||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
||||
else:
|
||||
# get the pre_text of same style as text
|
||||
# For now, do not do transform to the style text
|
||||
style_text = gt_pre_text
|
||||
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
|
||||
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(text),
|
||||
"pre_text": train_text_normalization(pre_text),
|
||||
"style_text": train_text_normalization(style_text),
|
||||
"transform_ids": i_text,
|
||||
}
|
||||
|
||||
|
||||
def naive_triplet_text_sampling(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
min_len_style: Optional[int] = 120,
|
||||
):
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(texts[0]),
|
||||
"pre_text": train_text_normalization(pre_texts[0]),
|
||||
"style_text": train_text_normalization(pre_texts[0][:150]),
|
||||
# "style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
|
||||
# "style_text": train_text_normalization(get_substring(pre_texts[0], min_len=min_len_style)),
|
||||
"transform_ids": 0,
|
||||
}
|
||||
|
||||
|
||||
def random_shuffle_subset(
|
||||
data: List[str],
|
||||
p: float = 0.2,
|
||||
p_mask: float = 0.05,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Randomly shuffle the subset by probability p, which means that p% of the samples
|
||||
in the original batch are shuffled, the others are kept in the original order.
|
||||
|
||||
With a probability of p_mask, replace the original string with an empty string.
|
||||
|
||||
"""
|
||||
|
||||
num_to_shuffle = int(len(data) * p)
|
||||
id_to_shuffle = np.random.choice(len(data), num_to_shuffle, replace=False)
|
||||
item_to_shuffle = [data[id] for id in id_to_shuffle]
|
||||
random.shuffle(item_to_shuffle)
|
||||
|
||||
# print(num_to_shuffle,id_to_shuffle, item_to_shuffle)
|
||||
for id, item in zip(id_to_shuffle, item_to_shuffle):
|
||||
data[id] = item
|
||||
|
||||
if p_mask > 0:
|
||||
for i in range(len(data)):
|
||||
if random.random() < p_mask:
|
||||
data[i] = ""
|
||||
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
texts = [
|
||||
"AA, BB, cC, dD!",
|
||||
"AA BB CC DD",
|
||||
]
|
||||
|
||||
pre_texts = [
|
||||
"EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?",
|
||||
"EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG",
|
||||
]
|
||||
# for i in range(10):
|
||||
# print(f"Run: {i}")
|
||||
# print(triplet_text_sampling(texts, pre_texts))
|
||||
|
||||
import time
|
||||
start = time.time()
|
||||
data = [str(i) for i in range(30)]
|
||||
random.shuffle(data)
|
||||
print(data)
|
||||
for i in range(1):
|
||||
shuffled = random_shuffle_subset(data=data, p=0.4, p_mask=0.1)
|
||||
print(shuffled)
|
||||
print((time.time() - start)/100)
|
858
egs/libriheavy/ASR/zipformer_prompt_asr/decode.py
Executable file
858
egs/libriheavy/ASR/zipformer_prompt_asr/decode.py
Executable file
@ -0,0 +1,858 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) modified beam search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Callable
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriHeavyAsrDataModule
|
||||
from beam_search import (
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from dataset import naive_triplet_text_sampling
|
||||
from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, lower_all_char, lower_only_alpha
|
||||
from train import (
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
get_transducer_model,
|
||||
_encode_texts_as_bytes,
|
||||
)
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
- modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion
|
||||
- modified_beam_search_LODR
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-pre-text",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use pre-text is available during decoding",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-style-prompt",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use style prompt when evaluation"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--post-normalization",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--compute-CER",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Reports CER. By default, only reports WER",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--style-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of style prompt, i.e style_text"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pre-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of content prompt, i.e pre_text"
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||
"""Apply transform to a list of text. By default, the text are in
|
||||
ground truth format, i.e mixed-punc.
|
||||
|
||||
Args:
|
||||
text (List[str]): Input text string
|
||||
transform (str): Transform to be applied
|
||||
|
||||
Returns:
|
||||
List[str]: _description_
|
||||
"""
|
||||
if transform == "mixed-punc":
|
||||
return text
|
||||
elif transform == "upper-no-punc":
|
||||
return [upper_only_alpha(s) for s in text]
|
||||
elif transform == "lower-no-punc":
|
||||
return [lower_only_alpha(s) for s in text]
|
||||
elif transform == "lower-punc":
|
||||
return [lower_all_char(s) for s in text]
|
||||
else:
|
||||
raise NotImplementedError(f"Unseen transform: {transform}")
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
LM:
|
||||
A neural net LM for shallow fusion. Only used when `--use-shallow-fusion`
|
||||
set to true.
|
||||
ngram_lm:
|
||||
A ngram lm. Used in LODR decoding.
|
||||
ngram_lm_scale:
|
||||
The scale of the ngram language model.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
batch_size = feature.size(0)
|
||||
|
||||
if "pre_text" in batch["supervisions"] and params.use_pre_text:
|
||||
pre_texts = batch["supervisions"]["pre_text"]
|
||||
else:
|
||||
pre_texts = ["" for _ in range(batch_size)]
|
||||
|
||||
if params.use_style_prompt:
|
||||
style_texts = batch["supervisions"]["style_text"]
|
||||
else:
|
||||
style_texts = ["" for _ in range(batch_size)] # use empty string
|
||||
|
||||
# Get the text embedding input
|
||||
if params.use_pre_text or params.use_style_prompt:
|
||||
|
||||
# apply style transform to the pre_text and style_text
|
||||
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
||||
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
pre_texts, pre_texts_lens, style_lens = _encode_texts_as_bytes(
|
||||
pre_texts,
|
||||
style_texts,
|
||||
device,
|
||||
max_len=1200
|
||||
) # note that the output pre_texts include style_text and actual pre_text
|
||||
|
||||
memory, memory_key_padding_mask = model.encode_text(
|
||||
text=pre_texts,
|
||||
text_lens=pre_texts_lens,
|
||||
style_lens=style_lens,
|
||||
)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
|
||||
# Get the transducer encoder output
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
encoder_out, encoder_out_lens = model.encode_audio(
|
||||
feature=feature,
|
||||
feature_lens=feature_lens,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
|
||||
hyps = []
|
||||
|
||||
if (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
LM:
|
||||
A neural network LM, used during shallow fusion
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"] # By default, this should be in mixed-punc format
|
||||
|
||||
# the style of ref_text should match style_text
|
||||
if params.use_style_prompt:
|
||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_text = ref_text_normalization(
|
||||
ref_text
|
||||
) # remove full-width symbols & some book marks
|
||||
if params.post_normalization:
|
||||
ref_words = ref_text.split()
|
||||
ref_words = [remove_non_alphabetic(w.upper()) for w in ref_words]
|
||||
ref_words = [w for w in ref_words if w != ""]
|
||||
hyp_words = [remove_non_alphabetic(w.upper()) for w in hyp_words]
|
||||
hyp_words = [w for w in hyp_words if w != ""]
|
||||
else:
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
test_set_cers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
if params.compute_CER:
|
||||
# Write CER statistics
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
|
||||
store_transcripts(filename=recog_path, texts=results, char_level=True)
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
cer = write_error_stats(
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
results,
|
||||
enable_log=True,
|
||||
compute_CER=params.compute_CER,
|
||||
)
|
||||
test_set_cers[key] = cer
|
||||
|
||||
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
if params.compute_CER:
|
||||
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tcER", file=f)
|
||||
for key, val in test_set_cers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_cers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriHeavyAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.causal:
|
||||
assert (
|
||||
"," not in params.chunk_size
|
||||
), "chunk_size should be one value in decoding."
|
||||
assert (
|
||||
"," not in params.left_context_frames
|
||||
), "left_context_frames should be one value in decoding."
|
||||
params.suffix += f"-chunk-{params.chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||
|
||||
if "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_pre_text:
|
||||
params.suffix += f"-pre-text-{params.pre_text_transform}"
|
||||
|
||||
if params.use_style_prompt:
|
||||
params.suffix += f"-style-prompt-{params.style_text_transform}"
|
||||
|
||||
if params.post_normalization:
|
||||
params.suffix += "-post-normalization"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
LM = None
|
||||
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
libriheavy = LibriHeavyAsrDataModule(args)
|
||||
|
||||
test_cuts = libriheavy.test_cuts()
|
||||
#test_cuts = test_cuts.subset(first=200)
|
||||
test_clean_cuts = libriheavy.test_clean_cuts()
|
||||
test_other_cuts = libriheavy.test_other_cuts()
|
||||
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
|
||||
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
|
||||
|
||||
test_dl = libriheavy.valid_dataloaders(test_cuts, text_sampling_func=naive_triplet_text_sampling)
|
||||
test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = libriheavy.test_dataloaders(test_other_cuts)
|
||||
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
|
||||
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
|
||||
|
||||
#test_sets = ["test-clean", "test-other", "ls-test-clean", "ls-test-other"]
|
||||
#test_dl = [test_clean_dl, test_other_dl, ls_test_clean_dl, ls_test_other_dl]
|
||||
|
||||
# test_sets = ["ls-test-clean", "ls-test-other"]
|
||||
# test_dl = [ls_test_clean_dl, ls_test_other_dl]
|
||||
|
||||
test_sets = ["test",]
|
||||
test_dl = [test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
123
egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py
Normal file
123
egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py
Normal file
@ -0,0 +1,123 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from scaling import Balancer
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""This class modifies the stateless decoder from the following paper:
|
||||
|
||||
RNN-transducer with stateless prediction network
|
||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
|
||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||
network. Different from the above paper, it adds an extra Conv1d
|
||||
right after the embedding layer.
|
||||
|
||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
decoder_dim: int,
|
||||
blank_id: int,
|
||||
context_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_size:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
decoder_dim:
|
||||
Dimension of the input embedding, and of the decoder output.
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
context_size:
|
||||
Number of previous words to use to predict the next word.
|
||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.embedding = nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=decoder_dim,
|
||||
padding_idx=blank_id,
|
||||
)
|
||||
# the balancers are to avoid any drift in the magnitude of the
|
||||
# embeddings, which would interact badly with parameter averaging.
|
||||
self.balancer = Balancer(decoder_dim, channel_dim=-1,
|
||||
min_positive=0.0, max_positive=1.0,
|
||||
min_abs=0.5, max_abs=1.0,
|
||||
prob=0.05)
|
||||
|
||||
self.blank_id = blank_id
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
if context_size > 1:
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels=decoder_dim,
|
||||
out_channels=decoder_dim,
|
||||
kernel_size=context_size,
|
||||
padding=0,
|
||||
groups=decoder_dim//4, # group size == 4
|
||||
bias=False,
|
||||
)
|
||||
self.balancer2 = Balancer(decoder_dim, channel_dim=-1,
|
||||
min_positive=0.0, max_positive=1.0,
|
||||
min_abs=0.5, max_abs=1.0,
|
||||
prob=0.05)
|
||||
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
need_pad:
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
y = y.to(torch.int64)
|
||||
# this stuff about clamp() is a temporary fix for a mismatch
|
||||
# at utterance start, we use negative ids in beam_search.py
|
||||
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
|
||||
|
||||
embedding_out = self.balancer(embedding_out)
|
||||
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
embedding_out = F.pad(
|
||||
embedding_out, pad=(self.context_size - 1, 0)
|
||||
)
|
||||
else:
|
||||
# During inference time, there is no need to do extra padding
|
||||
# as we only need one output
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
embedding_out = F.relu(embedding_out)
|
||||
embedding_out = self.balancer2(embedding_out)
|
||||
|
||||
return embedding_out
|
43
egs/libriheavy/ASR/zipformer_prompt_asr/encoder_interface.py
Normal file
43
egs/libriheavy/ASR/zipformer_prompt_asr/encoder_interface.py
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class EncoderInterface(nn.Module):
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A tensor of shape (batch_size, input_seq_len, num_features)
|
||||
containing the input features.
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames
|
||||
in `x` before padding.
|
||||
Returns:
|
||||
Return a tuple containing two tensors:
|
||||
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
|
||||
containing unnormalized probabilities, i.e., the output of a
|
||||
linear layer.
|
||||
- encoder_out_lens, a tensor of shape (batch_size,) containing
|
||||
the number of frames in `encoder_out` before padding.
|
||||
"""
|
||||
raise NotImplementedError("Please implement it in a subclass")
|
69
egs/libriheavy/ASR/zipformer_prompt_asr/joiner.py
Normal file
69
egs/libriheavy/ASR/zipformer_prompt_asr/joiner.py
Normal file
@ -0,0 +1,69 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import (
|
||||
ScaledLinear
|
||||
)
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
|
||||
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
|
||||
self.output_linear = nn.Linear(joiner_dim, vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
project_input: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, s_range, C).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, T, s_range, C).
|
||||
project_input:
|
||||
If true, apply input projections encoder_proj and decoder_proj.
|
||||
If this is false, it is the user's responsibility to do this
|
||||
manually.
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim == 4
|
||||
assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
|
||||
|
||||
if project_input:
|
||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
|
||||
decoder_out
|
||||
)
|
||||
else:
|
||||
logit = encoder_out + decoder_out
|
||||
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
|
||||
return logit
|
348
egs/libriheavy/ASR/zipformer_prompt_asr/model.py
Normal file
348
egs/libriheavy/ASR/zipformer_prompt_asr/model.py
Normal file
@ -0,0 +1,348 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import random
|
||||
import warnings
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos, make_pad_mask
|
||||
from scaling import penalize_abs_values_gt, ScaledLinear
|
||||
from torch import Tensor
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class PromptedTransducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
"Sequence Transduction with Recurrent Neural Networks"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_embed: nn.Module,
|
||||
encoder: EncoderInterface,
|
||||
text_embed: nn.Module,
|
||||
text_encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder_embed:
|
||||
It is a Convolutional 2D subsampling module. It converts
|
||||
an input of shape (N, T, idim) to an output of of shape
|
||||
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
|
||||
encoder:
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
|
||||
self.encoder_embed = encoder_embed
|
||||
self.encoder = encoder
|
||||
self.text_embed = text_embed
|
||||
self.text_encoder = text_encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = ScaledLinear(
|
||||
encoder_dim,
|
||||
vocab_size,
|
||||
initial_scale=0.25,
|
||||
)
|
||||
self.simple_lm_proj = ScaledLinear(
|
||||
decoder_dim,
|
||||
vocab_size,
|
||||
initial_scale=0.25,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lens: torch.Tensor,
|
||||
style_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
use_pre_text: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
text:
|
||||
A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
|
||||
It is exptected to contain the style prompt (first) and then the content
|
||||
prompt.
|
||||
text_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of elements (bytes)
|
||||
in `text` before padding, which will include the lengths of the
|
||||
style plus the content prompt.
|
||||
style_lens:
|
||||
A 1-D tensor of shape (N,), containing the number of elements (bytes)
|
||||
within each row of `text` that correspond to the style prompt (these
|
||||
are expected to come first).
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
x, x_lens = self.encoder_embed(x, x_lens)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
if use_pre_text:
|
||||
memory, memory_key_padding_mask = self.encode_text(
|
||||
text,
|
||||
text_lens,
|
||||
style_lens
|
||||
)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
|
||||
encoder_out, x_lens = self.encoder(
|
||||
x,
|
||||
x_lens,
|
||||
src_key_padding_mask,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(encoder_out.size(0), 4),
|
||||
dtype=torch.int64,
|
||||
device=encoder_out.device,
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
# if self.training and random.random() < 0.25:
|
||||
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
# if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
|
||||
|
||||
def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
|
||||
"""
|
||||
Adds to `memory` an indicator that is 1.0 for positions that correspond to
|
||||
the `style prompt` and 0 elsewhere. The scale can be fixed because the
|
||||
scale of the embedding vector can adjust to compensate.
|
||||
|
||||
Args:
|
||||
memory: (memory_len, batch_size, embed_dim)
|
||||
style_lens: (batch_size,), a vector of lengths of the style prompt.
|
||||
"""
|
||||
|
||||
(memory_len, batch_size, embed_dim) = memory.shape
|
||||
|
||||
indicator = (
|
||||
torch.arange(memory_len, device=memory.device).unsqueeze(-1)
|
||||
< style_lens
|
||||
)
|
||||
indicator = indicator.to(memory.dtype)
|
||||
|
||||
extra_term = torch.zeros_like(memory)
|
||||
extra_term[..., 0] += indicator
|
||||
|
||||
return memory + extra_term
|
||||
|
||||
def encode_text(
|
||||
self,
|
||||
text: Tensor,
|
||||
text_lens: Tensor,
|
||||
style_lens: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Get the embeddings of text
|
||||
|
||||
Args:
|
||||
text (Tensor): The input text data in utf-8 bytes, (N, T)
|
||||
text_lens (Tensor): The length of the input text (N, ), including style_prompt
|
||||
style_lens (Tensor): The length of the style prompt (N, )
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
|
||||
text_encoder and the attention mask
|
||||
"""
|
||||
text = text.t() # now (T, N)
|
||||
text = self.text_embed(text) # now (T, N, C)
|
||||
text_key_padding_mask = make_pad_mask(text_lens)
|
||||
|
||||
text = self._add_style_indicator(text, style_lens)
|
||||
|
||||
memory, text_lens = self.text_encoder(
|
||||
text, text_lens, text_key_padding_mask
|
||||
)
|
||||
|
||||
memory_key_padding_mask = make_pad_mask(text_lens)
|
||||
|
||||
return memory, memory_key_padding_mask
|
||||
|
||||
def encode_audio(
|
||||
self,
|
||||
feature: Tensor,
|
||||
feature_lens: Tensor,
|
||||
memory: Optional[Tensor],
|
||||
memory_key_padding_mask: Optional[Tensor],
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Encode the input audio features
|
||||
|
||||
Args:
|
||||
feature (Tensor): Input audio (N,T,C)
|
||||
feature_lens (Tensor): Length of input audio (N,)
|
||||
memory (Tensor): Embeddings from the text encoder
|
||||
memory_key_padding_mask (Tensor): _description_
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor]: _description_
|
||||
"""
|
||||
x, x_lens = self.encoder_embed(feature, feature_lens)
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = self.encoder(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
Transducer = PromptedTransducer # for decoding
|
1175
egs/libriheavy/ASR/zipformer_prompt_asr/optim.py
Normal file
1175
egs/libriheavy/ASR/zipformer_prompt_asr/optim.py
Normal file
File diff suppressed because it is too large
Load Diff
1809
egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py
Normal file
1809
egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py
Normal file
File diff suppressed because it is too large
Load Diff
280
egs/libriheavy/ASR/zipformer_prompt_asr/subsampling.py
Normal file
280
egs/libriheavy/ASR/zipformer_prompt_asr/subsampling.py
Normal file
@ -0,0 +1,280 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from scaling import (
|
||||
Balancer,
|
||||
BiasNorm,
|
||||
Dropout3,
|
||||
FloatLike,
|
||||
Optional,
|
||||
ScaledConv2d,
|
||||
ScaleGrad,
|
||||
ScheduledFloat,
|
||||
SwooshL,
|
||||
SwooshR,
|
||||
Whiten,
|
||||
)
|
||||
|
||||
|
||||
class ConvNeXt(nn.Module):
|
||||
"""
|
||||
Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
hidden_ratio: int = 3,
|
||||
kernel_size: Tuple[int, int] = (7, 7),
|
||||
layerdrop_rate: FloatLike = None,
|
||||
):
|
||||
super().__init__()
|
||||
padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
|
||||
hidden_channels = channels * hidden_ratio
|
||||
if layerdrop_rate is None:
|
||||
layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
|
||||
self.layerdrop_rate = layerdrop_rate
|
||||
|
||||
self.depthwise_conv = nn.Conv2d(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
groups=channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
self.pointwise_conv1 = nn.Conv2d(
|
||||
in_channels=channels, out_channels=hidden_channels, kernel_size=1
|
||||
)
|
||||
|
||||
self.hidden_balancer = Balancer(
|
||||
hidden_channels,
|
||||
channel_dim=1,
|
||||
min_positive=0.3,
|
||||
max_positive=1.0,
|
||||
min_abs=0.75,
|
||||
max_abs=5.0,
|
||||
)
|
||||
|
||||
self.activation = SwooshL()
|
||||
self.pointwise_conv2 = ScaledConv2d(
|
||||
in_channels=hidden_channels,
|
||||
out_channels=channels,
|
||||
kernel_size=1,
|
||||
initial_scale=0.01,
|
||||
)
|
||||
|
||||
self.out_balancer = Balancer(
|
||||
channels,
|
||||
channel_dim=1,
|
||||
min_positive=0.4,
|
||||
max_positive=0.6,
|
||||
min_abs=1.0,
|
||||
max_abs=6.0,
|
||||
)
|
||||
self.out_whiten = Whiten(
|
||||
num_groups=1,
|
||||
whitening_limit=5.0,
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01,
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting() or not self.training:
|
||||
return self.forward_internal(x)
|
||||
layerdrop_rate = float(self.layerdrop_rate)
|
||||
|
||||
if layerdrop_rate != 0.0:
|
||||
batch_size = x.shape[0]
|
||||
mask = (
|
||||
torch.rand(
|
||||
(batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
|
||||
)
|
||||
> layerdrop_rate
|
||||
)
|
||||
else:
|
||||
mask = None
|
||||
# turns out this caching idea does not work with --world-size > 1
|
||||
# return caching_eval(self.forward_internal, x, mask)
|
||||
return self.forward_internal(x, mask)
|
||||
|
||||
def forward_internal(
|
||||
self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
|
||||
) -> Tensor:
|
||||
"""
|
||||
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
|
||||
|
||||
The returned value has the same shape as x.
|
||||
"""
|
||||
bypass = x
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.pointwise_conv1(x)
|
||||
x = self.hidden_balancer(x)
|
||||
x = self.activation(x)
|
||||
x = self.pointwise_conv2(x)
|
||||
|
||||
if layer_skip_mask is not None:
|
||||
x = x * layer_skip_mask
|
||||
|
||||
x = bypass + x
|
||||
x = self.out_balancer(x)
|
||||
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
|
||||
x = self.out_whiten(x)
|
||||
x = x.transpose(1, 3) # (N, C, H, W)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Conv2dSubsampling(nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/2 length).
|
||||
|
||||
Convert an input of shape (N, T, idim) to an output
|
||||
with shape (N, T', odim), where
|
||||
T' = (T-3)//2 - 2 == (T-7)//2
|
||||
|
||||
It is based on
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
layer1_channels: int = 8,
|
||||
layer2_channels: int = 32,
|
||||
layer3_channels: int = 128,
|
||||
dropout: FloatLike = 0.1,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
in_channels:
|
||||
Number of channels in. The input shape is (N, T, in_channels).
|
||||
Caution: It requires: T >=7, in_channels >=7
|
||||
out_channels
|
||||
Output dim. The output shape is (N, (T-3)//2, out_channels)
|
||||
layer1_channels:
|
||||
Number of channels in layer1
|
||||
layer1_channels:
|
||||
Number of channels in layer2
|
||||
bottleneck:
|
||||
bottleneck dimension for 1d squeeze-excite
|
||||
"""
|
||||
assert in_channels >= 7
|
||||
super().__init__()
|
||||
|
||||
# The ScaleGrad module is there to prevent the gradients
|
||||
# w.r.t. the weight or bias of the first Conv2d module in self.conv from
|
||||
# exceeding the range of fp16 when using automatic mixed precision (amp)
|
||||
# training. (The second one is necessary to stop its bias from getting
|
||||
# a too-large gradient).
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=1,
|
||||
out_channels=layer1_channels,
|
||||
kernel_size=3,
|
||||
padding=(0, 1), # (time, freq)
|
||||
),
|
||||
ScaleGrad(0.2),
|
||||
Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
|
||||
SwooshR(),
|
||||
nn.Conv2d(
|
||||
in_channels=layer1_channels,
|
||||
out_channels=layer2_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0,
|
||||
),
|
||||
Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
|
||||
SwooshR(),
|
||||
nn.Conv2d(
|
||||
in_channels=layer2_channels,
|
||||
out_channels=layer3_channels,
|
||||
kernel_size=3,
|
||||
stride=(1, 2), # (time, freq)
|
||||
),
|
||||
Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
|
||||
SwooshR(),
|
||||
)
|
||||
|
||||
# just one convnext layer
|
||||
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
|
||||
|
||||
out_width = (((in_channels - 1) // 2) - 1) // 2
|
||||
|
||||
self.out = nn.Linear(out_width * layer3_channels, out_channels)
|
||||
# use a larger than normal grad_scale on this whitening module; there is
|
||||
# only one such module, so there is not a concern about adding together
|
||||
# many copies of this extra gradient term.
|
||||
self.out_whiten = Whiten(
|
||||
num_groups=1,
|
||||
whitening_limit=ScheduledFloat(
|
||||
(0.0, 4.0), (20000.0, 8.0), default=4.0
|
||||
),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.02,
|
||||
)
|
||||
|
||||
# max_log_eps=0.0 is to prevent both eps and the output of self.out from
|
||||
# getting large, there is an unnecessary degree of freedom.
|
||||
self.out_norm = BiasNorm(out_channels)
|
||||
self.dropout = Dropout3(dropout, shared_dim=1)
|
||||
|
||||
def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is (N, T, idim).
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
|
||||
Returns:
|
||||
- a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
|
||||
- output lengths, of shape (batch_size,)
|
||||
"""
|
||||
# On entry, x is (N, T, idim)
|
||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
|
||||
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
|
||||
# gradients.
|
||||
x = self.conv(x)
|
||||
x = self.convnext(x)
|
||||
|
||||
# Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
||||
b, c, t, f = x.size()
|
||||
|
||||
x = x.transpose(1, 2).reshape(b, t, c * f)
|
||||
# now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels))
|
||||
|
||||
x = self.out(x)
|
||||
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
||||
x = self.out_whiten(x)
|
||||
x = self.out_norm(x)
|
||||
x = self.dropout(x)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
x_lens = (x_lens - 7) // 2
|
||||
assert x.size(1) == x_lens.max().item()
|
||||
|
||||
return x, x_lens
|
118
egs/libriheavy/ASR/zipformer_prompt_asr/test_model.py
Executable file
118
egs/libriheavy/ASR/zipformer_prompt_asr/test_model.py
Executable file
@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless4/test_model.py
|
||||
"""
|
||||
|
||||
from train_deeper_conv_relu import get_params, get_transducer_model, get_text_encoder
|
||||
from zipformer import Zipformer2
|
||||
from scaling import ScheduledFloat
|
||||
|
||||
|
||||
def test_model_1():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = 24
|
||||
params.dim_feedforward = 1536 # 384 * 4
|
||||
params.encoder_dim = 384
|
||||
model = get_transducer_model(params)
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
|
||||
# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
|
||||
def test_model_M():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = "2,4,3,2,4"
|
||||
params.feedforward_dims = "1024,1024,2048,2048,1024"
|
||||
params.nhead = "8,8,8,8,8"
|
||||
params.encoder_dims = "384,384,384,384,384"
|
||||
params.attention_dims = "192,192,192,192,192"
|
||||
params.encoder_unmasked_dims = "256,256,256,256,256"
|
||||
params.zipformer_downsampling_factors = "1,2,4,8,2"
|
||||
params.cnn_module_kernels = "31,31,15,15"
|
||||
|
||||
params.text_encoder_dim = (192,192,256,384)
|
||||
params.decoder_dim = 512
|
||||
params.joiner_dim = 512
|
||||
model = Zipformer2(
|
||||
output_downsampling_factor=8,
|
||||
downsampling_factor=(1, 2, 4, 8),
|
||||
num_encoder_layers=(2, 4, 4, 4),
|
||||
encoder_dim=(192, 192, 256, 384),
|
||||
encoder_unmasked_dim=(192, 192, 256, 256),
|
||||
query_head_dim=(32, 32, 32, 32),
|
||||
pos_head_dim=(4, 4, 4, 4),
|
||||
value_head_dim=(12, 12, 12, 12),
|
||||
pos_dim=48,
|
||||
num_heads=(4, 4, 4, 8),
|
||||
feedforward_dim=(
|
||||
384,
|
||||
512,
|
||||
768,
|
||||
1024,
|
||||
), # could increase this if there is nough data
|
||||
cnn_module_kernel=(31, 31, 15, 15),
|
||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||
warmup_batches=4000.0,
|
||||
causal=False,
|
||||
)
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
model = Zipformer2(
|
||||
output_downsampling_factor=8,
|
||||
downsampling_factor=(1, 2, 4, 8),
|
||||
num_encoder_layers=(2, 4, 6, 6),
|
||||
encoder_dim=(256,256,384,512),
|
||||
encoder_unmasked_dim=(196, 196, 256, 256),
|
||||
query_head_dim=(32, 32, 32, 32),
|
||||
pos_head_dim=(4, 4, 4, 4),
|
||||
value_head_dim=(12, 12, 12, 12),
|
||||
pos_dim=48,
|
||||
num_heads=(4, 4, 4, 8),
|
||||
feedforward_dim=(
|
||||
384,
|
||||
512,
|
||||
768,
|
||||
1024,
|
||||
), # could increase this if there is nough data
|
||||
cnn_module_kernel=(31, 31, 15, 15),
|
||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||
warmup_batches=4000.0,
|
||||
causal=False,
|
||||
)
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
def main():
|
||||
# test_model_1()
|
||||
test_model_M()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import re
|
||||
|
||||
def train_text_normalization(s: str) -> str:
|
||||
s = s.replace("“", '"')
|
||||
s = s.replace("”", '"')
|
||||
s = s.replace("‘", "'")
|
||||
s = s.replace("’", "'")
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def ref_text_normalization(ref_text: str) -> str:
|
||||
# Rule 1: Remove the [FN#[]]
|
||||
p = r"[FN#[0-9]*]"
|
||||
pattern = re.compile(p)
|
||||
|
||||
# ref_text = ref_text.replace("”", "\"")
|
||||
# ref_text = ref_text.replace("’", "'")
|
||||
res = pattern.findall(ref_text)
|
||||
ref_text = re.sub(p, "", ref_text)
|
||||
|
||||
ref_text = train_text_normalization(ref_text)
|
||||
|
||||
return ref_text
|
||||
|
||||
|
||||
def remove_non_alphabetic(text: str) -> str:
|
||||
# Note, this also keeps space
|
||||
return re.sub("[^a-zA-Z\s]+", "", text)
|
||||
|
||||
|
||||
def recog_text_normalization(recog_text: str) -> str:
|
||||
pass
|
||||
|
||||
def upper_only_alpha(text: str) -> str:
|
||||
return remove_non_alphabetic(text.upper())
|
||||
|
||||
def lower_only_alpha(text: str) -> str:
|
||||
return remove_non_alphabetic(text.lower())
|
||||
|
||||
def lower_all_char(text: str) -> str:
|
||||
return text.lower()
|
||||
|
||||
def upper_all_char(text: str) -> str:
|
||||
return text.upper()
|
||||
|
||||
if __name__ == "__main__":
|
||||
ref_text = " Hello “! My name is ‘ haha"
|
||||
print(ref_text)
|
||||
res = train_text_normalization(ref_text)
|
||||
print(res)
|
1634
egs/libriheavy/ASR/zipformer_prompt_asr/train.py
Executable file
1634
egs/libriheavy/ASR/zipformer_prompt_asr/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1901
egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py
Normal file
1901
egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user