This is the initial commit for the neural biasing implementation with early context injection and text perturbation; the code runs well on the grid, but it needs a fair amount of cleaning up and refactoring before making a reasonable PR

This commit is contained in:
Ruizhe (Ray) Huang 2024-09-30 09:30:15 -04:00
parent 5c04c31292
commit 78b7ef3e3f
49 changed files with 24302 additions and 0 deletions

View File

@ -0,0 +1,206 @@
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import k2
import torch
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
# The force alignment problem can be formulated as finding
# a path in a rectangular lattice, where the path starts
# from the lower left corner and ends at the upper right
# corner. The horizontal axis of the lattice is `t` (representing
# acoustic frame indexes) and the vertical axis is `u` (representing
# BPE tokens of the transcript).
#
# The notations `t` and `u` are from the paper
# https://arxiv.org/pdf/1211.3711.pdf
#
# Beam search is used to find the path with the highest log probabilities.
#
# It assumes the maximum number of symbols that can be
# emitted per frame is 1.
def batch_force_alignment(
model: torch.nn.Module,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_list: List[List[int]],
beam_size: int = 4,
) -> List[List[int]]:
"""Compute the force alignment of a batch of utterances given their transcripts
in BPE tokens and the corresponding acoustic output from the encoder.
Caution:
This function is modified from `modified_beam_search` in beam_search.py.
We assume that the maximum number of symbols per frame is 1.
Args:
model:
The transducer model.
encoder_out:
A tensor of shape (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
ys_list:
A list of BPE token IDs list. We require that for each utterance i,
len(ys_list[i]) <= encoder_out_lens[i].
beam_size:
Size of the beam used in beam search.
Returns:
Return a list of frame-index lists, one for each utterance i,
where len(ans[i]) == len(ys_list[i]).
"""
assert encoder_out.ndim == 3, encoder_out.ndim
assert encoder_out.size(0) == len(ys_list), (encoder_out.size(0), len(ys_list))
assert encoder_out.size(0) > 0, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
sorted_indices = packed_encoder_out.sorted_indices.tolist()
encoder_out_lens = encoder_out_lens.tolist()
ys_lens = [len(ys) for ys in ys_list]
sorted_encoder_out_lens = [encoder_out_lens[i] for i in sorted_indices]
sorted_ys_lens = [ys_lens[i] for i in sorted_indices]
sorted_ys_list = [ys_list[i] for i in sorted_indices]
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[],
)
)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
sorted_encoder_out_lens = sorted_encoder_out_lens[:batch_size]
sorted_ys_lens = sorted_ys_lens[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out, decoder_out, project_input=False
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs.reshape(-1)
) # [batch][num_hyps*vocab_size]
for i in range(batch_size):
for h, hyp in enumerate(A[i]):
pos_u = len(hyp.timestamp)
idx_offset = h * vocab_size
if (sorted_encoder_out_lens[i] - 1 - t) >= (sorted_ys_lens[i] - pos_u):
# emit blank token
new_hyp = Hypothesis(
log_prob=ragged_log_probs[i][idx_offset + blank_id],
ys=hyp.ys[:],
timestamp=hyp.timestamp[:],
)
B[i].add(new_hyp)
if pos_u < sorted_ys_lens[i]:
# emit non-blank token
new_token = sorted_ys_list[i][pos_u]
new_hyp = Hypothesis(
log_prob=ragged_log_probs[i][idx_offset + new_token],
ys=hyp.ys + [new_token],
timestamp=hyp.timestamp + [t],
)
B[i].add(new_hyp)
if len(B[i]) > beam_size:
B[i] = B[i].topk(beam_size, length_norm=True)
B = B + finalized_B
sorted_hyps = [b.get_most_probable() for b in B]
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
hyps = [sorted_hyps[i] for i in unsorted_indices]
ans = []
for i, hyp in enumerate(hyps):
assert hyp.ys[context_size:] == ys_list[i], (hyp.ys[context_size:], ys_list[i])
ans.append(hyp.timestamp)
return ans
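# A minimal usage sketch (an illustration, not part of this file): this
# mirrors how compute_ali.py calls batch_force_alignment. The names `model`,
# `sp`, `feature`, `feature_lens` and `texts` are hypothetical placeholders
# for a trained transducer, a SentencePiece processor and one batch.
#
#   encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
#   ys_list: List[List[int]] = sp.encode(texts, out_type=int)
#   frame_indexes = batch_force_alignment(
#       model, encoder_out, encoder_out_lens, ys_list, beam_size=4
#   )
#   # frame_indexes[i][u] is the frame at which token ys_list[i][u] is emitted.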

View File

@ -0,0 +1,460 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LibriSpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
)
@lru_cache()
def train_clean_100_cuts_sample(self) -> CutSet:
logging.info("About to get train-clean-100 cuts (sampled)")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-100-0.4.jsonl.gz"
)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
)
@lru_cache()
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
logging.info(
"About to get the shuffled train-clean-100, \
train-clean-360 and train-other-500 cuts"
)
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
)
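# A minimal usage sketch (assumed, not part of this file): how the data module
# is typically wired up from a training or decoding script. The dataloader
# calls require the manifests under --manifest-dir to already exist.
#
#   parser = argparse.ArgumentParser()
#   LibriSpeechAsrDataModule.add_arguments(parser)
#   args = parser.parse_args(["--full-libri", "false", "--max-duration", "200"])
#   librispeech = LibriSpeechAsrDataModule(args)
#   train_dl = librispeech.train_dataloaders(librispeech.train_clean_100_cuts())
#   valid_dl = librispeech.valid_dataloaders(librispeech.dev_clean_cuts())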

View File

@ -0,0 +1,259 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class SBCAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/manifests"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
# @lru_cache()
# def test_healthcare_cuts(self) -> CutSet:
# logging.info("About to get test-healthcare cuts")
# return load_manifest_lazy(
# self.args.manifest_dir / "uniphore_cuts_healthcare.jsonl.gz"
# )
# @lru_cache()
# def test_banking_cuts(self) -> CutSet:
# logging.info("About to get test-banking cuts")
# return load_manifest_lazy(
# self.args.manifest_dir / "uniphore_cuts_banking.jsonl.gz"
# )
# @lru_cache()
# def test_insurance_cuts(self) -> CutSet:
# logging.info("About to get test-insurance cuts")
# return load_manifest_lazy(
# self.args.manifest_dir / "uniphore_cuts_insurance.jsonl.gz"
# )
@lru_cache()
def test_cuts(self, cuts_file) -> CutSet:
logging.info(f"About to get cuts from {cuts_file}")
return load_manifest_lazy(cuts_file)
# @lru_cache()
# def test_sbc_cuts(self, cuts_file, sampling_rate) -> CutSet:
# logging.info(f"About to get SBC cuts from {cuts_file}")
# cuts = CutSet.from_file(cuts_file)
# cuts = [c.to_mono(mono_downmix=True).resample(sampling_rate) for c in cuts]
# cuts = CutSet.from_cuts(cuts)
# return cuts
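# A minimal usage sketch (assumed, not part of this file): building a test
# dataloader from an arbitrary cuts manifest with this data module; the
# manifest path below is a placeholder.
#
#   parser = argparse.ArgumentParser()
#   SBCAsrDataModule.add_arguments(parser)
#   args = parser.parse_args([])
#   sbc = SBCAsrDataModule(args)
#   cuts = sbc.test_cuts("data/manifests/some_test_cuts.jsonl.gz")
#   test_dl = sbc.test_dataloaders(cuts)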

File diff suppressed because it is too large

View File

@ -0,0 +1,114 @@
import torch
import torch.nn as nn
class Ffn(nn.Module):
def __init__(self, input_dim, hidden_dim, out_dim, nlayers=1, drop_out=0.1, skip=False) -> None:
super().__init__()
layers = []
for ilayer in range(nlayers):
_in = hidden_dim if ilayer > 0 else input_dim
_out = hidden_dim if ilayer < nlayers - 1 else out_dim
layers.extend([
nn.Linear(_in, _out),
# nn.ReLU(),
# nn.Sigmoid(),
nn.Tanh(),
nn.Dropout(p=drop_out),
])
self.ffn = torch.nn.Sequential(
*layers,
)
self.skip = skip
def forward(self, x) -> torch.Tensor:
x_out = self.ffn(x)
if self.skip:
x_out = x_out + x
return x_out
class BiasingModule(torch.nn.Module):
def __init__(
self,
query_dim,
qkv_dim=64,
num_heads=4,
):
super(BiasingModule, self).__init__()
self.proj_in1 = nn.Linear(query_dim, qkv_dim)
self.proj_in2 = Ffn(
input_dim=qkv_dim,
hidden_dim=qkv_dim,
out_dim=qkv_dim,
skip=True,
drop_out=0.1,
nlayers=2,
)
self.multihead_attn = torch.nn.MultiheadAttention(
embed_dim=qkv_dim,
num_heads=num_heads,
# kdim=64,
# vdim=64,
batch_first=True,
)
self.proj_out1 = Ffn(
input_dim=qkv_dim,
hidden_dim=qkv_dim,
out_dim=qkv_dim,
skip=True,
drop_out=0.1,
nlayers=2,
)
self.proj_out2 = nn.Linear(qkv_dim, query_dim)
self.glu = nn.GLU()
def forward(
self,
queries,
contexts,
contexts_mask,
need_weights=False,
):
"""
Args:
queries:
of shape batch_size * seq_length * query_dim
contexts:
of shape batch_size * max_contexts_size * qkv_dim
contexts_mask:
of shape batch_size * max_contexts_size, where True marks padded entries to ignore
Returns:
biasing_output:
of shape batch_size * seq_length * query_dim
attn_output_weights:
of shape batch_size * seq_length * max_contexts_size (None unless need_weights=True)
"""
_queries = self.proj_in1(queries)
_queries = self.proj_in2(_queries)
# queries = queries / 0.01
attn_output, attn_output_weights = self.multihead_attn(
_queries, # query
contexts, # key
contexts, # value
key_padding_mask=contexts_mask,
need_weights=need_weights,
)
output = self.proj_out1(attn_output)
output = self.proj_out2(output)
# apply the gated linear unit
biasing_output = self.glu(output.repeat(1,1,2))
# print(f"query={query.shape}")
# print(f"value={contexts} value.shape={contexts.shape}")
# print(f"attn_output_weights={attn_output_weights} attn_output_weights.shape={attn_output_weights.shape}")
return biasing_output, attn_output_weights
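# A minimal, self-contained sketch (an illustration, not part of this file) of
# a BiasingModule forward pass on random tensors. It assumes the contexts are
# already projected to the attention dimension (qkv_dim), which is what the
# default MultiheadAttention configuration above expects.
if __name__ == "__main__":
    batch_size, seq_len, num_contexts = 2, 10, 5
    query_dim, qkv_dim = 512, 64
    biasing = BiasingModule(query_dim=query_dim, qkv_dim=qkv_dim, num_heads=4)
    queries = torch.randn(batch_size, seq_len, query_dim)
    contexts = torch.randn(batch_size, num_contexts, qkv_dim)
    # True marks padded context entries that the attention should ignore.
    contexts_mask = torch.zeros(batch_size, num_contexts, dtype=torch.bool)
    biasing_output, attn_weights = biasing(
        queries, contexts, contexts_mask, need_weights=True
    )
    print(biasing_output.shape)  # torch.Size([2, 10, 512])
    print(attn_weights.shape)    # torch.Size([2, 10, 5])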

View File

@ -0,0 +1,345 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The script gets forced-alignments based on the modified_beam_search decoding method.
Both token-level alignments and word-level alignments are saved to the new cuts manifests.
It loads a checkpoint and uses it to get the forced-alignments.
You can generate the checkpoint with the following command:
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
Usage of this script:
./pruned_transducer_stateless7/compute_ali.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--bpe-model data/lang_bpe_500/bpe.model \
--dataset test-clean \
--max-duration 300 \
--beam-size 4 \
--cuts-out-dir data/fbank_ali_beam_search
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import sentencepiece as spm
import torch
import torch.nn as nn
from alignment import batch_force_alignment
from asr_datamodule import LibriSpeechAsrDataModule
from lhotse import CutSet
from lhotse.serialization import SequentialJsonlWriter
from lhotse.supervision import AlignmentItem
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--dataset",
type=str,
required=True,
help="""The name of the dataset to compute alignments for.
Possible values are:
- test-clean
- test-other
- train-clean-100
- train-clean-360
- train-other-500
- dev-clean
- dev-other
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--cuts-out-dir",
type=str,
default="data/fbank_ali_beam_search",
help="The dir to save the new cuts manifests with alignments",
)
add_model_arguments(parser)
return parser
def align_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
) -> Tuple[List[List[str]], List[List[str]], List[List[float]], List[List[float]]]:
"""Get forced-alignments for one batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
token_list:
A list of token list.
word_list:
A list of word list.
token_time_list:
A list of timestamps list for tokens.
word_time_list:
A list of timestamps list for words.
where len(token_list) == len(word_list) == len(token_time_list) == len(word_time_list),
len(token_list[i]) == len(token_time_list[i]),
and len(word_list[i]) == len(word_time_list[i])
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
texts = supervisions["text"]
ys_list: List[List[int]] = sp.encode(texts, out_type=int)
frame_indexes = batch_force_alignment(
model, encoder_out, encoder_out_lens, ys_list, params.beam_size
)
token_list = []
word_list = []
token_time_list = []
word_time_list = []
for i in range(encoder_out.size(0)):
tokens = sp.id_to_piece(ys_list[i])
words = texts[i].split()
token_time = convert_timestamp(
frame_indexes[i], params.subsampling_factor, params.frame_shift_ms
)
word_time = parse_timestamp(tokens, token_time)
assert len(word_time) == len(words), (len(word_time), len(words))
token_list.append(tokens)
word_list.append(words)
token_time_list.append(token_time)
word_time_list.append(word_time)
return token_list, word_list, token_time_list, word_time_list
def align_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
writer: SequentialJsonlWriter,
) -> None:
"""Get forced-alignments for the dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
writer:
Writer to save the cuts with alignments.
"""
log_interval = 20
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
for batch_idx, batch in enumerate(dl):
token_list, word_list, token_time_list, word_time_list = align_one_batch(
params=params, model=model, sp=sp, batch=batch
)
cut_list = batch["supervisions"]["cut"]
for cut, token, word, token_time, word_time in zip(
cut_list, token_list, word_list, token_time_list, word_time_list
):
assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
token_ali = [
AlignmentItem(
symbol=token[i],
start=round(token_time[i], ndigits=3),
duration=None,
)
for i in range(len(token))
]
word_ali = [
AlignmentItem(
symbol=word[i], start=round(word_time[i], ndigits=3), duration=None
)
for i in range(len(word))
]
cut.supervisions[0].alignment = {"word": word_ali, "token": token_ali}
writer.write(cut, flush=True)
num_cuts += len(cut_list)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
if params.dataset == "test-clean":
test_clean_cuts = librispeech.test_clean_cuts()
dl = librispeech.test_dataloaders(test_clean_cuts)
elif params.dataset == "test-other":
test_other_cuts = librispeech.test_other_cuts()
dl = librispeech.test_dataloaders(test_other_cuts)
elif params.dataset == "train-clean-100":
train_clean_100_cuts = librispeech.train_clean_100_cuts()
dl = librispeech.train_dataloaders(train_clean_100_cuts)
elif params.dataset == "train-clean-360":
train_clean_360_cuts = librispeech.train_clean_360_cuts()
dl = librispeech.train_dataloaders(train_clean_360_cuts)
elif params.dataset == "train-other-500":
train_other_500_cuts = librispeech.train_other_500_cuts()
dl = librispeech.train_dataloaders(train_other_500_cuts)
elif params.dataset == "dev-clean":
dev_clean_cuts = librispeech.dev_clean_cuts()
dl = librispeech.valid_dataloaders(dev_clean_cuts)
else:
assert params.dataset == "dev-other", f"{params.dataset}"
dev_other_cuts = librispeech.dev_other_cuts()
dl = librispeech.valid_dataloaders(dev_other_cuts)
cuts_out_dir = Path(params.cuts_out_dir)
cuts_out_dir.mkdir(parents=True, exist_ok=True)
cuts_out_path = cuts_out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz"
with CutSet.open_writer(cuts_out_path) as writer:
align_dataset(dl=dl, params=params, model=model, sp=sp, writer=writer)
logging.info(
f"For dataset {params.dataset}, the cut manifest with framewise token alignments "
f"and word alignments are saved to {cuts_out_path}"
)
logging.info("Done!")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,371 @@
import torch
import random
from pathlib import Path
import sentencepiece as spm
from typing import Union, List
import logging
import ast
import numpy as np
from itertools import chain
from word_encoder_bert import BertEncoder
from context_wfst import generate_context_graph_nfa
class SentenceTokenizer:
def encode(self, word_list: List, out_type: type = int) -> List:
"""
Encode a list of words into a list of tokens
Args:
word_list:
A list of words where each word is a string.
E.g., ["nihao", "hello", "你好"]
out_type:
This defines the output type. If it is an "int" type,
then each token is represented by its integer token id.
Returns:
A list of tokenized words, where each tokenization
is a list of tokens.
"""
pass
class ContextCollector(torch.utils.data.Dataset):
def __init__(
self,
path_is21_deep_bias: Path,
sp: Union[spm.SentencePieceProcessor, SentenceTokenizer],
bert_encoder: BertEncoder = None,
n_distractors: int = 100,
ratio_distractors: int = None,
is_predefined: bool = False,
keep_ratio: float = 1.0,
is_full_context: bool = False,
backoff_id: int = None,
):
self.sp = sp
self.bert_encoder = bert_encoder
self.path_is21_deep_bias = path_is21_deep_bias
self.n_distractors = n_distractors
self.ratio_distractors = ratio_distractors
self.is_predefined = is_predefined
self.keep_ratio = keep_ratio
self.is_full_context = is_full_context # use all words (rare or common) in the context
# self.embedding_dim = self.bert_encoder.bert_model.config.hidden_size
self.backoff_id = backoff_id
logging.info(f"""
n_distractors={n_distractors},
ratio_distractors={ratio_distractors},
is_predefined={is_predefined},
keep_ratio={keep_ratio},
is_full_context={is_full_context},
bert_encoder={bert_encoder.name if bert_encoder is not None else None},
""")
self.common_words = None
self.rare_words = None
self.all_words = None
with open(path_is21_deep_bias / "words/all_rare_words.txt", "r") as fin:
self.rare_words = [l.strip().upper() for l in fin if len(l) > 0]
with open(path_is21_deep_bias / "words/common_words_5k.txt", "r") as fin:
self.common_words = [l.strip().upper() for l in fin if len(l) > 0]
self.all_words = self.rare_words + self.common_words # sp needs a list of strings, can't be a set
self.common_words = set(self.common_words)
self.rare_words = set(self.rare_words)
logging.info(f"Number of common words: {len(self.common_words)}. Examples: {random.sample(self.common_words, 5)}")
logging.info(f"Number of rare words: {len(self.rare_words)}. Examples: {random.sample(self.rare_words, 5)}")
logging.info(f"Number of all words: {len(self.all_words)}. Examples: {random.sample(self.all_words, 5)}")
self.test_clean_biasing_list = None
self.test_other_biasing_list = None
if is_predefined:
def read_ref_biasing_list(filename):
biasing_list = dict()
all_cnt = 0
rare_cnt = 0
with open(filename, "r") as fin:
for line in fin:
line = line.strip().upper()
if len(line) == 0:
continue
line = line.split("\t")
uid, ref_text, ref_rare_words, context_rare_words = line
context_rare_words = ast.literal_eval(context_rare_words)
biasing_list[uid] = context_rare_words
ref_rare_words = ast.literal_eval(ref_rare_words)
ref_text = ref_text.split()
all_cnt += len(ref_text)
rare_cnt += len(ref_rare_words)
return biasing_list, rare_cnt / all_cnt
self.test_clean_biasing_list, ratio_clean = \
read_ref_biasing_list(self.path_is21_deep_bias / f"ref/test-clean.biasing_{n_distractors}.tsv")
self.test_other_biasing_list, ratio_other = \
read_ref_biasing_list(self.path_is21_deep_bias / f"ref/test-other.biasing_{n_distractors}.tsv")
logging.info(f"Number of utterances in test_clean_biasing_list: {len(self.test_clean_biasing_list)}, rare ratio={ratio_clean:.2f}")
logging.info(f"Number of utterances in test_other_biasing_list: {len(self.test_other_biasing_list)}, rare ratio={ratio_other:.2f}")
self.all_words2pieces = None
if self.sp is not None:
all_words2pieces = sp.encode(self.all_words, out_type=int) # a list of list of int
self.all_words2pieces = {w: pieces for w, pieces in zip(self.all_words, all_words2pieces)}
logging.info(f"len(self.all_words2pieces)={len(self.all_words2pieces)}")
self.all_words2embeddings = None
if self.bert_encoder is not None:
all_words = list(chain(self.common_words, self.rare_words))
all_embeddings = self.bert_encoder.encode_strings(all_words)
assert len(all_words) == len(all_embeddings)
self.all_words2embeddings = {w: ebd for w, ebd in zip(all_words, all_embeddings)}
logging.info(f"len(self.all_words2embeddings)={len(self.all_words2embeddings)}")
if is_predefined:
new_words_bias = set()
all_words_bias = set()
for uid, wlist in chain(self.test_clean_biasing_list.items(), self.test_other_biasing_list.items()):
for word in wlist:
if word not in self.common_words and word not in self.rare_words:
new_words_bias.add(word)
all_words_bias.add(word)
# if self.all_words2pieces is not None and word not in self.all_words2pieces:
# self.all_words2pieces[word] = self.sp.encode(word, out_type=int)
# if self.all_words2embeddings is not None and word not in self.all_words2embeddings:
# self.all_words2embeddings[word] = self.bert_encoder.encode_strings([word])[0]
logging.info(f"OOVs in the biasing list: {len(new_words_bias)}/{len(all_words_bias)}")
if len(new_words_bias) > 0:
self.add_new_words(list(new_words_bias), silent=True)
if is_predefined:
assert self.ratio_distractors is None
assert self.n_distractors in [100, 500, 1000, 2000]
self.temp_dict = None
def add_new_words(self, new_words_list, return_dict=False, silent=False):
if len(new_words_list) == 0:
if return_dict is True:
return dict()
else:
return
if self.all_words2pieces is not None:
words_pieces_list = self.sp.encode(new_words_list, out_type=int)
new_words2pieces = {w: pieces for w, pieces in zip(new_words_list, words_pieces_list)}
if return_dict:
return new_words2pieces
else:
self.all_words2pieces.update(new_words2pieces)
if self.all_words2embeddings is not None:
embeddings_list = self.bert_encoder.encode_strings(new_words_list, silent=silent)
new_words2embeddings = {w: ebd for w, ebd in zip(new_words_list, embeddings_list)}
if return_dict:
return new_words2embeddings
else:
self.all_words2embeddings.update(new_words2embeddings)
self.all_words.extend(new_words_list)
self.rare_words.update(new_words_list)
def discard_some_common_words(self, words, keep_ratio):
pass
def _get_random_word_lists(self, batch):
texts = batch["supervisions"]["text"]
new_words = []
rare_words_list = []
for text in texts:
rare_words = []
for word in text.split():
if self.is_full_context or word not in self.common_words:
rare_words.append(word)
if self.all_words2pieces is not None and word not in self.all_words2pieces:
new_words.append(word)
# self.all_words2pieces[word] = self.sp.encode(word, out_type=int)
if self.all_words2embeddings is not None and word not in self.all_words2embeddings:
new_words.append(word)
# logging.info(f"New word detected: {word}")
# self.all_words2embeddings[word] = self.bert_encoder.encode_strings([word])[0]
rare_words = list(set(rare_words)) # deduplication
if self.keep_ratio < 1.0 and len(rare_words) > 0:
# # method 1:
# keep_size = int(len(rare_words) * self.keep_ratio)
# if keep_size > 0:
# rare_words = random.sample(rare_words, keep_size)
# else:
# rare_words = []
# method 2:
x = np.random.rand(len(rare_words))
new_rare_words = []
for xi in range(len(rare_words)):
if x[xi] < self.keep_ratio:
new_rare_words.append(rare_words[xi])
rare_words = new_rare_words
rare_words_list.append(rare_words)
self.temp_dict = None
if len(new_words) > 0:
self.temp_dict = self.add_new_words(new_words, return_dict=True, silent=True)
if self.ratio_distractors is not None:
n_distractors_each = []
for rare_words in rare_words_list:
n_distractors_each.append(len(rare_words) * self.ratio_distractors)
n_distractors_each = np.asarray(n_distractors_each, dtype=int)
else:
if self.n_distractors == -1: # variable context list sizes
n_distractors_each = np.random.randint(low=10, high=500, size=len(texts))
# n_distractors_each = np.random.randint(low=80, high=300, size=len(texts))
else:
n_distractors_each = np.full(len(texts), self.n_distractors, int)
distractors_cnt = n_distractors_each.sum()
distractors = random.sample( # without replacement
self.rare_words,
distractors_cnt
) # TODO: actually the context should contain both rare and common words
# distractors = random.choices( # random choices with replacement
# self.rare_words,
# distractors_cnt,
# )
distractors_pos = 0
for i, rare_words in enumerate(rare_words_list):
rare_words.extend(distractors[distractors_pos: distractors_pos + n_distractors_each[i]])
distractors_pos += n_distractors_each[i]
# random.shuffle(rare_words)
# logging.info(rare_words)
assert distractors_pos == len(distractors)
return rare_words_list
def _get_predefined_word_lists(self, batch):
rare_words_list = []
for cut in batch['supervisions']['cut']:
uid = cut.supervisions[0].id
if uid in self.test_clean_biasing_list:
rare_words_list.append(self.test_clean_biasing_list[uid])
elif uid in self.test_other_biasing_list:
rare_words_list.append(self.test_other_biasing_list[uid])
else:
rare_words_list.append([])
logging.error(f"uid={uid} cannot find the predefined biasing list of size {self.n_distractors}")
# for wl in rare_words_list:
# for w in wl:
# if w not in self.all_words2pieces:
# self.all_words2pieces[w] = self.sp.encode(w, out_type=int)
return rare_words_list
def get_context_word_list(
self,
batch: dict,
):
"""
Generate/Get the context biasing list as a list of words for each utterance
Use keep_ratio to simulate the "imperfect" context which may not have 100% coverage of the ground truth words.
"""
if self.is_predefined:
rare_words_list = self._get_predefined_word_lists(batch)
else:
rare_words_list = self._get_random_word_lists(batch)
rare_words_list = [sorted(rwl) for rwl in rare_words_list]
if self.all_words2embeddings is None:
# Use SentencePiece to encode the words
rare_words_pieces_list = []
max_pieces_len = 0
for rare_words in rare_words_list:
rare_words_pieces = [self.all_words2pieces[w] if w in self.all_words2pieces else self.temp_dict[w] for w in rare_words]
if len(rare_words_pieces) > 0:
max_pieces_len = max(max_pieces_len, max(len(pieces) for pieces in rare_words_pieces))
rare_words_pieces_list.append(rare_words_pieces)
else:
# Use BERT embeddings here
rare_words_embeddings_list = []
for rare_words in rare_words_list:
# for w in rare_words:
# if w not in self.all_words2embeddings and (self.temp_dict is not None and w not in self.temp_dict):
# import pdb; pdb.set_trace()
# if w == "STUBE":
# import pdb; pdb.set_trace()
rare_words_embeddings = [self.all_words2embeddings[w] if w in self.all_words2embeddings else self.temp_dict[w] for w in rare_words]
rare_words_embeddings_list.append(rare_words_embeddings)
if self.all_words2embeddings is None:
# Use SentencePiece to encode the words
word_list = []
word_lengths = []
num_words_per_utt = []
pad_token = 0
for rare_words_pieces in rare_words_pieces_list:
num_words_per_utt.append(len(rare_words_pieces))
word_lengths.extend([len(pieces) for pieces in rare_words_pieces])
# # TODO: this is a bug here: this will effectively modify the entries in 'self.all_words2embeddings'!!!
# for pieces in rare_words_pieces:
# pieces += [pad_token] * (max_pieces_len - len(pieces))
# word_list.extend(rare_words_pieces)
# Correction:
rare_words_pieces_padded = list()
for pieces in rare_words_pieces:
rare_words_pieces_padded.append(pieces + [pad_token] * (max_pieces_len - len(pieces)))
word_list.extend(rare_words_pieces_padded)
word_list = torch.tensor(word_list, dtype=torch.int32)
# word_lengths = torch.tensor(word_lengths, dtype=torch.int32)
# num_words_per_utt = torch.tensor(num_words_per_utt, dtype=torch.int32)
else:
# Use BERT embeddings here
word_list = []
word_lengths = None
num_words_per_utt = []
for rare_words_embeddings in rare_words_embeddings_list:
num_words_per_utt.append(len(rare_words_embeddings))
word_list.extend(rare_words_embeddings)
word_list = torch.stack(word_list)
return word_list, word_lengths, num_words_per_utt
def get_context_word_wfst(
self,
batch: dict,
):
"""
Get the WFST representation of the context biasing list as a list of words for each utterance
"""
if self.is_predefined:
rare_words_list = self._get_predefined_word_lists(batch)
else:
rare_words_list = self._get_random_word_lists(batch)
# TODO:
# We can associate weighted or dynamic weights for each rare word or token
nbest_size = 1 # TODO: The maximum number of different tokenization for each lexicon entry.
# Use SentencePiece to encode the words
rare_words_pieces_list = []
num_words_per_utt = []
for rare_words in rare_words_list:
rare_words_pieces = [self.all_words2pieces[w] if w in self.all_words2pieces else self.temp_dict[w] for w in rare_words]
rare_words_pieces_list.append(rare_words_pieces)
num_words_per_utt.append(len(rare_words))
fsa_list, fsa_sizes = generate_context_graph_nfa(
words_pieces_list = rare_words_pieces_list,
backoff_id = self.backoff_id,
sp = self.sp,
)
return fsa_list, fsa_sizes, num_words_per_utt
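# A minimal usage sketch (hypothetical, not part of this file): wiring the
# collector up with a SentencePiece model and the IS21 deep-biasing word
# lists, then building the per-utterance context for one batch. The paths and
# the `batch` variable are placeholders.
#
#   sp = spm.SentencePieceProcessor()
#   sp.load("data/lang_bpe_500/bpe.model")
#   context_collector = ContextCollector(
#       path_is21_deep_bias=Path("data/context/is21_deep_bias"),
#       sp=sp,
#       n_distractors=100,
#       keep_ratio=1.0,
#       is_predefined=False,
#   )
#   word_list, word_lengths, num_words_per_utt = \
#       context_collector.get_context_word_list(batch)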

View File

@ -0,0 +1,116 @@
import torch
import abc
class ContextEncoder(torch.nn.Module):
def __init__(self):
super(ContextEncoder, self).__init__()
@abc.abstractmethod
def forward(
self,
word_list,
word_lengths,
is_encoder_side=None,
):
pass
def embed_contexts(
self,
contexts,
is_encoder_side=None,
):
"""
Args:
contexts:
The contexts, see below for details
Returns:
final_h:
A tensor of shape (batch_size, max(num_words_per_utt) + 1, joiner_dim),
which is the embedding for each context word.
mask_h:
A tensor of shape (batch_size, max(num_words_per_utt) + 1),
which contains a True/False mask for final_h
"""
if contexts["mode"] == "get_context_word_list":
"""
word_list:
Option1: A list of words, where each word is a list of token ids.
The list of tokens for each word has been padded.
Option2: A list of words, where each word is an embedding.
word_lengths:
Option1: The number of tokens per word
Option2: None
num_words_per_utt:
The number of words in the context for each utterance
"""
word_list, word_lengths, num_words_per_utt = \
contexts["word_list"], contexts["word_lengths"], contexts["num_words_per_utt"]
assert word_lengths is None or word_list.size(0) == len(word_lengths)
batch_size = len(num_words_per_utt)
elif contexts["mode"] == "get_context_word_list_shared":
"""
word_list:
Option1: A list of words, where each word is a list of token ids.
The list of tokens for each word has been padded.
Option2: A list of words, where each word is an embedding.
word_lengths:
Option1: The number of tokens per word
Option2: None
positive_mask_list:
For each utterance, it contains a list of indices of the words should be masked
"""
word_list, word_lengths, positive_mask_list = \
contexts["word_list"], contexts["word_lengths"], contexts["positive_mask_list"]
batch_size = len(positive_mask_list)
assert word_lengths is None or word_list.size(0) == len(word_lengths)
else:
raise NotImplementedError
# print(f"word_list.shape={word_list.shape}")
final_h = self.forward(word_list, word_lengths, is_encoder_side=is_encoder_side)
if contexts["mode"] == "get_context_word_list":
final_h = torch.split(final_h, num_words_per_utt)
final_h = torch.nn.utils.rnn.pad_sequence(
final_h,
batch_first=True,
padding_value=0.0
)
# print(f"final_h.shape={final_h.shape}")
# add one no-bias token
no_bias_h = torch.zeros(final_h.shape[0], 1, final_h.shape[-1])
no_bias_h = no_bias_h.to(final_h.device)
final_h = torch.cat((no_bias_h, final_h), dim=1)
# print(final_h)
# https://stackoverflow.com/questions/53403306/how-to-batch-convert-sentence-lengths-to-masks-in-pytorch
mask_h = torch.arange(max(num_words_per_utt) + 1)
mask_h = mask_h.expand(len(num_words_per_utt), max(num_words_per_utt) + 1) > torch.Tensor(num_words_per_utt).unsqueeze(1)
mask_h = mask_h.to(final_h.device)
elif contexts["mode"] == "get_context_word_list_shared":
no_bias_h = torch.zeros(1, final_h.shape[-1])
no_bias_h = no_bias_h.to(final_h.device)
final_h = torch.cat((no_bias_h, final_h), dim=0)
final_h = final_h.expand(batch_size, -1, -1)
mask_h = torch.full((batch_size, final_h.size(1)), False, dtype=torch.bool, device=final_h.device)
for i, my_mask in enumerate(positive_mask_list):
if len(my_mask) > 0:
my_mask = torch.tensor(my_mask, dtype=torch.int64, device=final_h.device)
my_mask += 1
mask_h[i][my_mask] = True
# TODO: validate this shape is correct:
# final_h: batch_size * max_num_words_per_utt + 1 * dim
# mask_h: batch_size * max_num_words_per_utt + 1
return final_h, mask_h
def clustering(self):
pass
def cache(self):
pass
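# A small, self-contained sketch (not part of this file) of the
# length-to-mask construction used in embed_contexts() above: positions beyond
# each utterance's word count (plus the prepended no-bias slot at index 0)
# become True, matching the key_padding_mask convention of
# torch.nn.MultiheadAttention.
if __name__ == "__main__":
    num_words_per_utt = [2, 4, 0]  # assumed example counts
    max_len = max(num_words_per_utt) + 1  # +1 for the no-bias token
    mask_h = torch.arange(max_len).expand(
        len(num_words_per_utt), max_len
    ) > torch.tensor(num_words_per_utt).unsqueeze(1)
    print(mask_h)
    # tensor([[False, False, False,  True,  True],
    #         [False, False, False, False, False],
    #         [False,  True,  True,  True,  True]])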

View File

@ -0,0 +1,101 @@
import torch
from context_encoder import ContextEncoder
import copy
class ContextEncoderLSTM(ContextEncoder):
def __init__(
self,
vocab_size: int = None,
context_encoder_dim: int = None,
output_dim: int = None,
num_layers: int = None,
num_directions: int = None,
drop_out: float = 0.1,
bi_encoders: bool = False,
):
super(ContextEncoderLSTM, self).__init__()
self.num_layers = num_layers
self.num_directions = num_directions
self.context_encoder_dim = context_encoder_dim
torch.manual_seed(42)
self.embed = torch.nn.Embedding(
vocab_size,
context_encoder_dim
)
self.rnn = torch.nn.LSTM(
input_size=context_encoder_dim,
hidden_size=context_encoder_dim,
num_layers=self.num_layers,
batch_first=True,
bidirectional=(self.num_directions == 2),
dropout=0.0 if self.num_layers > 1 else 0
)
self.linear = torch.nn.Linear(
context_encoder_dim * self.num_directions,
output_dim
)
self.drop_out = torch.nn.Dropout(drop_out)
# TODO: Do we need some relu layer?
# https://galhever.medium.com/sentiment-analysis-with-pytorch-part-4-lstm-bilstm-model-84447f6c4525
# self.relu = nn.ReLU()
# self.dropout = nn.Dropout(dropout)
self.bi_encoders = bi_encoders
if bi_encoders:
# Create the decoder/predictor side of the context encoder
self.embed_dec = copy.deepcopy(self.embed)
self.rnn_dec = copy.deepcopy(self.rnn)
self.linear_dec = copy.deepcopy(self.linear)
self.drop_out_dec = copy.deepcopy(self.drop_out)
def forward(
self,
word_list,
word_lengths,
is_encoder_side=None,
):
if is_encoder_side is None or is_encoder_side is True:
embed = self.embed
rnn = self.rnn
linear = self.linear
drop_out = self.drop_out
else:
embed = self.embed_dec
rnn = self.rnn_dec
linear = self.linear_dec
drop_out = self.drop_out_dec
out = embed(word_list)
# https://stackoverflow.com/questions/51030782/why-do-we-pack-the-sequences-in-pytorch
out = torch.nn.utils.rnn.pack_padded_sequence(
out,
batch_first=True,
lengths=word_lengths,
enforce_sorted=False
)
output, (hn, cn) = rnn(out) # use default all zeros (h_0, c_0)
# # https://discuss.pytorch.org/t/bidirectional-3-layer-lstm-hidden-output/41336/4
# final_state = hn.view(
# self.num_layers,
# self.num_directions,
# word_list.shape[0],
# self.encoder_dim,
# )[-1] # Only the last layer
# h_1, h_2 = final_state[0], final_state[1]
# # X = h_1 + h_2 # Add both states (needs different input size for first linear layer)
# final_h = torch.cat((h_1, h_2), dim=1) # Concatenate both states
# final_h = self.linear(final_h)
# hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
# outputs are always from the last layer.
# hidden[-2, :, : ] is the last of the forwards RNN
# hidden[-1, :, : ] is the last of the backwards RNN
h_1, h_2 = hn[-2, :, :], hn[-1, :, :]
final_h = torch.cat((h_1, h_2), dim=1) # Concatenate both states
final_h = linear(final_h)
# final_h = drop_out(final_h)
return final_h
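# A minimal, self-contained sketch (an illustration, not part of this file):
# encoding a few padded context words (token-id sequences) into fixed-size
# embeddings with ContextEncoderLSTM. The hyper-parameters and token ids are
# made up for the example.
if __name__ == "__main__":
    encoder = ContextEncoderLSTM(
        vocab_size=500,
        context_encoder_dim=128,
        output_dim=256,
        num_layers=2,
        num_directions=2,
    )
    # 3 context words padded to 4 tokens each; true lengths passed separately.
    word_list = torch.tensor([[5, 9, 2, 0], [7, 3, 0, 0], [11, 4, 8, 6]])
    word_lengths = [3, 2, 4]
    final_h = encoder(word_list, word_lengths)
    print(final_h.shape)  # torch.Size([3, 256])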

View File

@ -0,0 +1,52 @@
import torch
from context_encoder import ContextEncoder
import torch.nn.functional as F
class ContextEncoderPretrained(ContextEncoder):
def __init__(
self,
vocab_size: int = None,
context_encoder_dim: int = None,
output_dim: int = None,
num_layers: int = None,
num_directions: int = None,
drop_out: float = 0.3,
):
super(ContextEncoderPretrained, self).__init__()
self.drop_out = torch.nn.Dropout(drop_out)
self.linear1 = torch.nn.Linear(
context_encoder_dim, # 768
256,
)
self.linear3 = torch.nn.Linear(
256,
256,
)
self.linear4 = torch.nn.Linear(
256,
256,
)
self.linear2 = torch.nn.Linear(
256,
output_dim
)
self.sigmoid = torch.nn.Sigmoid()
self.bi_encoders = False
def forward(
self,
word_list,
word_lengths,
is_encoder_side=None,
):
out = word_list # Shape: N*L*D
# out = self.drop_out(out)
out = self.sigmoid(self.linear1(out)) # Note: ReLU may not be a good choice here
out = self.sigmoid(self.linear3(out))
out = self.sigmoid(self.linear4(out))
# out = self.drop_out(out)
out = self.linear2(out)
return out
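# Hypothetical usage of the "pretrained" variant above. Judging from the 768 in
# the first linear layer, word_list here is expected to hold pre-computed word
# embeddings (e.g. from a BERT-like model) rather than BPE ids; the numbers
# below are illustrative only.
import torch

encoder = ContextEncoderPretrained(
    context_encoder_dim=768,    # dimension of the pretrained embeddings
    output_dim=256,
)
pretrained_embeddings = torch.randn(5, 768)                # 5 context words
out = encoder(pretrained_embeddings, word_lengths=None)    # shape: (5, 256)
# A 3-D input of shape (N, L, 768) works the same way, since the linear layers
# operate on the last dimension.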

View File

@@ -0,0 +1,90 @@
import torch
from context_encoder import ContextEncoder
import copy
class ContextEncoderReused(ContextEncoder):
def __init__(
self,
decoder,
decoder_dim: int = None,
output_dim: int = None,
num_lstm_layers: int = None,
num_lstm_directions: int = None,
drop_out: float = 0.1,
):
super(ContextEncoderReused, self).__init__()
# self.num_lstm_layers = num_lstm_layers
# self.num_lstm_directions = num_lstm_directions
# self.decoder_dim = decoder_dim
hidden_size = output_dim * 2 # decoder_dim
self.rnn = torch.nn.LSTM(
input_size=decoder_dim,
hidden_size=hidden_size,
num_layers=num_lstm_layers,
batch_first=True,
bidirectional=(num_lstm_directions == 2),
dropout=0.1 if num_lstm_layers > 1 else 0
)
self.linear = torch.nn.Linear(
hidden_size * num_lstm_directions,
output_dim
)
self.drop_out = torch.nn.Dropout(drop_out)
self.decoder = decoder
self.bi_encoders = False
# TODO: Do we need some relu layer?
# https://galhever.medium.com/sentiment-analysis-with-pytorch-part-4-lstm-bilstm-model-84447f6c4525
# self.relu = nn.ReLU()
# self.dropout = nn.Dropout(dropout)
def forward(
self,
word_list,
word_lengths,
is_encoder_side=None,
):
sos_id = self.decoder.blank_id
sos_list = torch.full((word_list.shape[0], 1), sos_id, dtype=word_list.dtype, device=word_list.device)
sos_word_list = torch.cat((sos_list, word_list), 1)
word_lengths = [x + 1 for x in word_lengths]
# sos_word_list: (N, U)
# decoder_out: (N, U, decoder_dim)
out = self.decoder(sos_word_list)
# https://stackoverflow.com/questions/51030782/why-do-we-pack-the-sequences-in-pytorch
# https://gist.github.com/HarshTrivedi/f4e7293e941b17d19058f6fb90ab0fec
out = torch.nn.utils.rnn.pack_padded_sequence(
out,
batch_first=True,
lengths=word_lengths,
enforce_sorted=False
)
output, (hn, cn) = self.rnn(out) # use default all zeros (h_0, c_0)
# # https://discuss.pytorch.org/t/bidirectional-3-layer-lstm-hidden-output/41336/4
# final_state = hn.view(
# self.num_layers,
# self.num_directions,
# word_list.shape[0],
# self.encoder_dim,
# )[-1] # Only the last layer
# h_1, h_2 = final_state[0], final_state[1]
# # X = h_1 + h_2 # Add both states (needs different input size for first linear layer)
# final_h = torch.cat((h_1, h_2), dim=1) # Concatenate both states
# final_h = self.linear(final_h)
# hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
# outputs are always from the last layer.
# hidden[-2, :, : ] is the last of the forwards RNN
# hidden[-1, :, : ] is the last of the backwards RNN
h_1, h_2 = hn[-2, :, :], hn[-1, :, :]
final_h = torch.cat((h_1, h_2), dim=1) # Concatenate both states
final_h = self.linear(final_h)
# final_h = drop_out(final_h)
return final_h

View File

@@ -0,0 +1,210 @@
import torch
import random
from pathlib import Path
import sentencepiece as spm
from typing import List
import logging
import ast
import numpy as np
class ContextGenerator(torch.utils.data.Dataset):
def __init__(
self,
path_is21_deep_bias: Path,
sp: spm.SentencePieceProcessor,
n_distractors: int = 100,
is_predefined: bool = False,
keep_ratio: float = 1.0,
is_full_context: bool = False,
):
self.sp = sp
self.path_is21_deep_bias = path_is21_deep_bias
self.n_distractors = n_distractors
self.is_predefined = is_predefined
self.keep_ratio = keep_ratio
self.is_full_context = is_full_context # use all words (rare or common) in the context
logging.info(f"""
n_distractors={n_distractors},
is_predefined={is_predefined},
keep_ratio={keep_ratio},
is_full_context={is_full_context},
""")
self.all_rare_words2pieces = None
self.common_words = None
if not is_predefined:
with open(path_is21_deep_bias / "words/all_rare_words.txt", "r") as fin:
all_rare_words = [l.strip().upper() for l in fin if len(l) > 0] # a list of strings
all_rare_words_pieces = sp.encode(all_rare_words, out_type=int) # a list of list of int
self.all_rare_words2pieces = {w: pieces for w, pieces in zip(all_rare_words, all_rare_words_pieces)}
with open(path_is21_deep_bias / "words/common_words_5k.txt", "r") as fin:
self.common_words = set([l.strip().upper() for l in fin if len(l) > 0]) # a list of strings
logging.info(f"Number of common words: {len(self.common_words)}")
logging.info(f"Number of rare words: {len(self.all_rare_words2pieces)}")
self.test_clean_biasing_list = None
self.test_other_biasing_list = None
if is_predefined:
def read_ref_biasing_list(filename):
biasing_list = dict()
all_cnt = 0
rare_cnt = 0
with open(filename, "r") as fin:
for line in fin:
line = line.strip().upper()
if len(line) == 0:
continue
line = line.split("\t")
uid, ref_text, ref_rare_words, context_rare_words = line
context_rare_words = ast.literal_eval(context_rare_words)
biasing_list[uid] = [w for w in context_rare_words]
ref_rare_words = ast.literal_eval(ref_rare_words)
ref_text = ref_text.split()
all_cnt += len(ref_text)
rare_cnt += len(ref_rare_words)
return biasing_list, rare_cnt / all_cnt
self.test_clean_biasing_list, ratio_clean = \
read_ref_biasing_list(self.path_is21_deep_bias / f"ref/test-clean.biasing_{n_distractors}.tsv")
self.test_other_biasing_list, ratio_other = \
read_ref_biasing_list(self.path_is21_deep_bias / f"ref/test-other.biasing_{n_distractors}.tsv")
logging.info(f"Number of utterances in test_clean_biasing_list: {len(self.test_clean_biasing_list)}, rare ratio={ratio_clean:.2f}")
logging.info(f"Number of utterances in test_other_biasing_list: {len(self.test_other_biasing_list)}, rare ratio={ratio_other:.2f}")
# from itertools import chain
# for uid, context_rare_words in chain(self.test_clean_biasing_list.items(), self.test_other_biasing_list.items()):
# for w in context_rare_words:
# if self.all_rare_words2pieces:
# pass
# else:
# logging.warning(f"new word: {w}")
def get_context_word_list(
self,
batch: dict,
):
# import pdb; pdb.set_trace()
if self.is_predefined:
return self.get_context_word_list_predefined(batch=batch)
else:
return self.get_context_word_list_random(batch=batch)
def discard_some_common_words(self, words, keep_ratio):
pass
def get_context_word_list_random(
self,
batch: dict,
):
"""
Generate the context biasing list as a list of words for each utterance.
Use keep_ratio to simulate an "imperfect" context that may not have 100% coverage of the ground-truth words.
"""
texts = batch["supervisions"]["text"]
rare_words_list = []
for text in texts:
rare_words = []
for word in text.split():
if self.is_full_context or word not in self.common_words:
rare_words.append(word)
if word not in self.all_rare_words2pieces:
self.all_rare_words2pieces[word] = self.sp.encode(word, out_type=int)
rare_words = list(set(rare_words)) # deduplication
if self.keep_ratio < 1.0 and len(rare_words) > 0:
rare_words = random.sample(rare_words, int(len(rare_words) * self.keep_ratio))
rare_words_list.append(rare_words)
n_distractors = self.n_distractors
if n_distractors == -1: # variable context list sizes
n_distractors_each = np.random.randint(low=80, high=1000, size=len(texts))
distractors_cnt = n_distractors_each.sum()
else:
n_distractors_each = np.zeros(len(texts), int)
n_distractors_each[:] = self.n_distractors
distractors_cnt = n_distractors_each.sum()
distractors = random.sample(
list(self.all_rare_words2pieces.keys()),  # random.sample needs a sequence, not a dict view
distractors_cnt
) # TODO: actually the context should contain both rare and common words
distractors_pos = 0
rare_words_pieces_list = []
max_pieces_len = 0
for i, rare_words in enumerate(rare_words_list):
rare_words.extend(distractors[distractors_pos: distractors_pos + n_distractors_each[i]])
distractors_pos += n_distractors_each[i]
# random.shuffle(rare_words)
# logging.info(rare_words)
rare_words_pieces = [list(self.all_rare_words2pieces[w]) for w in rare_words]  # copy, so the in-place padding below does not grow the cached piece lists
if len(rare_words_pieces) > 0:
max_pieces_len = max(max_pieces_len, max(len(pieces) for pieces in rare_words_pieces))
rare_words_pieces_list.append(rare_words_pieces)
assert distractors_pos == len(distractors)
word_list = []
word_lengths = []
num_words_per_utt = []
pad_token = 0
for rare_words_pieces in rare_words_pieces_list:
num_words_per_utt.append(len(rare_words_pieces))
word_lengths.extend([len(pieces) for pieces in rare_words_pieces])
for pieces in rare_words_pieces:
pieces += [pad_token] * (max_pieces_len - len(pieces))
word_list.extend(rare_words_pieces)
word_list = torch.tensor(word_list, dtype=torch.int32)
# word_lengths = torch.tensor(word_lengths, dtype=torch.int32)
# num_words_per_utt = torch.tensor(num_words_per_utt, dtype=torch.int32)
return word_list, word_lengths, num_words_per_utt
def get_context_word_list_predefined(
self,
batch: dict,
):
rare_words_list = []
for cut in batch['supervisions']['cut']:
uid = cut.supervisions[0].id
if uid in self.test_clean_biasing_list:
rare_words_list.append(self.test_clean_biasing_list[uid])
elif uid in self.test_other_biasing_list:
rare_words_list.append(self.test_other_biasing_list[uid])
else:
logging.error(f"uid={uid} cannot find the predefined biasing list of size {self.n_distractors}")
rare_words_pieces_list = []
max_pieces_len = 0
for rare_words in rare_words_list:
# logging.info(rare_words)
rare_words_pieces = self.sp.encode(rare_words, out_type=int)
max_pieces_len = max(max_pieces_len, max(len(pieces) for pieces in rare_words_pieces))
rare_words_pieces_list.append(rare_words_pieces)
word_list = []
word_lengths = []
num_words_per_utt = []
pad_token = 0
for rare_words_pieces in rare_words_pieces_list:
num_words_per_utt.append(len(rare_words_pieces))
word_lengths.extend([len(pieces) for pieces in rare_words_pieces])
for pieces in rare_words_pieces:
pieces += [pad_token] * (max_pieces_len - len(pieces))
word_list.extend(rare_words_pieces)
word_list = torch.tensor(word_list, dtype=torch.int32)
# word_lengths = torch.tensor(word_lengths, dtype=torch.int32)
# num_words_per_utt = torch.tensor(num_words_per_utt, dtype=torch.int32)
return word_list, word_lengths, num_words_per_utt
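# Illustration of the return format of get_context_word_list_* (made-up values):
# two utterances with 2 and 1 context words, padded to max_pieces_len = 3 with
# pad_token = 0. A consumer can recover the per-utterance grouping with
# torch.split over num_words_per_utt.
import torch

word_list = torch.tensor(
    [[11, 42, 0],    # utterance 0, word 0 (2 pieces)
     [7, 19, 23],    # utterance 0, word 1 (3 pieces)
     [95, 0, 0]],    # utterance 1, word 0 (1 piece)
    dtype=torch.int32,
)
word_lengths = [2, 3, 1]
num_words_per_utt = [2, 1]
per_utt = torch.split(word_list, num_words_per_utt, dim=0)
assert [t.shape[0] for t in per_utt] == num_words_per_utt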

View File

@@ -0,0 +1,371 @@
import logging
import math
import re
from collections import defaultdict
from typing import Any, Dict, List, Tuple
import k2
import sentencepiece as spm
import torch
from kaldifst.utils import k2_to_openfst
def generate_context_graph_simple(
words_pieces_list: list,
backoff_id: int,
sp: spm.SentencePieceProcessor,
bonus_per_token: float = 0.1,
):
"""Generate the context graph (in kaldifst format) given
the lexicon of the biasing list.
This context graph is a WFST as in
`https://arxiv.org/abs/1808.02480`
or
`https://wenet.org.cn/wenet/context.html`.
It is simple, as it does not have the capability to detect
word boundaries. So, if a biasing word (e.g., 'us', the country)
happens to be the prefix of another word (e.g., 'useful'),
the word 'useful' will be mistakenly boosted. This is not desired.
However, this context graph is easy to understand.
Args:
words_pieces_list:
A list (batch) of lists. Each sub-list contains the context for
the utterance. The sub-list again is a list of lists. Each sub-sub-list
is the token sequence of a word.
backoff_id:
The id of the backoff token. It serves for failure arcs.
bonus_per_token:
The bonus for each token during decoding, which will hopefully
boost the token up to survive beam search.
Returns:
Return an `openfst` object representing the context graph.
"""
# note: `k2_to_openfst` will multiply it with -1. So it will become +1 in the end.
flip = -1
fsa_list = []
fsa_sizes = []
for words_pieces in words_pieces_list:
start_state = 0
next_state = 1 # the next un-allocated state, will be incremented as we go.
arcs = []
arcs.append([start_state, start_state, backoff_id, 0, 0.0])
# for token_id in range(sp.vocab_size()):
# arcs.append([start_state, start_state, token_id, 0, 0.0])
for tokens in words_pieces:
assert len(tokens) > 0
cur_state = start_state
for i in range(len(tokens) - 1):
arcs.append(
[
cur_state,
next_state,
tokens[i],
0,
flip * bonus_per_token
]
)
arcs.append(
[
next_state,
start_state,
backoff_id,
0,
flip * -bonus_per_token * (i + 1),
]
)
cur_state = next_state
next_state += 1
# now for the last token of this word
i = len(tokens) - 1
arcs.append([cur_state, start_state, tokens[i], 0, flip * bonus_per_token])
final_state = next_state
arcs.append([start_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
fsa = k2.arc_sort(fsa)
fsa_sizes.append((fsa.shape[0], fsa.num_arcs)) # (n_states, n_arcs)
fsa = k2_to_openfst(fsa, olabels="aux_labels")
fsa_list.append(fsa)
return fsa_list, fsa_sizes
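# A worked example (no k2 needed) of the arc layout built above for a single
# context word whose BPE pieces are [12, 7, 30], with bonus_per_token = 0.1 and
# backoff_id = 500. The final-state arcs are omitted. Weights carry flip = -1;
# k2_to_openfst negates them again, so each matched token effectively earns +0.1.
word_pieces = [12, 7, 30]
bonus, backoff_id, flip = 0.1, 500, -1
arcs = [
    [0, 0, backoff_id, 0, 0.0],                # failure self-loop at the start state
    [0, 1, 12, 0, flip * bonus],               # match the 1st piece, earn a bonus
    [1, 0, backoff_id, 0, flip * -bonus],      # abandon the word: give 1 bonus back
    [1, 2, 7, 0, flip * bonus],                # match the 2nd piece
    [2, 0, backoff_id, 0, flip * -bonus * 2],  # abandon: give 2 bonuses back
    [2, 0, 30, 0, flip * bonus],               # last piece returns to the start state
]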
def generate_context_graph_nfa(
words_pieces_list: list,
backoff_id: int,
sp: spm.SentencePieceProcessor,
bonus_per_token: float = 0.1,
):
"""Generate the context graph (in kaldifst format) given
the lexicon of the biasing list.
This context graph is a WFST capable of detecting word boundaries.
It is epsilon-free, non-deterministic.
Args:
words_pieces_list:
A list (batch) of lists. Each sub-list contains the context for
the utterance. The sub-list again is a list of lists. Each sub-sub-list
is the token sequence of a word.
backoff_id:
The id of the backoff token. It serves for failure arcs.
bonus_per_token:
The bonus for each token during decoding, which will hopefully
boost the token up to survive beam search.
Returns:
Return an `openfst` object representing the context graph.
"""
# note: `k2_to_openfst` will multiply it with -1. So it will become +1 in the end.
flip = -1
fsa_list = []
fsa_sizes = []
for words_pieces in words_pieces_list:
start_state = 0
# if the path goes through this state, then a word boundary is detected
boundary_state = 1
next_state = 2 # the next un-allocated state, will be incremented as we go.
arcs = []
# arcs.append([start_state, start_state, backoff_id, 0, 0.0])
for token_id in range(sp.vocab_size()):
arcs.append([start_state, start_state, token_id, 0, 0.0])
# arcs.append([boundary_state, start_state, token_id, 0, 0.0])  # Note: adding this arc degrades performance, because it breaks the ability to detect word boundaries.
for tokens in words_pieces:
assert len(tokens) > 0
cur_state = start_state
# static/constant bonus per token
# my_bonus_per_token = flip * bonus_per_token * biasing_list[word]
# my_bonus_per_token = flip * 1.0 / len(tokens) * biasing_list[word]
my_bonus_per_token = flip * bonus_per_token # TODO: support weighted biasing list
for i in range(len(tokens) - 1):
arcs.append(
[cur_state, next_state, tokens[i], 0, my_bonus_per_token]
)
if i == 0:
arcs.append(
[boundary_state, next_state, tokens[i], 0, my_bonus_per_token]
)
cur_state = next_state
next_state += 1
# now for the last token of this word
i = len(tokens) - 1
arcs.append(
[cur_state, boundary_state, tokens[i], 0, my_bonus_per_token]
)
for token_id in range(sp.vocab_size()):
token = sp.id_to_piece(token_id)
if token.startswith("▁"):  # "▁" marks a word-initial piece in SentencePiece
arcs.append([boundary_state, start_state, token_id, 0, 0.0])
final_state = next_state
arcs.append([start_state, final_state, -1, -1, 0])
arcs.append([boundary_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
fsa = k2.arc_sort(fsa)
# fsa = k2.determinize(fsa) # No weight pushing is needed.
fsa_sizes.append((fsa.shape[0], fsa.num_arcs)) # (n_states, n_arcs)
fsa = k2_to_openfst(fsa, olabels="aux_labels")
fsa_list.append(fsa)
return fsa_list, fsa_sizes
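# A worked example of the arcs produced above for one word with pieces [12, 7, 30]
# (bonus_per_token = 0.1, flip = -1). State 0 is the "outside any context word"
# state with a self-loop per vocabulary token; state 1 is the boundary state,
# reached only after a full context word has been matched.
w = -0.1
arcs_for_one_word = [
    [0, 2, 12, 0, w],   # 1st piece, starting from state 0 ...
    [1, 2, 12, 0, w],   # ... or starting right after a previous context word
    [2, 3, 7, 0, w],    # 2nd piece
    [3, 1, 30, 0, w],   # last piece lands on the boundary state
]
# In addition, every token starting with "▁" gets an arc from state 1 back to
# state 0, confirming that the matched word ended at a word boundary.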
class TrieNode:
def __init__(self, token):
self.token = token
self.is_end = False
self.state_id = None
self.weight = -1e9
self.children = {}
class Trie(object):
# https://albertauyeung.github.io/2020/06/15/python-trie.html/
def __init__(self):
self.root = TrieNode("")
def insert(self, word_tokens, per_token_weight):
"""Insert a word into the trie"""
node = self.root
# Loop through each token in the word
# Check if there is no child containing the token, create a new child for the current node
for token in word_tokens:
if token in node.children:
node = node.children[token]
node.weight = max(per_token_weight, node.weight)
else:
# If a character is not found,
# create a new node in the trie
new_node = TrieNode(token)
node.children[token] = new_node
node = new_node
node.weight = per_token_weight
# Mark the end of a word
node.is_end = True
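# A quick illustration of the Trie above with hypothetical token ids: two context
# words that share a prefix share trie nodes, which keeps the resulting FST small.
# The weights are negative because the caller passes flip * bonus_per_token.
trie = Trie()
trie.insert([12, 7, 30], per_token_weight=-0.1)
trie.insert([12, 7, 9, 4], per_token_weight=-0.1)
assert list(trie.root.children) == [12]
assert list(trie.root.children[12].children) == [7]
assert sorted(trie.root.children[12].children[7].children) == [9, 30]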
def generate_context_graph_nfa_trie(
words_pieces_list: list,
backoff_id: int,
sp: spm.SentencePieceProcessor,
bonus_per_token: float = 0.1,
):
"""Generate the context graph (in kaldifst format) given
the lexicon of the biasing list.
This context graph is a WFST capable of detecting word boundaries.
It is epsilon-free, non-deterministic.
It is also optimized such that the words are organized in a trie. (from Kang Wei)
However, sharing prefixes in a trie may make it harder to assign a separate score to each word (to be verified).
Args:
words_pieces_list:
A list (batch) of lists. Each sub-list contains the context for
the utterance. The sub-list again is a list of lists. Each sub-sub-list
is the token sequence of a word.
backoff_id:
The id of the backoff token. It serves for failure arcs.
bonus_per_token:
The bonus for each token during decoding, which will hopefully
boost the token up to survive beam search.
Returns:
Return an `openfst` object representing the context graph.
"""
# note: `k2_to_openfst` will multiply it with -1. So it will become +1 in the end.
flip = -1
fsa_list = [] # These are the fsas in a batch
fsa_sizes = []
for words_pieces in words_pieces_list:
start_state = 0
# if the path goes through this state, then a word boundary is detected
boundary_state = 1
next_state = 2 # the next un-allocated state, will be incremented as we go.
arcs = []
# arcs.append([start_state, start_state, backoff_id, 0, 0.0])
for token_id in range(sp.vocab_size()):
arcs.append([start_state, start_state, token_id, 0, 0.0])
# First, compile the word list into a trie (prefix tree)
trie = Trie()
for tokens in words_pieces:
# static/constant bonus per token
# my_bonus_per_token = flip * bonus_per_token * biasing_list[word]
# my_bonus_per_token = flip * 1.0 / len(tokens) * biasing_list[word]
my_bonus_per_token = flip * bonus_per_token # TODO: support weighted biasing list
trie.insert(tokens, my_bonus_per_token)
trie.root.state_id = start_state
node_list = [trie.root]
while len(node_list) > 0:
cur_node = node_list.pop(0)
if cur_node.is_end:
continue
for token, child_node in cur_node.children.items():
if child_node.state_id is None and not child_node.is_end:
child_node.state_id = next_state
next_state += 1
if child_node.is_end:
# If the word has finished, go to the boundary state
arcs.append(
[
cur_node.state_id,
boundary_state,
token,
0,
child_node.weight,
]
)
else:
# Otherwise, add an arc from the current node to its child node
arcs.append(
[
cur_node.state_id,
child_node.state_id,
token,
0,
child_node.weight,
]
)
# If this is the first token in a word,
# also add an arc from the boundary node to the child node
if cur_node == trie.root:
arcs.append(
[
boundary_state,
child_node.state_id,
token,
0,
child_node.weight,
]
)
node_list.append(child_node)
for token_id in range(sp.vocab_size()):
token = sp.id_to_piece(token_id)
if token.startswith("▁"):  # "▁" marks a word-initial piece in SentencePiece
arcs.append([boundary_state, start_state, token_id, 0, 0.0])
final_state = next_state
arcs.append([start_state, final_state, -1, -1, 0])
arcs.append([boundary_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
fsa = k2.arc_sort(fsa)
# fsa = k2.determinize(fsa) # No weight pushing is needed.
fsa_sizes.append((fsa.shape[0], fsa.num_arcs)) # (n_states, n_arcs)
fsa = k2_to_openfst(fsa, olabels="aux_labels")
fsa_list.append(fsa)
return fsa_list, fsa_sizes
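# Hypothetical call pattern for the three graph builders above (they share the
# same signature, so they can be swapped at decode time). It needs k2, kaldifst
# and a trained SentencePiece model; the path and ids below just mirror the
# recipe defaults and are not guaranteed to exist.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")
words_pieces_list = [sp.encode(["NVIDIA", "KALDI"], out_type=int)]  # a batch of one utterance
fsa_list, fsa_sizes = generate_context_graph_nfa(
    words_pieces_list=words_pieces_list,
    backoff_id=500,           # matches the --backoff-id default used elsewhere in this commit
    sp=sp,
    bonus_per_token=0.1,
)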

File diff suppressed because it is too large

View File

@@ -0,0 +1,861 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
# from asr_datamodule import LibriSpeechAsrDataModule
from gigaspeech import GigaSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from gigaspeech_scoring import asr_text_post_processing
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
add_model_arguments(parser)
return parser
def post_processing(
results: List[Tuple[str, List[str], List[str]]],
) -> List[Tuple[str, List[str], List[str]]]:
new_results = []
for key, ref, hyp in results:
new_ref = asr_text_post_processing(" ".join(ref)).split()
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
new_results.append((key, new_ref, new_hyp))
return new_results
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming:
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut id, the reference transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = post_processing(results)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
"""
This script tests a LibriSpeech model with a LibriSpeech BPE model
on GigaSpeech.
"""
parser = get_parser()
GigaSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / (params.decoding_method + "_gigaspeech")
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
gigaspeech = GigaSpeechAsrDataModule(args)
dev_cuts = gigaspeech.dev_cuts()
test_cuts = gigaspeech.test_cuts()
dev_dl = gigaspeech.test_dataloaders(dev_cuts)
test_dl = gigaspeech.test_dataloaders(test_cuts)
test_sets = ["dev", "test"]
test_dls = [dev_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

File diff suppressed because it is too large

View File

@@ -0,0 +1,993 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(8) modified beam search with RNNLM shallow fusion
./pruned_transducer_stateless7/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method modified_beam_search_lm_shallow_fusion \
--beam-size 4 \
--lm-type rnn \
--lm-scale 0.3 \
--lm-exp-dir /path/to/LM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
(9) modified beam search with LM shallow fusion + LODR
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--max-duration 600 \
--exp-dir ./pruned_transducer_stateless7/exp \
--decoding-method modified_beam_search_LODR \
--beam-size 4 \
--lm-type rnn \
--lm-scale 0.4 \
--lm-exp-dir /path/to/LM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1 \
--tokens-ngram 2 \
--ngram-lm-scale -0.16
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
modified_beam_search_ngram_rescoring,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion
- modified_beam_search_LODR
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--use-shallow-fusion",
type=str2bool,
default=False,
help="""Use neural network LM for shallow fusion.
If you want to use LODR, you will also need to set this to true
""",
)
parser.add_argument(
"--lm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.3,
help="""The scale of the neural network LM
Used only when `--use-shallow-fusion` is set to True.
""",
)
parser.add_argument(
"--tokens-ngram",
type=int,
default=3,
help="""Token Ngram used for rescoring.
Used only when the decoding method is
modified_beam_search_ngram_rescoring, or LODR
""",
)
parser.add_argument(
"--backoff-id",
type=int,
default=500,
help="""ID of the backoff symbol.
Used only when the decoding method is
modified_beam_search_ngram_rescoring""",
)
parser.add_argument(
'--part',
type=str,
default=None,
help='e.g., 1/3'
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural net LM for shallow fusion. Only used when `--use-shallow-fusion`
is set to true.
ngram_lm:
An n-gram LM. Used in LODR decoding.
ngram_lm_scale:
The scale of the n-gram language model.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
      LM:
        A neural network LM, used during shallow fusion.
      ngram_lm:
        An n-gram LM, used during LODR decoding.
      ngram_lm_scale:
        The scale of the n-gram LM.
    Returns:
      Return a dict whose key may be "greedy_search" if greedy search
      is used, or "beam_size_7" if a beam size of 7 is used.
      Its value is a list of tuples. Each tuple contains three elements:
      the cut ID, the reference transcript (as a list of words), and the
      predicted result (as a list of words).
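
      For example (illustrative only), the returned dict may look like:
        {"greedy_search": [("cut-id-1", ["HELLO", "WORLD"], ["HELLO", "WORD"]), ...]}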
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
LM=LM,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
if args.part is not None:
args.part = args.part.split("/")
args.part = (int(args.part[0]), int(args.part[1]))
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_LODR",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if "ngram" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
if params.use_shallow_fusion:
if params.lm_type == "rnn":
params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
elif params.lm_type == "transformer":
params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
if "LODR" in params.decoding_method:
params.suffix += (
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
# only load N-gram LM when needed
if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
logging.info(f"lm filename: {lm_filename}")
ngram_lm = NgramLm(
str(params.lang_dir / lm_filename),
backoff_id=params.backoff_id,
is_binary=False,
)
logging.info(f"num states: {ngram_lm.lm.num_states}")
ngram_lm_scale = params.ngram_lm_scale
else:
ngram_lm = None
ngram_lm_scale = None
# only load the neural network LM if doing shallow fusion
if params.use_shallow_fusion:
LM = LmScorer(
lm_type=params.lm_type,
params=params,
device=device,
lm_scale=params.lm_scale,
)
LM.to(device)
LM.eval()
else:
LM = None
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
# test_clean_cuts = librispeech.test_clean_cuts()
# test_other_cuts = librispeech.test_other_cuts()
# test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
# test_other_dl = librispeech.test_dataloaders(test_other_cuts)
# test_sets = ["test-clean", "test-other"]
# test_dl = [test_clean_dl, test_other_dl]
train_cuts = librispeech.train_all_shuf_cuts()
if args.part is not None:
from lhotse import CutSet
        # args.part is (i, n): keep roughly 1/n of the training cuts for
        # this job, selected by a numeric field of the cut id.
        part, n_parts = args.part
        # train_cuts = CutSet.from_cuts([c for c in train_cuts if "_sp" not in c.id and int(c.id.split("-")[-1]) % n_parts == part])
        train_cuts = CutSet.from_cuts(
            [c for c in train_cuts if int(c.id.split("-")[-2]) % n_parts == part % n_parts]
        )
train_cuts = train_cuts.sort_by_duration(ascending=False)
# train_cuts.describe()
logging.info(f"part-{part}/{n_parts} len(train_cuts) = {len(train_cuts)}")
else:
part = 0
n_parts = 0
# part = "val"
train_dl = librispeech.train_dataloaders(train_cuts)
test_sets = [f"train-{part}.{n_parts}"]
test_dl = [train_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
LM=LM,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

File diff suppressed because it is too large

@@ -0,0 +1,105 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""
def __init__(
self,
vocab_size: int,
decoder_dim: int,
blank_id: int,
context_size: int,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
decoder_dim:
Dimension of the input embedding, and of the decoder output.
blank_id:
The ID of the blank symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=decoder_dim,
)
self.blank_id = blank_id
assert context_size >= 1, context_size
self.context_size = context_size
self.vocab_size = vocab_size
if context_size > 1:
self.conv = nn.Conv1d(
in_channels=decoder_dim,
out_channels=decoder_dim,
kernel_size=context_size,
padding=0,
groups=decoder_dim // 4, # group size == 4
bias=False,
)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, decoder_dim).
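
        Example (an illustrative sketch, not used by the recipe; the
        hyper-parameters below are made up):

        >>> decoder = Decoder(
        ...     vocab_size=500, decoder_dim=512, blank_id=0, context_size=2
        ... )
        >>> y = torch.zeros(4, 10, dtype=torch.int64)
        >>> decoder(y).shape
        torch.Size([4, 10, 512])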
"""
y = y.to(torch.int64)
        # The clamp() below is a temporary fix for a mismatch at utterance
        # start: beam_search.py uses negative token ids there.
if torch.jit.is_tracing():
# This is for exporting to PNNX via ONNX
embedding_out = self.embedding(y)
else:
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = F.relu(embedding_out)
return embedding_out

@@ -0,0 +1,43 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
class EncoderInterface(nn.Module):
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (batch_size, input_seq_len, num_features)
containing the input features.
x_lens:
A tensor of shape (batch_size,) containing the number of frames
in `x` before padding.
Returns:
Return a tuple containing two tensors:
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
containing unnormalized probabilities, i.e., the output of a
linear layer.
- encoder_out_lens, a tensor of shape (batch_size,) containing
the number of frames in `encoder_out` before padding.
"""
raise NotImplementedError("Please implement it in a subclass")
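

# A minimal, hypothetical implementation of this interface, kept here only to
# illustrate the expected input/output shapes; nothing in the recipe uses it.
class _ExampleLinearEncoder(EncoderInterface):
    """A frame-wise linear projection that leaves the lengths unchanged."""

    def __init__(self, num_features: int, output_dim: int):
        super().__init__()
        self.proj = nn.Linear(num_features, output_dim)

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # (batch_size, seq_len, num_features) -> (batch_size, seq_len, output_dim)
        return self.proj(x), x_lens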

@@ -0,0 +1,560 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt"
cd exp
ln -s pretrained-epoch-30-avg-9.pt epoch-99.pt
popd
2. Export the model to ONNX
./pruned_transducer_stateless7/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--feedforward-dims "1024,1024,2048,2048,1024"
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
See ./onnx_pretrained.py and ./onnx_check.py for how to
use the exported ONNX models.
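
A minimal sketch of running the exported encoder with onnxruntime
(illustrative only; real usage should feed 80-dim fbank features, see
./onnx_pretrained.py):

    import numpy as np
    import onnxruntime as ort

    encoder = ort.InferenceSession("encoder-epoch-99-avg-1.onnx")
    x = np.random.randn(1, 100, 80).astype(np.float32)  # (N, T, C) dummy features
    x_lens = np.array([100], dtype=np.int64)
    encoder_out, encoder_out_lens = encoder.run(None, {"x": x, "x_lens": x_lens})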
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
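
    Example (illustrative only):
        add_meta_data("decoder.onnx", {"context_size": "2", "vocab_size": "500"})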
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear):
"""
Args:
encoder:
A Zipformer encoder.
encoder_proj:
The projection layer for encoder from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_proj = encoder_proj
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of Zipformer.forward
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
- encoder_out_lens, A 1-D tensor of shape (N,)
"""
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
encoder_out = self.encoder_proj(encoder_out)
# Now encoder_out is of shape (N, T, joiner_dim)
return encoder_out, encoder_out_lens
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
        Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T', joiner_dim)
- encoder_out_lens, a tensor of shape (N,)
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
torch.onnx.export(
encoder_model,
(x, x_lens),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["encoder_out", "encoder_out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
},
)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
if __name__ == "__main__":
main()

@@ -0,0 +1,320 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.script()
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
Check
https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
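
Within Python, a minimal sketch of loading the torchscript model
(illustrative only; the encoder expects 80-dim fbank features):

    import torch

    model = torch.jit.load("cpu_jit.pt")
    model.eval()
    x = torch.randn(1, 100, 80)  # (N, T, C) dummy features
    x_lens = torch.tensor([100], dtype=torch.int64)
    encoder_out, encoder_out_lens = model.encoder(x, x_lens)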
(2) Export `model.state_dict()`
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `pruned_transducer_stateless7/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./pruned_transducer_stateless7/decode.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named cpu_jit.pt
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
        # torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

File diff suppressed because it is too large

@@ -0,0 +1,282 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) use the averaged model with checkpoint exp_dir/epoch-xxx.pt
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
--epoch 28 \
--avg 15 \
--use-averaged-model True \
--exp-dir ./pruned_transducer_stateless7/exp
It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`.
You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`.
(2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
--iter 22000 \
--avg 5 \
--use-averaged-model True \
--exp-dir ./pruned_transducer_stateless7/exp
It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`.
You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`.
(3) use the original model with checkpoint exp_dir/epoch-xxx.pt
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
--epoch 28 \
--avg 15 \
--use-averaged-model False \
--exp-dir ./pruned_transducer_stateless7/exp
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
(4) use the original model with checkpoint exp_dir/checkpoint-iter.pt
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
--iter 22000 \
--avg 5 \
--use-averaged-model False \
--exp-dir ./pruned_transducer_stateless7/exp
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
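
A sketch of loading one of the generated files (illustrative only; `params`
must match the configuration used during training):

    import torch
    from train import get_params, get_transducer_model

    params = get_params()
    # ... set the model-related fields of `params` (e.g. vocab_size) ...
    model = get_transducer_model(params)
    state = torch.load("epoch-28-avg-15-use-averaged-model.pt", map_location="cpu")
    model.load_state_dict(state["model"])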
"""
import argparse
from pathlib import Path
from typing import Dict, List
import sentencepiece as spm
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model."
"If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
print("Script started")
device = torch.device("cpu")
print(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
print("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
print(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
torch.save({"model": model.state_dict()}, filename)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save({"model": model.state_dict()}, filename)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
print(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save({"model": model.state_dict()}, filename)
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
print(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
filename = (
params.exp_dir
/ f"iter-{params.iter}-avg-{params.avg}-use-averaged-model.pt"
)
torch.save({"model": model.state_dict()}, filename)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
print(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
filename = (
params.exp_dir
/ f"epoch-{params.epoch}-avg-{params.avg}-use-averaged-model.pt"
)
torch.save({"model": model.state_dict()}, filename)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
print("Done!")
if __name__ == "__main__":
main()

@@ -0,0 +1,406 @@
# Copyright 2021 Piotr Żelasko
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class GigaSpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
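
    A minimal usage sketch (illustrative only; it assumes ``args`` contains the
    options registered by ``GigaSpeechAsrDataModule.add_arguments``):

        gigaspeech = GigaSpeechAsrDataModule(args)
        train_dl = gigaspeech.train_dataloaders(gigaspeech.train_cuts())
        valid_dl = gigaspeech.valid_dataloaders(gigaspeech.dev_cuts())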
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it "
"with training dataset. ",
)
# GigaSpeech specific arguments
group.add_argument(
"--subset",
type=str,
default="XL",
help="Select the GigaSpeech subset (XS|S|M|L|XL)",
)
group.add_argument(
"--small-dev",
type=str2bool,
default=False,
help="Should we use only 1000 utterances for dev (speeds up training)",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=True,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info(f"About to get train_{self.args.subset} cuts")
path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz"
cuts_train = CutSet.from_jsonl_lazy(path)
return cuts_train
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
if self.args.small_dev:
return cuts_valid.subset(first=1000)
else:
return cuts_valid
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")

View File

@ -0,0 +1,115 @@
#!/usr/bin/env python3
# Copyright 2021 Jiayu Du
# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
conversational_filler = [
"UH",
"UHH",
"UM",
"EH",
"MM",
"HM",
"AH",
"HUH",
"HA",
"ER",
"OOF",
"HEE",
"ACH",
"EEE",
"EW",
]
unk_tags = ["<UNK>", "<unk>"]
gigaspeech_punctuations = [
"<COMMA>",
"<PERIOD>",
"<QUESTIONMARK>",
"<EXCLAMATIONPOINT>",
]
gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
non_scoring_words = (
conversational_filler
+ unk_tags
+ gigaspeech_punctuations
+ gigaspeech_garbage_utterance_tags
)
def asr_text_post_processing(text: str) -> str:
# 1. convert to uppercase
text = text.upper()
# 2. remove hyphen
# "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
text = text.replace("-", " ")
# 3. remove non-scoring words from evaluation
remaining_words = []
for word in text.split():
if word in non_scoring_words:
continue
remaining_words.append(word)
return " ".join(remaining_words)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="This script evaluates GigaSpeech ASR result via"
"SCTK's tool sclite"
)
parser.add_argument(
"ref",
type=str,
help="sclite's standard transcription(trn) reference file",
)
parser.add_argument(
"hyp",
type=str,
help="sclite's standard transcription(trn) hypothesis file",
)
parser.add_argument(
"work_dir",
type=str,
help="working dir",
)
args = parser.parse_args()
if not os.path.isdir(args.work_dir):
os.mkdir(args.work_dir)
REF = os.path.join(args.work_dir, "REF")
HYP = os.path.join(args.work_dir, "HYP")
RESULT = os.path.join(args.work_dir, "RESULT")
for io in [(args.ref, REF), (args.hyp, HYP)]:
with open(io[0], "r", encoding="utf8") as fi:
with open(io[1], "w+", encoding="utf8") as fo:
for line in fi:
line = line.strip()
if line:
cols = line.split()
text = asr_text_post_processing(" ".join(cols[0:-1]))
uttid_field = cols[-1]
print(f"{text} {uttid_field}", file=fo)
# GigaSpeech's uttid conforms to the SWB (Switchboard) naming convention
os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}")

View File

@ -0,0 +1,272 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`
and uses them to decode waves.
You can use the following command to get the exported models:
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--jit 1
Usage of this script:
./pruned_transducer_stateless7/jit_pretrained.py \
--nn-model-filename ./pruned_transducer_stateless7/exp/cpu_jit.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model-filename",
type=str,
required=True,
help="Path to the torchscript model cpu_jit.pt",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
model: torch.jit.ScriptModule,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
Returns:
Return the decoded results for each utterance.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
device = encoder_out.device
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (N, context_size)
decoder_out = model.decoder(
decoder_input,
need_pad=torch.tensor([False]),
).squeeze(1)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.joiner(
current_encoder_out,
decoder_out,
)
# logits' shape: (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=torch.tensor([False]),
)
decoder_out = decoder_out.squeeze(1)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = torch.jit.load(args.nn_model_filename)
model.eval()
model.to(device)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features,
x_lens=feature_lengths,
)
hyps = greedy_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = sp.decode(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,64 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class Joiner(nn.Module):
def __init__(
self,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
):
super().__init__()
self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
self.output_linear = nn.Linear(joiner_dim, vocab_size)
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
project_input: bool = True,
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
project_input:
If true, apply input projections encoder_proj and decoder_proj.
If this is false, it is the user's responsibility to do this
manually.
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim
assert encoder_out.ndim in (2, 4)
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
else:
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
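# Minimal usage sketch (illustrative dims only, not the recipe defaults):
#   joiner = Joiner(encoder_dim=384, decoder_dim=512, joiner_dim=512, vocab_size=500)
#   enc = torch.rand(8, 100, 5, 384)   # (N, T, s_range, encoder_dim)
#   dec = torch.rand(8, 100, 5, 512)   # (N, T, s_range, decoder_dim)
#   logits = joiner(enc, dec)          # (8, 100, 5, 500), unnormalized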

View File

@ -0,0 +1,242 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import penalize_abs_values_gt
from icefall.utils import add_sos
from typing import Union, List
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
context_encoder: nn.Module,
encoder_biasing_adapter: nn.Module,
decoder_biasing_adapter: nn.Module,
):
"""
Args:
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of shape (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
self.context_encoder = context_encoder
self.encoder_biasing_adapter = encoder_biasing_adapter
self.decoder_biasing_adapter = decoder_biasing_adapter
self.simple_am_proj = nn.Linear(
encoder_dim,
vocab_size,
)
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
# For temporary convenience
self.scratch_space = None
self.no_encoder_biasing = None
self.no_decoder_biasing = None
self.no_wfst_lm_biasing = None
self.params = None
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
contexts: dict,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
contexts:
A dict specifying the biasing contexts. It typically contains
"word_list" (each word as a padded list of token ids or as an embedding),
"word_lengths" (the number of tokens per word) and
"num_words_per_utt" (the number of context words per utterance);
see ContextEncoder.embed_contexts for the exact format.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
# breakpoint()
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
contexts_h, contexts_mask = self.context_encoder.embed_contexts(
contexts
)
# assert x.size(0) == contexts_h.size(0) == contexts_mask.size(0)
# assert contexts_h.ndim == 3
# assert contexts_h.ndim == 2
if self.params.irrelevance_learning:
need_weights = True
else:
need_weights = False
encoder_biasing_out, attn_enc = self.encoder_biasing_adapter.forward(encoder_out, contexts_h, contexts_mask, need_weights=need_weights)
encoder_out = encoder_out + encoder_biasing_out
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
if self.context_encoder.bi_encoders:
contexts_dec_h, contexts_dec_mask = self.context_encoder.embed_contexts(
contexts,
is_encoder_side=False,
)
else:
contexts_dec_h, contexts_dec_mask = contexts_h, contexts_mask
decoder_biasing_out, attn_dec = self.decoder_biasing_adapter.forward(decoder_out, contexts_dec_h, contexts_dec_mask, need_weights=need_weights)
decoder_out = decoder_out + decoder_biasing_out
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we already applied the encoder's and decoder's
# input projections prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)
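# Note (illustrative): the training script typically combines the two returned
# losses with per-term scales, e.g.
#   loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
# The exact scales and warm-up schedule are defined in the training script, not here.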

View File

@ -0,0 +1,146 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage: ./pruned_transducer_stateless7/my_profile.py
"""
import argparse
import logging
import sentencepiece as spm
import torch
from typing import Tuple
from torch import Tensor, nn
from icefall.profiler import get_model_profile
from scaling import BasicNorm, DoubleSwish
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
add_model_arguments(parser)
return parser
def _basic_norm_flops_compute(module, input, output):
assert len(input) == 1, len(input)
# estimate as layer_norm, see icefall/profiler.py
flops = input[0].numel() * 5
module.__flops__ += int(flops)
def _doubleswish_module_flops_compute(module, input, output):
# For DoubleSwish
assert len(input) == 1, len(input)
# estimate as swish/silu, see icefall/profiler.py
flops = input[0].numel()
module.__flops__ += int(flops)
MODULE_HOOK_MAPPING = {
BasicNorm: _basic_norm_flops_compute,
DoubleSwish: _doubleswish_module_flops_compute,
}
class Model(nn.Module):
"""A Wrapper for encoder and encoder_proj"""
def __init__(
self,
encoder: nn.Module,
encoder_proj: nn.Module,
) -> None:
super().__init__()
self.encoder = encoder
self.encoder_proj = encoder_proj
def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]:
encoder_out, encoder_out_lens = self.encoder(feature, feature_lens)
logits = self.encoder_proj(encoder_out)
return logits, encoder_out_lens
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
# We only profile the encoder part
model = Model(
encoder=get_encoder_model(params),
encoder_proj=get_joiner_model(params).encoder_proj,
)
model.eval()
model.to(device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# for 30-second input
B, T, D = 1, 3000, 80
feature = torch.ones(B, T, D, dtype=torch.float32).to(device)
feature_lens = torch.full((B,), T, dtype=torch.int64).to(device)
flops, params = get_model_profile(
model=model,
args=(feature, feature_lens),
module_hoop_mapping=MODULE_HOOK_MAPPING,
)
logging.info(f"For the encoder part, params: {params}, flops: {flops}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,634 @@
import torch
import torch.nn as nn
import abc
import copy
from collections import OrderedDict
class Ffn(nn.Module):
def __init__(self, input_dim, hidden_dim, out_dim, nlayers=1, drop_out=0.1, skip=False) -> None:
super().__init__()
layers = []
for ilayer in range(nlayers):
_in = hidden_dim if ilayer > 0 else input_dim
_out = hidden_dim if ilayer < nlayers - 1 else out_dim
layers.extend([
nn.Linear(_in, _out),
# nn.ReLU(),
# nn.Sigmoid(),
nn.Tanh(),
nn.Dropout(p=drop_out),
])
self.ffn = torch.nn.Sequential(
*layers,
)
self.skip = skip
def forward(self, x) -> torch.Tensor:
x_out = self.ffn(x)
if self.skip:
x_out = x_out + x
return x_out
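# Minimal usage sketch (illustrative): a 2-layer Tanh FFN with a residual skip.
# The skip connection requires input_dim == out_dim.
#   ffn = Ffn(input_dim=64, hidden_dim=64, out_dim=64, nlayers=2, skip=True)
#   y = ffn(torch.rand(4, 10, 64))   # -> (4, 10, 64)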
class ContextEncoder(torch.nn.Module):
def __init__(self):
super(ContextEncoder, self).__init__()
self.stats_num_distractors_per_utt = 0
self.stats_num_utt = 0
@abc.abstractmethod
def forward(
self,
word_list,
word_lengths,
is_encoder_side=None,
):
pass
def embed_contexts(
self,
contexts,
is_encoder_side=None,
):
"""
Args:
contexts:
The contexts, see below for details
Returns:
final_h:
A tensor of shape (batch_size, max(num_words_per_utt) + 1, joiner_dim),
which is the embedding for each context word.
mask_h:
A tensor of shape (batch_size, max(num_words_per_utt) + 1),
which contains a True/False mask for final_h
"""
if contexts["mode"] == "get_context_word_list":
"""
word_list:
Option1: A list of words, where each word is a list of token ids.
The list of tokens for each word has been padded.
Option2: A list of words, where each word is an embedding.
word_lengths:
Option1: The number of tokens per word
Option2: None
num_words_per_utt:
The number of words in the context for each utterance
"""
word_list, word_lengths, num_words_per_utt = \
contexts["word_list"], contexts["word_lengths"], contexts["num_words_per_utt"]
assert word_lengths is None or word_list.size(0) == len(word_lengths)
batch_size = len(num_words_per_utt)
elif contexts["mode"] == "get_context_word_list_shared":
"""
word_list:
Option1: A list of words, where each word is a list of token ids.
The list of tokens for each word has been padded.
Option2: A list of words, where each word is an embedding.
word_lengths:
Option1: The number of tokens per word
Option2: None
positive_mask_list:
For each utterance, it contains a list of indices of the words should be masked
"""
# word_list, word_lengths, positive_mask_list = \
# contexts["word_list"], contexts["word_lengths"], contexts["positive_mask_list"]
# batch_size = len(positive_mask_list)
word_list, word_lengths, num_words_per_utt = \
contexts["word_list"], contexts["word_lengths"], contexts["num_words_per_utt"]
batch_size = len(num_words_per_utt)
assert word_lengths is None or word_list.size(0) == len(word_lengths)
else:
raise NotImplementedError
# print(f"word_list.shape={word_list.shape}")
final_h = self.forward(word_list, word_lengths, is_encoder_side=is_encoder_side)
if contexts["mode"] == "get_context_word_list":
final_h = torch.split(final_h, num_words_per_utt)
final_h = torch.nn.utils.rnn.pad_sequence(
final_h,
batch_first=True,
padding_value=0.0
)
# print(f"final_h.shape={final_h.shape}")
# add one no-bias token
no_bias_h = torch.zeros(final_h.shape[0], 1, final_h.shape[-1])
no_bias_h = no_bias_h.to(final_h.device)
final_h = torch.cat((no_bias_h, final_h), dim=1)
# print(final_h)
# https://stackoverflow.com/questions/53403306/how-to-batch-convert-sentence-lengths-to-masks-in-pytorch
mask_h = torch.arange(max(num_words_per_utt) + 1)
mask_h = mask_h.expand(len(num_words_per_utt), max(num_words_per_utt) + 1) > torch.Tensor(num_words_per_utt).unsqueeze(1)
mask_h = mask_h.to(final_h.device)
num_utt = len(num_words_per_utt)
self.stats_num_distractors_per_utt = len(word_list) / (num_utt + self.stats_num_utt) + self.stats_num_utt / (num_utt + self.stats_num_utt) * self.stats_num_distractors_per_utt
self.stats_num_utt += num_utt
elif contexts["mode"] == "get_context_word_list_shared":
no_bias_h = torch.zeros(1, final_h.shape[-1])
no_bias_h = no_bias_h.to(final_h.device)
final_h = torch.cat((no_bias_h, final_h), dim=0)
final_h = final_h.expand(batch_size, -1, -1)
# mask_h = torch.full(False, (batch_size, final_h.shape(1))) # TODO
# for i, my_mask in enumerate(positive_mask_list):
# if len(my_mask) > 0:
# my_mask = torch.Tensor(my_mask, dtype=int)
# my_mask += 1
# mask_h[i][my_mask] = True
mask_h = None
num_utt = len(num_words_per_utt)
self.stats_num_distractors_per_utt = num_utt / (num_utt + self.stats_num_utt) * len(word_list) + self.stats_num_utt / (num_utt + self.stats_num_utt) * self.stats_num_distractors_per_utt
self.stats_num_utt += num_utt
# TODO: validate this shape is correct:
# final_h: batch_size * max_num_words_per_utt + 1 * dim
# mask_h: batch_size * max_num_words_per_utt + 1
return final_h, mask_h
def clustering(self):
pass
def cache(self):
pass
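# Illustrative example (assumed names/shapes) of the `contexts` dict expected by
# embed_contexts() in "get_context_word_list" mode:
#   contexts = {
#       "mode": "get_context_word_list",
#       "word_list": padded_token_ids,        # (total_num_words, max_word_len)
#       "word_lengths": [3, 2, 4, ...],       # number of tokens per word
#       "num_words_per_utt": [2, 0, 1, ...],  # context words per utterance
#   }
#   final_h, mask_h = context_encoder.embed_contexts(contexts)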
class ContextEncoderLSTM(ContextEncoder):
def __init__(
self,
vocab_size: int = None,
context_encoder_dim: int = None,
embedding_layer: nn.Module = None,
output_dim: int = None,
num_layers: int = None,
num_directions: int = None,
drop_out: float = 0.1,
bi_encoders: bool = False,
):
super(ContextEncoderLSTM, self).__init__()
self.num_layers = num_layers
self.num_directions = num_directions
self.context_encoder_dim = context_encoder_dim
if embedding_layer is not None:
# self.embed = embedding_layer
self.embed = torch.nn.Embedding(
vocab_size,
context_encoder_dim
)
self.embed.weight.data = embedding_layer.weight.data
else:
self.embed = torch.nn.Embedding(
vocab_size,
context_encoder_dim
)
self.rnn = torch.nn.LSTM(
input_size=self.embed.weight.shape[1],
hidden_size=context_encoder_dim,
num_layers=self.num_layers,
batch_first=True,
bidirectional=(self.num_directions == 2),
dropout=0.1 if self.num_layers > 1 else 0
)
self.linear = torch.nn.Linear(
context_encoder_dim * self.num_directions,
output_dim
)
self.drop_out = torch.nn.Dropout(drop_out)
# TODO: Do we need some relu layer?
# https://galhever.medium.com/sentiment-analysis-with-pytorch-part-4-lstm-bilstm-model-84447f6c4525
# self.relu = nn.ReLU()
# self.dropout = nn.Dropout(dropout)
self.bi_encoders = bi_encoders
if bi_encoders:
# Create the decoder/predictor side of the context encoder
self.embed_dec = copy.deepcopy(self.embed)
self.rnn_dec = copy.deepcopy(self.rnn)
self.linear_dec = copy.deepcopy(self.linear)
self.drop_out_dec = copy.deepcopy(self.drop_out)
def forward(
self,
word_list,
word_lengths,
is_encoder_side=None,
):
if is_encoder_side is None or is_encoder_side is True:
embed = self.embed
rnn = self.rnn
linear = self.linear
drop_out = self.drop_out
else:
embed = self.embed_dec
rnn = self.rnn_dec
linear = self.linear_dec
drop_out = self.drop_out_dec
out = embed(word_list)
# https://stackoverflow.com/questions/51030782/why-do-we-pack-the-sequences-in-pytorch
out = torch.nn.utils.rnn.pack_padded_sequence(
out,
batch_first=True,
lengths=word_lengths,
enforce_sorted=False
)
output, (hn, cn) = rnn(out) # use default all zeros (h_0, c_0)
# # https://discuss.pytorch.org/t/bidirectional-3-layer-lstm-hidden-output/41336/4
# final_state = hn.view(
# self.num_layers,
# self.num_directions,
# word_list.shape[0],
# self.encoder_dim,
# )[-1] # Only the last layer
# h_1, h_2 = final_state[0], final_state[1]
# # X = h_1 + h_2 # Add both states (needs different input size for first linear layer)
# final_h = torch.cat((h_1, h_2), dim=1) # Concatenate both states
# final_h = self.linear(final_h)
# hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
# outputs are always from the last layer.
# hidden[-2, :, : ] is the last of the forwards RNN
# hidden[-1, :, : ] is the last of the backwards RNN
h_1, h_2 = hn[-2, :, :], hn[-1, :, :]
final_h = torch.cat((h_1, h_2), dim=1) # Concatenate both states
final_h = linear(final_h)
# final_h = drop_out(final_h)
return final_h
class SimpleGLU(nn.Module):
def __init__(self):
super(SimpleGLU, self).__init__()
# Initialize the learnable parameter 'a'
self.a = nn.Parameter(torch.tensor(0.0))
def forward(self, x):
# Perform the operation a * x
return self.a * x
class BiasingModule(torch.nn.Module):
def __init__(
self,
query_dim,
qkv_dim=64,
num_heads=4,
):
super(BiasingModule, self).__init__()
self.proj_in1 = nn.Linear(query_dim, qkv_dim)
self.proj_in2 = Ffn(
input_dim=qkv_dim,
hidden_dim=qkv_dim,
out_dim=qkv_dim,
skip=True,
drop_out=0.1,
nlayers=2,
)
self.multihead_attn = torch.nn.MultiheadAttention(
embed_dim=qkv_dim,
num_heads=num_heads,
# kdim=64,
# vdim=64,
batch_first=True,
)
self.proj_out1 = Ffn(
input_dim=qkv_dim,
hidden_dim=qkv_dim,
out_dim=qkv_dim,
skip=True,
drop_out=0.1,
nlayers=2,
)
self.proj_out2 = nn.Linear(qkv_dim, query_dim)
self.glu = nn.GLU()
# self.glu = SimpleGLU()
self.contexts = None
self.contexts_mask = None
# def __init__(
# self,
# query_dim,
# qkv_dim=64,
# num_heads=4,
# ):
# super(BiasingModule, self).__init__()
# self.proj_in1 = Ffn(
# input_dim=query_dim,
# hidden_dim=query_dim,
# out_dim=query_dim,
# skip=True,
# drop_out=0.1,
# nlayers=2,
# )
# self.proj_in2 = nn.Linear(query_dim, qkv_dim)
# self.multihead_attn = torch.nn.MultiheadAttention(
# embed_dim=qkv_dim,
# num_heads=num_heads,
# # kdim=64,
# # vdim=64,
# batch_first=True,
# )
# self.proj_out1 = nn.Linear(qkv_dim, query_dim)
# self.proj_out2 = Ffn(
# input_dim=query_dim,
# hidden_dim=query_dim,
# out_dim=query_dim,
# skip=True,
# drop_out=0.1,
# nlayers=2,
# )
# self.glu = nn.GLU()
# self.contexts = None
# self.contexts_mask = None
# self.attn_output_weights = None
def forward(
self,
queries,
contexts=None,
contexts_mask=None,
need_weights=False,
):
"""
Args:
queries:
of shape batch_size * seq_length * query_dim
contexts:
of shape batch_size * max_contexts_size * query_dim
contexts_mask:
of shape batch_size * max_contexts_size
Returns:
attn_output:
of shape batch_size * seq_length * context_dim
"""
if contexts is None:
contexts = self.contexts
if contexts_mask is None:
contexts_mask = self.contexts_mask
_queries = self.proj_in1(queries)
_queries = self.proj_in2(_queries)
# _queries = _queries / 0.01
attn_output, attn_output_weights = self.multihead_attn(
_queries, # query
contexts, # key
contexts, # value
key_padding_mask=contexts_mask,
need_weights=need_weights,
)
biasing_output = self.proj_out1(attn_output)
biasing_output = self.proj_out2(biasing_output)
# apply the gated linear unit
biasing_output = self.glu(biasing_output.repeat(1,1,2))
# biasing_output = self.glu(biasing_output)
# inject contexts here
output = queries + biasing_output
# print(f"query={query.shape}")
# print(f"value={contexts} value.shape={contexts.shape}")
# print(f"attn_output_weights={attn_output_weights} attn_output_weights.shape={attn_output_weights.shape}")
return output, attn_output_weights
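# Minimal usage sketch (illustrative shapes):
#   biasing = BiasingModule(query_dim=512, qkv_dim=64, num_heads=4)
#   queries  = torch.rand(4, 100, 512)               # (batch, seq_len, query_dim)
#   ctx      = torch.rand(4, 10, 64)                 # (batch, max_contexts, qkv_dim)
#   ctx_mask = torch.zeros(4, 10, dtype=torch.bool)  # True marks padded context slots
#   out, attn = biasing(queries, ctx, ctx_mask, need_weights=True)
#   # out: (4, 100, 512) = queries + gated cross-attention biasing output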
def tuple_to_list(t):
if isinstance(t, tuple):
return list(map(tuple_to_list, t))
return t
def list_to_tuple(l):
if isinstance(l, list):
return tuple(map(list_to_tuple, l))
return l
class ContextualSequential(nn.Sequential):
def __init__(self, *args):
super(ContextualSequential, self).__init__(*args)
self.contexts_h = None
self.contexts_mask = None
def set_contexts(self, contexts_h, contexts_masks):
self.contexts_h, self.contexts_mask = contexts_h, contexts_masks
def forward(self, *args, **kwargs):
# print(f"input: {type(args[0])=}, {args[0].shape=}")
is_hf = False
for module in self._modules.values():
module_name = type(module).__name__
# if "AudioEncoder" in module_name:
# args = (module(*args, **kwargs),)
# elif "TextDecoder" in module_name:
# args = (module(*args, **kwargs),)
# elif "WhisperDecoderLayer" in module_name:
# args = (module(*args, **kwargs),)
# elif "WhisperEncoderLayer" in module_name:
# args = (module(*args, **kwargs),)
if "WhisperDecoderLayer" in module_name or "WhisperEncoderLayer" in module_name:
is_hf = True
if "BiasingModule" in module_name:
x = args[0]
while isinstance(x, list) or isinstance(x, tuple):
x = x[0]
if self.contexts_h is not None:
# The contexts are injected here
x, attn_output_weights = module(x, contexts=self.contexts_h, contexts_mask=self.contexts_mask, need_weights=True)
else:
# final_h: batch_size * max_num_words_per_utt + 1 * dim
# mask_h: batch_size * max_num_words_per_utt + 1
batch_size = x.size(0)
contexts_h = torch.zeros(batch_size, 1, module.multihead_attn.embed_dim)
contexts_h = contexts_h.to(x.device)
contexts_mask = torch.zeros(batch_size, 1, dtype=torch.bool)
contexts_mask = contexts_mask.to(x.device)
x, attn_output_weights = module(x, contexts=contexts_h, contexts_mask=contexts_mask, need_weights=True)
args = (x,)
else:
x = module(*args, **kwargs)
while isinstance(x, list) or isinstance(x, tuple):
x = x[0]
args = (x,)
# print(f"output: {type(args[0])=}, {args[0].shape=}")
if is_hf:
return args
else:
return args[0]
def set_contexts_for_model(model, contexts):
# check each module in the model, if it is a class of "ContextualSequential",
# then set the contexts for the module
# if hasattr(model, "context_encoder") and model.context_encoder is not None:
contexts_h, contexts_mask = model.context_encoder.embed_contexts(
contexts
)
for module in model.modules():
if isinstance(module, ContextualSequential):
module.set_contexts(contexts_h, contexts_mask)
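# Typical usage sketch (illustrative; `model.transcribe` is the assumed
# downstream decoding call of the wrapped Whisper model):
#   set_contexts_for_model(model, contexts)  # contexts dict as in ContextEncoder.embed_contexts
#   result = model.transcribe(audio)         # biasing adapters now attend to the contexts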
def get_contextual_model(model, encoder_biasing_layers="31,", decoder_biasing_layers="31,", context_dim=128) -> nn.Module:
# context_dim = 128 # 1.5%
# context_dim = 256 # 5.22% => seems better?
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# [(n, param.requires_grad) for n, param in context_encoder.named_parameters()]
# [(n, p.numel()) for n, p in model.named_parameters() if p.requires_grad]
print("Before neural biasing:")
print(f"{total_params=}")
print(f"{trainable_params=} ({trainable_params/total_params*100:.2f}%)")
encoder_biasing_layers = [int(l) for l in encoder_biasing_layers.strip().split(",") if len(l) > 0]
decoder_biasing_layers = [int(l) for l in decoder_biasing_layers.strip().split(",") if len(l) > 0]
if len(encoder_biasing_layers) > 0 or len(decoder_biasing_layers) > 0:
if hasattr(model, "model"): # Hugegingface models
embedding_layer = model.model.decoder.embed_tokens
else:
embedding_layer = model.decoder.token_embedding
context_encoder = ContextEncoderLSTM(
# vocab_size=embedding_layer.weight.shape[0],
embedding_layer=embedding_layer,
# context_encoder_dim=int(params.encoder_dims.split(",")[-1]),
context_encoder_dim=context_dim,
output_dim=context_dim,
num_layers=2,
num_directions=2,
drop_out=0.1,
)
model.context_encoder = context_encoder
# encoder_biasing_adapter = BiasingModule(
# query_dim=int(params.encoder_dims.split(",")[-1]),
# qkv_dim=context_dim,
# num_heads=4,
# )
if hasattr(model, "model") and hasattr(model.model.encoder, "layers"): # Huggingface models
for i, layer in enumerate(model.model.encoder.layers):
if i in encoder_biasing_layers:
layer_output_dim = layer.final_layer_norm.normalized_shape[0]
model.model.encoder.layers[i] = ContextualSequential(OrderedDict([
("layer", layer),
("biasing_adapter", BiasingModule(
query_dim=layer_output_dim,
qkv_dim=context_dim,
num_heads=4,
))
]))
elif hasattr(model.encoder, "blocks"): # OpenAI models
for i, layer in enumerate(model.encoder.blocks):
if i in encoder_biasing_layers:
layer_output_dim = layer.mlp_ln.normalized_shape[0]
model.encoder.blocks[i] = ContextualSequential(OrderedDict([
("layer", layer),
("biasing_adapter", BiasingModule(
query_dim=layer_output_dim,
qkv_dim=context_dim,
num_heads=4,
))
]))
else:
raise NotImplementedError
if hasattr(model, "model") and hasattr(model.model.decoder, "layers"):
for i, layer in enumerate(model.model.decoder.layers):
if i in decoder_biasing_layers:
layer_output_dim = layer.final_layer_norm.normalized_shape[0]
model.model.decoder.layers[i] = ContextualSequential(OrderedDict([
("layer", layer),
("biasing_adapter", BiasingModule(
query_dim=layer_output_dim,
qkv_dim=context_dim,
num_heads=4,
))
]))
elif hasattr(model.decoder, "blocks"):
for i, layer in enumerate(model.decoder.blocks):
if i in decoder_biasing_layers:
layer_output_dim = layer.mlp_ln.normalized_shape[0]
model.decoder.blocks[i] = ContextualSequential(OrderedDict([
("layer", layer),
("biasing_adapter", BiasingModule(
query_dim=layer_output_dim,
qkv_dim=context_dim,
num_heads=4,
))
]))
else:
raise NotImplementedError
# Freeze the model params
# exception_types = (BiasingModule, ContextEncoderLSTM)
for name, param in model.named_parameters():
# Check if the parameter belongs to a layer of the specified types
if "biasing_adapter" in name:
param.requires_grad = True
elif "context_encoder" in name and "context_encoder.embed" not in name: # We will not fine-tune the embedding layer, which comes from the original model
param.requires_grad = True
elif "context_encoder" in name and "context_encoder.embed" in name: # Debug
param.requires_grad = False
else:
param.requires_grad = False
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# [(n, param.requires_grad) for n, param in context_encoder.named_parameters()]
# [(n, p.numel()) for n, p in model.named_parameters() if p.requires_grad]
print("Neural biasing:")
print(f"{total_params=}")
print(f"{trainable_params=} ({trainable_params/total_params*100:.2f}%)")
return model
# # Test:
# import whisper, torch; device = torch.device("cuda", 0)
# model = whisper.load_model("large-v2", is_ctx=True, device=device)
# from neural_biasing import get_contextual_model
# model1 = get_contextual_model(model)

View File

@ -0,0 +1,239 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks that exported onnx models produce the same output
as the given torchscript model for the same input.
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt"
cd exp
ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt
popd
2. Export the model via torchscript (torch.jit.script())
./pruned_transducer_stateless3/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 9999 \
--avg 1 \
--exp-dir $repo/exp/ \
--jit 1
It will generate the following file in $repo/exp:
- cpu_jit.pt
3. Export the model to ONNX
./pruned_transducer_stateless3/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 9999 \
--avg 1 \
--exp-dir $repo/exp/
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-9999-avg-1.onnx
- decoder-epoch-9999-avg-1.onnx
- joiner-epoch-9999-avg-1.onnx
4. Run this file
./pruned_transducer_stateless3/onnx_check.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-encoder-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
--onnx-decoder-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
--onnx-joiner-filename $repo/exp/joiner-epoch-9999-avg-1.onnx
"""
import argparse
import logging
from icefall import is_module_available
from onnx_pretrained import OnnxModel
import torch
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--jit-filename",
required=True,
type=str,
help="Path to the torchscript model",
)
parser.add_argument(
"--onnx-encoder-filename",
required=True,
type=str,
help="Path to the onnx encoder model",
)
parser.add_argument(
"--onnx-decoder-filename",
required=True,
type=str,
help="Path to the onnx decoder model",
)
parser.add_argument(
"--onnx-joiner-filename",
required=True,
type=str,
help="Path to the onnx joiner model",
)
return parser
def test_encoder(
torch_model: torch.jit.ScriptModule,
onnx_model: OnnxModel,
):
C = 80
for i in range(3):
N = torch.randint(low=1, high=20, size=(1,)).item()
T = torch.randint(low=30, high=50, size=(1,)).item()
logging.info(f"test_encoder: iter {i}, N={N}, T={T}")
x = torch.rand(N, T, C)
x_lens = torch.randint(low=30, high=T + 1, size=(N,))
x_lens[0] = T
torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens)
torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out)
onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens)
assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), (
(torch_encoder_out - onnx_encoder_out).abs().max()
)
def test_decoder(
torch_model: torch.jit.ScriptModule,
onnx_model: OnnxModel,
):
context_size = onnx_model.context_size
vocab_size = onnx_model.vocab_size
for i in range(10):
N = torch.randint(1, 100, size=(1,)).item()
logging.info(f"test_decoder: iter {i}, N={N}")
x = torch.randint(
low=1,
high=vocab_size,
size=(N, context_size),
dtype=torch.int64,
)
torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False]))
torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out)
torch_decoder_out = torch_decoder_out.squeeze(1)
onnx_decoder_out = onnx_model.run_decoder(x)
assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), (
(torch_decoder_out - onnx_decoder_out).abs().max()
)
def test_joiner(
torch_model: torch.jit.ScriptModule,
onnx_model: OnnxModel,
):
encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1]
decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1]
for i in range(10):
N = torch.randint(1, 100, size=(1,)).item()
logging.info(f"test_joiner: iter {i}, N={N}")
encoder_out = torch.rand(N, encoder_dim)
decoder_out = torch.rand(N, decoder_dim)
projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out)
projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out)
torch_joiner_out = torch_model.joiner(encoder_out, decoder_out)
onnx_joiner_out = onnx_model.run_joiner(
projected_encoder_out, projected_decoder_out
)
assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
(torch_joiner_out - onnx_joiner_out).abs().max()
)
@torch.no_grad()
def main():
args = get_parser().parse_args()
logging.info(vars(args))
torch_model = torch.jit.load(args.jit_filename)
onnx_model = OnnxModel(
encoder_model_filename=args.onnx_encoder_filename,
decoder_model_filename=args.onnx_decoder_filename,
joiner_model_filename=args.onnx_joiner_filename,
)
logging.info("Test encoder")
test_encoder(torch_model, onnx_model)
logging.info("Test decoder")
test_decoder(torch_model, onnx_model)
logging.info("Test joiner")
test_joiner(torch_model, onnx_model)
logging.info("Finished checking ONNX models")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
# See https://github.com/pytorch/pytorch/issues/38342
# and https://github.com/pytorch/pytorch/issues/33354
#
# If we don't do this, the delay increases whenever there is
# a new request that changes the actual batch size.
# If you use `py-spy dump --pid <server-pid> --native`, you will
# see a lot of time is spent in re-compiling the torch script model.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
torch.manual_seed(20220727)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,319 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX exported models and uses them to decode the test sets.
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt"
cd exp
ln -s pretrained-epoch-30-avg-9.pt epoch-9999.pt
popd
2. Export the model to ONNX
./pruned_transducer_stateless7/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 9999 \
--avg 1 \
--exp-dir $repo/exp/
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-9999-avg-1.onnx
- decoder-epoch-9999-avg-1.onnx
- joiner-epoch-9999-avg-1.onnx
3. Run this file
./pruned_transducer_stateless7/onnx_decode.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \
"""
import argparse
import logging
import time
from pathlib import Path
from typing import List, Tuple
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from onnx_pretrained import greedy_search, OnnxModel
from icefall.utils import setup_logger, store_transcripts, write_error_stats
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
return parser
def decode_one_batch(
model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict
) -> List[List[str]]:
"""Decode one batch and return the result.
Currently only greedy_search is supported.
Args:
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoded results for each utterance.
"""
feature = batch["inputs"]
assert feature.ndim == 3
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
hyps = greedy_search(
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
)
hyps = [sp.decode(h).split() for h in hyps]
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
model: nn.Module,
sp: spm.SentencePieceProcessor,
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
model:
The neural model.
sp:
The BPE model.
Returns:
- A list of tuples. Each tuple contains three elements:
- cut_id,
- reference transcript,
- predicted result.
- The total duration (in seconds) of the dataset.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
log_interval = 10
total_duration = 0
results = []
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
hyps = decode_one_batch(model=model, sp=sp, batch=batch)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results.extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results, total_duration
def save_results(
res_dir: Path,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
):
recog_path = res_dir / f"recogs-{test_set_name}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = res_dir / f"errs-{test_set_name}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("WER", file=f)
print(wer, file=f)
s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert (
args.decoding_method == "greedy_search"
), "Only supports greedy_search currently."
res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
setup_logger(f"{res_dir}/log-decode")
logging.info("Decoding started")
device = torch.device("cpu")
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
blank_id = sp.piece_to_id("<blk>")
assert blank_id == 0, blank_id
logging.info(vars(args))
logging.info("About to create model")
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dls = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dls):
start_time = time.time()
results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp)
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration
logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
logging.info(f"Wave duration: {total_duration:.3f} s")
logging.info(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
save_results(res_dir=res_dir, test_set_name=test_set, results=results)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,417 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX models and uses them to decode waves.
You can use the following command to get the exported models:
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt"
cd exp
ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt
popd
2. Export the model to ONNX
./pruned_transducer_stateless3/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 9999 \
--avg 1 \
--exp-dir $repo/exp/
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-9999-avg-1.onnx
- decoder-epoch-9999-avg-1.onnx
- joiner-epoch-9999-avg-1.onnx
3. Run this file
./pruned_transducer_stateless3/onnx_pretrained.py \
--encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
"""
import argparse
import logging
import math
from typing import List, Tuple
import k2
import kaldifeat
import onnxruntime as ort
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
return parser
class OnnxModel:
def __init__(
self,
encoder_model_filename: str,
decoder_model_filename: str,
joiner_model_filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
self.session_opts = session_opts
self.init_encoder(encoder_model_filename)
self.init_decoder(decoder_model_filename)
self.init_joiner(joiner_model_filename)
def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
)
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
self.context_size = int(decoder_meta["context_size"])
self.vocab_size = int(decoder_meta["vocab_size"])
logging.info(f"context_size: {self.context_size}")
logging.info(f"vocab_size: {self.vocab_size}")
def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
self.joiner_dim = int(joiner_meta["joiner_dim"])
logging.info(f"joiner_dim: {self.joiner_dim}")
def run_encoder(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
- encoder_out, its shape is (N, T', joiner_dim)
- encoder_out_lens, its shape is (N,)
"""
out = self.encoder.run(
[
self.encoder.get_outputs()[0].name,
self.encoder.get_outputs()[1].name,
],
{
self.encoder.get_inputs()[0].name: x.numpy(),
self.encoder.get_inputs()[1].name: x_lens.numpy(),
},
)
return torch.from_numpy(out[0]), torch.from_numpy(out[1])
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
"""
Args:
decoder_input:
A 2-D tensor of shape (N, context_size)
Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
out = self.decoder.run(
[self.decoder.get_outputs()[0].name],
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
)[0]
return torch.from_numpy(out)
def run_joiner(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
out = self.joiner.run(
[self.joiner.get_outputs()[0].name],
{
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
},
)[0]
return torch.from_numpy(out)
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
model: OnnxModel,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (N, T, joiner_dim)
encoder_out_lens:
A 1-D tensor of shape (N,).
Returns:
Return the decoded results for each utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
context_size = model.context_size
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
dtype=torch.int64,
) # (N, context_size)
decoder_out = model.run_decoder(decoder_input)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
# current_encoder_out's shape: (batch_size, joiner_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.run_joiner(current_encoder_out, decoder_out)
# logits' shape: (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
dtype=torch.int64,
)
decoder_out = model.run_decoder(decoder_input)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths)
hyps = greedy_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
s = "\n"
symbol_table = k2.SymbolTable.from_file(args.tokens)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += symbol_table[i]
return text.replace("▁", " ").strip()
for filename, hyp in zip(args.sound_files, hyps):
words = token_ids_to_words(hyp)
s += f"{filename}:\n{words}\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

File diff suppressed because it is too large

View File

@ -0,0 +1,355 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
Usage of this script:
(1) greedy search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by
./pruned_transducer_stateless7/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

File diff suppressed because it is too large

View File

@ -0,0 +1,214 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file replaces various modules in a model.
Specifically, ActivationBalancer is replaced with an identity operator;
Whiten is also replaced with an identity operator;
BasicNorm is replaced by a module with `exp` removed.
"""
import copy
from typing import List, Tuple
import torch
import torch.nn as nn
from scaling import ActivationBalancer, BasicNorm, Whiten
from zipformer import PoolingModule
class PoolingModuleNoProj(nn.Module):
def forward(
self,
x: torch.Tensor,
cached_len: torch.Tensor,
cached_avg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (T, N, C)
cached_len:
A tensor of shape (N,)
cached_avg:
A tensor of shape (N, C)
Returns:
Return a tuple containing:
- new_x
- new_cached_len
- new_cached_avg
"""
x = x.cumsum(dim=0) # (T, N, C)
x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0)
# Cumulated numbers of frames from start
cum_mask = torch.arange(1, x.size(0) + 1, device=x.device)
cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N)
pooling_mask = (1.0 / cum_mask).unsqueeze(2)
# now pooling_mask: (T, N, 1)
x = x * pooling_mask # (T, N, C)
cached_len = cached_len + x.size(0)
cached_avg = x[-1]
return x, cached_len, cached_avg
class PoolingModuleWithProj(nn.Module):
def __init__(self, proj: torch.nn.Module):
super().__init__()
self.proj = proj
self.pooling = PoolingModuleNoProj()
def forward(
self,
x: torch.Tensor,
cached_len: torch.Tensor,
cached_avg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (T, N, C)
cached_len:
A tensor of shape (N,)
cached_avg:
A tensor of shape (N, C)
Returns:
Return a tuple containing:
- new_x
- new_cached_len
- new_cached_avg
"""
x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg)
return self.proj(x), cached_len, cached_avg
def streaming_forward(
self,
x: torch.Tensor,
cached_len: torch.Tensor,
cached_avg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (T, N, C)
cached_len:
A tensor of shape (N,)
cached_avg:
A tensor of shape (N, C)
Returns:
Return a tuple containing:
- new_x
- new_cached_len
- new_cached_avg
"""
x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg)
return self.proj(x), cached_len, cached_avg
class NonScaledNorm(nn.Module):
"""See BasicNorm for doc"""
def __init__(
self,
num_channels: int,
eps_exp: float,
channel_dim: int = -1, # CAUTION: see documentation.
):
super().__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
self.eps_exp = eps_exp
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not torch.jit.is_tracing():
assert x.shape[self.channel_dim] == self.num_channels
scales = (
torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp
).pow(-0.5)
return x * scales
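# Note: BasicNorm keeps a learnable log-epsilon and computes
#   y = x * (mean(x * x, dim=channel_dim) + exp(eps)) ** -0.5,
# whereas NonScaledNorm takes eps_exp = exp(eps) pre-computed by
# convert_basic_norm() below, so the `exp` op disappears from the forward pass.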
def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
assert isinstance(basic_norm, BasicNorm), type(basic_norm)
norm = NonScaledNorm(
num_channels=basic_norm.num_channels,
eps_exp=basic_norm.eps.data.exp().item(),
channel_dim=basic_norm.channel_dim,
)
return norm
def convert_pooling_module(pooling: PoolingModule) -> PoolingModuleWithProj:
assert isinstance(pooling, PoolingModule), type(pooling)
return PoolingModuleWithProj(proj=pooling.proj)
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
# get_submodule was added to nn.Module at v1.9.0
def get_submodule(model, target):
if target == "":
return model
atoms: List[str] = target.split(".")
mod: torch.nn.Module = model
for item in atoms:
if not hasattr(mod, item):
raise AttributeError(
mod._get_name() + " has no " "attribute `" + item + "`"
)
mod = getattr(mod, item)
if not isinstance(mod, torch.nn.Module):
raise AttributeError("`" + item + "` is not " "an nn.Module")
return mod
def convert_scaled_to_non_scaled(
model: nn.Module,
inplace: bool = False,
is_pnnx: bool = False,
):
"""
Args:
model:
The model to be converted.
inplace:
If True, the input model is modified inplace.
If False, the input model is copied and we modify the copied version.
is_pnnx:
True if we are going to export the model for PNNX.
Return:
Return a model without scaled layers.
"""
if not inplace:
model = copy.deepcopy(model)
d = {}
for name, m in model.named_modules():
if isinstance(m, BasicNorm):
d[name] = convert_basic_norm(m)
elif isinstance(m, (ActivationBalancer, Whiten)):
d[name] = nn.Identity()
elif isinstance(m, PoolingModule) and is_pnnx:
d[name] = convert_pooling_module(m)
for k, v in d.items():
if "." in k:
parent, child = k.rsplit(".", maxsplit=1)
setattr(get_submodule(model, parent), child, v)
else:
setattr(model, k, v)
return model
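# --- Usage sketch (illustrative only) ---
# A minimal example of how this converter is typically applied before scripting
# or ONNX export; `get_params`/`get_transducer_model` are assumed to come from
# the recipe's train.py and are not defined in this file.
#
#   from train import get_params, get_transducer_model
#
#   params = get_params()
#   model = get_transducer_model(params)
#   model.eval()
#   model = convert_scaled_to_non_scaled(model, inplace=True)
#   # The converted model has no ActivationBalancer/Whiten modules and uses
#   # NonScaledNorm in place of BasicNorm, so torch.jit.script(model) works.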

View File

@ -0,0 +1,337 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Taken from: https://github.com/facebookresearch/fbai-speech/blob/main/is21_deep_bias/score.py
from collections import deque
from enum import Enum
import argparse
import logging
import json
from pathlib import Path, PosixPath
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())
class Code(Enum):
match = 1
substitution = 2
insertion = 3
deletion = 4
class AlignmentResult(object):
def __init__(self, refs, hyps, codes, score):
self.refs = refs # deque<int>
self.hyps = hyps # deque<int>
self.codes = codes # deque<Code>
self.score = score # float
class WordError(object):
def __init__(self):
self.errors = {
Code.substitution: 0,
Code.insertion: 0,
Code.deletion: 0,
}
self.ref_words = 0
def get_wer(self):
assert self.ref_words != 0
errors = (
self.errors[Code.substitution]
+ self.errors[Code.insertion]
+ self.errors[Code.deletion]
)
return 100.0 * errors / self.ref_words
def get_result_string(self):
return (
f"error_rate={self.get_wer()}, "
f"ref_words={self.ref_words}, "
f"subs={self.errors[Code.substitution]}, "
f"ins={self.errors[Code.insertion]}, "
f"dels={self.errors[Code.deletion]}"
)
def coordinate_to_offset(row, col, ncols):
return int(row * ncols + col)
def offset_to_row(offset, ncols):
return int(offset / ncols)
def offset_to_col(offset, ncols):
return int(offset % ncols)
class EditDistance(object):
def __init__(self):
self.scores_ = None
self.backtraces_ = None
self.confusion_pairs_ = {}
self.inserted_words_ = {}
self.deleted_words_ = {}
def cost(self, ref, hyp, code):
if code == Code.match:
return 0
elif code == Code.insertion or code == Code.deletion:
return 3
else: # substitution
return 4
def get_result(self, refs, hyps):
res = AlignmentResult(refs=deque(), hyps=deque(), codes=deque(), score=None)
num_rows, num_cols = len(self.scores_), len(self.scores_[0])
res.score = self.scores_[num_rows - 1][num_cols - 1]
curr_offset = coordinate_to_offset(num_rows - 1, num_cols - 1, num_cols)
while curr_offset != 0:
curr_row = offset_to_row(curr_offset, num_cols)
curr_col = offset_to_col(curr_offset, num_cols)
prev_offset = self.backtraces_[curr_row][curr_col]
prev_row = offset_to_row(prev_offset, num_cols)
prev_col = offset_to_col(prev_offset, num_cols)
res.refs.appendleft(curr_row - 1)
res.hyps.appendleft(curr_col - 1)
if curr_row - 1 == prev_row and curr_col == prev_col:
ref_str = refs[res.refs[0]]
deleted_word = ref_str
if deleted_word not in self.deleted_words_:
self.deleted_words_[deleted_word] = 1
else:
self.deleted_words_[deleted_word] += 1
res.codes.appendleft(Code.deletion)
elif curr_row == prev_row and curr_col - 1 == prev_col:
hyp_str = hyps[res.hyps[0]]
inserted_word = hyp_str
if inserted_word not in self.inserted_words_:
self.inserted_words_[inserted_word] = 1
else:
self.inserted_words_[inserted_word] += 1
res.codes.appendleft(Code.insertion)
else:
# assert(curr_row - 1 == prev_row and curr_col - 1 == prev_col)
ref_str = refs[res.refs[0]]
hyp_str = hyps[res.hyps[0]]
if ref_str == hyp_str:
res.codes.appendleft(Code.match)
else:
res.codes.appendleft(Code.substitution)
confusion_pair = "%s -> %s" % (ref_str, hyp_str)
if confusion_pair not in self.confusion_pairs_:
self.confusion_pairs_[confusion_pair] = 1
else:
self.confusion_pairs_[confusion_pair] += 1
curr_offset = prev_offset
return res
def align(self, refs, hyps):
if len(refs) == 0 and len(hyps) == 0:
raise ValueError("Doesn't support empty ref AND hyp!")
# NOTE: we're not resetting the values in these matrices because every value
# will be overridden in the loop below. If this assumption doesn't hold,
# be sure to set all entries in self.scores_ and self.backtraces_ to 0.
self.scores_ = [[0.0] * (len(hyps) + 1) for _ in range(len(refs) + 1)]
self.backtraces_ = [[0] * (len(hyps) + 1) for _ in range(len(refs) + 1)]
num_rows, num_cols = len(self.scores_), len(self.scores_[0])
for i in range(num_rows):
for j in range(num_cols):
if i == 0 and j == 0:
self.scores_[i][j] = 0.0
self.backtraces_[i][j] = 0
continue
if i == 0:
self.scores_[i][j] = self.scores_[i][j - 1] + self.cost(
None, hyps[j - 1], Code.insertion
)
self.backtraces_[i][j] = coordinate_to_offset(i, j - 1, num_cols)
continue
if j == 0:
self.scores_[i][j] = self.scores_[i - 1][j] + self.cost(
refs[i - 1], None, Code.deletion
)
self.backtraces_[i][j] = coordinate_to_offset(i - 1, j, num_cols)
continue
# Below here both i and j are greater than 0
ref = refs[i - 1]
hyp = hyps[j - 1]
best_score = self.scores_[i - 1][j - 1] + (
self.cost(ref, hyp, Code.match)
if ref == hyp
else self.cost(ref, hyp, Code.substitution)
)
prev_row = i - 1
prev_col = j - 1
ins = self.scores_[i][j - 1] + self.cost(None, hyp, Code.insertion)
if ins < best_score:
best_score = ins
prev_row = i
prev_col = j - 1
delt = self.scores_[i - 1][j] + self.cost(ref, None, Code.deletion)
if delt < best_score:
best_score = delt
prev_row = i - 1
prev_col = j
self.scores_[i][j] = best_score
self.backtraces_[i][j] = coordinate_to_offset(
prev_row, prev_col, num_cols
)
return self.get_result(refs, hyps)
def main(args):
refs = {}
if type(args.refs) is str or type(args.refs) is PosixPath:
with open(args.refs, "r") as f:
for line in f:
ary = line.strip().split("\t")
uttid, ref, biasing_words = ary[0], ary[1], set(json.loads(ary[2]))
refs[uttid] = {"text": ref, "biasing_words": biasing_words}
logger.info("Loaded %d reference utts from %s", len(refs), args.refs)
elif type(args.refs) is dict:
refs = args.refs
logger.info("Loaded %d reference utts", len(refs))
else:
raise NotImplementedError
hyps = {}
if type(args.hyps) is str or type(args.hyps) is PosixPath:
with open(args.hyps, "r") as f:
for line in f:
ary = line.strip().split("\t")
# May have empty hypo
if len(ary) >= 2:
uttid, hyp = ary[0], ary[1]
else:
uttid, hyp = ary[0], ""
hyps[uttid] = hyp
logger.info("Loaded %d hypothesis utts from %s", len(hyps), args.hyps)
elif type(args.hyps) is dict:
hyps = args.hyps
logger.info("Loaded %d hypothesis utts", len(hyps))
else:
raise NotImplementedError
if not args.lenient:
for uttid in refs:
if uttid in hyps:
continue
raise ValueError(
f"{uttid} missing in hyps! Set `--lenient` flag to ignore this error."
)
# train_rare_count = dict()
# with open("", "r") as fin:
# for line in fin:
# w, c = line.strip().split()
# train_rare_count[w] = int(c)
test_rare_count = dict()
# Calculate WER, U-WER, and B-WER
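# B-WER counts errors only on reference words that appear in the utterance's
# biasing list, U-WER counts errors on the remaining reference words, and
# insertions are attributed to B-WER or U-WER depending on whether the
# inserted hypothesis word is a biasing word (see the insertion branch below).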
wer = WordError()
u_wer = WordError()
b_wer = WordError()
for uttid in refs:
if uttid not in hyps:
continue
ref_tokens = refs[uttid]["text"].split()
biasing_words = refs[uttid]["biasing_words"]
hyp_tokens = hyps[uttid].split()
ed = EditDistance()
result = ed.align(ref_tokens, hyp_tokens)
for code, ref_idx, hyp_idx in zip(result.codes, result.refs, result.hyps):
if code == Code.match:
wer.ref_words += 1
if ref_tokens[ref_idx] in biasing_words:
b_wer.ref_words += 1
else:
u_wer.ref_words += 1
elif code == Code.substitution:
wer.ref_words += 1
wer.errors[Code.substitution] += 1
if ref_tokens[ref_idx] in biasing_words:
b_wer.ref_words += 1
b_wer.errors[Code.substitution] += 1
else:
u_wer.ref_words += 1
u_wer.errors[Code.substitution] += 1
elif code == Code.deletion:
wer.ref_words += 1
wer.errors[Code.deletion] += 1
if ref_tokens[ref_idx] in biasing_words:
b_wer.ref_words += 1
b_wer.errors[Code.deletion] += 1
else:
u_wer.ref_words += 1
u_wer.errors[Code.deletion] += 1
elif code == Code.insertion:
wer.errors[Code.insertion] += 1
if hyp_tokens[hyp_idx] in biasing_words:
b_wer.errors[Code.insertion] += 1
else:
u_wer.errors[Code.insertion] += 1
# Report results
print(f"WER: {wer.get_result_string()}")
print(f"U-WER: {u_wer.get_result_string()}")
print(f"B-WER: {b_wer.get_result_string()}")
print(f"{wer.get_wer():.2f}({u_wer.get_wer():.2f}/{b_wer.get_wer():.2f})")
if __name__ == "__main__":
desc = "Compute WER, U-WER, and B-WER. Results are output to stdout."
parser = argparse.ArgumentParser(description=desc)
parser.add_argument(
"--refs",
required=True,
help="Path to tab-separated reference file. First column is utterance ID. "
"Second column is reference text. Last column is list of biasing words.",
)
parser.add_argument(
"--hyps",
required=True,
help="Path to tab-separated hypothesis file. First column is utterance ID. "
"Second column is hypothesis text.",
)
parser.add_argument(
"--lenient",
action="store_true",
help="If set, hyps doesn't have to cover all of refs.",
)
args = parser.parse_args()
main(args)
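# Usage sketch (illustrative; the file name and paths below are placeholders).
# The refs file is tab-separated as <uttid>\t<reference text>\t<JSON list of
# biasing words>, and the hyps file is <uttid>\t<hypothesis text>.
#
#   python score.py \
#     --refs data/fbai-speech/is21_deep_bias/ref.tsv \
#     --hyps exp/decode/hyp.tsv \
#     --lenient
#
# WER, U-WER, and B-WER are printed to stdout.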

View File

@ -0,0 +1,197 @@
#!/usr/bin/env bash
#$ -wd /exp/rhuang/meta/icefall/egs/librispeech/ASR/
#$ -V
#$ -N train_context
#$ -j y -o /exp/rhuang/meta/icefall/egs/librispeech/ASR/log/log-$JOB_NAME-$JOB_ID.out
#$ -M ruizhe@jhu.edu
#$ -m e
#$ -l mem_free=32G,h_rt=600:00:00,gpu=4,hostname=!r7n07*
#$ -q gpu.q@@v100
# #$ -q gpu.q@@v100
# #$ -q gpu.q@@rtx
# #$ -l ram_free=300G,mem_free=300G,gpu=0,hostname=b*
# hostname=b19
# hostname=!c04*&!b*&!octopod*
# hostname
# nvidia-smi
# conda activate /home/hltcoe/rhuang/mambaforge/envs/aligner5
export PATH="/home/hltcoe/rhuang/mambaforge/envs/aligner5/bin/":$PATH
module load cuda11.7/toolkit
module load cudnn/8.5.0.96_cuda11.x
module load nccl/2.13.4-1_cuda11.7
module load gcc/7.2.0
module load intel/mkl/64/2019/5.281
which python
nvcc --version
nvidia-smi
date
# k2
K2_ROOT=/exp/rhuang/meta/k2/
export PYTHONPATH=$K2_ROOT/k2/python:$PYTHONPATH # for `import k2`
export PYTHONPATH=$K2_ROOT/temp.linux-x86_64-cpython-310/lib:$PYTHONPATH # for `import _k2`
export PYTHONPATH=/exp/rhuang/meta/icefall:$PYTHONPATH
# # torchaudio recipe
# cd /exp/rhuang/meta/audio
# cd examples/asr/librispeech_conformer_ctc
# To verify SGE_HGR_gpu and CUDA_VISIBLE_DEVICES match for GPU jobs.
env | grep SGE_HGR_gpu
env | grep CUDA_VISIBLE_DEVICES
echo "hostname: `hostname`"
echo "current path:" `pwd`
# export PYTHONPATH=/exp/rhuang/meta/audio/examples/asr/librispeech_conformer_ctc2:$PYTHONPATH
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_ali/exp/exp_libri # 11073148
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_ali/exp/exp_libri_100 # 11073150
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_ali/exp/exp_libri_100_ts # 11073234, 11073240, 11073243
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_ali/exp/exp_libri_ts # 11073238, 11073255, log-train-2024-01-13-20-11-00 => 11073331 => log-train-2024-01-14-02-16-10
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_ali/exp/exp_libri_ts2 # log-train-2024-01-14-07-22-54-0
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_100
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri2 # baseline, no biasing
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri # 11169512
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy # 11169515, 11169916
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_34
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_3ctc # 11171405, log-train-2024-03-04-02-10-27-2,
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_3ctc # 11171405, log-train-2024-03-04-21-22-31-0
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_3ctc_attn #
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_2early
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_4early
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_234early
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_3early_no5
exp_dir=pruned_transducer_stateless7_contextual/exp/exp_libri_test
mkdir -p $exp_dir
echo
echo "exp_dir:" $exp_dir
echo
path_to_pretrained_asr_model=/exp/rhuang/librispeech/pretrained2/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/
# path_to_pretrained_asr_model=/scratch4/skhudan1/rhuang25/icefall/egs/librispeech/ASR/download/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
# From pretrained ASR model
if [ ! -f $exp_dir/epoch-1.pt ]; then
ln -s $path_to_pretrained_asr_model/exp/pretrained.pt $exp_dir/epoch-1.pt
fi
####################################
# train
####################################
# if false; then
# echo "True"
# else
# echo "False"
# fi
if true; then
# # stage 1:
max_duration=1600
# max_duration=400 # libri100
n_distractors=0
is_full_context=true
# stage 2:
# max_duration=1600
# # max_duration=400 # libri100
# n_distractors=100
# is_full_context=false
python pruned_transducer_stateless7_contextual/train.py \
--world-size 4 \
--use-fp16 true \
--max-duration $max_duration \
--exp-dir $exp_dir \
--bpe-model "data/lang_bpe_500/bpe.model" \
--prune-range 5 \
--full-libri true \
--context-dir "data/fbai-speech/is21_deep_bias/" \
--keep-ratio 1.0 \
--start-epoch 2 \
--num-epochs 30 \
--is-pretrained-context-encoder false \
--is-reused-context-encoder false \
--is-full-context $is_full_context \
--n-distractors $n_distractors --start-epoch 14 --num-epochs 40 --master-port 12357 --proxy-prob 0.2 --keep-ratio 0.8 --throwaway-prob 0.7 # --start-batch 24000 # --base-lr 0.08 --master-port 12355 --irrelevance-learning true
fi
# Alternative flag sets kept for reference (to be appended to the training command above):
#   --n-distractors $n_distractors --master-port 12357 --proxy-prob 0.4 --early-layers 2 --enable-nn true
#   --n-distractors $n_distractors --master-port 12357 --proxy-prob 0.4 --early-layers 4 --enable-nn true
#   --n-distractors $n_distractors --master-port 12357 --proxy-prob 0.4 --early-layers 2,3,4 --enable-nn true
####################################
# tensorboard
####################################
# tensorboard dev upload --logdir /exp/rhuang/meta/icefall/egs/librispeech/ASR/$exp_dir/tensorboard --description `pwd`
# wandb sync $exp_dir/tensorboard
# https://github.com/k2-fsa/icefall/issues/1298
# python -c "import wandb; wandb.init(project='icefall-asr-gigaspeech-zipformer-2023-10-20')"
# wandb sync zipformer/exp/tensorboard -p icefall-asr-gigaspeech-zipformer-2023-10-20
# https://stackoverflow.com/questions/37987839/how-can-i-run-tensorboard-on-a-remote-server
# ssh -L 16006:127.0.0.1:6006 rhuang@test1.hltcoe.jhu.edu
# tensorboard --logdir $exp_dir/tensorboard --port 6006
# http://localhost:16006
# no-biasing: /exp/rhuang/icefall_latest/egs/spgispeech/ASR/pruned_transducer_stateless7/exp_500_norm/tensorboard/
# bi_enc: /exp/rhuang/icefall_latest/egs/spgispeech/ASR/pruned_transducer_stateless7_context/exp/exp_libri_full_c-1_stage1/
# single_enc: /exp/rhuang/icefall_latest/egs/spgispeech/ASR/pruned_transducer_stateless7_context/exp/exp_libri_full_c-1_stage1_single_enc
# exp_dir=pruned_transducer_stateless7_context_ali/exp
# n_distractors=0
# max_duration=1200
# python /exp/rhuang/meta/icefall/egs/spgispeech/ASR/pruned_transducer_stateless7_context_ali/train.py --world-size 1 --use-fp16 true --max-duration $max_duration --exp-dir $exp_dir --bpe-model "data/lang_bpe_500/bpe.model" --prune-range 5 --use-fp16 true --context-dir "data/uniphore_contexts/" --keep-ratio 1.0 --start-epoch 2 --num-epochs 30 --is-bi-context-encoder false --is-pretrained-context-encoder false --is-full-context true --n-distractors $n_distractors
####### debug: RuntimeError: grad_scale is too small, exiting: 8.470329472543003e-22
# encoder_out=rs["encoder_out"]; contexts_h=rs["contexts_h"]; contexts_mask=rs["contexts_mask"]
# queries=encoder_out; contexts=contexts_h; contexts_mask=contexts_mask; need_weights=True
# md = rs["encoder_biasing_adapter"]
# with torch.cuda.amp.autocast(enabled=True):
# queries = md.proj_in(queries)
# print("queries:", torch.any(torch.isnan(queries) | torch.isinf(queries)))
# attn_output, attn_output_weights = md.multihead_attn(queries,contexts,contexts,key_padding_mask=contexts_mask,need_weights=need_weights,)
# print(torch.any(torch.isnan(attn_output) | torch.isinf(attn_output)))
# print(torch.any(torch.isnan(attn_output_weights) | torch.isinf(attn_output_weights)))
# output = md.proj_out(attn_output)
# print(torch.any(torch.isnan(output) | torch.isinf(output)))
# encoder_out=rs["encoder_out"]; contexts_h=rs["contexts_h"]; contexts_mask=rs["contexts_mask"]
# queries=encoder_out; contexts=contexts_h; contexts_mask=contexts_mask; need_weights=True
# md = rs["encoder_biasing_adapter"]
# with torch.cuda.amp.autocast(enabled=False):
# queries = md.proj_in(queries)
# print("queries:", torch.any(torch.isnan(queries) | torch.isinf(queries)))
# attn_output, attn_output_weights = md.multihead_attn(queries,contexts,contexts,key_padding_mask=contexts_mask,need_weights=need_weights,)
# print(torch.any(torch.isnan(attn_output) | torch.isinf(attn_output)))
# print(torch.any(torch.isnan(attn_output_weights) | torch.isinf(attn_output_weights)))
# output = md.proj_out(attn_output)
# print(torch.any(torch.isnan(output) | torch.isinf(output)))
# encoder_out=rs["encoder_out"]; contexts_h=rs["contexts_h"]; contexts_mask=rs["contexts_mask"]
# queries=encoder_out; contexts=contexts_h; contexts_mask=contexts_mask; need_weights=True
# md = rs["encoder_biasing_adapter"]
# queries = md.proj_in(queries)
# print("queries:", torch.any(torch.isnan(queries) | torch.isinf(queries)))
# attn_output, attn_output_weights = md.multihead_attn(queries,contexts,contexts,key_padding_mask=contexts_mask,need_weights=need_weights,)
# print(torch.any(torch.isnan(attn_output) | torch.isinf(attn_output)))
# print(torch.any(torch.isnan(attn_output_weights) | torch.isinf(attn_output_weights)))
# output = md.proj_out(attn_output)
# print(torch.any(torch.isnan(output) | torch.isinf(output)))
# Reference command from a previous run of the proxy_all_layers recipe (kept commented out
# so it does not launch a second training job when this script is submitted):
# python pruned_transducer_stateless7_context_proxy_all_layers/train.py --world-size 4 --use-fp16 true --max-duration $max_duration --exp-dir $exp_dir --bpe-model "data/lang_bpe_500/bpe.model" --prune-range 5 --full-libri true --context-dir "data/fbai-speech/is21_deep_bias/" --keep-ratio 1.0 --start-epoch 2 --num-epochs 30 --is-pretrained-context-encoder false --is-reused-context-encoder false --is-full-context $is_full_context --n-distractors $n_distractors --start-epoch 14 --num-epochs 40 --master-port 12357 --proxy-prob 0.2 --keep-ratio 0.8 --throwaway-prob 0.7 --n-distractors 100 --start-epoch 29 --num-epochs 40 --world-size 4 --max-duration 1800

View File

@ -0,0 +1,130 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script compares the word-level alignments generated based on modified_beam_search decoding
(in ./pruned_transducer_stateless7/compute_ali.py) to the reference alignments generated
by torchaudio framework (in ./add_alignments.sh).
Usage:
./pruned_transducer_stateless7/compute_ali.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--bpe-model data/lang_bpe_500/bpe.model \
--dataset test-clean \
--max-duration 300 \
--beam-size 4 \
--cuts-out-dir data/fbank_ali_beam_search
And then you can run:
./pruned_transducer_stateless7/test_compute_ali.py \
--cuts-out-dir ./data/fbank_ali_test \
--cuts-ref-dir ./data/fbank_ali_torch \
--dataset train-clean-100
"""
import argparse
import logging
from pathlib import Path
import torch
from lhotse import load_manifest
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--cuts-out-dir",
type=Path,
default="./data/fbank_ali",
help="The dir that saves the generated cuts manifests with alignments",
)
parser.add_argument(
"--cuts-ref-dir",
type=Path,
default="./data/fbank_ali_torch",
help="The dir that saves the reference cuts manifests with alignments",
)
parser.add_argument(
"--dataset",
type=str,
required=True,
help="""The name of the dataset:
Possible values are:
- test-clean
- test-other
- train-clean-100
- train-clean-360
- train-other-500
- dev-clean
- dev-other
""",
)
return parser
def main():
args = get_parser().parse_args()
cuts_out_jsonl = args.cuts_out_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz"
cuts_ref_jsonl = args.cuts_ref_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz"
logging.info(f"Loading {cuts_out_jsonl} and {cuts_ref_jsonl}")
cuts_out = load_manifest(cuts_out_jsonl)
cuts_ref = load_manifest(cuts_ref_jsonl)
cuts_ref = cuts_ref.sort_like(cuts_out)
all_time_diffs = []
for cut_out, cut_ref in zip(cuts_out, cuts_ref):
time_out = [
ali.start
for ali in cut_out.supervisions[0].alignment["word"]
if ali.symbol != ""
]
time_ref = [
ali.start
for ali in cut_ref.supervisions[0].alignment["word"]
if ali.symbol != ""
]
assert len(time_out) == len(time_ref), (len(time_out), len(time_ref))
diff = [
round(abs(out - ref), ndigits=3) for out, ref in zip(time_out, time_ref)
]
all_time_diffs += diff
all_time_diffs = torch.tensor(all_time_diffs)
logging.info(
f"For the word-level alignments abs difference on dataset {args.dataset}, "
f"mean: {'%.2f' % all_time_diffs.mean()}s, std: {'%.2f' % all_time_diffs.std()}s"
)
logging.info("Done!")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,68 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless7/test_model.py
"""
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = "2,4,3,2,4"
params.feedforward_dims = "1024,1024,2048,2048,1024"
params.nhead = "8,8,8,8,8"
params.encoder_dims = "384,384,384,384,384"
params.attention_dims = "192,192,192,192,192"
params.encoder_unmasked_dims = "256,256,256,256,256"
params.zipformer_downsampling_factors = "1,2,4,8,2"
params.cnn_module_kernels = "31,31,31,31,31"
params.decoder_dim = 512
params.joiner_dim = 512
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
# Test jit script
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
print("Using torch.jit.script")
model = torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,374 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file is to test that models can be exported to onnx.
"""
import os
from icefall import is_module_available
if not is_module_available("onnxruntime"):
raise ValueError("Please 'pip install onnxruntime' first.")
import onnxruntime as ort
import torch
from scaling_converter import convert_scaled_to_non_scaled
from zipformer import (
Conv2dSubsampling,
RelPositionalEncoding,
Zipformer,
ZipformerEncoder,
ZipformerEncoderLayer,
)
ort.set_default_logger_severity(3)
def test_conv2d_subsampling():
filename = "conv2d_subsampling.onnx"
opset_version = 13
N = 30
T = 50
num_features = 80
d_model = 512
x = torch.rand(N, T, num_features)
encoder_embed = Conv2dSubsampling(num_features, d_model)
encoder_embed.eval()
encoder_embed = convert_scaled_to_non_scaled(encoder_embed, inplace=True)
torch.onnx.export(
encoder_embed,
x,
filename,
verbose=False,
opset_version=opset_version,
input_names=["x"],
output_names=["y"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"y": {0: "N", 1: "T"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
assert input_nodes[0].name == "x"
assert input_nodes[0].shape == ["N", "T", num_features]
inputs = {input_nodes[0].name: x.numpy()}
onnx_y = session.run(["y"], inputs)[0]
onnx_y = torch.from_numpy(onnx_y)
torch_y = encoder_embed(x)
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()
os.remove(filename)
def test_rel_pos():
filename = "rel_pos.onnx"
opset_version = 13
N = 30
T = 50
num_features = 80
d_model = 512
x = torch.rand(N, T, num_features)
encoder_pos = RelPositionalEncoding(d_model, dropout_rate=0.1)
encoder_pos.eval()
encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True)
x = x.permute(1, 0, 2)
torch.onnx.export(
encoder_pos,
x,
filename,
verbose=False,
opset_version=opset_version,
input_names=["x"],
output_names=["pos_emb"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"pos_emb": {0: "N", 1: "T"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
assert input_nodes[0].name == "x"
assert input_nodes[0].shape == ["N", "T", num_features]
inputs = {input_nodes[0].name: x.numpy()}
onnx_pos_emb = session.run(["pos_emb"], inputs)
onnx_pos_emb = torch.from_numpy(onnx_pos_emb[0])
torch_pos_emb = encoder_pos(x)
assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), (
(onnx_pos_emb - torch_pos_emb).abs().max()
)
print(onnx_pos_emb.abs().sum(), torch_pos_emb.abs().sum())
os.remove(filename)
def test_zipformer_encoder_layer():
filename = "zipformer_encoder_layer.onnx"
opset_version = 13
N = 30
T = 50
d_model = 384
attention_dim = 192
nhead = 8
feedforward_dim = 1024
dropout = 0.1
cnn_module_kernel = 31
pos_dim = 4
x = torch.rand(N, T, d_model)
encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_pos.eval()
encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True)
x = x.permute(1, 0, 2)
pos_emb = encoder_pos(x)
encoder_layer = ZipformerEncoderLayer(
d_model,
attention_dim,
nhead,
feedforward_dim,
dropout,
cnn_module_kernel,
pos_dim,
)
encoder_layer.eval()
encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True)
torch.onnx.export(
encoder_layer,
(x, pos_emb),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "pos_emb"],
output_names=["y"],
dynamic_axes={
"x": {0: "T", 1: "N"},
"pos_emb": {0: "N", 1: "T"},
"y": {0: "T", 1: "N"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
inputs = {
input_nodes[0].name: x.numpy(),
input_nodes[1].name: pos_emb.numpy(),
}
onnx_y = session.run(["y"], inputs)[0]
onnx_y = torch.from_numpy(onnx_y)
torch_y = encoder_layer(x, pos_emb)
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()
print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)
os.remove(filename)
def test_zipformer_encoder():
filename = "zipformer_encoder.onnx"
opset_version = 13
N = 3
T = 15
d_model = 512
attention_dim = 192
nhead = 8
feedforward_dim = 1024
dropout = 0.1
cnn_module_kernel = 31
pos_dim = 4
num_encoder_layers = 12
warmup_batches = 4000.0
warmup_begin = warmup_batches / (num_encoder_layers + 1)
warmup_end = warmup_batches / (num_encoder_layers + 1)
x = torch.rand(N, T, d_model)
encoder_layer = ZipformerEncoderLayer(
d_model,
attention_dim,
nhead,
feedforward_dim,
dropout,
cnn_module_kernel,
pos_dim,
)
encoder = ZipformerEncoder(
encoder_layer, num_encoder_layers, dropout, warmup_begin, warmup_end
)
encoder.eval()
encoder = convert_scaled_to_non_scaled(encoder, inplace=True)
# jit_model = torch.jit.trace(encoder, (pos_emb))
torch_y = encoder(x)
torch.onnx.export(
encoder,
(x),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x"],
output_names=["y"],
dynamic_axes={
"x": {0: "T", 1: "N"},
"y": {0: "T", 1: "N"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
inputs = {
input_nodes[0].name: x.numpy(),
}
onnx_y = session.run(["y"], inputs)[0]
onnx_y = torch.from_numpy(onnx_y)
torch_y = encoder(x)
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()
print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)
os.remove(filename)
def test_zipformer():
filename = "zipformer.onnx"
opset_version = 11
N = 3
T = 15
num_features = 80
x = torch.rand(N, T, num_features)
x_lens = torch.full((N,), fill_value=T, dtype=torch.int64)
zipformer = Zipformer(num_features=num_features)
zipformer.eval()
zipformer = convert_scaled_to_non_scaled(zipformer, inplace=True)
# jit_model = torch.jit.trace(zipformer, (x, x_lens))
torch.onnx.export(
zipformer,
(x, x_lens),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["y", "y_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"y": {0: "N", 1: "T"},
"y_lens": {0: "N"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
inputs = {
input_nodes[0].name: x.numpy(),
input_nodes[1].name: x_lens.numpy(),
}
onnx_y, onnx_y_lens = session.run(["y", "y_lens"], inputs)
onnx_y = torch.from_numpy(onnx_y)
onnx_y_lens = torch.from_numpy(onnx_y_lens)
torch_y, torch_y_lens = zipformer(x, x_lens)
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()
assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), (
(onnx_y_lens - torch_y_lens).abs().max()
)
print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)
print(onnx_y_lens, torch_y_lens)
os.remove(filename)
@torch.no_grad()
def main():
test_conv2d_subsampling()
test_rel_pos()
test_zipformer_encoder_layer()
test_zipformer_encoder()
test_zipformer()
if __name__ == "__main__":
torch.manual_seed(20221011)
main()

File diff suppressed because it is too large

View File

@ -0,0 +1,83 @@
import torch
from transformers import BertTokenizer, BertModel
import logging
class BertEncoder:
def __init__(self, device=None, **kwargs):
# https://huggingface.co/bert-base-uncased
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# TODO: fast_tokenizers: https://huggingface.co/docs/transformers/v4.27.2/en/fast_tokenizers
self.bert_model = BertModel.from_pretrained("bert-base-uncased")
self.bert_model.eval()
for param in self.bert_model.parameters():
param.requires_grad = False
self.name = self.bert_model.config._name_or_path
self.embedding_size = 768
num_param = sum([p.numel() for p in self.bert_model.parameters()])
logging.info(f"Number of parameters in '{self.name}': {num_param}")
if device is not None:
self.bert_model.to(device)
logging.info(f"Loaded '{self.name}' to {device}")
logging.info(f"cuda.memory_allocated: {torch.cuda.memory_allocated()/1024/1024:.2f} MB")
def _encode_strings(
self,
word_list,
):
# word_list is a list of uncased strings
encoded_input = self.tokenizer(word_list, return_tensors='pt', padding=True)
encoded_input = encoded_input.to(self.bert_model.device)
out = self.bert_model(**encoded_input).pooler_output
# TODO:
# 1. compare with some online API
# 2. smaller or no dropout
# 3. other activation function: different ranges, sigmoid
# 4. compare the range with lstm encoder or rnnt hidden representation
# 5. more layers, transformer layers: how to connect two spaces
# out is of the shape: len(word_list) * 768
return out
def encode_strings(
self,
word_list,
batch_size = 6000,
silent = False,
):
"""
Encode a list of uncased strings into a list of embeddings
Args:
word_list:
A list of words, where each word is a string
Returns:
embeddings:
A list of embeddings (on CPU), of the shape len(word_list) * 768
"""
i = 0
embeddings_list = list()
while i < len(word_list):
if not silent and int(i / 10000) % 5 == 0:
logging.info(f"Using '{self.name}' to encode the wordlist: {i}/{len(word_list)}")
wlist = word_list[i: i + batch_size]
embeddings = self._encode_strings(wlist)
embeddings = embeddings.detach().cpu() # To save GPU memory
embeddings = list(embeddings)
embeddings_list.extend(embeddings)
i += batch_size
if not silent:
logging.info(f"Done, len(embeddings_list)={len(embeddings_list)}")
return embeddings_list
def free_up(self):
self.bert_model = self.bert_model.to(torch.device("cpu"))
torch.cuda.empty_cache()
logging.info(f"cuda.memory_allocated: {torch.cuda.memory_allocated()/1024/1024:.2f} MB")

View File

@ -0,0 +1,81 @@
import torch
import logging
import io
import fasttext
class FastTextEncoder:
def __init__(self, embeddings_path=None, model_path=None, **kwargs):
logging.info(f"Loading word embeddings from: {embeddings_path}")
self.word_to_vector = self.load_vectors(embeddings_path)
logging.info(f"Number of word embeddings: {len(self.word_to_vector)}")
self.model_path = model_path
self.model = None
self.name = "FastText"
self.embedding_size = 300
def load_vectors(self, fname):
fin = io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore')
# n, d = map(int, fin.readline().split())
data = {}
for line in fin:
tokens = line.rstrip().split(' ')
w = tokens[0].upper()
embedding = list(map(float, tokens[1:]))
data[w] = torch.tensor(embedding)
return data
def _encode_strings(
self,
word_list,
lower_case=True,
):
'''
Encode unseen strings
'''
if len(word_list) == 0:
return []
if self.model is None:
logging.info(f"Lazy loading FastText model from: {self.model_path}")
self.model = fasttext.FastText.load_model(self.model_path) # "/exp/rhuang/fastText/cc.en.300.bin"
out = [self.model[w.lower()] if lower_case else self.model[w] for w in word_list]
for w, embedding in zip(word_list, out):
self.word_to_vector[w] = torch.tensor(embedding)
return out
def encode_strings(
self,
word_list,
batch_size = 6000,
silent = False,
):
"""
Encode a list of uncased strings into a list of embeddings
Args:
word_list:
A list of words, where each word is a string
Returns:
embeddings:
A list of embeddings (on CPU), of the shape len(word_list) * 300
"""
embeddings_list = list()
for i, w in enumerate(word_list):
if not silent and i % 50000 == 0:
logging.info(f"Using FastText to encode the wordlist: {i}/{len(word_list)}")
if w in self.word_to_vector:
embeddings_list.append(self.word_to_vector[w])
else:
embedding = self._encode_strings([w], lower_case=True)[0]
embeddings_list.append(embedding)
if not silent:
logging.info(f"Done, len(embeddings_list)={len(embeddings_list)}")
return embeddings_list
def free_up(self):
pass
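# Usage sketch (illustrative only; the paths below are placeholders):
#
#   logging.basicConfig(level=logging.INFO)
#   encoder = FastTextEncoder(
#       embeddings_path="path/to/precomputed_embeddings.vec",  # one "<word> <300 floats>" per line
#       model_path="path/to/cc.en.300.bin",  # full FastText model, lazily loaded for OOV words
#   )
#   embeddings = encoder.encode_strings(["NEBRASKA", "ZIPFORMER"], silent=True)
#   # Each element has dimensionality 300, i.e. encoder.embedding_size.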

File diff suppressed because it is too large