Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
This is the initial commit for the neural biasing implementation with early context injection and text perturbation. The code runs well on the grid; however, it needs a fair amount of cleanup and refactoring before it can become a reasonable PR.
This commit is contained in:
parent 5c04c31292
commit 78b7ef3e3f
@@ -0,0 +1,206 @@
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
#                                            Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List

import k2
import torch

from beam_search import Hypothesis, HypothesisList, get_hyps_shape

# The force alignment problem can be formulated as finding
# a path in a rectangular lattice, where the path starts
# from the lower left corner and ends at the upper right
# corner. The horizontal axis of the lattice is `t` (representing
# acoustic frame indexes) and the vertical axis is `u` (representing
# BPE tokens of the transcript).
#
# The notations `t` and `u` are from the paper
# https://arxiv.org/pdf/1211.3711.pdf
#
# Beam search is used to find the path with the highest log probabilities.
#
# It assumes the maximum number of symbols that can be
# emitted per frame is 1.


def batch_force_alignment(
    model: torch.nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_list: List[List[int]],
    beam_size: int = 4,
) -> List[List[int]]:
    """Compute the forced alignment of a batch of utterances given their transcripts
    in BPE tokens and the corresponding acoustic output from the encoder.

    Caution:
      This function is modified from `modified_beam_search` in beam_search.py.
      We assume that the maximum number of symbols per frame is 1.

    Args:
      model:
        The transducer model.
      encoder_out:
        A tensor of shape (N, T, C).
      encoder_out_lens:
        A 1-D tensor of shape (N,), containing the number of valid frames in
        encoder_out before padding.
      ys_list:
        A list of lists of BPE token IDs. We require that for each utterance i,
        len(ys_list[i]) <= encoder_out_lens[i].
      beam_size:
        Size of the beam used in beam search.

    Returns:
      Return a list of frame-index lists, one per utterance i,
      where len(ans[i]) == len(ys_list[i]).
    """
    assert encoder_out.ndim == 3, encoder_out.ndim
    assert encoder_out.size(0) == len(ys_list), (encoder_out.size(0), len(ys_list))
    assert encoder_out.size(0) > 0, encoder_out.size(0)

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = next(model.parameters()).device

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )
    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    sorted_indices = packed_encoder_out.sorted_indices.tolist()
    encoder_out_lens = encoder_out_lens.tolist()
    ys_lens = [len(ys) for ys in ys_list]
    sorted_encoder_out_lens = [encoder_out_lens[i] for i in sorted_indices]
    sorted_ys_lens = [ys_lens[i] for i in sorted_indices]
    sorted_ys_list = [ys_list[i] for i in sorted_indices]

    B = [HypothesisList() for _ in range(N)]
    for i in range(N):
        B[i].add(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                timestamp=[],
            )
        )

    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    finalized_B = []
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
        offset = end

        finalized_B = B[batch_size:] + finalized_B
        B = B[:batch_size]
        sorted_encoder_out_lens = sorted_encoder_out_lens[:batch_size]
        sorted_ys_lens = sorted_ys_lens[:batch_size]

        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out, decoder_out, project_input=False
        )  # (num_hyps, 1, 1, vocab_size)
        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(
            shape=log_probs_shape, value=log_probs.reshape(-1)
        )  # [batch][num_hyps*vocab_size]

        for i in range(batch_size):
            for h, hyp in enumerate(A[i]):
                pos_u = len(hyp.timestamp)
                idx_offset = h * vocab_size
                if (sorted_encoder_out_lens[i] - 1 - t) >= (sorted_ys_lens[i] - pos_u):
                    # emit blank token
                    new_hyp = Hypothesis(
                        log_prob=ragged_log_probs[i][idx_offset + blank_id],
                        ys=hyp.ys[:],
                        timestamp=hyp.timestamp[:],
                    )
                    B[i].add(new_hyp)
                if pos_u < sorted_ys_lens[i]:
                    # emit non-blank token
                    new_token = sorted_ys_list[i][pos_u]
                    new_hyp = Hypothesis(
                        log_prob=ragged_log_probs[i][idx_offset + new_token],
                        ys=hyp.ys + [new_token],
                        timestamp=hyp.timestamp + [t],
                    )
                    B[i].add(new_hyp)

            if len(B[i]) > beam_size:
                B[i] = B[i].topk(beam_size, length_norm=True)

    B = B + finalized_B
    sorted_hyps = [b.get_most_probable() for b in B]
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    hyps = [sorted_hyps[i] for i in unsorted_indices]
    ans = []
    for i, hyp in enumerate(hyps):
        assert hyp.ys[context_size:] == ys_list[i], (hyp.ys[context_size:], ys_list[i])
        ans.append(hyp.timestamp)

    return ans
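A minimal usage sketch (not part of the original commit): it assumes `model` is a loaded transducer checkpoint and `sp` a SentencePiece processor, set up as in compute_ali.py later in this diff; the helper name `align_batch` is hypothetical.

import sentencepiece as spm
import torch


def align_batch(
    model: torch.nn.Module,
    sp: spm.SentencePieceProcessor,
    feature: torch.Tensor,       # (N, T, C) fbank features
    feature_lens: torch.Tensor,  # (N,) number of valid frames per utterance
    texts: List[str],
) -> List[List[int]]:
    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
    ys_list = sp.encode(texts, out_type=int)  # BPE token IDs per utterance
    # frame_indexes[i][j] is the (subsampled) frame at which token j of
    # utterance i is emitted; len(frame_indexes[i]) == len(ys_list[i]).
    frame_indexes = batch_force_alignment(
        model, encoder_out, encoder_out_lens, ys_list, beam_size=4
    )
    return frame_indexes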
@ -0,0 +1,460 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriSpeechAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--full-libri",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SpeechRecognitionDataset(
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_100_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-100 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_100_cuts_sample(self) -> CutSet:
|
||||
logging.info("About to get train-clean-100 cuts (sampled)")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-100-0.4.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_360_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-360 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_other_500_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-other-500 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_all_shuf_cuts(self) -> CutSet:
|
||||
logging.info(
|
||||
"About to get the shuffled train-clean-100, \
|
||||
train-clean-360 and train-other-500 cuts"
|
||||
)
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
|
||||
)
|
||||
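A short usage sketch of this data module (added here for illustration, not in the original commit); it assumes the LibriSpeech fbank manifests have already been prepared under --manifest-dir.

import argparse

parser = argparse.ArgumentParser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--manifest-dir", "data/fbank", "--max-duration", "200"])

librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders(librispeech.train_clean_100_cuts())
for batch in train_dl:
    features = batch["inputs"]            # (N, T, 80) padded fbank features
    supervisions = batch["supervisions"]  # texts, num_frames and, with --return-cuts, the cuts
    break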
@ -0,0 +1,259 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class SBCAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/manifests"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
# @lru_cache()
|
||||
# def test_healthcare_cuts(self) -> CutSet:
|
||||
# logging.info("About to get test-healthcare cuts")
|
||||
# return load_manifest_lazy(
|
||||
# self.args.manifest_dir / "uniphore_cuts_healthcare.jsonl.gz"
|
||||
# )
|
||||
|
||||
# @lru_cache()
|
||||
# def test_banking_cuts(self) -> CutSet:
|
||||
# logging.info("About to get test-banking cuts")
|
||||
# return load_manifest_lazy(
|
||||
# self.args.manifest_dir / "uniphore_cuts_banking.jsonl.gz"
|
||||
# )
|
||||
|
||||
# @lru_cache()
|
||||
# def test_insurance_cuts(self) -> CutSet:
|
||||
# logging.info("About to get test-insurance cuts")
|
||||
# return load_manifest_lazy(
|
||||
# self.args.manifest_dir / "uniphore_cuts_insurance.jsonl.gz"
|
||||
# )
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self, cuts_file) -> CutSet:
|
||||
logging.info(f"About to get cuts from {cuts_file}")
|
||||
return load_manifest_lazy(cuts_file)
|
||||
|
||||
# @lru_cache()
|
||||
# def test_sbc_cuts(self, cuts_file, sampling_rate) -> CutSet:
|
||||
# logging.info(f"About to get SBC cuts from {cuts_file}")
|
||||
# cuts = CutSet.from_file(cuts_file)
|
||||
# cuts = [c.to_mono(mono_downmix=True).resample(sampling_rate) for c in cuts]
|
||||
# cuts = CutSet.from_cuts(cuts)
|
||||
# return cuts
|
||||
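A usage sketch for the SBC module (illustrative only; the manifest file name below is a hypothetical example, any Lhotse cuts manifest works):

import argparse

parser = argparse.ArgumentParser()
SBCAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--manifest-dir", "data/manifests", "--on-the-fly-feats", "true"])

sbc = SBCAsrDataModule(args)
cuts = sbc.test_cuts("data/manifests/sbc_cuts_test.jsonl.gz")  # hypothetical manifest name
test_dl = sbc.test_dataloaders(cuts)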
File diff suppressed because it is too large.
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn


class Ffn(nn.Module):
    def __init__(self, input_dim, hidden_dim, out_dim, nlayers=1, drop_out=0.1, skip=False) -> None:
        super().__init__()

        layers = []
        for ilayer in range(nlayers):
            _in = hidden_dim if ilayer > 0 else input_dim
            _out = hidden_dim if ilayer < nlayers - 1 else out_dim
            layers.extend([
                nn.Linear(_in, _out),
                # nn.ReLU(),
                # nn.Sigmoid(),
                nn.Tanh(),
                nn.Dropout(p=drop_out),
            ])
        self.ffn = torch.nn.Sequential(*layers)

        self.skip = skip

    def forward(self, x) -> torch.Tensor:
        x_out = self.ffn(x)

        if self.skip:
            x_out = x_out + x

        return x_out


class BiasingModule(torch.nn.Module):
    def __init__(
        self,
        query_dim,
        qkv_dim=64,
        num_heads=4,
    ):
        super(BiasingModule, self).__init__()
        self.proj_in1 = nn.Linear(query_dim, qkv_dim)
        self.proj_in2 = Ffn(
            input_dim=qkv_dim,
            hidden_dim=qkv_dim,
            out_dim=qkv_dim,
            skip=True,
            drop_out=0.1,
            nlayers=2,
        )
        self.multihead_attn = torch.nn.MultiheadAttention(
            embed_dim=qkv_dim,
            num_heads=num_heads,
            # kdim=64,
            # vdim=64,
            batch_first=True,
        )
        self.proj_out1 = Ffn(
            input_dim=qkv_dim,
            hidden_dim=qkv_dim,
            out_dim=qkv_dim,
            skip=True,
            drop_out=0.1,
            nlayers=2,
        )
        self.proj_out2 = nn.Linear(qkv_dim, query_dim)
        self.glu = nn.GLU()

    def forward(
        self,
        queries,
        contexts,
        contexts_mask,
        need_weights=False,
    ):
        """
        Args:
            queries:
                of shape batch_size * seq_length * query_dim
            contexts:
                of shape batch_size * max_contexts_size * qkv_dim
            contexts_mask:
                of shape batch_size * max_contexts_size; True marks padded entries
        Returns:
            biasing_output:
                of shape batch_size * seq_length * query_dim
            attn_output_weights:
                of shape batch_size * seq_length * max_contexts_size
        """
        _queries = self.proj_in1(queries)
        _queries = self.proj_in2(_queries)
        # queries = queries / 0.01

        attn_output, attn_output_weights = self.multihead_attn(
            _queries,  # query
            contexts,  # key
            contexts,  # value
            key_padding_mask=contexts_mask,
            need_weights=need_weights,
        )
        output = self.proj_out1(attn_output)
        output = self.proj_out2(output)

        # apply the gated linear unit
        biasing_output = self.glu(output.repeat(1, 1, 2))

        # print(f"queries={queries.shape}")
        # print(f"value={contexts} value.shape={contexts.shape}")
        # print(f"attn_output_weights={attn_output_weights} attn_output_weights.shape={attn_output_weights.shape}")
        return biasing_output, attn_output_weights
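A self-contained shape check for BiasingModule (added for illustration; the context vectors here are random, whereas in the recipe they would come from pooled BPE-piece or BERT word embeddings). Since kdim/vdim are left unset, MultiheadAttention expects the contexts' last dimension to equal qkv_dim as the module is written here.

import torch

query_dim, qkv_dim = 512, 64
biasing = BiasingModule(query_dim=query_dim, qkv_dim=qkv_dim, num_heads=4)

batch_size, seq_len, num_contexts = 2, 50, 10
queries = torch.randn(batch_size, seq_len, query_dim)      # e.g. encoder frames
contexts = torch.randn(batch_size, num_contexts, qkv_dim)  # one vector per biasing word
# key_padding_mask convention: True marks padded (invalid) context slots.
contexts_mask = torch.zeros(batch_size, num_contexts, dtype=torch.bool)
contexts_mask[1, 5:] = True  # the second utterance has only 5 real context words

biasing_out, attn_weights = biasing(queries, contexts, contexts_mask, need_weights=True)
print(biasing_out.shape)   # torch.Size([2, 50, 512]); added back onto the encoder stream
print(attn_weights.shape)  # torch.Size([2, 50, 10])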
egs/librispeech/ASR/pruned_transducer_stateless7_contextual/compute_ali.py (new executable file, 345 lines)
@ -0,0 +1,345 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao,
|
||||
# Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
The script gets forced-alignments based on the modified_beam_search decoding method.
|
||||
Both token-level alignments and word-level alignments are saved to the new cuts manifests.
|
||||
|
||||
It loads a checkpoint and uses it to get the forced-alignments.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless7/compute_ali.py \
|
||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--dataset test-clean \
|
||||
--max-duration 300 \
|
||||
--beam-size 4 \
|
||||
--cuts-out-dir data/fbank_ali_beam_search
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from alignment import batch_force_alignment
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from lhotse import CutSet
|
||||
from lhotse.serialization import SequentialJsonlWriter
|
||||
from lhotse.supervision import AlignmentItem
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=True,
|
||||
help="""The name of the dataset to compute alignments for.
|
||||
Possible values are:
|
||||
- test-clean
|
||||
- test-other
|
||||
- train-clean-100
|
||||
- train-clean-360
|
||||
- train-other-500
|
||||
- dev-clean
|
||||
- dev-other
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cuts-out-dir",
|
||||
type=str,
|
||||
default="data/fbank_ali_beam_search",
|
||||
help="The dir to save the new cuts manifests with alignments",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def align_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
) -> Tuple[List[List[str]], List[List[str]], List[List[float]], List[List[float]]]:
|
||||
"""Get forced-alignments for one batch.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
|
||||
Returns:
|
||||
token_list:
|
||||
A list of token list.
|
||||
word_list:
|
||||
A list of word list.
|
||||
token_time_list:
|
||||
A list of timestamps list for tokens.
|
||||
word_time_list.
|
||||
A list of timestamps list for words.
|
||||
|
||||
where len(token_list) == len(word_list) == len(token_time_list) == len(word_time_list),
|
||||
len(token_list[i]) == len(token_time_list[i]),
|
||||
and len(word_list[i]) == len(word_time_list[i])
|
||||
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
texts = supervisions["text"]
|
||||
ys_list: List[List[int]] = sp.encode(texts, out_type=int)
|
||||
|
||||
frame_indexes = batch_force_alignment(
|
||||
model, encoder_out, encoder_out_lens, ys_list, params.beam_size
|
||||
)
|
||||
|
||||
token_list = []
|
||||
word_list = []
|
||||
token_time_list = []
|
||||
word_time_list = []
|
||||
for i in range(encoder_out.size(0)):
|
||||
tokens = sp.id_to_piece(ys_list[i])
|
||||
words = texts[i].split()
|
||||
token_time = convert_timestamp(
|
||||
frame_indexes[i], params.subsampling_factor, params.frame_shift_ms
|
||||
)
|
||||
word_time = parse_timestamp(tokens, token_time)
|
||||
assert len(word_time) == len(words), (len(word_time), len(words))
|
||||
|
||||
token_list.append(tokens)
|
||||
word_list.append(words)
|
||||
token_time_list.append(token_time)
|
||||
word_time_list.append(word_time)
|
||||
|
||||
return token_list, word_list, token_time_list, word_time_list
|
||||
|
||||
|
||||
def align_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
writer: SequentialJsonlWriter,
|
||||
) -> None:
|
||||
"""Get forced-alignments for the dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
writer:
|
||||
Writer to save the cuts with alignments.
|
||||
"""
|
||||
log_interval = 20
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
token_list, word_list, token_time_list, word_time_list = align_one_batch(
|
||||
params=params, model=model, sp=sp, batch=batch
|
||||
)
|
||||
|
||||
cut_list = batch["supervisions"]["cut"]
|
||||
for cut, token, word, token_time, word_time in zip(
|
||||
cut_list, token_list, word_list, token_time_list, word_time_list
|
||||
):
|
||||
assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
|
||||
token_ali = [
|
||||
AlignmentItem(
|
||||
symbol=token[i],
|
||||
start=round(token_time[i], ndigits=3),
|
||||
duration=None,
|
||||
)
|
||||
for i in range(len(token))
|
||||
]
|
||||
word_ali = [
|
||||
AlignmentItem(
|
||||
symbol=word[i], start=round(word_time[i], ndigits=3), duration=None
|
||||
)
|
||||
for i in range(len(word))
|
||||
]
|
||||
cut.supervisions[0].alignment = {"word": word_ali, "token": token_ali}
|
||||
writer.write(cut, flush=True)
|
||||
|
||||
num_cuts += len(cut_list)
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
if params.dataset == "test-clean":
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
elif params.dataset == "test-other":
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
elif params.dataset == "train-clean-100":
|
||||
train_clean_100_cuts = librispeech.train_clean_100_cuts()
|
||||
dl = librispeech.train_dataloaders(train_clean_100_cuts)
|
||||
elif params.dataset == "train-clean-360":
|
||||
train_clean_360_cuts = librispeech.train_clean_360_cuts()
|
||||
dl = librispeech.train_dataloaders(train_clean_360_cuts)
|
||||
elif params.dataset == "train-other-500":
|
||||
train_other_500_cuts = librispeech.train_other_500_cuts()
|
||||
dl = librispeech.train_dataloaders(train_other_500_cuts)
|
||||
elif params.dataset == "dev-clean":
|
||||
dev_clean_cuts = librispeech.dev_clean_cuts()
|
||||
dl = librispeech.valid_dataloaders(dev_clean_cuts)
|
||||
else:
|
||||
assert params.dataset == "dev-other", f"{params.dataset}"
|
||||
dev_other_cuts = librispeech.dev_other_cuts()
|
||||
dl = librispeech.valid_dataloaders(dev_other_cuts)
|
||||
|
||||
cuts_out_dir = Path(params.cuts_out_dir)
|
||||
cuts_out_dir.mkdir(parents=True, exist_ok=True)
|
||||
cuts_out_path = cuts_out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz"
|
||||
|
||||
with CutSet.open_writer(cuts_out_path) as writer:
|
||||
align_dataset(dl=dl, params=params, model=model, sp=sp, writer=writer)
|
||||
|
||||
logging.info(
|
||||
f"For dataset {params.dataset}, the cut manifest with framewise token alignments "
|
||||
f"and word alignments are saved to {cuts_out_path}"
|
||||
)
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
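For reference, a sketch (not from the original commit) of how the manifests written by this script can be inspected with Lhotse afterwards; the path matches the defaults --cuts-out-dir data/fbank_ali_beam_search and --dataset test-clean.

from lhotse import load_manifest_lazy

cuts = load_manifest_lazy(
    "data/fbank_ali_beam_search/librispeech_cuts_test-clean.jsonl.gz"
)
cut = next(iter(cuts))
for item in cut.supervisions[0].alignment["word"]:
    # Each AlignmentItem stores the word and its start time in seconds;
    # compute_ali.py leaves the duration as None.
    print(item.symbol, item.start)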
@ -0,0 +1,371 @@
|
||||
import torch
|
||||
import random
|
||||
from pathlib import Path
|
||||
import sentencepiece as spm
|
||||
from typing import Union, List
|
||||
import logging
|
||||
import ast
|
||||
import numpy as np
|
||||
from itertools import chain
|
||||
from word_encoder_bert import BertEncoder
|
||||
from context_wfst import generate_context_graph_nfa
|
||||
|
||||
class SentenceTokenizer:
|
||||
def encode(self, word_list: List, out_type: type = int) -> List:
|
||||
"""
|
||||
Encode a list of words into a list of tokens
|
||||
|
||||
Args:
|
||||
word_list:
|
||||
A list of words where each word is a string.
|
||||
E.g., ["nihao", "hello", "你好"]
|
||||
out_type:
|
||||
This defines the output type. If it is an "int" type,
|
||||
then each token is represented by its integer token id.
|
||||
Returns:
|
||||
A list of tokenized words, where each tokenization
|
||||
is a list of tokens.
|
||||
"""
|
||||
pass
|
||||
|
||||
class ContextCollector(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
path_is21_deep_bias: Path,
|
||||
sp: Union[spm.SentencePieceProcessor, SentenceTokenizer],
|
||||
bert_encoder: BertEncoder = None,
|
||||
n_distractors: int = 100,
|
||||
ratio_distractors: int = None,
|
||||
is_predefined: bool = False,
|
||||
keep_ratio: float = 1.0,
|
||||
is_full_context: bool = False,
|
||||
backoff_id: int = None,
|
||||
):
|
||||
self.sp = sp
|
||||
self.bert_encoder = bert_encoder
|
||||
self.path_is21_deep_bias = path_is21_deep_bias
|
||||
self.n_distractors = n_distractors
|
||||
self.ratio_distractors = ratio_distractors
|
||||
self.is_predefined = is_predefined
|
||||
self.keep_ratio = keep_ratio
|
||||
self.is_full_context = is_full_context # use all words (rare or common) in the context
|
||||
# self.embedding_dim = self.bert_encoder.bert_model.config.hidden_size
|
||||
self.backoff_id = backoff_id
|
||||
|
||||
logging.info(f"""
|
||||
n_distractors={n_distractors},
|
||||
ratio_distractors={ratio_distractors},
|
||||
is_predefined={is_predefined},
|
||||
keep_ratio={keep_ratio},
|
||||
is_full_context={is_full_context},
|
||||
bert_encoder={bert_encoder.name if bert_encoder is not None else None},
|
||||
""")
|
||||
|
||||
self.common_words = None
|
||||
self.rare_words = None
|
||||
self.all_words = None
|
||||
with open(path_is21_deep_bias / "words/all_rare_words.txt", "r") as fin:
|
||||
self.rare_words = [l.strip().upper() for l in fin if len(l) > 0]
|
||||
|
||||
with open(path_is21_deep_bias / "words/common_words_5k.txt", "r") as fin:
|
||||
self.common_words = [l.strip().upper() for l in fin if len(l) > 0]
|
||||
|
||||
self.all_words = self.rare_words + self.common_words # sp needs a list of strings, can't be a set
|
||||
self.common_words = set(self.common_words)
|
||||
self.rare_words = set(self.rare_words)
|
||||
|
||||
logging.info(f"Number of common words: {len(self.common_words)}. Examples: {random.sample(self.common_words, 5)}")
|
||||
logging.info(f"Number of rare words: {len(self.rare_words)}. Examples: {random.sample(self.rare_words, 5)}")
|
||||
logging.info(f"Number of all words: {len(self.all_words)}. Examples: {random.sample(self.all_words, 5)}")
|
||||
|
||||
self.test_clean_biasing_list = None
|
||||
self.test_other_biasing_list = None
|
||||
if is_predefined:
|
||||
def read_ref_biasing_list(filename):
|
||||
biasing_list = dict()
|
||||
all_cnt = 0
|
||||
rare_cnt = 0
|
||||
with open(filename, "r") as fin:
|
||||
for line in fin:
|
||||
line = line.strip().upper()
|
||||
if len(line) == 0:
|
||||
continue
|
||||
line = line.split("\t")
|
||||
uid, ref_text, ref_rare_words, context_rare_words = line
|
||||
context_rare_words = ast.literal_eval(context_rare_words)
|
||||
biasing_list[uid] = context_rare_words
|
||||
|
||||
ref_rare_words = ast.literal_eval(ref_rare_words)
|
||||
ref_text = ref_text.split()
|
||||
all_cnt += len(ref_text)
|
||||
rare_cnt += len(ref_rare_words)
|
||||
return biasing_list, rare_cnt / all_cnt
|
||||
|
||||
self.test_clean_biasing_list, ratio_clean = \
|
||||
read_ref_biasing_list(self.path_is21_deep_bias / f"ref/test-clean.biasing_{n_distractors}.tsv")
|
||||
self.test_other_biasing_list, ratio_other = \
|
||||
read_ref_biasing_list(self.path_is21_deep_bias / f"ref/test-other.biasing_{n_distractors}.tsv")
|
||||
|
||||
logging.info(f"Number of utterances in test_clean_biasing_list: {len(self.test_clean_biasing_list)}, rare ratio={ratio_clean:.2f}")
|
||||
logging.info(f"Number of utterances in test_other_biasing_list: {len(self.test_other_biasing_list)}, rare ratio={ratio_other:.2f}")
|
||||
|
||||
self.all_words2pieces = None
|
||||
if self.sp is not None:
|
||||
all_words2pieces = sp.encode(self.all_words, out_type=int) # a list of list of int
|
||||
self.all_words2pieces = {w: pieces for w, pieces in zip(self.all_words, all_words2pieces)}
|
||||
logging.info(f"len(self.all_words2pieces)={len(self.all_words2pieces)}")
|
||||
|
||||
self.all_words2embeddings = None
|
||||
if self.bert_encoder is not None:
|
||||
all_words = list(chain(self.common_words, self.rare_words))
|
||||
all_embeddings = self.bert_encoder.encode_strings(all_words)
|
||||
assert len(all_words) == len(all_embeddings)
|
||||
self.all_words2embeddings = {w: ebd for w, ebd in zip(all_words, all_embeddings)}
|
||||
logging.info(f"len(self.all_words2embeddings)={len(self.all_words2embeddings)}")
|
||||
|
||||
if is_predefined:
|
||||
new_words_bias = set()
|
||||
all_words_bias = set()
|
||||
for uid, wlist in chain(self.test_clean_biasing_list.items(), self.test_other_biasing_list.items()):
|
||||
for word in wlist:
|
||||
if word not in self.common_words and word not in self.rare_words:
|
||||
new_words_bias.add(word)
|
||||
all_words_bias.add(word)
|
||||
# if self.all_words2pieces is not None and word not in self.all_words2pieces:
|
||||
# self.all_words2pieces[word] = self.sp.encode(word, out_type=int)
|
||||
# if self.all_words2embeddings is not None and word not in self.all_words2embeddings:
|
||||
# self.all_words2embeddings[word] = self.bert_encoder.encode_strings([word])[0]
|
||||
logging.info(f"OOVs in the biasing list: {len(new_words_bias)}/{len(all_words_bias)}")
|
||||
if len(new_words_bias) > 0:
|
||||
self.add_new_words(list(new_words_bias), silent=True)
|
||||
|
||||
if is_predefined:
|
||||
assert self.ratio_distractors is None
|
||||
assert self.n_distractors in [100, 500, 1000, 2000]
|
||||
|
||||
self.temp_dict = None
|
||||
|
||||
def add_new_words(self, new_words_list, return_dict=False, silent=False):
|
||||
if len(new_words_list) == 0:
|
||||
if return_dict is True:
|
||||
return dict()
|
||||
else:
|
||||
return
|
||||
|
||||
if self.all_words2pieces is not None:
|
||||
words_pieces_list = self.sp.encode(new_words_list, out_type=int)
|
||||
new_words2pieces = {w: pieces for w, pieces in zip(new_words_list, words_pieces_list)}
|
||||
if return_dict:
|
||||
return new_words2pieces
|
||||
else:
|
||||
self.all_words2pieces.update(new_words2pieces)
|
||||
|
||||
if self.all_words2embeddings is not None:
|
||||
embeddings_list = self.bert_encoder.encode_strings(new_words_list, silent=silent)
|
||||
new_words2embeddings = {w: ebd for w, ebd in zip(new_words_list, embeddings_list)}
|
||||
if return_dict:
|
||||
return new_words2embeddings
|
||||
else:
|
||||
self.all_words2embeddings.update(new_words2embeddings)
|
||||
|
||||
self.all_words.extend(new_words_list)
|
||||
self.rare_words.update(new_words_list)
|
||||
|
||||
def discard_some_common_words(words, keep_ratio):
|
||||
pass
|
||||
|
||||
def _get_random_word_lists(self, batch):
|
||||
texts = batch["supervisions"]["text"]
|
||||
|
||||
new_words = []
|
||||
rare_words_list = []
|
||||
for text in texts:
|
||||
rare_words = []
|
||||
for word in text.split():
|
||||
if self.is_full_context or word not in self.common_words:
|
||||
rare_words.append(word)
|
||||
|
||||
if self.all_words2pieces is not None and word not in self.all_words2pieces:
|
||||
new_words.append(word)
|
||||
# self.all_words2pieces[word] = self.sp.encode(word, out_type=int)
|
||||
if self.all_words2embeddings is not None and word not in self.all_words2embeddings:
|
||||
new_words.append(word)
|
||||
# logging.info(f"New word detected: {word}")
|
||||
# self.all_words2embeddings[word] = self.bert_encoder.encode_strings([word])[0]
|
||||
|
||||
rare_words = list(set(rare_words)) # deduplication
|
||||
|
||||
if self.keep_ratio < 1.0 and len(rare_words) > 0:
|
||||
# # method 1:
|
||||
# keep_size = int(len(rare_words) * self.keep_ratio)
|
||||
# if keep_size > 0:
|
||||
# rare_words = random.sample(rare_words, keep_size)
|
||||
# else:
|
||||
# rare_words = []
|
||||
|
||||
# method 2:
|
||||
x = np.random.rand(len(rare_words))
|
||||
new_rare_words = []
|
||||
for xi in range(len(rare_words)):
|
||||
if x[xi] < self.keep_ratio:
|
||||
new_rare_words.append(rare_words[xi])
|
||||
rare_words = new_rare_words
|
||||
|
||||
rare_words_list.append(rare_words)
|
||||
|
||||
self.temp_dict = None
|
||||
if len(new_words) > 0:
|
||||
self.temp_dict = self.add_new_words(new_words, return_dict=True, silent=True)
|
||||
|
||||
if self.ratio_distractors is not None:
|
||||
n_distractors_each = []
|
||||
for rare_words in rare_words_list:
|
||||
n_distractors_each.append(len(rare_words) * self.ratio_distractors)
|
||||
n_distractors_each = np.asarray(n_distractors_each, dtype=int)
|
||||
else:
|
||||
if self.n_distractors == -1: # variable context list sizes
|
||||
n_distractors_each = np.random.randint(low=10, high=500, size=len(texts))
|
||||
# n_distractors_each = np.random.randint(low=80, high=300, size=len(texts))
|
||||
else:
|
||||
n_distractors_each = np.full(len(texts), self.n_distractors, int)
|
||||
distractors_cnt = n_distractors_each.sum()
|
||||
|
||||
distractors = random.sample( # without replacement
|
||||
list(self.rare_words),  # random.sample needs a sequence, not a set
|
||||
distractors_cnt
|
||||
) # TODO: actually the context should contain both rare and common words
|
||||
# distractors = random.choices( # random choices with replacement
|
||||
# self.rare_words,
|
||||
# distractors_cnt,
|
||||
# )
|
||||
distractors_pos = 0
|
||||
for i, rare_words in enumerate(rare_words_list):
|
||||
rare_words.extend(distractors[distractors_pos: distractors_pos + n_distractors_each[i]])
|
||||
distractors_pos += n_distractors_each[i]
|
||||
# random.shuffle(rare_words)
|
||||
# logging.info(rare_words)
|
||||
assert distractors_pos == len(distractors)
|
||||
|
||||
return rare_words_list
|
||||
|
||||
def _get_predefined_word_lists(self, batch):
|
||||
rare_words_list = []
|
||||
for cut in batch['supervisions']['cut']:
|
||||
uid = cut.supervisions[0].id
|
||||
if uid in self.test_clean_biasing_list:
|
||||
rare_words_list.append(self.test_clean_biasing_list[uid])
|
||||
elif uid in self.test_other_biasing_list:
|
||||
rare_words_list.append(self.test_other_biasing_list[uid])
|
||||
else:
|
||||
rare_words_list.append([])
|
||||
logging.error(f"uid={uid} cannot find the predefined biasing list of size {self.n_distractors}")
|
||||
# for wl in rare_words_list:
|
||||
# for w in wl:
|
||||
# if w not in self.all_words2pieces:
|
||||
# self.all_words2pieces[w] = self.sp.encode(w, out_type=int)
|
||||
return rare_words_list
|
||||
|
||||
def get_context_word_list(
|
||||
self,
|
||||
batch: dict,
|
||||
):
|
||||
"""
|
||||
Generate or fetch the context biasing list as a list of words for each utterance.
|
||||
Use keep_ratio to simulate an "imperfect" context that may not have 100% coverage of the ground-truth words.
|
||||
"""
|
||||
if self.is_predefined:
|
||||
rare_words_list = self._get_predefined_word_lists(batch)
|
||||
else:
|
||||
rare_words_list = self._get_random_word_lists(batch)
|
||||
|
||||
rare_words_list = [sorted(rwl) for rwl in rare_words_list]
|
||||
|
||||
if self.all_words2embeddings is None:
|
||||
# Use SentencePiece to encode the words
|
||||
rare_words_pieces_list = []
|
||||
max_pieces_len = 0
|
||||
for rare_words in rare_words_list:
|
||||
rare_words_pieces = [self.all_words2pieces[w] if w in self.all_words2pieces else self.temp_dict[w] for w in rare_words]
|
||||
if len(rare_words_pieces) > 0:
|
||||
max_pieces_len = max(max_pieces_len, max(len(pieces) for pieces in rare_words_pieces))
|
||||
rare_words_pieces_list.append(rare_words_pieces)
|
||||
else:
|
||||
# Use BERT embeddings here
|
||||
rare_words_embeddings_list = []
|
||||
for rare_words in rare_words_list:
|
||||
# for w in rare_words:
|
||||
# if w not in self.all_words2embeddings and (self.temp_dict is not None and w not in self.temp_dict):
|
||||
# import pdb; pdb.set_trace()
|
||||
# if w == "STUBE":
|
||||
# import pdb; pdb.set_trace()
|
||||
rare_words_embeddings = [self.all_words2embeddings[w] if w in self.all_words2embeddings else self.temp_dict[w] for w in rare_words]
|
||||
rare_words_embeddings_list.append(rare_words_embeddings)
|
||||
|
||||
if self.all_words2embeddings is None:
|
||||
# Use SentencePiece to encode the words
|
||||
word_list = []
|
||||
word_lengths = []
|
||||
num_words_per_utt = []
|
||||
pad_token = 0
|
||||
for rare_words_pieces in rare_words_pieces_list:
|
||||
num_words_per_utt.append(len(rare_words_pieces))
|
||||
word_lengths.extend([len(pieces) for pieces in rare_words_pieces])
|
||||
|
||||
# # TODO: this was a bug: it would effectively modify the entries in 'self.all_words2pieces' in place!
|
||||
# for pieces in rare_words_pieces:
|
||||
# pieces += [pad_token] * (max_pieces_len - len(pieces))
|
||||
# word_list.extend(rare_words_pieces)
|
||||
|
||||
# Correction:
|
||||
rare_words_pieces_padded = list()
|
||||
for pieces in rare_words_pieces:
|
||||
rare_words_pieces_padded.append(pieces + [pad_token] * (max_pieces_len - len(pieces)))
|
||||
word_list.extend(rare_words_pieces_padded)
|
||||
|
||||
word_list = torch.tensor(word_list, dtype=torch.int32)
|
||||
# word_lengths = torch.tensor(word_lengths, dtype=torch.int32)
|
||||
# num_words_per_utt = torch.tensor(num_words_per_utt, dtype=torch.int32)
|
||||
else:
|
||||
# Use BERT embeddings here
|
||||
word_list = []
|
||||
word_lengths = None
|
||||
num_words_per_utt = []
|
||||
for rare_words_embeddings in rare_words_embeddings_list:
|
||||
num_words_per_utt.append(len(rare_words_embeddings))
|
||||
word_list.extend(rare_words_embeddings)
|
||||
word_list = torch.stack(word_list)
|
||||
|
||||
return word_list, word_lengths, num_words_per_utt
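# Illustrative usage sketch (comment only, not part of the recipe). The names
# "collector" and "batch" are hypothetical stand-ins for the objects built in
# the training script:
#
#   word_list, word_lengths, num_words_per_utt = collector.get_context_word_list(batch)
#   # word_list: padded BPE ids of shape (total_num_words, max_pieces_len) when
#   #            SentencePiece is used, or stacked word embeddings when BERT
#   #            embeddings are used
#   # word_lengths: unpadded piece counts per word, or None for embeddings
#   # num_words_per_utt: number of context words belonging to each utterance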
|
||||
|
||||
def get_context_word_wfst(
|
||||
self,
|
||||
batch: dict,
|
||||
):
|
||||
"""
|
||||
Get the WFST representation of the context biasing list as a list of words for each utterance
|
||||
"""
|
||||
if self.is_predefined:
|
||||
rare_words_list = self._get_predefined_word_lists(batch)
|
||||
else:
|
||||
rare_words_list = self._get_random_word_lists(batch)
|
||||
|
||||
# TODO:
|
||||
# We can associate weighted or dynamic weights for each rare word or token
|
||||
|
||||
nbest_size = 1  # TODO: the maximum number of different tokenizations for each lexicon entry.
|
||||
|
||||
# Use SentencePiece to encode the words
|
||||
rare_words_pieces_list = []
|
||||
num_words_per_utt = []
|
||||
for rare_words in rare_words_list:
|
||||
rare_words_pieces = [self.all_words2pieces[w] if w in self.all_words2pieces else self.temp_dict[w] for w in rare_words]
|
||||
rare_words_pieces_list.append(rare_words_pieces)
|
||||
num_words_per_utt.append(len(rare_words))
|
||||
|
||||
fsa_list, fsa_sizes = generate_context_graph_nfa(
|
||||
words_pieces_list = rare_words_pieces_list,
|
||||
backoff_id = self.backoff_id,
|
||||
sp = self.sp,
|
||||
)
|
||||
|
||||
return fsa_list, fsa_sizes, num_words_per_utt
|
||||
@ -0,0 +1,116 @@
|
||||
import torch
|
||||
import abc
|
||||
|
||||
class ContextEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(ContextEncoder, self).__init__()
|
||||
|
||||
@abc.abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
word_list,
|
||||
word_lengths,
|
||||
is_encoder_side=None,
|
||||
):
|
||||
pass
|
||||
|
||||
def embed_contexts(
|
||||
self,
|
||||
contexts,
|
||||
is_encoder_side=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
contexts:
|
||||
The contexts, see below for details
|
||||
Returns:
|
||||
final_h:
|
||||
A tensor of shape (batch_size, max(num_words_per_utt) + 1, joiner_dim),
|
||||
which is the embedding for each context word.
|
||||
mask_h:
|
||||
A tensor of shape (batch_size, max(num_words_per_utt) + 1),
|
||||
which contains a True/False mask for final_h
|
||||
"""
|
||||
if contexts["mode"] == "get_context_word_list":
|
||||
"""
|
||||
word_list:
|
||||
Option1: A list of words, where each word is a list of token ids.
|
||||
The list of tokens for each word has been padded.
|
||||
Option2: A list of words, where each word is an embedding.
|
||||
word_lengths:
|
||||
Option1: The number of tokens per word
|
||||
Option2: None
|
||||
num_words_per_utt:
|
||||
The number of words in the context for each utterance
|
||||
"""
|
||||
word_list, word_lengths, num_words_per_utt = \
|
||||
contexts["word_list"], contexts["word_lengths"], contexts["num_words_per_utt"]
|
||||
|
||||
assert word_lengths is None or word_list.size(0) == len(word_lengths)
|
||||
batch_size = len(num_words_per_utt)
|
||||
elif contexts["mode"] == "get_context_word_list_shared":
|
||||
"""
|
||||
word_list:
|
||||
Option1: A list of words, where each word is a list of token ids.
|
||||
The list of tokens for each word has been padded.
|
||||
Option2: A list of words, where each word is an embedding.
|
||||
word_lengths:
|
||||
Option1: The number of tokens per word
|
||||
Option2: None
|
||||
positive_mask_list:
|
||||
For each utterance, it contains a list of indices of the words that should be masked
|
||||
"""
|
||||
word_list, word_lengths, positive_mask_list = \
|
||||
contexts["word_list"], contexts["word_lengths"], contexts["positive_mask_list"]
|
||||
batch_size = len(positive_mask_list)
|
||||
|
||||
assert word_lengths is None or word_list.size(0) == len(word_lengths)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# print(f"word_list.shape={word_list.shape}")
|
||||
final_h = self.forward(word_list, word_lengths, is_encoder_side=is_encoder_side)
|
||||
|
||||
if contexts["mode"] == "get_context_word_list":
|
||||
final_h = torch.split(final_h, num_words_per_utt)
|
||||
final_h = torch.nn.utils.rnn.pad_sequence(
|
||||
final_h,
|
||||
batch_first=True,
|
||||
padding_value=0.0
|
||||
)
|
||||
# print(f"final_h.shape={final_h.shape}")
|
||||
|
||||
# add one no-bias token
|
||||
no_bias_h = torch.zeros(final_h.shape[0], 1, final_h.shape[-1])
|
||||
no_bias_h = no_bias_h.to(final_h.device)
|
||||
final_h = torch.cat((no_bias_h, final_h), dim=1)
|
||||
# print(final_h)
|
||||
|
||||
# https://stackoverflow.com/questions/53403306/how-to-batch-convert-sentence-lengths-to-masks-in-pytorch
|
||||
mask_h = torch.arange(max(num_words_per_utt) + 1)
|
||||
mask_h = mask_h.expand(len(num_words_per_utt), max(num_words_per_utt) + 1) > torch.Tensor(num_words_per_utt).unsqueeze(1)
|
||||
mask_h = mask_h.to(final_h.device)
|
||||
elif contexts["mode"] == "get_context_word_list_shared":
|
||||
no_bias_h = torch.zeros(1, final_h.shape[-1])
|
||||
no_bias_h = no_bias_h.to(final_h.device)
|
||||
final_h = torch.cat((no_bias_h, final_h), dim=0)
|
||||
|
||||
final_h = final_h.expand(batch_size, -1, -1)
|
||||
|
||||
mask_h = torch.full((batch_size, final_h.size(1)), False, device=final_h.device)  # TODO
|
||||
for i, my_mask in enumerate(positive_mask_list):
|
||||
if len(my_mask) > 0:
|
||||
my_mask = torch.tensor(my_mask, dtype=torch.long)
|
||||
my_mask += 1
|
||||
mask_h[i][my_mask] = True
|
||||
|
||||
# TODO: validate this shape is correct:
|
||||
# final_h: (batch_size, max_num_words_per_utt + 1, dim)
|
||||
# mask_h: (batch_size, max_num_words_per_utt + 1)
|
||||
return final_h, mask_h
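# A small worked example of the length-to-mask trick used above (comment only):
# with num_words_per_utt = [2, 3], final_h has 1 + 3 = 4 slots per utterance
# (slot 0 is the no-bias token), and
#   torch.arange(4).expand(2, 4) > torch.tensor([[2.], [3.]])
# gives
#   [[False, False, False, True ],
#    [False, False, False, False]]
# i.e. only the unused trailing slot of the first utterance is masked out.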
|
||||
|
||||
def clustering(self):
|
||||
pass
|
||||
|
||||
def cache(self):
|
||||
pass
|
||||
@ -0,0 +1,101 @@
|
||||
import torch
|
||||
from context_encoder import ContextEncoder
|
||||
import copy
|
||||
|
||||
class ContextEncoderLSTM(ContextEncoder):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = None,
|
||||
context_encoder_dim: int = None,
|
||||
output_dim: int = None,
|
||||
num_layers: int = None,
|
||||
num_directions: int = None,
|
||||
drop_out: float = 0.1,
|
||||
bi_encoders: bool = False,
|
||||
):
|
||||
super(ContextEncoderLSTM, self).__init__()
|
||||
self.num_layers = num_layers
|
||||
self.num_directions = num_directions
|
||||
self.context_encoder_dim = context_encoder_dim
|
||||
|
||||
torch.manual_seed(42)
|
||||
self.embed = torch.nn.Embedding(
|
||||
vocab_size,
|
||||
context_encoder_dim
|
||||
)
|
||||
self.rnn = torch.nn.LSTM(
|
||||
input_size=context_encoder_dim,
|
||||
hidden_size=context_encoder_dim,
|
||||
num_layers=self.num_layers,
|
||||
batch_first=True,
|
||||
bidirectional=(self.num_directions == 2),
|
||||
dropout=0.0 if self.num_layers > 1 else 0
|
||||
)
|
||||
self.linear = torch.nn.Linear(
|
||||
context_encoder_dim * self.num_directions,
|
||||
output_dim
|
||||
)
|
||||
self.drop_out = torch.nn.Dropout(drop_out)
|
||||
|
||||
# TODO: Do we need some relu layer?
|
||||
# https://galhever.medium.com/sentiment-analysis-with-pytorch-part-4-lstm-bilstm-model-84447f6c4525
|
||||
# self.relu = nn.ReLU()
|
||||
# self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.bi_encoders = bi_encoders
|
||||
if bi_encoders:
|
||||
# Create the decoder/predictor side of the context encoder
|
||||
self.embed_dec = copy.deepcopy(self.embed)
|
||||
self.rnn_dec = copy.deepcopy(self.rnn)
|
||||
self.linear_dec = copy.deepcopy(self.linear)
|
||||
self.drop_out_dec = copy.deepcopy(self.drop_out)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
word_list,
|
||||
word_lengths,
|
||||
is_encoder_side=None,
|
||||
):
|
||||
if is_encoder_side is None or is_encoder_side is True:
|
||||
embed = self.embed
|
||||
rnn = self.rnn
|
||||
linear = self.linear
|
||||
drop_out = self.drop_out
|
||||
else:
|
||||
embed = self.embed_dec
|
||||
rnn = self.rnn_dec
|
||||
linear = self.linear_dec
|
||||
drop_out = self.drop_out_dec
|
||||
|
||||
out = embed(word_list)
|
||||
# https://stackoverflow.com/questions/51030782/why-do-we-pack-the-sequences-in-pytorch
|
||||
out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
out,
|
||||
batch_first=True,
|
||||
lengths=word_lengths,
|
||||
enforce_sorted=False
|
||||
)
|
||||
output, (hn, cn) = rnn(out) # use default all zeros (h_0, c_0)
|
||||
|
||||
# # https://discuss.pytorch.org/t/bidirectional-3-layer-lstm-hidden-output/41336/4
|
||||
# final_state = hn.view(
|
||||
# self.num_layers,
|
||||
# self.num_directions,
|
||||
# word_list.shape[0],
|
||||
# self.encoder_dim,
|
||||
# )[-1] # Only the last layer
|
||||
# h_1, h_2 = final_state[0], final_state[1]
|
||||
# # X = h_1 + h_2 # Add both states (needs different input size for first linear layer)
|
||||
# final_h = torch.cat((h_1, h_2), dim=1) # Concatenate both states
|
||||
# final_h = self.linear(final_h)
|
||||
|
||||
# hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
|
||||
# outputs are always from the last layer.
|
||||
# hidden[-2, :, : ] is the last of the forwards RNN
|
||||
# hidden[-1, :, : ] is the last of the backwards RNN
|
||||
h_1, h_2 = hn[-2, :, :], hn[-1, :, :]
|
||||
final_h = torch.cat((h_1, h_2), dim=1) # Concatenate both states
|
||||
final_h = linear(final_h)
|
||||
# final_h = drop_out(final_h)
|
||||
|
||||
return final_h
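# The following smoke test is an illustrative addition, not part of the original
# recipe; all sizes and token ids below are arbitrary assumptions chosen only to
# exercise the forward pass end to end.
if __name__ == "__main__":
    encoder = ContextEncoderLSTM(
        vocab_size=500,
        context_encoder_dim=128,
        output_dim=256,
        num_layers=2,
        num_directions=2,
    )
    # Three "words" of token ids, padded to length 4; lengths give the unpadded sizes.
    word_list = torch.tensor(
        [[5, 7, 9, 0], [3, 0, 0, 0], [2, 4, 6, 8]], dtype=torch.long
    )
    word_lengths = [3, 1, 4]
    out = encoder(word_list, word_lengths)
    assert out.shape == (3, 256)
    print("ContextEncoderLSTM smoke test passed:", out.shape)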
|
||||
@ -0,0 +1,52 @@
|
||||
import torch
|
||||
from context_encoder import ContextEncoder
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ContextEncoderPretrained(ContextEncoder):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = None,
|
||||
context_encoder_dim: int = None,
|
||||
output_dim: int = None,
|
||||
num_layers: int = None,
|
||||
num_directions: int = None,
|
||||
drop_out: float = 0.3,
|
||||
):
|
||||
super(ContextEncoderPretrained, self).__init__()
|
||||
|
||||
self.drop_out = torch.nn.Dropout(drop_out)
|
||||
self.linear1 = torch.nn.Linear(
|
||||
context_encoder_dim, # 768
|
||||
256,
|
||||
)
|
||||
self.linear3 = torch.nn.Linear(
|
||||
256,
|
||||
256,
|
||||
)
|
||||
self.linear4 = torch.nn.Linear(
|
||||
256,
|
||||
256,
|
||||
)
|
||||
self.linear2 = torch.nn.Linear(
|
||||
256,
|
||||
output_dim
|
||||
)
|
||||
self.sigmoid = torch.nn.Sigmoid()
|
||||
|
||||
self.bi_encoders = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
word_list,
|
||||
word_lengths,
|
||||
is_encoder_side=None,
|
||||
):
|
||||
out = word_list # Shape: N*L*D
|
||||
# out = self.drop_out(out)
|
||||
out = self.sigmoid(self.linear1(out)) # Note: ReLU may not be a good choice here
|
||||
out = self.sigmoid(self.linear3(out))
|
||||
out = self.sigmoid(self.linear4(out))
|
||||
# out = self.drop_out(out)
|
||||
out = self.linear2(out)
|
||||
return out
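# The following smoke test is an illustrative addition, not part of the original
# recipe. The 768 input size mimics a BERT-base embedding; all other numbers are
# arbitrary assumptions.
if __name__ == "__main__":
    encoder = ContextEncoderPretrained(
        context_encoder_dim=768,
        output_dim=256,
    )
    # A batch of 4 context-word embeddings (e.g. from a frozen BERT encoder).
    word_list = torch.randn(4, 768)
    out = encoder(word_list, word_lengths=None)
    assert out.shape == (4, 256)
    print("ContextEncoderPretrained smoke test passed:", out.shape)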
|
||||
@ -0,0 +1,90 @@
|
||||
import torch
|
||||
from context_encoder import ContextEncoder
|
||||
import copy
|
||||
|
||||
class ContextEncoderReused(ContextEncoder):
|
||||
def __init__(
|
||||
self,
|
||||
decoder,
|
||||
decoder_dim: int = None,
|
||||
output_dim: int = None,
|
||||
num_lstm_layers: int = None,
|
||||
num_lstm_directions: int = None,
|
||||
drop_out: float = 0.1,
|
||||
):
|
||||
super(ContextEncoderReused, self).__init__()
|
||||
# self.num_lstm_layers = num_lstm_layers
|
||||
# self.num_lstm_directions = num_lstm_directions
|
||||
# self.decoder_dim = decoder_dim
|
||||
|
||||
hidden_size = output_dim * 2 # decoder_dim
|
||||
self.rnn = torch.nn.LSTM(
|
||||
input_size=decoder_dim,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_lstm_layers,
|
||||
batch_first=True,
|
||||
bidirectional=(num_lstm_directions == 2),
|
||||
dropout=0.1 if num_lstm_layers > 1 else 0
|
||||
)
|
||||
self.linear = torch.nn.Linear(
|
||||
hidden_size * num_lstm_directions,
|
||||
output_dim
|
||||
)
|
||||
self.drop_out = torch.nn.Dropout(drop_out)
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
self.bi_encoders = False
|
||||
|
||||
# TODO: Do we need some relu layer?
|
||||
# https://galhever.medium.com/sentiment-analysis-with-pytorch-part-4-lstm-bilstm-model-84447f6c4525
|
||||
# self.relu = nn.ReLU()
|
||||
# self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
word_list,
|
||||
word_lengths,
|
||||
is_encoder_side=None,
|
||||
):
|
||||
sos_id = self.decoder.blank_id
|
||||
sos_list = torch.full((word_list.shape[0], 1), sos_id, dtype=word_list.dtype, device=word_list.device)
|
||||
sos_word_list = torch.cat((sos_list, word_list), 1)
|
||||
word_lengths = [x + 1 for x in word_lengths]
|
||||
|
||||
# sos_word_list: (N, U)
|
||||
# decoder_out: (N, U, decoder_dim)
|
||||
out = self.decoder(sos_word_list)
|
||||
|
||||
# https://stackoverflow.com/questions/51030782/why-do-we-pack-the-sequences-in-pytorch
|
||||
# https://gist.github.com/HarshTrivedi/f4e7293e941b17d19058f6fb90ab0fec
|
||||
out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
out,
|
||||
batch_first=True,
|
||||
lengths=word_lengths,
|
||||
enforce_sorted=False
|
||||
)
|
||||
output, (hn, cn) = self.rnn(out) # use default all zeros (h_0, c_0)
|
||||
|
||||
# # https://discuss.pytorch.org/t/bidirectional-3-layer-lstm-hidden-output/41336/4
|
||||
# final_state = hn.view(
|
||||
# self.num_layers,
|
||||
# self.num_directions,
|
||||
# word_list.shape[0],
|
||||
# self.encoder_dim,
|
||||
# )[-1] # Only the last layer
|
||||
# h_1, h_2 = final_state[0], final_state[1]
|
||||
# # X = h_1 + h_2 # Add both states (needs different input size for first linear layer)
|
||||
# final_h = torch.cat((h_1, h_2), dim=1) # Concatenate both states
|
||||
# final_h = self.linear(final_h)
|
||||
|
||||
# hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
|
||||
# outputs are always from the last layer.
|
||||
# hidden[-2, :, : ] is the last of the forwards RNN
|
||||
# hidden[-1, :, : ] is the last of the backwards RNN
|
||||
h_1, h_2 = hn[-2, :, :], hn[-1, :, :]
|
||||
final_h = torch.cat((h_1, h_2), dim=1) # Concatenate both states
|
||||
final_h = self.linear(final_h)
|
||||
# final_h = drop_out(final_h)
|
||||
|
||||
return final_h
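# The following smoke test is an illustrative addition, not part of the original
# recipe. `_ToyDecoder` is a hypothetical stand-in for the transducer decoder that
# is normally reused here; it only mimics the interface this module needs
# (a `blank_id` attribute and a call returning (N, U, decoder_dim)).
if __name__ == "__main__":
    class _ToyDecoder(torch.nn.Module):
        def __init__(self, vocab_size=500, decoder_dim=64):
            super().__init__()
            self.blank_id = 0
            self.embedding = torch.nn.Embedding(vocab_size, decoder_dim)

        def forward(self, y):
            return self.embedding(y)

    encoder = ContextEncoderReused(
        decoder=_ToyDecoder(),
        decoder_dim=64,
        output_dim=128,
        num_lstm_layers=2,
        num_lstm_directions=2,
    )
    word_list = torch.tensor([[5, 7, 9, 0], [3, 0, 0, 0]], dtype=torch.long)
    word_lengths = [3, 1]
    out = encoder(word_list, word_lengths)
    assert out.shape == (2, 128)
    print("ContextEncoderReused smoke test passed:", out.shape)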
|
||||
@ -0,0 +1,210 @@
|
||||
import torch
|
||||
import random
|
||||
from pathlib import Path
|
||||
import sentencepiece as spm
|
||||
from typing import List
|
||||
import logging
|
||||
import ast
|
||||
import numpy as np
|
||||
|
||||
class ContextGenerator(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
path_is21_deep_bias: Path,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
n_distractors: int = 100,
|
||||
is_predefined: bool = False,
|
||||
keep_ratio: float = 1.0,
|
||||
is_full_context: bool = False,
|
||||
):
|
||||
self.sp = sp
|
||||
self.path_is21_deep_bias = path_is21_deep_bias
|
||||
self.n_distractors = n_distractors
|
||||
self.is_predefined = is_predefined
|
||||
self.keep_ratio = keep_ratio
|
||||
self.is_full_context = is_full_context # use all words (rare or common) in the context
|
||||
|
||||
logging.info(f"""
|
||||
n_distractors={n_distractors},
|
||||
is_predefined={is_predefined},
|
||||
keep_ratio={keep_ratio},
|
||||
is_full_context={is_full_context},
|
||||
""")
|
||||
|
||||
self.all_rare_words2pieces = None
|
||||
self.common_words = None
|
||||
if not is_predefined:
|
||||
with open(path_is21_deep_bias / "words/all_rare_words.txt", "r") as fin:
|
||||
all_rare_words = [l.strip().upper() for l in fin if len(l.strip()) > 0]  # a list of strings
|
||||
all_rare_words_pieces = sp.encode(all_rare_words, out_type=int) # a list of list of int
|
||||
self.all_rare_words2pieces = {w: pieces for w, pieces in zip(all_rare_words, all_rare_words_pieces)}
|
||||
|
||||
with open(path_is21_deep_bias / "words/common_words_5k.txt", "r") as fin:
|
||||
self.common_words = set(l.strip().upper() for l in fin if len(l.strip()) > 0)  # a set of strings
|
||||
|
||||
logging.info(f"Number of common words: {len(self.common_words)}")
|
||||
logging.info(f"Number of rare words: {len(self.all_rare_words2pieces)}")
|
||||
|
||||
self.test_clean_biasing_list = None
|
||||
self.test_other_biasing_list = None
|
||||
if is_predefined:
|
||||
def read_ref_biasing_list(filename):
|
||||
biasing_list = dict()
|
||||
all_cnt = 0
|
||||
rare_cnt = 0
|
||||
with open(filename, "r") as fin:
|
||||
for line in fin:
|
||||
line = line.strip().upper()
|
||||
if len(line) == 0:
|
||||
continue
|
||||
line = line.split("\t")
|
||||
uid, ref_text, ref_rare_words, context_rare_words = line
|
||||
context_rare_words = ast.literal_eval(context_rare_words)
|
||||
biasing_list[uid] = [w for w in context_rare_words]
|
||||
|
||||
ref_rare_words = ast.literal_eval(ref_rare_words)
|
||||
ref_text = ref_text.split()
|
||||
all_cnt += len(ref_text)
|
||||
rare_cnt += len(ref_rare_words)
|
||||
return biasing_list, rare_cnt / all_cnt
|
||||
|
||||
self.test_clean_biasing_list, ratio_clean = \
|
||||
read_ref_biasing_list(self.path_is21_deep_bias / f"ref/test-clean.biasing_{n_distractors}.tsv")
|
||||
self.test_other_biasing_list, ratio_other = \
|
||||
read_ref_biasing_list(self.path_is21_deep_bias / f"ref/test-other.biasing_{n_distractors}.tsv")
|
||||
|
||||
logging.info(f"Number of utterances in test_clean_biasing_list: {len(self.test_clean_biasing_list)}, rare ratio={ratio_clean:.2f}")
|
||||
logging.info(f"Number of utterances in test_other_biasing_list: {len(self.test_other_biasing_list)}, rare ratio={ratio_other:.2f}")
|
||||
|
||||
# from itertools import chain
|
||||
# for uid, context_rare_words in chain(self.test_clean_biasing_list.items(), self.test_other_biasing_list.items()):
|
||||
# for w in context_rare_words:
|
||||
# if self.all_rare_words2pieces:
|
||||
# pass
|
||||
# else:
|
||||
# logging.warning(f"new word: {w}")
|
||||
|
||||
def get_context_word_list(
|
||||
self,
|
||||
batch: dict,
|
||||
):
|
||||
# import pdb; pdb.set_trace()
|
||||
if self.is_predefined:
|
||||
return self.get_context_word_list_predefined(batch=batch)
|
||||
else:
|
||||
return self.get_context_word_list_random(batch=batch)
|
||||
|
||||
def discard_some_common_words(self, words, keep_ratio):
|
||||
pass
|
||||
|
||||
def get_context_word_list_random(
|
||||
self,
|
||||
batch: dict,
|
||||
):
|
||||
"""
|
||||
Generate the context biasing list as a list of words for each utterance.
|
||||
Use keep_ratio to simulate an "imperfect" context that may not have 100% coverage of the ground-truth words.
|
||||
"""
|
||||
texts = batch["supervisions"]["text"]
|
||||
|
||||
rare_words_list = []
|
||||
for text in texts:
|
||||
rare_words = []
|
||||
for word in text.split():
|
||||
if self.is_full_context or word not in self.common_words:
|
||||
rare_words.append(word)
|
||||
if word not in self.all_rare_words2pieces:
|
||||
self.all_rare_words2pieces[word] = self.sp.encode(word, out_type=int)
|
||||
|
||||
rare_words = list(set(rare_words)) # deduplication
|
||||
|
||||
if self.keep_ratio < 1.0 and len(rare_words) > 0:
|
||||
rare_words = random.sample(rare_words, int(len(rare_words) * self.keep_ratio))
|
||||
|
||||
rare_words_list.append(rare_words)
|
||||
|
||||
n_distractors = self.n_distractors
|
||||
if n_distractors == -1: # variable context list sizes
|
||||
n_distractors_each = np.random.randint(low=80, high=1000, size=len(texts))
|
||||
distractors_cnt = n_distractors_each.sum()
|
||||
else:
|
||||
n_distractors_each = np.zeros(len(texts), int)
|
||||
n_distractors_each[:] = self.n_distractors
|
||||
distractors_cnt = n_distractors_each.sum()
|
||||
|
||||
distractors = random.sample(
|
||||
list(self.all_rare_words2pieces.keys()),  # random.sample needs a sequence
|
||||
distractors_cnt
|
||||
) # TODO: actually the context should contain both rare and common words
|
||||
distractors_pos = 0
|
||||
rare_words_pieces_list = []
|
||||
max_pieces_len = 0
|
||||
for i, rare_words in enumerate(rare_words_list):
|
||||
rare_words.extend(distractors[distractors_pos: distractors_pos + n_distractors_each[i]])
|
||||
distractors_pos += n_distractors_each[i]
|
||||
# random.shuffle(rare_words)
|
||||
# logging.info(rare_words)
|
||||
|
||||
rare_words_pieces = [self.all_rare_words2pieces[w] for w in rare_words]
|
||||
if len(rare_words_pieces) > 0:
|
||||
max_pieces_len = max(max_pieces_len, max(len(pieces) for pieces in rare_words_pieces))
|
||||
rare_words_pieces_list.append(rare_words_pieces)
|
||||
assert distractors_pos == len(distractors)
|
||||
|
||||
word_list = []
|
||||
word_lengths = []
|
||||
num_words_per_utt = []
|
||||
pad_token = 0
|
||||
for rare_words_pieces in rare_words_pieces_list:
|
||||
num_words_per_utt.append(len(rare_words_pieces))
|
||||
word_lengths.extend([len(pieces) for pieces in rare_words_pieces])
|
||||
|
||||
# Pad a copy of each piece sequence; padding in place would corrupt the
# cached entries in self.all_rare_words2pieces.
rare_words_pieces = [
    pieces + [pad_token] * (max_pieces_len - len(pieces))
    for pieces in rare_words_pieces
]
word_list.extend(rare_words_pieces)
|
||||
|
||||
word_list = torch.tensor(word_list, dtype=torch.int32)
|
||||
# word_lengths = torch.tensor(word_lengths, dtype=torch.int32)
|
||||
# num_words_per_utt = torch.tensor(num_words_per_utt, dtype=torch.int32)
|
||||
|
||||
return word_list, word_lengths, num_words_per_utt
|
||||
|
||||
def get_context_word_list_predefined(
|
||||
self,
|
||||
batch: dict,
|
||||
):
|
||||
rare_words_list = []
|
||||
for cut in batch['supervisions']['cut']:
|
||||
uid = cut.supervisions[0].id
|
||||
if uid in self.test_clean_biasing_list:
|
||||
rare_words_list.append(self.test_clean_biasing_list[uid])
|
||||
elif uid in self.test_other_biasing_list:
|
||||
rare_words_list.append(self.test_other_biasing_list[uid])
|
||||
else:
|
||||
logging.error(f"uid={uid} cannot find the predefined biasing list of size {self.n_distractors}")
|
||||
|
||||
rare_words_pieces_list = []
|
||||
max_pieces_len = 0
|
||||
for rare_words in rare_words_list:
|
||||
# logging.info(rare_words)
|
||||
rare_words_pieces = self.sp.encode(rare_words, out_type=int)
|
||||
if len(rare_words_pieces) > 0:
    max_pieces_len = max(max_pieces_len, max(len(pieces) for pieces in rare_words_pieces))
|
||||
rare_words_pieces_list.append(rare_words_pieces)
|
||||
|
||||
word_list = []
|
||||
word_lengths = []
|
||||
num_words_per_utt = []
|
||||
pad_token = 0
|
||||
for rare_words_pieces in rare_words_pieces_list:
|
||||
num_words_per_utt.append(len(rare_words_pieces))
|
||||
word_lengths.extend([len(pieces) for pieces in rare_words_pieces])
|
||||
|
||||
for pieces in rare_words_pieces:
|
||||
pieces += [pad_token] * (max_pieces_len - len(pieces))
|
||||
word_list.extend(rare_words_pieces)
|
||||
|
||||
word_list = torch.tensor(word_list, dtype=torch.int32)
|
||||
# word_lengths = torch.tensor(word_lengths, dtype=torch.int32)
|
||||
# num_words_per_utt = torch.tensor(num_words_per_utt, dtype=torch.int32)
|
||||
|
||||
return word_list, word_lengths, num_words_per_utt
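# Illustrative usage sketch (comment only, not runnable as-is): it assumes a
# trained BPE model and the IS21 deep-biasing word lists; the biasing-list path
# below is hypothetical, and `batch` is a batch yielded by the Lhotse dataloader.
#
#   sp = spm.SentencePieceProcessor()
#   sp.load("data/lang_bpe_500/bpe.model")
#   context_generator = ContextGenerator(
#       path_is21_deep_bias=Path("data/fbai-speech/is21_deep_bias"),
#       sp=sp,
#       n_distractors=100,
#       is_predefined=False,
#       keep_ratio=1.0,
#   )
#   word_list, word_lengths, num_words_per_utt = \
#       context_generator.get_context_word_list(batch)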
|
||||
@ -0,0 +1,371 @@
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from kaldifst.utils import k2_to_openfst
|
||||
|
||||
def generate_context_graph_simple(
|
||||
words_pieces_list: list,
|
||||
backoff_id: int,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
bonus_per_token: float = 0.1,
|
||||
):
|
||||
"""Generate the context graph (in kaldifst format) given
|
||||
the lexicon of the biasing list.
|
||||
|
||||
This context graph is a WFST as in
|
||||
`https://arxiv.org/abs/1808.02480`
|
||||
or
|
||||
`https://wenet.org.cn/wenet/context.html`.
|
||||
It is simple, as it does not have the capability to detect
|
||||
word boundaries. So, if a biasing word (e.g., 'us', the country)
|
||||
happens to be the prefix of another word (e.g., 'useful'),
|
||||
the word 'useful' will be mistakenly boosted. This is not desired.
|
||||
However, this context graph is easy to understand.
|
||||
|
||||
Args:
|
||||
words_pieces_list:
|
||||
A list (batch) of lists. Each sub-list contains the context for
|
||||
the utterance. The sub-list again is a list of lists. Each sub-sub-list
|
||||
is the token sequence of a word.
|
||||
backoff_id:
|
||||
The id of the backoff token. It serves for failure arcs.
|
||||
bonus_per_token:
|
||||
The bonus for each token during decoding, which will hopefully
|
||||
boost the token up to survive beam search.
|
||||
|
||||
Returns:
|
||||
Return an `openfst` object representing the context graph.
|
||||
"""
|
||||
# note: `k2_to_openfst` will multiply it with -1. So it will become +1 in the end.
|
||||
flip = -1
|
||||
|
||||
fsa_list = []
|
||||
fsa_sizes = []
|
||||
for words_pieces in words_pieces_list:
|
||||
start_state = 0
|
||||
next_state = 1  # the next unallocated state; it will be incremented as we go.
|
||||
arcs = []
|
||||
|
||||
arcs.append([start_state, start_state, backoff_id, 0, 0.0])
|
||||
# for token_id in range(sp.vocab_size()):
|
||||
# arcs.append([start_state, start_state, token_id, 0, 0.0])
|
||||
|
||||
for tokens in words_pieces:
|
||||
assert len(tokens) > 0
|
||||
cur_state = start_state
|
||||
|
||||
for i in range(len(tokens) - 1):
|
||||
arcs.append(
|
||||
[
|
||||
cur_state,
|
||||
next_state,
|
||||
tokens[i],
|
||||
0,
|
||||
flip * bonus_per_token
|
||||
]
|
||||
)
|
||||
arcs.append(
|
||||
[
|
||||
next_state,
|
||||
start_state,
|
||||
backoff_id,
|
||||
0,
|
||||
flip * -bonus_per_token * (i + 1),
|
||||
]
|
||||
)
|
||||
|
||||
cur_state = next_state
|
||||
next_state += 1
|
||||
|
||||
# now for the last token of this word
|
||||
i = len(tokens) - 1
|
||||
arcs.append([cur_state, start_state, tokens[i], 0, flip * bonus_per_token])
|
||||
|
||||
final_state = next_state
|
||||
arcs.append([start_state, final_state, -1, -1, 0])
|
||||
arcs.append([final_state])
|
||||
|
||||
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||
arcs = [" ".join(arc) for arc in arcs]
|
||||
arcs = "\n".join(arcs)
|
||||
|
||||
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||
fsa = k2.arc_sort(fsa)
|
||||
fsa_sizes.append((fsa.shape[0], fsa.num_arcs)) # (n_states, n_arcs)
|
||||
|
||||
fsa = k2_to_openfst(fsa, olabels="aux_labels")
|
||||
fsa_list.append(fsa)
|
||||
|
||||
return fsa_list, fsa_sizes
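# Illustrative example of the arc layout produced above (comment only): for a
# single context word with pieces [a, b, c] and bonus_per_token = 0.1, the graph
# (after `k2_to_openfst` flips the sign) contains
#   0 -a-> 1   (+0.1)      partial-match bonus
#   1 -b-> 2   (+0.1)
#   2 -c-> 0   (+0.1)      a full match keeps the accumulated +0.3 bonus
#   1 -backoff-> 0 (-0.1)  failure arcs cancel the partial bonus
#   2 -backoff-> 0 (-0.2)
# plus a backoff self-loop at state 0 with score 0.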
|
||||
|
||||
|
||||
def generate_context_graph_nfa(
|
||||
words_pieces_list: list,
|
||||
backoff_id: int,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
bonus_per_token: float = 0.1,
|
||||
):
|
||||
"""Generate the context graph (in kaldifst format) given
|
||||
the lexicon of the biasing list.
|
||||
|
||||
This context graph is a WFST capable of detecting word boundaries.
|
||||
It is epsilon-free, non-deterministic.
|
||||
|
||||
Args:
|
||||
words_pieces_list:
|
||||
A list (batch) of lists. Each sub-list contains the context for
|
||||
the utterance. The sub-list again is a list of lists. Each sub-sub-list
|
||||
is the token sequence of a word.
|
||||
backoff_id:
|
||||
The id of the backoff token. It serves for failure arcs.
|
||||
bonus_per_token:
|
||||
The bonus for each token during decoding, which will hopefully
|
||||
boost the token up to survive beam search.
|
||||
|
||||
Returns:
|
||||
Return an `openfst` object representing the context graph.
|
||||
"""
|
||||
# note: `k2_to_openfst` will multiply it with -1. So it will become +1 in the end.
|
||||
flip = -1
|
||||
|
||||
fsa_list = []
|
||||
fsa_sizes = []
|
||||
for words_pieces in words_pieces_list:
|
||||
start_state = 0
|
||||
# if the path goes through this state, then a word boundary is detected
|
||||
boundary_state = 1
|
||||
next_state = 2  # the next unallocated state; it will be incremented as we go.
|
||||
arcs = []
|
||||
|
||||
# arcs.append([start_state, start_state, backoff_id, 0, 0.0])
|
||||
for token_id in range(sp.vocab_size()):
|
||||
arcs.append([start_state, start_state, token_id, 0, 0.0])
|
||||
# arcs.append([boundary_state, start_state, token_id, 0, 0.0]) # Note: adding this line here degrades performance. Why? Because it will break the ability to detect word boundary.
|
||||
|
||||
for tokens in words_pieces:
|
||||
assert len(tokens) > 0
|
||||
cur_state = start_state
|
||||
|
||||
# static/constant bonus per token
|
||||
# my_bonus_per_token = flip * bonus_per_token * biasing_list[word]
|
||||
# my_bonus_per_token = flip * 1.0 / len(tokens) * biasing_list[word]
|
||||
my_bonus_per_token = flip * bonus_per_token # TODO: support weighted biasing list
|
||||
|
||||
for i in range(len(tokens) - 1):
|
||||
arcs.append(
|
||||
[cur_state, next_state, tokens[i], 0, my_bonus_per_token]
|
||||
)
|
||||
if i == 0:
|
||||
arcs.append(
|
||||
[boundary_state, next_state, tokens[i], 0, my_bonus_per_token]
|
||||
)
|
||||
|
||||
cur_state = next_state
|
||||
next_state += 1
|
||||
|
||||
# now for the last token of this word
|
||||
i = len(tokens) - 1
|
||||
arcs.append(
|
||||
[cur_state, boundary_state, tokens[i], 0, my_bonus_per_token]
|
||||
)
|
||||
|
||||
for token_id in range(sp.vocab_size()):
|
||||
token = sp.id_to_piece(token_id)
|
||||
if token.startswith("▁"):
|
||||
arcs.append([boundary_state, start_state, token_id, 0, 0.0])
|
||||
|
||||
final_state = next_state
|
||||
arcs.append([start_state, final_state, -1, -1, 0])
|
||||
arcs.append([boundary_state, final_state, -1, -1, 0])
|
||||
arcs.append([final_state])
|
||||
|
||||
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||
arcs = [" ".join(arc) for arc in arcs]
|
||||
arcs = "\n".join(arcs)
|
||||
|
||||
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||
fsa = k2.arc_sort(fsa)
|
||||
# fsa = k2.determinize(fsa) # No weight pushing is needed.
|
||||
fsa_sizes.append((fsa.shape[0], fsa.num_arcs)) # (n_states, n_arcs)
|
||||
|
||||
fsa = k2_to_openfst(fsa, olabels="aux_labels")
|
||||
fsa_list.append(fsa)
|
||||
|
||||
return fsa_list, fsa_sizes
|
||||
|
||||
|
||||
class TrieNode:
|
||||
def __init__(self, token):
|
||||
self.token = token
|
||||
self.is_end = False
|
||||
self.state_id = None
|
||||
self.weight = -1e9
|
||||
self.children = {}
|
||||
|
||||
|
||||
class Trie(object):
|
||||
# https://albertauyeung.github.io/2020/06/15/python-trie.html/
|
||||
|
||||
def __init__(self):
|
||||
self.root = TrieNode("")
|
||||
|
||||
def insert(self, word_tokens, per_token_weight):
|
||||
"""Insert a word into the trie"""
|
||||
node = self.root
|
||||
|
||||
# Loop through each token in the word
|
||||
# Check if there is no child containing the token, create a new child for the current node
|
||||
for token in word_tokens:
|
||||
if token in node.children:
|
||||
node = node.children[token]
|
||||
node.weight = max(per_token_weight, node.weight)
|
||||
else:
|
||||
# If a character is not found,
|
||||
# create a new node in the trie
|
||||
new_node = TrieNode(token)
|
||||
node.children[token] = new_node
|
||||
node = new_node
|
||||
node.weight = per_token_weight
|
||||
|
||||
# Mark the end of a word
|
||||
node.is_end = True
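# A small worked example of the prefix sharing (comment only): inserting the
# piece sequences [5, 7, 9] and [5, 7, 2] with per-token weight -0.1 yields
#   root -5-> node(5) -7-> node(7) -9-> leaf (is_end)
#                                  -2-> leaf (is_end)
# so the common prefix [5, 7] is stored once and both words share its arcs when
# the trie is unrolled into the context graph below.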
|
||||
|
||||
|
||||
def generate_context_graph_nfa_trie(
|
||||
words_pieces_list: list,
|
||||
backoff_id: int,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
bonus_per_token: float = 0.1,
|
||||
):
|
||||
"""Generate the context graph (in kaldifst format) given
|
||||
the lexicon of the biasing list.
|
||||
|
||||
This context graph is a WFST capable of detecting word boundaries.
|
||||
It is epsilon-free, non-deterministic.
|
||||
It is also optimized such that the words are organized in a trie. (from Kang Wei)
|
||||
However, it is not yet clear whether this trie representation can support a separate score for each word.
|
||||
|
||||
Args:
|
||||
words_pieces_list:
|
||||
A list (batch) of lists. Each sub-list contains the context for
|
||||
the utterance. The sub-list again is a list of lists. Each sub-sub-list
|
||||
is the token sequence of a word.
|
||||
backoff_id:
|
||||
The id of the backoff token. It serves for failure arcs.
|
||||
bonus_per_token:
|
||||
The bonus for each token during decoding, which will hopefully
|
||||
boost the token up to survive beam search.
|
||||
|
||||
Returns:
|
||||
Return an `openfst` object representing the context graph.
|
||||
"""
|
||||
# note: `k2_to_openfst` will multiply it with -1. So it will become +1 in the end.
|
||||
flip = -1
|
||||
|
||||
fsa_list = [] # These are the fsas in a batch
|
||||
fsa_sizes = []
|
||||
for words_pieces in words_pieces_list:
|
||||
start_state = 0
|
||||
# if the path goes through this state, then a word boundary is detected
|
||||
boundary_state = 1
|
||||
next_state = 2  # the next unallocated state; it will be incremented as we go.
|
||||
arcs = []
|
||||
|
||||
# arcs.append([start_state, start_state, backoff_id, 0, 0.0])
|
||||
for token_id in range(sp.vocab_size()):
|
||||
arcs.append([start_state, start_state, token_id, 0, 0.0])
|
||||
|
||||
# First, compile the word list into a trie (prefix tree)
|
||||
trie = Trie()
|
||||
for tokens in words_pieces:
|
||||
|
||||
# static/constant bonus per token
|
||||
# my_bonus_per_token = flip * bonus_per_token * biasing_list[word]
|
||||
# my_bonus_per_token = flip * 1.0 / len(tokens) * biasing_list[word]
|
||||
my_bonus_per_token = flip * bonus_per_token # TODO: support weighted biasing list
|
||||
|
||||
trie.insert(tokens, my_bonus_per_token)
|
||||
|
||||
trie.root.state_id = start_state
|
||||
node_list = [trie.root]
|
||||
while len(node_list) > 0:
|
||||
cur_node = node_list.pop(0)
|
||||
if cur_node.is_end:
|
||||
continue
|
||||
|
||||
for token, child_node in cur_node.children.items():
|
||||
if child_node.state_id is None and not child_node.is_end:
|
||||
child_node.state_id = next_state
|
||||
next_state += 1
|
||||
|
||||
if child_node.is_end:
|
||||
# If the word has finished, go to the boundary state
|
||||
arcs.append(
|
||||
[
|
||||
cur_node.state_id,
|
||||
boundary_state,
|
||||
token,
|
||||
0,
|
||||
child_node.weight,
|
||||
]
|
||||
)
|
||||
else:
|
||||
# Otherwise, add an arc from the current node to its child node
|
||||
arcs.append(
|
||||
[
|
||||
cur_node.state_id,
|
||||
child_node.state_id,
|
||||
token,
|
||||
0,
|
||||
child_node.weight,
|
||||
]
|
||||
)
|
||||
|
||||
# If this is the first token in a word,
|
||||
# also add an arc from the boundary node to the child node
|
||||
if cur_node == trie.root:
|
||||
arcs.append(
|
||||
[
|
||||
boundary_state,
|
||||
child_node.state_id,
|
||||
token,
|
||||
0,
|
||||
child_node.weight,
|
||||
]
|
||||
)
|
||||
|
||||
node_list.append(child_node)
|
||||
|
||||
for token_id in range(sp.vocab_size()):
|
||||
token = sp.id_to_piece(token_id)
|
||||
if token.startswith("▁"):
|
||||
arcs.append([boundary_state, start_state, token_id, 0, 0.0])
|
||||
|
||||
final_state = next_state
|
||||
arcs.append([start_state, final_state, -1, -1, 0])
|
||||
arcs.append([boundary_state, final_state, -1, -1, 0])
|
||||
arcs.append([final_state])
|
||||
|
||||
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||
arcs = [" ".join(arc) for arc in arcs]
|
||||
arcs = "\n".join(arcs)
|
||||
|
||||
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||
fsa = k2.arc_sort(fsa)
|
||||
# fsa = k2.determinize(fsa) # No weight pushing is needed.
|
||||
fsa_sizes.append((fsa.shape[0], fsa.num_arcs)) # (n_states, n_arcs)
|
||||
|
||||
fsa = k2_to_openfst(fsa, olabels="aux_labels")
|
||||
fsa_list.append(fsa)
|
||||
|
||||
return fsa_list, fsa_sizes
|
||||
|
||||
1271
egs/librispeech/ASR/pruned_transducer_stateless7_contextual/decode.py
Executable file
File diff suppressed because it is too large
@ -0,0 +1,861 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao,
|
||||
# Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search (one best)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(5) fast beam search (nbest)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(7) fast beam search (with LG)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from gigaspeech import GigaSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from gigaspeech_scoring import asr_text_post_processing
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--simulate-streaming",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to simulate streaming in decoding, this is a good way to
|
||||
test a streaming model.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decode-chunk-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The chunk size for decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--left-context",
|
||||
type=int,
|
||||
default=64,
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def post_processing(
|
||||
results: List[Tuple[str, List[str], List[str]]],
|
||||
) -> List[Tuple[str, List[str], List[str]]]:
|
||||
new_results = []
|
||||
for key, ref, hyp in results:
|
||||
new_ref = asr_text_post_processing(" ".join(ref)).split()
|
||||
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
|
||||
new_results.append((key, new_ref, new_hyp))
|
||||
return new_results
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search".
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
if params.simulate_streaming:
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
chunk_size=params.decode_chunk_size,
|
||||
left_context=params.left_context,
|
||||
simulate_streaming=True,
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains three elements:
the cut id, the reference transcript, and the predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = post_processing(results)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
"""
|
||||
This script tests a LibriSpeech model with the LibriSpeech BPE model
on GigaSpeech.
|
||||
"""
|
||||
parser = get_parser()
|
||||
GigaSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / (params.decoding_method + "_gigaspeech")
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.simulate_streaming:
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
if params.simulate_streaming:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
), "Decoding in streaming requires causal convolution"
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
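# Scale the arc scores of the LG graph by the n-gram LM weight.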
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
gigaspeech = GigaSpeechAsrDataModule(args)
|
||||
|
||||
dev_cuts = gigaspeech.dev_cuts()
|
||||
test_cuts = gigaspeech.test_cuts()
|
||||
|
||||
dev_dl = gigaspeech.test_dataloaders(dev_cuts)
|
||||
test_dl = gigaspeech.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = ["dev", "test"]
|
||||
test_dls = [dev_dl, test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1257 egs/librispeech/ASR/pruned_transducer_stateless7_contextual/decode_sbc.py (Executable file)
File diff suppressed because it is too large
@@ -0,0 +1,993 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao,
|
||||
# Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search (one best)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(5) fast beam search (nbest)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(7) fast beam search (with LG)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(8) modified beam search with RNNLM shallow fusion
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search_lm_shallow_fusion \
|
||||
--beam-size 4 \
|
||||
--lm-type rnn \
|
||||
--lm-scale 0.3 \
|
||||
--lm-exp-dir /path/to/LM \
|
||||
--rnn-lm-epoch 99 \
|
||||
--rnn-lm-avg 1 \
|
||||
--rnn-lm-num-layers 3 \
|
||||
--rnn-lm-tie-weights 1
|
||||
|
||||
(9) modified beam search with LM shallow fusion + LODR
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--max-duration 600 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--decoding-method modified_beam_search_LODR \
|
||||
--beam-size 4 \
|
||||
--lm-type rnn \
|
||||
--lm-scale 0.4 \
|
||||
--lm-exp-dir /path/to/LM \
|
||||
--rnn-lm-epoch 99 \
|
||||
--rnn-lm-avg 1 \
|
||||
--rnn-lm-num-layers 3 \
|
||||
--rnn-lm-tie-weights 1 \
--tokens-ngram 2 \
--ngram-lm-scale -0.16
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
modified_beam_search_lm_shallow_fusion,
|
||||
modified_beam_search_LODR,
|
||||
modified_beam_search_ngram_rescoring,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import LmScorer, NgramLm
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
- modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion
|
||||
- modified_beam_search_LODR
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-shallow-fusion",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Use neural network LM for shallow fusion.
|
||||
If you want to use LODR, you will also need to set this to true
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-type",
|
||||
type=str,
|
||||
default="rnn",
|
||||
help="Type of NN lm",
|
||||
choices=["rnn", "transformer"],
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-scale",
|
||||
type=float,
|
||||
default=0.3,
|
||||
help="""The scale of the neural network LM
|
||||
Used only when `--use-shallow-fusion` is set to True.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens-ngram",
|
||||
type=int,
|
||||
default=3,
|
||||
help="""Token Ngram used for rescoring.
|
||||
Used only when the decoding method is
|
||||
modified_beam_search_ngram_rescoring, or LODR
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backoff-id",
|
||||
type=int,
|
||||
default=500,
|
||||
help="""ID of the backoff symbol.
|
||||
Used only when the decoding method is
|
||||
modified_beam_search_ngram_rescoring""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--part",
type=str,
default=None,
help="""Decode only one part of the data, e.g., "1/3" means
part 1 out of 3 equal parts of the training cuts.""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
LM: Optional[LmScorer] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
LM:
|
||||
A neural network LM for shallow fusion. Only used when
`--use-shallow-fusion` is set to True.
ngram_lm:
An n-gram LM. Used in LODR decoding.
ngram_lm_scale:
The scale of the n-gram language model.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
|
||||
hyp_tokens = modified_beam_search_lm_shallow_fusion(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
LM=LM,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search_LODR":
|
||||
hyp_tokens = modified_beam_search_LODR(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
LODR_lm=ngram_lm,
|
||||
LODR_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
LM: Optional[LmScorer] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
LM:
|
||||
A neural network LM, used during shallow fusion
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains three elements:
the cut id, the reference transcript, and the predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
LmScorer.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
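# Parse --part of the form "i/n" into the tuple (i, n): decode only
# the i-th of n partitions of the training data.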
if args.part is not None:
|
||||
args.part = args.part.split("/")
|
||||
args.part = (int(args.part[0]), int(args.part[1]))
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
"modified_beam_search_lm_shallow_fusion",
|
||||
"modified_beam_search_LODR",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
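# Record the n-gram / neural LM settings in the result-file suffix.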
if "ngram" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
if params.use_shallow_fusion:
|
||||
if params.lm_type == "rnn":
|
||||
params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
|
||||
elif params.lm_type == "transformer":
|
||||
params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
|
||||
|
||||
if "LODR" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
||||
)
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
# only load N-gram LM when needed
|
||||
if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
|
||||
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
|
||||
logging.info(f"lm filename: {lm_filename}")
|
||||
ngram_lm = NgramLm(
|
||||
str(params.lang_dir / lm_filename),
|
||||
backoff_id=params.backoff_id,
|
||||
is_binary=False,
|
||||
)
|
||||
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||
ngram_lm_scale = params.ngram_lm_scale
|
||||
else:
|
||||
ngram_lm = None
|
||||
ngram_lm_scale = None
|
||||
|
||||
# only load the neural network LM if doing shallow fusion
|
||||
if params.use_shallow_fusion:
|
||||
LM = LmScorer(
|
||||
lm_type=params.lm_type,
|
||||
params=params,
|
||||
device=device,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
LM.to(device)
|
||||
LM.eval()
|
||||
|
||||
else:
|
||||
LM = None
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
# test_clean_cuts = librispeech.test_clean_cuts()
|
||||
# test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
# test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
# test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
# test_sets = ["test-clean", "test-other"]
|
||||
# test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
|
||||
train_cuts = librispeech.train_all_shuf_cuts()
|
||||
if args.part is not None:
|
||||
from lhotse import CutSet
|
||||
part, n_parts = args.part
|
||||
# train_cuts = CutSet.from_cuts([c for c in train_cuts if "_sp" not in c.id and int(c.id.split("-")[-1]) % n_parts == part])
|
||||
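# Keep only the cuts whose id (second-to-last "-" separated field,
# taken modulo n_parts) matches the requested part.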
train_cuts = CutSet.from_cuts([c for c in train_cuts if int(c.id.split("-")[-2]) % n_parts == part % n_parts])
|
||||
train_cuts = train_cuts.sort_by_duration(ascending=False)
|
||||
# train_cuts.describe()
|
||||
logging.info(f"part-{part}/{n_parts} len(train_cuts) = {len(train_cuts)}")
|
||||
else:
|
||||
part = 0
|
||||
n_parts = 0
|
||||
# part = "val"
|
||||
|
||||
train_dl = librispeech.train_dataloaders(train_cuts)
|
||||
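# Decode the (partitioned) training set instead of the usual test sets,
# which are commented out above.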
test_sets = [f"train-{part}.{n_parts}"]
|
||||
test_dls = [train_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1256 egs/librispeech/ASR/pruned_transducer_stateless7_contextual/decode_uniphore.py (Executable file)
File diff suppressed because it is too large
@@ -0,0 +1,105 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""This class modifies the stateless decoder from the following paper:
|
||||
|
||||
RNN-transducer with stateless prediction network
|
||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
|
||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||
network. Different from the above paper, it adds an extra Conv1d
|
||||
right after the embedding layer.
|
||||
|
||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
decoder_dim: int,
|
||||
blank_id: int,
|
||||
context_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_size:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
decoder_dim:
|
||||
Dimension of the input embedding, and of the decoder output.
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
context_size:
|
||||
Number of previous words to use to predict the next word.
|
||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.embedding = nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=decoder_dim,
|
||||
)
|
||||
self.blank_id = blank_id
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
self.vocab_size = vocab_size
|
||||
if context_size > 1:
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels=decoder_dim,
|
||||
out_channels=decoder_dim,
|
||||
kernel_size=context_size,
|
||||
padding=0,
|
||||
groups=decoder_dim // 4, # group size == 4
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
need_pad:
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
y = y.to(torch.int64)
|
||||
# The clamp() below is a workaround: at utterance start, beam_search.py
# uses negative token ids. They are clamped to 0 for the embedding lookup
# and the corresponding embeddings are then zeroed out by the mask.
|
||||
if torch.jit.is_tracing():
|
||||
# This is for exporting to PNNX via ONNX
|
||||
embedding_out = self.embedding(y)
|
||||
else:
|
||||
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
|
||||
else:
|
||||
# During inference time, there is no need to do extra padding
|
||||
# as we only need one output
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
embedding_out = F.relu(embedding_out)
|
||||
return embedding_out
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class EncoderInterface(nn.Module):
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A tensor of shape (batch_size, input_seq_len, num_features)
|
||||
containing the input features.
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames
|
||||
in `x` before padding.
|
||||
Returns:
|
||||
Return a tuple containing two tensors:
|
||||
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
|
||||
containing unnormalized probabilities, i.e., the output of a
|
||||
linear layer.
|
||||
- encoder_out_lens, a tensor of shape (batch_size,) containing
|
||||
the number of frames in `encoder_out` before padding.
|
||||
"""
|
||||
raise NotImplementedError("Please implement it in a subclass")
|
||||
560 egs/librispeech/ASR/pruned_transducer_stateless7_contextual/export-onnx.py (Executable file)
@@ -0,0 +1,560 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script exports a transducer model from PyTorch to ONNX.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained-epoch-30-avg-9.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless7/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp \
|
||||
--feedforward-dims "1024,1024,2048,2048,1024"
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
See ./onnx_pretrained.py and ./onnx_check.py for how to
|
||||
use the exported ONNX models.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import onnx
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from decoder import Decoder
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from zipformer import Zipformer
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless5/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = value
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
class OnnxEncoder(nn.Module):
|
||||
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
A Zipformer encoder.
|
||||
encoder_proj:
|
||||
The projection layer for encoder from the joiner.
|
||||
"""
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_proj = encoder_proj
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Please see the help information of Zipformer.forward
|
||||
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
|
||||
- encoder_out_lens, A 1-D tensor of shape (N,)
|
||||
"""
|
||||
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
|
||||
|
||||
encoder_out = self.encoder_proj(encoder_out)
|
||||
# Now encoder_out is of shape (N, T, joiner_dim)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
class OnnxDecoder(nn.Module):
|
||||
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.decoder_proj = decoder_proj
|
||||
|
||||
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, context_size).
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
need_pad = False
|
||||
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||
decoder_output = decoder_output.squeeze(1)
|
||||
output = self.decoder_proj(decoder_output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class OnnxJoiner(nn.Module):
|
||||
"""A wrapper for the joiner"""
|
||||
|
||||
def __init__(self, output_linear: nn.Linear):
|
||||
super().__init__()
|
||||
self.output_linear = output_linear
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
logit = encoder_out + decoder_out
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
return logit
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: OnnxEncoder,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the given encoder model to ONNX format.
|
||||
The exported model has two inputs:
|
||||
|
||||
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||
|
||||
and it has two outputs:
|
||||
|
||||
- encoder_out, a tensor of shape (N, T', joiner_dim)
|
||||
- encoder_out_lens, a tensor of shape (N,)
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
(x, x_lens),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_lens"],
|
||||
output_names=["encoder_out", "encoder_out_lens"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
"encoder_out": {0: "N", 1: "T"},
|
||||
"encoder_out_lens": {0: "N"},
|
||||
},
|
||||
)
|
||||
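# Hedged illustration (not part of the original script): a minimal sketch of
# how the encoder exported by export_encoder_model_onnx() could be
# sanity-checked with onnxruntime. The helper name and the dummy shapes
# (which mirror the tensors used during export) are assumptions.
def _check_exported_encoder_example(encoder_filename: str) -> None:
    import numpy as np
    import onnxruntime

    session = onnxruntime.InferenceSession(
        encoder_filename, providers=["CPUExecutionProvider"]
    )
    x = np.zeros((1, 100, 80), dtype=np.float32)
    x_lens = np.array([100], dtype=np.int64)
    encoder_out, encoder_out_lens = session.run(
        ["encoder_out", "encoder_out_lens"], {"x": x, "x_lens": x_lens}
    )
    print(encoder_out.shape, encoder_out_lens)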
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: OnnxDecoder,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
context_size = decoder_model.decoder.context_size
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"context_size": str(context_size),
|
||||
"vocab_size": str(vocab_size),
|
||||
}
|
||||
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
|
||||
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
"""
|
||||
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||
logging.info(f"joiner dim: {joiner_dim}")
|
||||
|
||||
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"encoder_out",
|
||||
"decoder_out",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
meta_data = {
|
||||
"joiner_dim": str(joiner_dim),
|
||||
}
|
||||
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
|
||||
encoder = OnnxEncoder(
|
||||
encoder=model.encoder,
|
||||
encoder_proj=model.joiner.encoder_proj,
|
||||
)
|
||||
|
||||
decoder = OnnxDecoder(
|
||||
decoder=model.decoder,
|
||||
decoder_proj=model.joiner.decoder_proj,
|
||||
)
|
||||
|
||||
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||
|
||||
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||
logging.info(f"total parameters: {total_num_param}")
|
||||
|
||||
if params.iter > 0:
|
||||
suffix = f"iter-{params.iter}"
|
||||
else:
|
||||
suffix = f"epoch-{params.epoch}"
|
||||
|
||||
suffix += f"-avg-{params.avg}"
|
||||
|
||||
opset_version = 13
|
||||
|
||||
logging.info("Exporting encoder")
|
||||
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||
export_encoder_model_onnx(
|
||||
encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported encoder to {encoder_filename}")
|
||||
|
||||
logging.info("Exporting decoder")
|
||||
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||
export_decoder_model_onnx(
|
||||
decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported decoder to {decoder_filename}")
|
||||
|
||||
logging.info("Exporting joiner")
|
||||
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||
export_joiner_model_onnx(
|
||||
joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported joiner to {joiner_filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
main()
|
||||
egs/librispeech/ASR/pruned_transducer_stateless7_contextual/export.py (new executable file, 320 lines)
@@ -0,0 +1,320 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("cpu_jit.pt")`.
|
||||
|
||||
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
||||
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
To use the generated file with `pruned_transducer_stateless7/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
Check ./pretrained.py for its usage.
|
||||
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate a file named cpu_jit.pt
|
||||
|
||||
Check ./jit_pretrained.py for how to use it.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptable.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
egs/librispeech/ASR/pruned_transducer_stateless7_contextual/finetune.py (new executable file, 1342 lines; diff suppressed because it is too large)
@@ -0,0 +1,282 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) use the averaged model with checkpoint exp_dir/epoch-xxx.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--use-averaged-model True \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`.
|
||||
|
||||
(2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
--iter 22000 \
|
||||
--avg 5 \
|
||||
--use-averaged-model True \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`.
|
||||
|
||||
(3) use the original model with checkpoint exp_dir/epoch-xxx.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--use-averaged-model False \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
|
||||
|
||||
(4) use the original model with checkpoint exp_dir/checkpoint-iter.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
--iter 22000 \
|
||||
--avg 5 \
|
||||
--use-averaged-model False \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import str2bool
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model."
|
||||
"If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
print("Script started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
print(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
print("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
print(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
print(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
print(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
filename = (
|
||||
params.exp_dir
|
||||
/ f"iter-{params.iter}-avg-{params.avg}-use-averaged-model.pt"
|
||||
)
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
print(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
filename = (
|
||||
params.exp_dir
|
||||
/ f"epoch-{params.epoch}-avg-{params.avg}-use-averaged-model.pt"
|
||||
)
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,406 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import (
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class GigaSpeechAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it "
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
# GigaSpeech specific arguments
|
||||
group.add_argument(
|
||||
"--subset",
|
||||
type=str,
|
||||
default="XL",
|
||||
help="Select the GigaSpeech subset (XS|S|M|L|XL)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--small-dev",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Should we use only 1000 utterances for dev (speeds up training)",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info(f"About to get train_{self.args.subset} cuts")
|
||||
path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz"
|
||||
cuts_train = CutSet.from_jsonl_lazy(path)
|
||||
return cuts_train
|
||||
|
||||
@lru_cache()
|
||||
def dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
|
||||
if self.args.small_dev:
|
||||
return cuts_valid.subset(first=1000)
|
||||
else:
|
||||
return cuts_valid
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
|
||||
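# Hedged illustration (not part of the original module): a minimal sketch of
# how GigaSpeechAsrDataModule is typically wired up in a training script.
# The command-line values below are made up for demonstration and assume the
# corresponding manifests exist under --manifest-dir.
def _data_module_usage_example() -> None:
    parser = argparse.ArgumentParser()
    GigaSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args(["--subset", "S", "--max-duration", "200"])

    gigaspeech = GigaSpeechAsrDataModule(args)
    train_dl = gigaspeech.train_dataloaders(gigaspeech.train_cuts())
    valid_dl = gigaspeech.valid_dataloaders(gigaspeech.dev_cuts())
    print(type(train_dl), type(valid_dl))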
@@ -0,0 +1,115 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Jiayu Du
|
||||
# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
conversational_filler = [
|
||||
"UH",
|
||||
"UHH",
|
||||
"UM",
|
||||
"EH",
|
||||
"MM",
|
||||
"HM",
|
||||
"AH",
|
||||
"HUH",
|
||||
"HA",
|
||||
"ER",
|
||||
"OOF",
|
||||
"HEE",
|
||||
"ACH",
|
||||
"EEE",
|
||||
"EW",
|
||||
]
|
||||
unk_tags = ["<UNK>", "<unk>"]
|
||||
gigaspeech_punctuations = [
|
||||
"<COMMA>",
|
||||
"<PERIOD>",
|
||||
"<QUESTIONMARK>",
|
||||
"<EXCLAMATIONPOINT>",
|
||||
]
|
||||
gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
|
||||
non_scoring_words = (
|
||||
conversational_filler
|
||||
+ unk_tags
|
||||
+ gigaspeech_punctuations
|
||||
+ gigaspeech_garbage_utterance_tags
|
||||
)
|
||||
|
||||
|
||||
def asr_text_post_processing(text: str) -> str:
|
||||
# 1. convert to uppercase
|
||||
text = text.upper()
|
||||
|
||||
# 2. remove hyphen
|
||||
# "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
|
||||
text = text.replace("-", " ")
|
||||
|
||||
# 3. remove non-scoring words from evaluation
|
||||
remaining_words = []
|
||||
for word in text.split():
|
||||
if word in non_scoring_words:
|
||||
continue
|
||||
remaining_words.append(word)
|
||||
|
||||
return " ".join(remaining_words)
|
||||
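# Hedged illustration (not part of the original script): a quick check of what
# asr_text_post_processing() produces; the input sentence is made up.
def _post_processing_example() -> None:
    assert (
        asr_text_post_processing("uh I like e-commerce <COMMA> hm")
        == "I LIKE E COMMERCE"
    )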
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="This script evaluates GigaSpeech ASR result via"
|
||||
"SCTK's tool sclite"
|
||||
)
|
||||
parser.add_argument(
|
||||
"ref",
|
||||
type=str,
|
||||
help="sclite's standard transcription(trn) reference file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"hyp",
|
||||
type=str,
|
||||
help="sclite's standard transcription(trn) hypothesis file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"work_dir",
|
||||
type=str,
|
||||
help="working dir",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.isdir(args.work_dir):
|
||||
os.mkdir(args.work_dir)
|
||||
|
||||
REF = os.path.join(args.work_dir, "REF")
|
||||
HYP = os.path.join(args.work_dir, "HYP")
|
||||
RESULT = os.path.join(args.work_dir, "RESULT")
|
||||
|
||||
for io in [(args.ref, REF), (args.hyp, HYP)]:
|
||||
with open(io[0], "r", encoding="utf8") as fi:
|
||||
with open(io[1], "w+", encoding="utf8") as fo:
|
||||
for line in fi:
|
||||
line = line.strip()
|
||||
if line:
|
||||
cols = line.split()
|
||||
text = asr_text_post_processing(" ".join(cols[0:-1]))
|
||||
uttid_field = cols[-1]
|
||||
print(f"{text} {uttid_field}", file=fo)
|
||||
|
||||
# GigaSpeech's uttid conforms to swb
|
||||
os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}")
|
||||
egs/librispeech/ASR/pruned_transducer_stateless7_contextual/jit_pretrained.py (new executable file, 272 lines)
@@ -0,0 +1,272 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads torchscript models, exported by `torch.jit.script()`
|
||||
and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--jit 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless7/jit_pretrained.py \
|
||||
--nn-model-filename ./pruned_transducer_stateless7/exp/cpu_jit.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nn-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the torchscript model cpu_jit.pt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float = 16000
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: torch.jit.ScriptModule,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
device = encoder_out.device
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
context_size = model.decoder.context_size
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
).squeeze(1)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = packed_encoder_out.data[start:end]
|
||||
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
decoder_out = decoder_out[:batch_size]
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
)
|
||||
# logits' shape: (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
)
|
||||
decoder_out = decoder_out.squeeze(1)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model = torch.jit.load(args.nn_model_filename)
|
||||
|
||||
model.eval()
|
||||
|
||||
model.to(device)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
)
|
||||
|
||||
hyps = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = sp.decode(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
@ -0,0 +1,64 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
|
||||
self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
|
||||
self.output_linear = nn.Linear(joiner_dim, vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
project_input: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, s_range, C).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, T, s_range, C).
|
||||
project_input:
|
||||
If true, apply input projections encoder_proj and decoder_proj.
|
||||
If this is false, it is the user's responsibility to do this
|
||||
manually.
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim
|
||||
assert encoder_out.ndim in (2, 4)
|
||||
|
||||
if project_input:
|
||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
|
||||
else:
|
||||
logit = encoder_out + decoder_out
|
||||
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
|
||||
return logit
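
# A minimal usage sketch (not part of the recipe; all sizes below are assumed
# for illustration). With project_input=True the joiner projects both inputs
# to joiner_dim before combining them:
#
#   joiner = Joiner(encoder_dim=384, decoder_dim=512, joiner_dim=512, vocab_size=500)
#   enc = torch.rand(8, 100, 4, 384)   # (N, T, s_range, encoder_dim)
#   dec = torch.rand(8, 100, 4, 512)   # (N, T, s_range, decoder_dim)
#   logits = joiner(enc, dec)          # (N, T, s_range, vocab_size)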
|
||||
@ -0,0 +1,242 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import random
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import penalize_abs_values_gt
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from typing import Union, List
|
||||
|
||||
class Transducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
"Sequence Transduction with Recurrent Neural Networks"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
context_encoder: nn.Module,
|
||||
encoder_biasing_adapter: nn.Module,
|
||||
decoder_biasing_adapter: nn.Module,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
            It is the transcription network in the paper. It accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.context_encoder = context_encoder
|
||||
self.encoder_biasing_adapter = encoder_biasing_adapter
|
||||
self.decoder_biasing_adapter = decoder_biasing_adapter
|
||||
|
||||
self.simple_am_proj = nn.Linear(
|
||||
encoder_dim,
|
||||
vocab_size,
|
||||
)
|
||||
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
|
||||
|
||||
# For temporary convenience
|
||||
self.scratch_space = None
|
||||
self.no_encoder_biasing = None
|
||||
self.no_decoder_biasing = None
|
||||
self.no_wfst_lm_biasing = None
|
||||
self.params = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
contexts: dict,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
          contexts:
            A dict describing the biasing contexts for this batch. It holds
            the padded list of context words (each word a list of token ids),
            the number of tokens per word, and the number of context words
            per utterance. See ContextEncoder.embed_contexts for the exact
            keys.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
# breakpoint()
|
||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
contexts_h, contexts_mask = self.context_encoder.embed_contexts(
|
||||
contexts
|
||||
)
|
||||
# assert x.size(0) == contexts_h.size(0) == contexts_mask.size(0)
|
||||
# assert contexts_h.ndim == 3
|
||||
# assert contexts_h.ndim == 2
|
||||
if self.params.irrelevance_learning:
|
||||
need_weights = True
|
||||
else:
|
||||
need_weights = False
|
||||
encoder_biasing_out, attn_enc = self.encoder_biasing_adapter.forward(encoder_out, contexts_h, contexts_mask, need_weights=need_weights)
|
||||
encoder_out = encoder_out + encoder_biasing_out
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
if self.context_encoder.bi_encoders:
|
||||
contexts_dec_h, contexts_dec_mask = self.context_encoder.embed_contexts(
|
||||
contexts,
|
||||
is_encoder_side=False,
|
||||
)
|
||||
else:
|
||||
contexts_dec_h, contexts_dec_mask = contexts_h, contexts_mask
|
||||
|
||||
decoder_biasing_out, attn_dec = self.decoder_biasing_adapter.forward(decoder_out, contexts_dec_h, contexts_dec_mask, need_weights=need_weights)
|
||||
decoder_out = decoder_out + decoder_biasing_out
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
# if self.training and random.random() < 0.25:
|
||||
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
# if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
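
# A hedged sketch of how the two returned losses are typically combined in a
# training loop (the 0.5 scale is an assumption for illustration, not a value
# taken from this recipe):
#
#   simple_loss, pruned_loss = model(x, x_lens, y, contexts, prune_range=5)
#   loss = 0.5 * simple_loss + pruned_loss
#   loss.backward()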
|
||||
146
egs/librispeech/ASR/pruned_transducer_stateless7_contextual/my_profile.py
Executable file
@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage: ./pruned_transducer_stateless7_contextual/my_profile.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
|
||||
from typing import Tuple
|
||||
from torch import Tensor, nn
|
||||
|
||||
from icefall.profiler import get_model_profile
|
||||
from scaling import BasicNorm, DoubleSwish
|
||||
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _basic_norm_flops_compute(module, input, output):
|
||||
assert len(input) == 1, len(input)
|
||||
# estimate as layer_norm, see icefall/profiler.py
|
||||
flops = input[0].numel() * 5
|
||||
module.__flops__ += int(flops)
|
||||
|
||||
|
||||
def _doubleswish_module_flops_compute(module, input, output):
|
||||
# For DoubleSwish
|
||||
assert len(input) == 1, len(input)
|
||||
# estimate as swish/silu, see icefall/profiler.py
|
||||
flops = input[0].numel()
|
||||
module.__flops__ += int(flops)
|
||||
|
||||
|
||||
MODULE_HOOK_MAPPING = {
|
||||
BasicNorm: _basic_norm_flops_compute,
|
||||
DoubleSwish: _doubleswish_module_flops_compute,
|
||||
}
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""A Wrapper for encoder and encoder_proj"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: nn.Module,
|
||||
encoder_proj: nn.Module,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_proj = encoder_proj
|
||||
|
||||
def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
encoder_out, encoder_out_lens = self.encoder(feature, feature_lens)
|
||||
|
||||
logits = self.encoder_proj(encoder_out)
|
||||
|
||||
return logits, encoder_out_lens
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
|
||||
# We only profile the encoder part
|
||||
model = Model(
|
||||
encoder=get_encoder_model(params),
|
||||
encoder_proj=get_joiner_model(params).encoder_proj,
|
||||
)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# for 30-second input
|
||||
B, T, D = 1, 3000, 80
|
||||
feature = torch.ones(B, T, D, dtype=torch.float32).to(device)
|
||||
feature_lens = torch.full((B,), T, dtype=torch.int64).to(device)
|
||||
|
||||
flops, params = get_model_profile(
|
||||
model=model,
|
||||
args=(feature, feature_lens),
|
||||
module_hoop_mapping=MODULE_HOOK_MAPPING,
|
||||
)
|
||||
logging.info(f"For the encoder part, params: {params}, flops: {flops}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
||||
@ -0,0 +1,634 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import abc
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class Ffn(nn.Module):
|
||||
def __init__(self, input_dim, hidden_dim, out_dim, nlayers=1, drop_out=0.1, skip=False) -> None:
|
||||
super().__init__()
|
||||
|
||||
layers = []
|
||||
for ilayer in range(nlayers):
|
||||
_in = hidden_dim if ilayer > 0 else input_dim
|
||||
_out = hidden_dim if ilayer < nlayers - 1 else out_dim
|
||||
layers.extend([
|
||||
nn.Linear(_in, _out),
|
||||
# nn.ReLU(),
|
||||
# nn.Sigmoid(),
|
||||
nn.Tanh(),
|
||||
nn.Dropout(p=drop_out),
|
||||
])
|
||||
self.ffn = torch.nn.Sequential(
|
||||
*layers,
|
||||
)
|
||||
|
||||
self.skip = skip
|
||||
|
||||
def forward(self, x) -> torch.Tensor:
|
||||
x_out = self.ffn(x)
|
||||
|
||||
if self.skip:
|
||||
x_out = x_out + x
|
||||
|
||||
return x_out
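
# Illustrative sketch (assumed sizes): a 2-layer Ffn with a residual skip
# keeps the feature dimension unchanged, so the skip connection is well
# defined.
#
#   ffn = Ffn(input_dim=64, hidden_dim=64, out_dim=64, nlayers=2, skip=True)
#   y = ffn(torch.rand(3, 10, 64))   # same shape as the input: (3, 10, 64)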
|
||||
|
||||
|
||||
class ContextEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(ContextEncoder, self).__init__()
|
||||
|
||||
self.stats_num_distractors_per_utt = 0
|
||||
self.stats_num_utt = 0
|
||||
|
||||
@abc.abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
word_list,
|
||||
word_lengths,
|
||||
is_encoder_side=None,
|
||||
):
|
||||
pass
|
||||
|
||||
def embed_contexts(
|
||||
self,
|
||||
contexts,
|
||||
is_encoder_side=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
contexts:
|
||||
The contexts, see below for details
|
||||
Returns:
|
||||
final_h:
|
||||
A tensor of shape (batch_size, max(num_words_per_utt) + 1, joiner_dim),
|
||||
which is the embedding for each context word.
|
||||
mask_h:
|
||||
A tensor of shape (batch_size, max(num_words_per_utt) + 1),
|
||||
which contains a True/False mask for final_h
|
||||
"""
|
||||
if contexts["mode"] == "get_context_word_list":
|
||||
"""
|
||||
word_list:
|
||||
Option1: A list of words, where each word is a list of token ids.
|
||||
The list of tokens for each word has been padded.
|
||||
Option2: A list of words, where each word is an embedding.
|
||||
word_lengths:
|
||||
Option1: The number of tokens per word
|
||||
Option2: None
|
||||
num_words_per_utt:
|
||||
The number of words in the context for each utterance
|
||||
"""
|
||||
word_list, word_lengths, num_words_per_utt = \
|
||||
contexts["word_list"], contexts["word_lengths"], contexts["num_words_per_utt"]
|
||||
|
||||
assert word_lengths is None or word_list.size(0) == len(word_lengths)
|
||||
batch_size = len(num_words_per_utt)
|
||||
elif contexts["mode"] == "get_context_word_list_shared":
|
||||
"""
|
||||
word_list:
|
||||
Option1: A list of words, where each word is a list of token ids.
|
||||
The list of tokens for each word has been padded.
|
||||
Option2: A list of words, where each word is an embedding.
|
||||
word_lengths:
|
||||
Option1: The number of tokens per word
|
||||
Option2: None
|
||||
positive_mask_list:
|
||||
                For each utterance, it contains a list of indices of the words that should be masked
|
||||
"""
|
||||
# word_list, word_lengths, positive_mask_list = \
|
||||
# contexts["word_list"], contexts["word_lengths"], contexts["positive_mask_list"]
|
||||
# batch_size = len(positive_mask_list)
|
||||
word_list, word_lengths, num_words_per_utt = \
|
||||
contexts["word_list"], contexts["word_lengths"], contexts["num_words_per_utt"]
|
||||
batch_size = len(num_words_per_utt)
|
||||
|
||||
assert word_lengths is None or word_list.size(0) == len(word_lengths)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# print(f"word_list.shape={word_list.shape}")
|
||||
final_h = self.forward(word_list, word_lengths, is_encoder_side=is_encoder_side)
|
||||
|
||||
if contexts["mode"] == "get_context_word_list":
|
||||
final_h = torch.split(final_h, num_words_per_utt)
|
||||
final_h = torch.nn.utils.rnn.pad_sequence(
|
||||
final_h,
|
||||
batch_first=True,
|
||||
padding_value=0.0
|
||||
)
|
||||
# print(f"final_h.shape={final_h.shape}")
|
||||
|
||||
# add one no-bias token
|
||||
no_bias_h = torch.zeros(final_h.shape[0], 1, final_h.shape[-1])
|
||||
no_bias_h = no_bias_h.to(final_h.device)
|
||||
final_h = torch.cat((no_bias_h, final_h), dim=1)
|
||||
# print(final_h)
|
||||
|
||||
# https://stackoverflow.com/questions/53403306/how-to-batch-convert-sentence-lengths-to-masks-in-pytorch
|
||||
mask_h = torch.arange(max(num_words_per_utt) + 1)
|
||||
mask_h = mask_h.expand(len(num_words_per_utt), max(num_words_per_utt) + 1) > torch.Tensor(num_words_per_utt).unsqueeze(1)
|
||||
mask_h = mask_h.to(final_h.device)
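            # Worked example (illustrative): with num_words_per_utt = [2, 1],
            # position 0 is the no-bias token and positions 1..num_words are
            # real context words, so
            #   mask_h == [[False, False, False],
            #              [False, False, True]]
            # i.e. only the padded slot of the second utterance is masked out.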
|
||||
|
||||
num_utt = len(num_words_per_utt)
|
||||
self.stats_num_distractors_per_utt = len(word_list) / (num_utt + self.stats_num_utt) + self.stats_num_utt / (num_utt + self.stats_num_utt) * self.stats_num_distractors_per_utt
|
||||
self.stats_num_utt += num_utt
|
||||
elif contexts["mode"] == "get_context_word_list_shared":
|
||||
no_bias_h = torch.zeros(1, final_h.shape[-1])
|
||||
no_bias_h = no_bias_h.to(final_h.device)
|
||||
final_h = torch.cat((no_bias_h, final_h), dim=0)
|
||||
|
||||
final_h = final_h.expand(batch_size, -1, -1)
|
||||
|
||||
# mask_h = torch.full(False, (batch_size, final_h.shape(1))) # TODO
|
||||
# for i, my_mask in enumerate(positive_mask_list):
|
||||
# if len(my_mask) > 0:
|
||||
# my_mask = torch.Tensor(my_mask, dtype=int)
|
||||
# my_mask += 1
|
||||
# mask_h[i][my_mask] = True
|
||||
mask_h = None
|
||||
|
||||
num_utt = len(num_words_per_utt)
|
||||
self.stats_num_distractors_per_utt = num_utt / (num_utt + self.stats_num_utt) * len(word_list) + self.stats_num_utt / (num_utt + self.stats_num_utt) * self.stats_num_distractors_per_utt
|
||||
self.stats_num_utt += num_utt
|
||||
|
||||
|
||||
# TODO: validate this shape is correct:
|
||||
        # final_h: (batch_size, max_num_words_per_utt + 1, dim)
        # mask_h: (batch_size, max_num_words_per_utt + 1)
|
||||
return final_h, mask_h
|
||||
|
||||
def clustering(self):
|
||||
pass
|
||||
|
||||
def cache(self):
|
||||
pass
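
# Hedged usage sketch: the dict keys below follow embed_contexts above, but
# the concrete token ids, lengths and word counts are made-up values.
#
#   contexts = {
#       "mode": "get_context_word_list",
#       "word_list": torch.tensor([[5, 7, 0], [9, 0, 0], [3, 4, 8]]),  # padded token ids per word
#       "word_lengths": [2, 1, 3],
#       "num_words_per_utt": [1, 2],
#   }
#   final_h, mask_h = context_encoder.embed_contexts(contexts)
#   # final_h: (2, max(num_words_per_utt) + 1, output_dim)
#   # mask_h:  (2, max(num_words_per_utt) + 1)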
|
||||
|
||||
|
||||
class ContextEncoderLSTM(ContextEncoder):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = None,
|
||||
context_encoder_dim: int = None,
|
||||
embedding_layer: nn.Module = None,
|
||||
output_dim: int = None,
|
||||
num_layers: int = None,
|
||||
num_directions: int = None,
|
||||
drop_out: float = 0.1,
|
||||
bi_encoders: bool = False,
|
||||
):
|
||||
super(ContextEncoderLSTM, self).__init__()
|
||||
self.num_layers = num_layers
|
||||
self.num_directions = num_directions
|
||||
self.context_encoder_dim = context_encoder_dim
|
||||
|
||||
if embedding_layer is not None:
|
||||
# self.embed = embedding_layer
|
||||
self.embed = torch.nn.Embedding(
|
||||
vocab_size,
|
||||
context_encoder_dim
|
||||
)
|
||||
self.embed.weight.data = embedding_layer.weight.data
|
||||
else:
|
||||
self.embed = torch.nn.Embedding(
|
||||
vocab_size,
|
||||
context_encoder_dim
|
||||
)
|
||||
|
||||
self.rnn = torch.nn.LSTM(
|
||||
input_size=self.embed.weight.shape[1],
|
||||
hidden_size=context_encoder_dim,
|
||||
num_layers=self.num_layers,
|
||||
batch_first=True,
|
||||
bidirectional=(self.num_directions == 2),
|
||||
dropout=0.1 if self.num_layers > 1 else 0
|
||||
)
|
||||
self.linear = torch.nn.Linear(
|
||||
context_encoder_dim * self.num_directions,
|
||||
output_dim
|
||||
)
|
||||
self.drop_out = torch.nn.Dropout(drop_out)
|
||||
|
||||
# TODO: Do we need some relu layer?
|
||||
# https://galhever.medium.com/sentiment-analysis-with-pytorch-part-4-lstm-bilstm-model-84447f6c4525
|
||||
# self.relu = nn.ReLU()
|
||||
# self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.bi_encoders = bi_encoders
|
||||
if bi_encoders:
|
||||
# Create the decoder/predictor side of the context encoder
|
||||
self.embed_dec = copy.deepcopy(self.embed)
|
||||
self.rnn_dec = copy.deepcopy(self.rnn)
|
||||
self.linear_dec = copy.deepcopy(self.linear)
|
||||
self.drop_out_dec = copy.deepcopy(self.drop_out)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
word_list,
|
||||
word_lengths,
|
||||
is_encoder_side=None,
|
||||
):
|
||||
if is_encoder_side is None or is_encoder_side is True:
|
||||
embed = self.embed
|
||||
rnn = self.rnn
|
||||
linear = self.linear
|
||||
drop_out = self.drop_out
|
||||
else:
|
||||
embed = self.embed_dec
|
||||
rnn = self.rnn_dec
|
||||
linear = self.linear_dec
|
||||
drop_out = self.drop_out_dec
|
||||
|
||||
out = embed(word_list)
|
||||
# https://stackoverflow.com/questions/51030782/why-do-we-pack-the-sequences-in-pytorch
|
||||
out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
out,
|
||||
batch_first=True,
|
||||
lengths=word_lengths,
|
||||
enforce_sorted=False
|
||||
)
|
||||
output, (hn, cn) = rnn(out) # use default all zeros (h_0, c_0)
|
||||
|
||||
# # https://discuss.pytorch.org/t/bidirectional-3-layer-lstm-hidden-output/41336/4
|
||||
# final_state = hn.view(
|
||||
# self.num_layers,
|
||||
# self.num_directions,
|
||||
# word_list.shape[0],
|
||||
# self.encoder_dim,
|
||||
# )[-1] # Only the last layer
|
||||
# h_1, h_2 = final_state[0], final_state[1]
|
||||
# # X = h_1 + h_2 # Add both states (needs different input size for first linear layer)
|
||||
# final_h = torch.cat((h_1, h_2), dim=1) # Concatenate both states
|
||||
# final_h = self.linear(final_h)
|
||||
|
||||
# hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
|
||||
# outputs are always from the last layer.
|
||||
# hidden[-2, :, : ] is the last of the forwards RNN
|
||||
# hidden[-1, :, : ] is the last of the backwards RNN
|
||||
        h_1, h_2 = hn[-2, :, :], hn[-1, :, :]
|
||||
final_h = torch.cat((h_1, h_2), dim=1) # Concatenate both states
|
||||
final_h = linear(final_h)
|
||||
# final_h = drop_out(final_h)
|
||||
|
||||
return final_h
|
||||
|
||||
|
||||
class SimpleGLU(nn.Module):
|
||||
def __init__(self):
|
||||
super(SimpleGLU, self).__init__()
|
||||
# Initialize the learnable parameter 'a'
|
||||
self.a = nn.Parameter(torch.tensor(0.0))
|
||||
|
||||
def forward(self, x):
|
||||
# Perform the operation a * x
|
||||
return self.a * x
|
||||
|
||||
|
||||
class BiasingModule(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim,
|
||||
qkv_dim=64,
|
||||
num_heads=4,
|
||||
):
|
||||
super(BiasingModule, self).__init__()
|
||||
self.proj_in1 = nn.Linear(query_dim, qkv_dim)
|
||||
self.proj_in2 = Ffn(
|
||||
input_dim=qkv_dim,
|
||||
hidden_dim=qkv_dim,
|
||||
out_dim=qkv_dim,
|
||||
skip=True,
|
||||
drop_out=0.1,
|
||||
nlayers=2,
|
||||
)
|
||||
self.multihead_attn = torch.nn.MultiheadAttention(
|
||||
embed_dim=qkv_dim,
|
||||
num_heads=num_heads,
|
||||
# kdim=64,
|
||||
# vdim=64,
|
||||
batch_first=True,
|
||||
)
|
||||
self.proj_out1 = Ffn(
|
||||
input_dim=qkv_dim,
|
||||
hidden_dim=qkv_dim,
|
||||
out_dim=qkv_dim,
|
||||
skip=True,
|
||||
drop_out=0.1,
|
||||
nlayers=2,
|
||||
)
|
||||
self.proj_out2 = nn.Linear(qkv_dim, query_dim)
|
||||
self.glu = nn.GLU()
|
||||
# self.glu = SimpleGLU()
|
||||
|
||||
self.contexts = None
|
||||
self.contexts_mask = None
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# query_dim,
|
||||
# qkv_dim=64,
|
||||
# num_heads=4,
|
||||
# ):
|
||||
# super(BiasingModule, self).__init__()
|
||||
# self.proj_in1 = Ffn(
|
||||
# input_dim=query_dim,
|
||||
# hidden_dim=query_dim,
|
||||
# out_dim=query_dim,
|
||||
# skip=True,
|
||||
# drop_out=0.1,
|
||||
# nlayers=2,
|
||||
# )
|
||||
# self.proj_in2 = nn.Linear(query_dim, qkv_dim)
|
||||
# self.multihead_attn = torch.nn.MultiheadAttention(
|
||||
# embed_dim=qkv_dim,
|
||||
# num_heads=num_heads,
|
||||
# # kdim=64,
|
||||
# # vdim=64,
|
||||
# batch_first=True,
|
||||
# )
|
||||
# self.proj_out1 = nn.Linear(qkv_dim, query_dim)
|
||||
# self.proj_out2 = Ffn(
|
||||
# input_dim=query_dim,
|
||||
# hidden_dim=query_dim,
|
||||
# out_dim=query_dim,
|
||||
# skip=True,
|
||||
# drop_out=0.1,
|
||||
# nlayers=2,
|
||||
# )
|
||||
# self.glu = nn.GLU()
|
||||
|
||||
# self.contexts = None
|
||||
# self.contexts_mask = None
|
||||
# self.attn_output_weights = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
queries,
|
||||
contexts=None,
|
||||
contexts_mask=None,
|
||||
need_weights=False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
query:
|
||||
of shape batch_size * seq_length * query_dim
|
||||
contexts:
|
||||
of shape batch_size * max_contexts_size * query_dim
|
||||
contexts_mask:
|
||||
of shape batch_size * max_contexts_size
|
||||
Returns:
|
||||
attn_output:
|
||||
of shape batch_size * seq_length * context_dim
|
||||
"""
|
||||
|
||||
if contexts is None:
|
||||
contexts = self.contexts
|
||||
if contexts_mask is None:
|
||||
contexts_mask = self.contexts_mask
|
||||
|
||||
_queries = self.proj_in1(queries)
|
||||
_queries = self.proj_in2(_queries)
|
||||
# _queries = _queries / 0.01
|
||||
|
||||
attn_output, attn_output_weights = self.multihead_attn(
|
||||
_queries, # query
|
||||
contexts, # key
|
||||
contexts, # value
|
||||
key_padding_mask=contexts_mask,
|
||||
need_weights=need_weights,
|
||||
)
|
||||
biasing_output = self.proj_out1(attn_output)
|
||||
biasing_output = self.proj_out2(biasing_output)
|
||||
|
||||
# apply the gated linear unit
|
||||
biasing_output = self.glu(biasing_output.repeat(1,1,2))
|
||||
# biasing_output = self.glu(biasing_output)
|
||||
|
||||
# inject contexts here
|
||||
output = queries + biasing_output
|
||||
|
||||
# print(f"query={query.shape}")
|
||||
# print(f"value={contexts} value.shape={contexts.shape}")
|
||||
# print(f"attn_output_weights={attn_output_weights} attn_output_weights.shape={attn_output_weights.shape}")
|
||||
|
||||
return output, attn_output_weights
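
# Hedged shape sketch (all dims are assumptions): queries of shape
# (batch, seq_len, query_dim) attend over context embeddings of shape
# (batch, num_contexts, qkv_dim); the gated biasing vector is added back onto
# the queries, so the output keeps the query shape.
#
#   biasing = BiasingModule(query_dim=384, qkv_dim=64, num_heads=4)
#   q = torch.rand(2, 50, 384)
#   ctx = torch.rand(2, 6, 64)
#   ctx_mask = torch.zeros(2, 6, dtype=torch.bool)
#   out, attn_w = biasing(q, contexts=ctx, contexts_mask=ctx_mask, need_weights=True)
#   # out: (2, 50, 384), attn_w: (2, 50, 6)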
|
||||
|
||||
|
||||
def tuple_to_list(t):
|
||||
if isinstance(t, tuple):
|
||||
return list(map(tuple_to_list, t))
|
||||
return t
|
||||
|
||||
|
||||
def list_to_tuple(l):
|
||||
if isinstance(l, list):
|
||||
return tuple(map(list_to_tuple, l))
|
||||
return l
|
||||
|
||||
|
||||
class ContextualSequential(nn.Sequential):
|
||||
def __init__(self, *args):
|
||||
super(ContextualSequential, self).__init__(*args)
|
||||
|
||||
self.contexts_h = None
|
||||
self.contexts_mask = None
|
||||
|
||||
def set_contexts(self, contexts_h, contexts_masks):
|
||||
self.contexts_h, self.contexts_mask = contexts_h, contexts_masks
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
# print(f"input: {type(args[0])=}, {args[0].shape=}")
|
||||
|
||||
is_hf = False
|
||||
|
||||
for module in self._modules.values():
|
||||
module_name = type(module).__name__
|
||||
# if "AudioEncoder" in module_name:
|
||||
# args = (module(*args, **kwargs),)
|
||||
# elif "TextDecoder" in module_name:
|
||||
# args = (module(*args, **kwargs),)
|
||||
# elif "WhisperDecoderLayer" in module_name:
|
||||
# args = (module(*args, **kwargs),)
|
||||
# elif "WhisperEncoderLayer" in module_name:
|
||||
# args = (module(*args, **kwargs),)
|
||||
|
||||
if "WhisperDecoderLayer" in module_name or "WhisperEncoderLayer" in module_name:
|
||||
is_hf = True
|
||||
|
||||
if "BiasingModule" in module_name:
|
||||
x = args[0]
|
||||
while isinstance(x, list) or isinstance(x, tuple):
|
||||
x = x[0]
|
||||
|
||||
if self.contexts_h is not None:
|
||||
# The contexts are injected here
|
||||
x, attn_output_weights = module(x, contexts=self.contexts_h, contexts_mask=self.contexts_mask, need_weights=True)
|
||||
else:
|
||||
# final_h: batch_size * max_num_words_per_utt + 1 * dim
|
||||
# mask_h: batch_size * max_num_words_per_utt + 1
|
||||
|
||||
batch_size = x.size(0)
|
||||
contexts_h = torch.zeros(batch_size, 1, module.multihead_attn.embed_dim)
|
||||
contexts_h = contexts_h.to(x.device)
|
||||
|
||||
contexts_mask = torch.zeros(batch_size, 1, dtype=torch.bool)
|
||||
contexts_mask = contexts_mask.to(x.device)
|
||||
|
||||
x, attn_output_weights = module(x, contexts=contexts_h, contexts_mask=contexts_mask, need_weights=True)
|
||||
|
||||
args = (x,)
|
||||
else:
|
||||
x = module(*args, **kwargs)
|
||||
|
||||
while isinstance(x, list) or isinstance(x, tuple):
|
||||
x = x[0]
|
||||
args = (x,)
|
||||
|
||||
# print(f"output: {type(args[0])=}, {args[0].shape=}")
|
||||
if is_hf:
|
||||
return args
|
||||
else:
|
||||
return args[0]
|
||||
|
||||
|
||||
def set_contexts_for_model(model, contexts):
|
||||
# check each module in the model, if it is a class of "ContextualSequential",
|
||||
# then set the contexts for the module
|
||||
|
||||
# if hasattr(model, "context_encoder") and model.context_encoder is not None:
|
||||
contexts_h, contexts_mask = model.context_encoder.embed_contexts(
|
||||
contexts
|
||||
)
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, ContextualSequential):
|
||||
module.set_contexts(contexts_h, contexts_mask)
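
# Hedged sketch of the intended call order at inference time (the `model` and
# `contexts` objects are placeholders): embed the biasing words once, cache
# them inside every ContextualSequential wrapper, then decode as usual.
#
#   model = get_contextual_model(model)       # wrap the chosen layers
#   set_contexts_for_model(model, contexts)   # inject the biasing words
#   # ... run the normal decoding code afterwards ...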
|
||||
|
||||
|
||||
def get_contextual_model(model, encoder_biasing_layers="31,", decoder_biasing_layers="31,", context_dim=128) -> nn.Module:
|
||||
# context_dim = 128 # 1.5%
|
||||
# context_dim = 256 # 5.22% => seems better?
|
||||
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
# [(n, param.requires_grad) for n, param in context_encoder.named_parameters()]
|
||||
# [(n, p.numel()) for n, p in model.named_parameters() if p.requires_grad]
|
||||
|
||||
print("Before neural biasing:")
|
||||
print(f"{total_params=}")
|
||||
print(f"{trainable_params=} ({trainable_params/total_params*100:.2f}%)")
|
||||
|
||||
encoder_biasing_layers = [int(l) for l in encoder_biasing_layers.strip().split(",") if len(l) > 0]
|
||||
decoder_biasing_layers = [int(l) for l in decoder_biasing_layers.strip().split(",") if len(l) > 0]
|
||||
|
||||
if len(encoder_biasing_layers) > 0 or len(decoder_biasing_layers) > 0:
|
||||
        if hasattr(model, "model"):  # Huggingface models
|
||||
embedding_layer = model.model.decoder.embed_tokens
|
||||
else:
|
||||
embedding_layer = model.decoder.token_embedding
|
||||
|
||||
context_encoder = ContextEncoderLSTM(
|
||||
# vocab_size=embedding_layer.weight.shape[0],
|
||||
embedding_layer=embedding_layer,
|
||||
# context_encoder_dim=int(params.encoder_dims.split(",")[-1]),
|
||||
context_encoder_dim=context_dim,
|
||||
output_dim=context_dim,
|
||||
num_layers=2,
|
||||
num_directions=2,
|
||||
drop_out=0.1,
|
||||
)
|
||||
model.context_encoder = context_encoder
|
||||
|
||||
# encoder_biasing_adapter = BiasingModule(
|
||||
# query_dim=int(params.encoder_dims.split(",")[-1]),
|
||||
# qkv_dim=context_dim,
|
||||
# num_heads=4,
|
||||
# )
|
||||
|
||||
if hasattr(model, "model") and hasattr(model.model.encoder, "layers"): # Huggingface models
|
||||
for i, layer in enumerate(model.model.encoder.layers):
|
||||
if i in encoder_biasing_layers:
|
||||
layer_output_dim = layer.final_layer_norm.normalized_shape[0]
|
||||
model.model.encoder.layers[i] = ContextualSequential(OrderedDict([
|
||||
("layer", layer),
|
||||
("biasing_adapter", BiasingModule(
|
||||
query_dim=layer_output_dim,
|
||||
qkv_dim=context_dim,
|
||||
num_heads=4,
|
||||
))
|
||||
]))
|
||||
elif hasattr(model.encoder, "blocks"): # OpenAI models
|
||||
for i, layer in enumerate(model.encoder.blocks):
|
||||
if i in encoder_biasing_layers:
|
||||
layer_output_dim = layer.mlp_ln.normalized_shape[0]
|
||||
model.encoder.blocks[i] = ContextualSequential(OrderedDict([
|
||||
("layer", layer),
|
||||
("biasing_adapter", BiasingModule(
|
||||
query_dim=layer_output_dim,
|
||||
qkv_dim=context_dim,
|
||||
num_heads=4,
|
||||
))
|
||||
]))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if hasattr(model, "model") and hasattr(model.model.decoder, "layers"):
|
||||
for i, layer in enumerate(model.model.decoder.layers):
|
||||
if i in decoder_biasing_layers:
|
||||
layer_output_dim = layer.final_layer_norm.normalized_shape[0]
|
||||
model.model.decoder.layers[i] = ContextualSequential(OrderedDict([
|
||||
("layer", layer),
|
||||
("biasing_adapter", BiasingModule(
|
||||
query_dim=layer_output_dim,
|
||||
qkv_dim=context_dim,
|
||||
num_heads=4,
|
||||
))
|
||||
]))
|
||||
elif hasattr(model.decoder, "blocks"):
|
||||
for i, layer in enumerate(model.decoder.blocks):
|
||||
if i in decoder_biasing_layers:
|
||||
layer_output_dim = layer.mlp_ln.normalized_shape[0]
|
||||
model.decoder.blocks[i] = ContextualSequential(OrderedDict([
|
||||
("layer", layer),
|
||||
("biasing_adapter", BiasingModule(
|
||||
query_dim=layer_output_dim,
|
||||
qkv_dim=context_dim,
|
||||
num_heads=4,
|
||||
))
|
||||
]))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# Freeze the model params
|
||||
# exception_types = (BiasingModule, ContextEncoderLSTM)
|
||||
for name, param in model.named_parameters():
|
||||
# Check if the parameter belongs to a layer of the specified types
|
||||
if "biasing_adapter" in name:
|
||||
param.requires_grad = True
|
||||
elif "context_encoder" in name and "context_encoder.embed" not in name: # We will not fine-tune the embedding layer, which comes from the original model
|
||||
param.requires_grad = True
|
||||
elif "context_encoder" in name and "context_encoder.embed" in name: # Debug
|
||||
param.requires_grad = False
|
||||
else:
|
||||
param.requires_grad = False
|
||||
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
# [(n, param.requires_grad) for n, param in context_encoder.named_parameters()]
|
||||
# [(n, p.numel()) for n, p in model.named_parameters() if p.requires_grad]
|
||||
|
||||
print("Neural biasing:")
|
||||
print(f"{total_params=}")
|
||||
print(f"{trainable_params=} ({trainable_params/total_params*100:.2f}%)")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# # Test:
|
||||
# import whisper, torch; device = torch.device("cuda", 0)
|
||||
# model = whisper.load_model("large-v2", is_ctx=True, device=device)
|
||||
# from neural_biasing import get_contextual_model
|
||||
# model1 = get_contextual_model(model)
|
||||
239
egs/librispeech/ASR/pruned_transducer_stateless7_contextual/onnx_check.py
Executable file
@ -0,0 +1,239 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script checks that exported onnx models produce the same output
|
||||
with the given torchscript model for the same input.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt
|
||||
popd
|
||||
|
||||
2. Export the model via torchscript (torch.jit.script())
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp/ \
|
||||
--jit 1
|
||||
|
||||
It will generate the following file in $repo/exp:
|
||||
- cpu_jit.pt
|
||||
|
||||
3. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless3/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-9999-avg-1.onnx
|
||||
- decoder-epoch-9999-avg-1.onnx
|
||||
- joiner-epoch-9999-avg-1.onnx
|
||||
|
||||
4. Run this file
|
||||
|
||||
./pruned_transducer_stateless3/onnx_check.py \
|
||||
--jit-filename $repo/exp/cpu_jit.pt \
|
||||
--onnx-encoder-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
|
||||
--onnx-decoder-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
|
||||
--onnx-joiner-filename $repo/exp/joiner-epoch-9999-avg-1.onnx
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from icefall import is_module_available
|
||||
from onnx_pretrained import OnnxModel
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-encoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-decoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-joiner-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx joiner model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def test_encoder(
|
||||
torch_model: torch.jit.ScriptModule,
|
||||
onnx_model: OnnxModel,
|
||||
):
|
||||
C = 80
|
||||
for i in range(3):
|
||||
N = torch.randint(low=1, high=20, size=(1,)).item()
|
||||
T = torch.randint(low=30, high=50, size=(1,)).item()
|
||||
logging.info(f"test_encoder: iter {i}, N={N}, T={T}")
|
||||
|
||||
x = torch.rand(N, T, C)
|
||||
x_lens = torch.randint(low=30, high=T + 1, size=(N,))
|
||||
x_lens[0] = T
|
||||
|
||||
torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens)
|
||||
torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out)
|
||||
|
||||
onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens)
|
||||
|
||||
assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), (
|
||||
(torch_encoder_out - onnx_encoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_decoder(
|
||||
torch_model: torch.jit.ScriptModule,
|
||||
onnx_model: OnnxModel,
|
||||
):
|
||||
context_size = onnx_model.context_size
|
||||
vocab_size = onnx_model.vocab_size
|
||||
for i in range(10):
|
||||
N = torch.randint(1, 100, size=(1,)).item()
|
||||
logging.info(f"test_decoder: iter {i}, N={N}")
|
||||
x = torch.randint(
|
||||
low=1,
|
||||
high=vocab_size,
|
||||
size=(N, context_size),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False]))
|
||||
torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out)
|
||||
torch_decoder_out = torch_decoder_out.squeeze(1)
|
||||
|
||||
onnx_decoder_out = onnx_model.run_decoder(x)
|
||||
assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), (
|
||||
(torch_decoder_out - onnx_decoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_joiner(
|
||||
torch_model: torch.jit.ScriptModule,
|
||||
onnx_model: OnnxModel,
|
||||
):
|
||||
encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1]
|
||||
decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1]
|
||||
for i in range(10):
|
||||
N = torch.randint(1, 100, size=(1,)).item()
|
||||
logging.info(f"test_joiner: iter {i}, N={N}")
|
||||
encoder_out = torch.rand(N, encoder_dim)
|
||||
decoder_out = torch.rand(N, decoder_dim)
|
||||
|
||||
projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out)
|
||||
projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
torch_joiner_out = torch_model.joiner(encoder_out, decoder_out)
|
||||
onnx_joiner_out = onnx_model.run_joiner(
|
||||
projected_encoder_out, projected_decoder_out
|
||||
)
|
||||
|
||||
assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
|
||||
(torch_joiner_out - onnx_joiner_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
torch_model = torch.jit.load(args.jit_filename)
|
||||
|
||||
onnx_model = OnnxModel(
|
||||
encoder_model_filename=args.onnx_encoder_filename,
|
||||
decoder_model_filename=args.onnx_decoder_filename,
|
||||
joiner_model_filename=args.onnx_joiner_filename,
|
||||
)
|
||||
|
||||
logging.info("Test encoder")
|
||||
test_encoder(torch_model, onnx_model)
|
||||
|
||||
logging.info("Test decoder")
|
||||
test_decoder(torch_model, onnx_model)
|
||||
|
||||
logging.info("Test joiner")
|
||||
test_joiner(torch_model, onnx_model)
|
||||
logging.info("Finished checking ONNX models")
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
# See https://github.com/pytorch/pytorch/issues/38342
|
||||
# and https://github.com/pytorch/pytorch/issues/33354
|
||||
#
|
||||
# If we don't do this, the delay increases whenever there is
|
||||
# a new request that changes the actual batch size.
|
||||
# If you use `py-spy dump --pid <server-pid> --native`, you will
|
||||
# see a lot of time is spent in re-compiling the torch script model.
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._set_graph_executor_optimize(False)
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20220727)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
319
egs/librispeech/ASR/pruned_transducer_stateless7_contextual/onnx_decode.py
Executable file
@ -0,0 +1,319 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao,
|
||||
# Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads ONNX exported models and uses them to decode the test sets.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained-epoch-30-avg-9.pt epoch-9999.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless7/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-9999-avg-1.onnx
|
||||
- decoder-epoch-9999-avg-1.onnx
|
||||
- joiner-epoch-9999-avg-1.onnx
|
||||
|
||||
2. Run this file
|
||||
|
||||
./pruned_transducer_stateless7/onnx_decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
|
||||
from onnx_pretrained import greedy_search, OnnxModel
|
||||
|
||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="Valid values are greedy_search and modified_beam_search",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict
|
||||
) -> List[List[str]]:
|
||||
"""Decode one batch and return the result.
|
||||
    Currently only greedy_search is supported.
|
||||
|
||||
Args:
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
|
||||
|
||||
encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
hyps = greedy_search(
|
||||
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
|
||||
)
|
||||
|
||||
hyps = [sp.decode(h).split() for h in hyps]
|
||||
return hyps
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
|
||||
Returns:
|
||||
- A list of tuples. Each tuple contains three elements:
|
||||
- cut_id,
|
||||
- reference transcript,
|
||||
- predicted result.
|
||||
- The total duration (in seconds) of the dataset.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
log_interval = 10
|
||||
total_duration = 0
|
||||
|
||||
results = []
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
|
||||
|
||||
hyps = decode_one_batch(model=model, sp=sp, batch=batch)
|
||||
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results.extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
|
||||
return results, total_duration
|
||||
|
||||
|
||||
def save_results(
|
||||
res_dir: Path,
|
||||
test_set_name: str,
|
||||
results: List[Tuple[str, List[str], List[str]]],
|
||||
):
|
||||
recog_path = res_dir / f"recogs-{test_set_name}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = res_dir / f"errs-{test_set_name}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("WER", file=f)
|
||||
print(wer, file=f)
|
||||
|
||||
s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert (
|
||||
args.decoding_method == "greedy_search"
|
||||
), "Only supports greedy_search currently."
|
||||
res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
|
||||
|
||||
setup_logger(f"{res_dir}/log-decode")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
blank_id = sp.piece_to_id("<blk>")
|
||||
assert blank_id == 0, blank_id
|
||||
|
||||
logging.info(vars(args))
|
||||
|
||||
logging.info("About to create model")
|
||||
model = OnnxModel(
|
||||
encoder_model_filename=args.encoder_model_filename,
|
||||
decoder_model_filename=args.decoder_model_filename,
|
||||
joiner_model_filename=args.joiner_model_filename,
|
||||
)
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
start_time = time.time()
|
||||
results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp)
|
||||
end_time = time.time()
|
||||
elapsed_seconds = end_time - start_time
|
||||
rtf = elapsed_seconds / total_duration
|
||||
|
||||
logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||
logging.info(f"Wave duration: {total_duration:.3f} s")
|
||||
logging.info(
|
||||
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||
)
|
||||
|
||||
save_results(res_dir=res_dir, test_set_name=test_set, results=results)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
417
egs/librispeech/ASR/pruned_transducer_stateless7_contextual/onnx_pretrained.py
Executable file
@ -0,0 +1,417 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads ONNX models and uses them to decode waves.
|
||||
You can use the following commands to get the exported models:
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless3/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-9999-avg-1.onnx
|
||||
- decoder-epoch-9999-avg-1.onnx
|
||||
- joiner-epoch-9999-avg-1.onnx
|
||||
|
||||
3. Run this file
|
||||
|
||||
./pruned_transducer_stateless3/onnx_pretrained.py \
|
||||
--encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
encoder_model_filename: str,
|
||||
decoder_model_filename: str,
|
||||
joiner_model_filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 4
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_encoder(encoder_model_filename)
|
||||
self.init_decoder(decoder_model_filename)
|
||||
self.init_joiner(joiner_model_filename)
|
||||
|
||||
def init_encoder(self, encoder_model_filename: str):
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
def init_decoder(self, decoder_model_filename: str):
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
self.context_size = int(decoder_meta["context_size"])
|
||||
self.vocab_size = int(decoder_meta["vocab_size"])
|
||||
|
||||
logging.info(f"context_size: {self.context_size}")
|
||||
logging.info(f"vocab_size: {self.vocab_size}")
|
||||
|
||||
def init_joiner(self, joiner_model_filename: str):
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
self.joiner_dim = int(joiner_meta["joiner_dim"])
|
||||
|
||||
logging.info(f"joiner_dim: {self.joiner_dim}")
|
||||
|
||||
def run_encoder(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out, its shape is (N, T', joiner_dim)
|
||||
- encoder_out_lens, its shape is (N,)
|
||||
"""
|
||||
out = self.encoder.run(
|
||||
[
|
||||
self.encoder.get_outputs()[0].name,
|
||||
self.encoder.get_outputs()[1].name,
|
||||
],
|
||||
{
|
||||
self.encoder.get_inputs()[0].name: x.numpy(),
|
||||
self.encoder.get_inputs()[1].name: x_lens.numpy(),
|
||||
},
|
||||
)
|
||||
return torch.from_numpy(out[0]), torch.from_numpy(out[1])
|
||||
|
||||
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
decoder_input:
|
||||
A 2-D tensor of shape (N, context_size)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
out = self.decoder.run(
|
||||
[self.decoder.get_outputs()[0].name],
|
||||
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
def run_joiner(
|
||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
out = self.joiner.run(
|
||||
[self.joiner.get_outputs()[0].name],
|
||||
{
|
||||
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
|
||||
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
|
||||
},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
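# Illustrative single-step usage of the three sessions above for a single
# utterance (N=1). The filenames and `features`/`feature_lengths` below are
# placeholders; main() shows how the fbank features are actually computed:
#
#   model = OnnxModel(
#       encoder_model_filename="encoder.onnx",
#       decoder_model_filename="decoder.onnx",
#       joiner_model_filename="joiner.onnx",
#   )
#   encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths)
#   decoder_out = model.run_decoder(
#       torch.zeros(1, model.context_size, dtype=torch.int64)  # blank history
#   )
#   logits = model.run_joiner(encoder_out[:, 0], decoder_out)  # (1, vocab_size)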
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: OnnxModel,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, joiner_dim)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
context_size = model.context_size
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = packed_encoder_out.data[start:end]
|
||||
# current_encoder_out's shape: (batch_size, joiner_dim)
|
||||
offset = end
|
||||
|
||||
decoder_out = decoder_out[:batch_size]
|
||||
logits = model.run_joiner(current_encoder_out, decoder_out)
|
||||
|
||||
# logits' shape: (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
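# Note (illustrative): pack_padded_sequence sorts utterances by length, so the
# effective batch shrinks as shorter utterances run out of frames; the
# `unsorted_indices` lookup above restores the original utterance order.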
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
model = OnnxModel(
|
||||
encoder_model_filename=args.encoder_model_filename,
|
||||
decoder_model_filename=args.decoder_model_filename,
|
||||
joiner_model_filename=args.joiner_model_filename,
|
||||
)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = args.sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||
encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths)
|
||||
|
||||
hyps = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
s = "\n"
|
||||
|
||||
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||
|
||||
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||
text = ""
|
||||
for i in token_ids:
|
||||
text += symbol_table[i]
|
||||
return text.replace("▁", " ").strip()
|
||||
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = token_ids_to_words(hyp)
|
||||
s += f"{filename}:\n{words}\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
1103
egs/librispeech/ASR/pruned_transducer_stateless7_contextual/optim.py
Normal file
File diff suppressed because it is too large
355
egs/librispeech/ASR/pruned_transducer_stateless7_contextual/pretrained.py
Executable file
@ -0,0 +1,355 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless7/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless7/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method fast_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by
|
||||
./pruned_transducer_stateless7/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
if params.method == "beam_search":
|
||||
msg += f" with beam size {params.beam_size}"
|
||||
logging.info(msg)
|
||||
|
||||
if params.method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for i in range(num_waves):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
File diff suppressed because it is too large
@ -0,0 +1,214 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file replaces various modules in a model.
|
||||
Specifically, ActivationBalancer is replaced with an identity operator;
|
||||
Whiten is also replaced with an identity operator;
|
||||
BasicNorm is replaced by a module with `exp` removed.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import ActivationBalancer, BasicNorm, Whiten
|
||||
from zipformer import PoolingModule
|
||||
|
||||
|
||||
class PoolingModuleNoProj(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cached_len: torch.Tensor,
|
||||
cached_avg: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A tensor of shape (T, N, C)
|
||||
cached_len:
|
||||
A tensor of shape (N,)
|
||||
cached_avg:
|
||||
A tensor of shape (N, C)
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- new_x
|
||||
- new_cached_len
|
||||
- new_cached_avg
|
||||
"""
|
||||
x = x.cumsum(dim=0) # (T, N, C)
|
||||
x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0)
|
||||
# Cumulated numbers of frames from start
|
||||
cum_mask = torch.arange(1, x.size(0) + 1, device=x.device)
|
||||
cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N)
|
||||
pooling_mask = (1.0 / cum_mask).unsqueeze(2)
|
||||
# now pooling_mask: (T, N, 1)
|
||||
x = x * pooling_mask # (T, N, C)
|
||||
|
||||
cached_len = cached_len + x.size(0)
|
||||
cached_avg = x[-1]
|
||||
|
||||
return x, cached_len, cached_avg
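# A minimal sketch of driving PoolingModuleNoProj with a fresh (zero) cache,
# assuming T=4 frames, N=1 stream and C=8 channels (dummy values):
#
#   pool = PoolingModuleNoProj()
#   x = torch.randn(4, 1, 8)
#   cached_len = torch.zeros(1)
#   cached_avg = torch.zeros(1, 8)
#   y, cached_len, cached_avg = pool(x, cached_len, cached_avg)
#   # y[t] is the running average of x[0..t]; cached_avg equals y[-1] and
#   # cached_len is now 4, ready for the next streaming chunk.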
|
||||
|
||||
|
||||
class PoolingModuleWithProj(nn.Module):
|
||||
def __init__(self, proj: torch.nn.Module):
|
||||
super().__init__()
|
||||
self.proj = proj
|
||||
self.pooling = PoolingModuleNoProj()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cached_len: torch.Tensor,
|
||||
cached_avg: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A tensor of shape (T, N, C)
|
||||
cached_len:
|
||||
A tensor of shape (N,)
|
||||
cached_avg:
|
||||
A tensor of shape (N, C)
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- new_x
|
||||
- new_cached_len
|
||||
- new_cached_avg
|
||||
"""
|
||||
x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg)
|
||||
return self.proj(x), cached_len, cached_avg
|
||||
|
||||
def streaming_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cached_len: torch.Tensor,
|
||||
cached_avg: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A tensor of shape (T, N, C)
|
||||
cached_len:
|
||||
A tensor of shape (N,)
|
||||
cached_avg:
|
||||
A tensor of shape (N, C)
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- new_x
|
||||
- new_cached_len
|
||||
- new_cached_avg
|
||||
"""
|
||||
x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg)
|
||||
return self.proj(x), cached_len, cached_avg
|
||||
|
||||
|
||||
class NonScaledNorm(nn.Module):
|
||||
"""See BasicNorm for doc"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels: int,
|
||||
eps_exp: float,
|
||||
channel_dim: int = -1, # CAUTION: see documentation.
|
||||
):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.channel_dim = channel_dim
|
||||
self.eps_exp = eps_exp
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if not torch.jit.is_tracing():
|
||||
assert x.shape[self.channel_dim] == self.num_channels
|
||||
scales = (
|
||||
torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp
|
||||
).pow(-0.5)
|
||||
return x * scales
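# Note (illustrative): NonScaledNorm reproduces BasicNorm's output when eps_exp
# equals exp(eps) (see convert_basic_norm below); pre-computing the exp() keeps
# the exported graph free of an exp op.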
|
||||
|
||||
|
||||
def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
|
||||
assert isinstance(basic_norm, BasicNorm), type(basic_norm)
|
||||
norm = NonScaledNorm(
|
||||
num_channels=basic_norm.num_channels,
|
||||
eps_exp=basic_norm.eps.data.exp().item(),
|
||||
channel_dim=basic_norm.channel_dim,
|
||||
)
|
||||
return norm
|
||||
|
||||
|
||||
def convert_pooling_module(pooling: PoolingModule) -> PoolingModuleWithProj:
|
||||
assert isinstance(pooling, PoolingModule), type(pooling)
|
||||
return PoolingModuleWithProj(proj=pooling.proj)
|
||||
|
||||
|
||||
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
|
||||
# get_submodule was added to nn.Module at v1.9.0
|
||||
def get_submodule(model, target):
|
||||
if target == "":
|
||||
return model
|
||||
atoms: List[str] = target.split(".")
|
||||
mod: torch.nn.Module = model
|
||||
for item in atoms:
|
||||
if not hasattr(mod, item):
|
||||
raise AttributeError(
|
||||
mod._get_name() + " has no " "attribute `" + item + "`"
|
||||
)
|
||||
mod = getattr(mod, item)
|
||||
if not isinstance(mod, torch.nn.Module):
|
||||
raise AttributeError("`" + item + "` is not " "an nn.Module")
|
||||
return mod
|
||||
|
||||
|
||||
def convert_scaled_to_non_scaled(
|
||||
model: nn.Module,
|
||||
inplace: bool = False,
|
||||
is_pnnx: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
model:
|
||||
The model to be converted.
|
||||
inplace:
|
||||
If True, the input model is modified inplace.
|
||||
If False, the input model is copied and we modify the copied version.
|
||||
is_pnnx:
|
||||
True if we are going to export the model for PNNX.
|
||||
Return:
|
||||
Return a model without scaled layers.
|
||||
"""
|
||||
if not inplace:
|
||||
model = copy.deepcopy(model)
|
||||
|
||||
d = {}
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, BasicNorm):
|
||||
d[name] = convert_basic_norm(m)
|
||||
elif isinstance(m, (ActivationBalancer, Whiten)):
|
||||
d[name] = nn.Identity()
|
||||
elif isinstance(m, PoolingModule) and is_pnnx:
|
||||
d[name] = convert_pooling_module(m)
|
||||
|
||||
for k, v in d.items():
|
||||
if "." in k:
|
||||
parent, child = k.rsplit(".", maxsplit=1)
|
||||
setattr(get_submodule(model, parent), child, v)
|
||||
else:
|
||||
setattr(model, k, v)
|
||||
|
||||
return model
|
||||
@ -0,0 +1,337 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Taken from: https://github.com/facebookresearch/fbai-speech/blob/main/is21_deep_bias/score.py
|
||||
|
||||
from collections import deque
|
||||
from enum import Enum
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
from pathlib import Path, PosixPath
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.addHandler(logging.StreamHandler())
|
||||
|
||||
|
||||
class Code(Enum):
|
||||
match = 1
|
||||
substitution = 2
|
||||
insertion = 3
|
||||
deletion = 4
|
||||
|
||||
|
||||
class AlignmentResult(object):
|
||||
def __init__(self, refs, hyps, codes, score):
|
||||
self.refs = refs # deque<int>
|
||||
self.hyps = hyps # deque<int>
|
||||
self.codes = codes # deque<Code>
|
||||
self.score = score # float
|
||||
|
||||
|
||||
class WordError(object):
|
||||
def __init__(self):
|
||||
self.errors = {
|
||||
Code.substitution: 0,
|
||||
Code.insertion: 0,
|
||||
Code.deletion: 0,
|
||||
}
|
||||
self.ref_words = 0
|
||||
|
||||
def get_wer(self):
|
||||
assert self.ref_words != 0
|
||||
errors = (
|
||||
self.errors[Code.substitution]
|
||||
+ self.errors[Code.insertion]
|
||||
+ self.errors[Code.deletion]
|
||||
)
|
||||
return 100.0 * errors / self.ref_words
|
||||
|
||||
def get_result_string(self):
|
||||
return (
|
||||
f"error_rate={self.get_wer()}, "
|
||||
f"ref_words={self.ref_words}, "
|
||||
f"subs={self.errors[Code.substitution]}, "
|
||||
f"ins={self.errors[Code.insertion]}, "
|
||||
f"dels={self.errors[Code.deletion]}"
|
||||
)
|
||||
|
||||
|
||||
def coordinate_to_offset(row, col, ncols):
|
||||
return int(row * ncols + col)
|
||||
|
||||
|
||||
def offset_to_row(offset, ncols):
|
||||
return int(offset / ncols)
|
||||
|
||||
|
||||
def offset_to_col(offset, ncols):
|
||||
return int(offset % ncols)
|
||||
|
||||
|
||||
class EditDistance(object):
|
||||
def __init__(self):
|
||||
self.scores_ = None
|
||||
self.backtraces_ = None
|
||||
self.confusion_pairs_ = {}
|
||||
self.inserted_words_ = {}
|
||||
self.deleted_words_ = {}
|
||||
|
||||
def cost(self, ref, hyp, code):
|
||||
if code == Code.match:
|
||||
return 0
|
||||
elif code == Code.insertion or code == Code.deletion:
|
||||
return 3
|
||||
else: # substitution
|
||||
return 4
|
||||
|
||||
def get_result(self, refs, hyps):
|
||||
res = AlignmentResult(refs=deque(), hyps=deque(), codes=deque(), score=None)
|
||||
|
||||
num_rows, num_cols = len(self.scores_), len(self.scores_[0])
|
||||
res.score = self.scores_[num_rows - 1][num_cols - 1]
|
||||
|
||||
curr_offset = coordinate_to_offset(num_rows - 1, num_cols - 1, num_cols)
|
||||
|
||||
while curr_offset != 0:
|
||||
curr_row = offset_to_row(curr_offset, num_cols)
|
||||
curr_col = offset_to_col(curr_offset, num_cols)
|
||||
|
||||
prev_offset = self.backtraces_[curr_row][curr_col]
|
||||
|
||||
prev_row = offset_to_row(prev_offset, num_cols)
|
||||
prev_col = offset_to_col(prev_offset, num_cols)
|
||||
|
||||
res.refs.appendleft(curr_row - 1)
|
||||
res.hyps.appendleft(curr_col - 1)
|
||||
if curr_row - 1 == prev_row and curr_col == prev_col:
|
||||
ref_str = refs[res.refs[0]]
|
||||
deleted_word = ref_str
|
||||
if deleted_word not in self.deleted_words_:
|
||||
self.deleted_words_[deleted_word] = 1
|
||||
else:
|
||||
self.deleted_words_[deleted_word] += 1
|
||||
|
||||
res.codes.appendleft(Code.deletion)
|
||||
|
||||
elif curr_row == prev_row and curr_col - 1 == prev_col:
|
||||
hyp_str = hyps[res.hyps[0]]
|
||||
inserted_word = hyp_str
|
||||
if inserted_word not in self.inserted_words_:
|
||||
self.inserted_words_[inserted_word] = 1
|
||||
else:
|
||||
self.inserted_words_[inserted_word] += 1
|
||||
|
||||
res.codes.appendleft(Code.insertion)
|
||||
|
||||
else:
|
||||
# assert(curr_row - 1 == prev_row and curr_col - 1 == prev_col)
|
||||
ref_str = refs[res.refs[0]]
|
||||
hyp_str = hyps[res.hyps[0]]
|
||||
|
||||
if ref_str == hyp_str:
|
||||
res.codes.appendleft(Code.match)
|
||||
else:
|
||||
res.codes.appendleft(Code.substitution)
|
||||
|
||||
confusion_pair = "%s -> %s" % (ref_str, hyp_str)
|
||||
if confusion_pair not in self.confusion_pairs_:
|
||||
self.confusion_pairs_[confusion_pair] = 1
|
||||
else:
|
||||
self.confusion_pairs_[confusion_pair] += 1
|
||||
|
||||
curr_offset = prev_offset
|
||||
|
||||
return res
|
||||
|
||||
def align(self, refs, hyps):
|
||||
if len(refs) == 0 and len(hyps) == 0:
|
||||
raise ValueError("Doesn't support empty ref AND hyp!")
|
||||
|
||||
# NOTE: we're not resetting the values in these matrices because every value
|
||||
# will be overridden in the loop below. If this assumption doesn't hold,
|
||||
# be sure to set all entries in self.scores_ and self.backtraces_ to 0.
|
||||
self.scores_ = [[0.0] * (len(hyps) + 1) for _ in range(len(refs) + 1)]
|
||||
self.backtraces_ = [[0] * (len(hyps) + 1) for _ in range(len(refs) + 1)]
|
||||
|
||||
num_rows, num_cols = len(self.scores_), len(self.scores_[0])
|
||||
|
||||
for i in range(num_rows):
|
||||
for j in range(num_cols):
|
||||
if i == 0 and j == 0:
|
||||
self.scores_[i][j] = 0.0
|
||||
self.backtraces_[i][j] = 0
|
||||
continue
|
||||
|
||||
if i == 0:
|
||||
self.scores_[i][j] = self.scores_[i][j - 1] + self.cost(
|
||||
None, hyps[j - 1], Code.insertion
|
||||
)
|
||||
self.backtraces_[i][j] = coordinate_to_offset(i, j - 1, num_cols)
|
||||
continue
|
||||
|
||||
if j == 0:
|
||||
self.scores_[i][j] = self.scores_[i - 1][j] + self.cost(
|
||||
refs[i - 1], None, Code.deletion
|
||||
)
|
||||
self.backtraces_[i][j] = coordinate_to_offset(i - 1, j, num_cols)
|
||||
continue
|
||||
|
||||
# Below here both i and j are greater than 0
|
||||
ref = refs[i - 1]
|
||||
hyp = hyps[j - 1]
|
||||
best_score = self.scores_[i - 1][j - 1] + (
|
||||
self.cost(ref, hyp, Code.match)
|
||||
if ref == hyp
|
||||
else self.cost(ref, hyp, Code.substitution)
|
||||
)
|
||||
|
||||
prev_row = i - 1
|
||||
prev_col = j - 1
|
||||
ins = self.scores_[i][j - 1] + self.cost(None, hyp, Code.insertion)
|
||||
if ins < best_score:
|
||||
best_score = ins
|
||||
prev_row = i
|
||||
prev_col = j - 1
|
||||
|
||||
delt = self.scores_[i - 1][j] + self.cost(ref, None, Code.deletion)
|
||||
if delt < best_score:
|
||||
best_score = delt
|
||||
prev_row = i - 1
|
||||
prev_col = j
|
||||
|
||||
self.scores_[i][j] = best_score
|
||||
self.backtraces_[i][j] = coordinate_to_offset(
|
||||
prev_row, prev_col, num_cols
|
||||
)
|
||||
|
||||
return self.get_result(refs, hyps)
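# A minimal sketch of using EditDistance on its own (hypothetical word lists,
# not part of the scoring flow below):
#
#   ed = EditDistance()
#   result = ed.align(["the", "cat", "sat"], ["the", "bat", "sat", "down"])
#   # result.codes == deque([Code.match, Code.substitution, Code.match,
#   #                        Code.insertion])
#   # result.score == 7.0  (a substitution costs 4, an insertion costs 3)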
|
||||
|
||||
|
||||
def main(args):
|
||||
refs = {}
|
||||
if type(args.refs) is str or type(args.refs) is PosixPath:
|
||||
with open(args.refs, "r") as f:
|
||||
for line in f:
|
||||
ary = line.strip().split("\t")
|
||||
uttid, ref, biasing_words = ary[0], ary[1], set(json.loads(ary[2]))
|
||||
refs[uttid] = {"text": ref, "biasing_words": biasing_words}
|
||||
logger.info("Loaded %d reference utts from %s", len(refs), args.refs)
|
||||
elif type(args.refs) is dict:
|
||||
refs = args.refs
|
||||
logger.info("Loaded %d reference utts", len(refs))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
hyps = {}
|
||||
if type(args.hyps) is str or type(args.hyps) is PosixPath:
|
||||
with open(args.hyps, "r") as f:
|
||||
for line in f:
|
||||
ary = line.strip().split("\t")
|
||||
# May have empty hypo
|
||||
if len(ary) >= 2:
|
||||
uttid, hyp = ary[0], ary[1]
|
||||
else:
|
||||
uttid, hyp = ary[0], ""
|
||||
hyps[uttid] = hyp
|
||||
logger.info("Loaded %d hypothesis utts from %s", len(hyps), args.hyps)
|
||||
elif type(args.hyps) is dict:
|
||||
hyps = args.hyps
|
||||
logger.info("Loaded %d hypothesis utts", len(hyps))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if not args.lenient:
|
||||
for uttid in refs:
|
||||
if uttid in hyps:
|
||||
continue
|
||||
raise ValueError(
|
||||
f"{uttid} missing in hyps! Set `--lenient` flag to ignore this error."
|
||||
)
|
||||
|
||||
# train_rare_count = dict()
|
||||
# with open("", "r") as fin:
|
||||
# for line in fin:
|
||||
# w, c = line.strip().split()
|
||||
# train_rare_count[w] = int(c)
|
||||
|
||||
test_rare_count = dict()
|
||||
|
||||
# Calculate WER, U-WER, and B-WER
|
||||
wer = WordError()
|
||||
u_wer = WordError()
|
||||
b_wer = WordError()
|
||||
for uttid in refs:
|
||||
if uttid not in hyps:
|
||||
continue
|
||||
ref_tokens = refs[uttid]["text"].split()
|
||||
biasing_words = refs[uttid]["biasing_words"]
|
||||
hyp_tokens = hyps[uttid].split()
|
||||
ed = EditDistance()
|
||||
result = ed.align(ref_tokens, hyp_tokens)
|
||||
for code, ref_idx, hyp_idx in zip(result.codes, result.refs, result.hyps):
|
||||
if code == Code.match:
|
||||
wer.ref_words += 1
|
||||
if ref_tokens[ref_idx] in biasing_words:
|
||||
b_wer.ref_words += 1
|
||||
else:
|
||||
u_wer.ref_words += 1
|
||||
elif code == Code.substitution:
|
||||
wer.ref_words += 1
|
||||
wer.errors[Code.substitution] += 1
|
||||
if ref_tokens[ref_idx] in biasing_words:
|
||||
b_wer.ref_words += 1
|
||||
b_wer.errors[Code.substitution] += 1
|
||||
else:
|
||||
u_wer.ref_words += 1
|
||||
u_wer.errors[Code.substitution] += 1
|
||||
elif code == Code.deletion:
|
||||
wer.ref_words += 1
|
||||
wer.errors[Code.deletion] += 1
|
||||
if ref_tokens[ref_idx] in biasing_words:
|
||||
b_wer.ref_words += 1
|
||||
b_wer.errors[Code.deletion] += 1
|
||||
else:
|
||||
u_wer.ref_words += 1
|
||||
u_wer.errors[Code.deletion] += 1
|
||||
elif code == Code.insertion:
|
||||
wer.errors[Code.insertion] += 1
|
||||
if hyp_tokens[hyp_idx] in biasing_words:
|
||||
b_wer.errors[Code.insertion] += 1
|
||||
else:
|
||||
u_wer.errors[Code.insertion] += 1
|
||||
|
||||
# Report results
|
||||
print(f"WER: {wer.get_result_string()}")
|
||||
print(f"U-WER: {u_wer.get_result_string()}")
|
||||
print(f"B-WER: {b_wer.get_result_string()}")
|
||||
print(f"{wer.get_wer():.2f}({u_wer.get_wer():.2f}/{b_wer.get_wer():.2f})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
desc = "Compute WER, U-WER, and B-WER. Results are output to stdout."
|
||||
parser = argparse.ArgumentParser(description=desc)
|
||||
parser.add_argument(
|
||||
"--refs",
|
||||
required=True,
|
||||
help="Path to tab-separated reference file. First column is utterance ID. "
|
||||
"Second column is reference text. Last column is list of biasing words.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hyps",
|
||||
required=True,
|
||||
help="Path to tab-separated hypothesis file. First column is utterance ID. "
|
||||
"Second column is hypothesis text.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lenient",
|
||||
action="store_true",
|
||||
help="If set, hyps doesn't have to cover all of refs.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
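# Illustrative input formats (hypothetical utterance ID and text):
#
#   --refs line (tab-separated; the last column is a JSON list of biasing words):
#     1089-134686-0001<TAB>HE HOPED THERE WOULD BE STEW FOR DINNER<TAB>["STEW"]
#   --hyps line (tab-separated):
#     1089-134686-0001<TAB>HE HOPED THERE WOULD BE STEW FOR DINNER
#
# U-WER is computed over reference words outside the biasing list and B-WER
# over the biasing words themselves, matching the per-word bookkeeping above.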
|
||||
@ -0,0 +1,197 @@
|
||||
#!/usr/bin/env bash
|
||||
#$ -wd /exp/rhuang/meta/icefall/egs/librispeech/ASR/
|
||||
#$ -V
|
||||
#$ -N train_context
|
||||
#$ -j y -o /exp/rhuang/meta/icefall/egs/librispeech/ASR/log/log-$JOB_NAME-$JOB_ID.out
|
||||
#$ -M ruizhe@jhu.edu
|
||||
#$ -m e
|
||||
#$ -l mem_free=32G,h_rt=600:00:00,gpu=4,hostname=!r7n07*
|
||||
#$ -q gpu.q@@v100
|
||||
|
||||
# #$ -q gpu.q@@v100
|
||||
# #$ -q gpu.q@@rtx
|
||||
|
||||
# #$ -l ram_free=300G,mem_free=300G,gpu=0,hostname=b*
|
||||
|
||||
# hostname=b19
|
||||
# hostname=!c04*&!b*&!octopod*
|
||||
# hostname
|
||||
# nvidia-smi
|
||||
|
||||
# conda activate /home/hltcoe/rhuang/mambaforge/envs/aligner5
|
||||
export PATH="/home/hltcoe/rhuang/mambaforge/envs/aligner5/bin/":$PATH
|
||||
module load cuda11.7/toolkit
|
||||
module load cudnn/8.5.0.96_cuda11.x
|
||||
module load nccl/2.13.4-1_cuda11.7
|
||||
module load gcc/7.2.0
|
||||
module load intel/mkl/64/2019/5.281
|
||||
|
||||
which python
|
||||
nvcc --version
|
||||
nvidia-smi
|
||||
date
|
||||
|
||||
# k2
|
||||
K2_ROOT=/exp/rhuang/meta/k2/
|
||||
export PYTHONPATH=$K2_ROOT/k2/python:$PYTHONPATH # for `import k2`
|
||||
export PYTHONPATH=$K2_ROOT/temp.linux-x86_64-cpython-310/lib:$PYTHONPATH # for `import _k2`
|
||||
export PYTHONPATH=/exp/rhuang/meta/icefall:$PYTHONPATH
|
||||
|
||||
# # torchaudio recipe
|
||||
# cd /exp/rhuang/meta/audio
|
||||
# cd examples/asr/librispeech_conformer_ctc
|
||||
|
||||
# To verify SGE_HGR_gpu and CUDA_VISIBLE_DEVICES match for GPU jobs.
|
||||
env | grep SGE_HGR_gpu
|
||||
env | grep CUDA_VISIBLE_DEVICES
|
||||
echo "hostname: `hostname`"
|
||||
echo "current path:" `pwd`
|
||||
|
||||
# export PYTHONPATH=/exp/rhuang/meta/audio/examples/asr/librispeech_conformer_ctc2:$PYTHONPATH
|
||||
|
||||
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_ali/exp/exp_libri # 11073148
|
||||
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_ali/exp/exp_libri_100 # 11073150
|
||||
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_ali/exp/exp_libri_100_ts # 11073234, 11073240, 11073243
|
||||
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_ali/exp/exp_libri_ts # 11073238, 11073255, log-train-2024-01-13-20-11-00 => 11073331 => log-train-2024-01-14-02-16-10
|
||||
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_ali/exp/exp_libri_ts2 # log-train-2024-01-14-07-22-54-0
|
||||
|
||||
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_100
|
||||
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri2 # baseline, no biasing
|
||||
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri # 11169512
|
||||
# exp_dir=/exp/rhuang/meta/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy # 11169515, 11169916
|
||||
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_34
|
||||
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_3ctc # 11171405, log-train-2024-03-04-02-10-27-2,
|
||||
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_3ctc # 11171405, log-train-2024-03-04-21-22-31-0
|
||||
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_3ctc_attn #
|
||||
|
||||
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_2early
|
||||
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_4early
|
||||
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_234early
|
||||
# exp_dir=pruned_transducer_stateless7_context_proxy_all_layers/exp/exp_libri_proxy_3early_no5
|
||||
|
||||
exp_dir=pruned_transducer_stateless7_contextual/exp/exp_libri_test
|
||||
|
||||
mkdir -p $exp_dir
|
||||
|
||||
echo
|
||||
echo "exp_dir:" $exp_dir
|
||||
echo
|
||||
|
||||
path_to_pretrained_asr_model=/exp/rhuang/librispeech/pretrained2/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/
|
||||
# path_to_pretrained_asr_model=/scratch4/skhudan1/rhuang25/icefall/egs/librispeech/ASR/download/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||
|
||||
# From pretrained ASR model
|
||||
if [ ! -f $exp_dir/epoch-1.pt ]; then
|
||||
ln -s $path_to_pretrained_asr_model/exp/pretrained.pt $exp_dir/epoch-1.pt
|
||||
fi
|
||||
|
||||
####################################
|
||||
# train
|
||||
####################################
|
||||
|
||||
# if false; then
|
||||
# echo "True"
|
||||
# else
|
||||
# echo "False"
|
||||
# fi
|
||||
|
||||
if true; then
|
||||
# # stage 1:
|
||||
max_duration=1600
|
||||
# max_duration=400 # libri100
|
||||
n_distractors=0
|
||||
is_full_context=true
|
||||
|
||||
# stage 2:
|
||||
# max_duration=1600
|
||||
# # max_duration=400 # libri100
|
||||
# n_distractors=100
|
||||
# is_full_context=false
|
||||
|
||||
python pruned_transducer_stateless7_contextual/train.py \
|
||||
--world-size 4 \
|
||||
--use-fp16 true \
|
||||
--max-duration $max_duration \
|
||||
--exp-dir $exp_dir \
|
||||
--bpe-model "data/lang_bpe_500/bpe.model" \
|
||||
--prune-range 5 \
|
||||
--full-libri true \
|
||||
--context-dir "data/fbai-speech/is21_deep_bias/" \
|
||||
--keep-ratio 1.0 \
|
||||
--start-epoch 2 \
|
||||
--num-epochs 30 \
|
||||
--is-pretrained-context-encoder false \
|
||||
--is-reused-context-encoder false \
|
||||
--is-full-context $is_full_context \
|
||||
--n-distractors $n_distractors --start-epoch 14 --num-epochs 40 --master-port 12357 --proxy-prob 0.2 --keep-ratio 0.8 --throwaway-prob 0.7 # --start-batch 24000 # --base-lr 0.08 --master-port 12355 --irrelevance-learning true
|
||||
fi
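# Note (illustrative): stage 1 above fine-tunes with the full biasing context
# (n_distractors=0, is_full_context=true); for stage 2, re-run this block with
# the commented stage-2 settings (n_distractors=100, is_full_context=false).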
|
||||
|
||||
# Alternative early-context-injection settings (kept for reference, not executed):
# --n-distractors $n_distractors --master-port 12357 --proxy-prob 0.4 --early-layers 2 --enable-nn true
|
||||
# --n-distractors $n_distractors --master-port 12357 --proxy-prob 0.4 --early-layers 4 --enable-nn true
|
||||
# --n-distractors $n_distractors --master-port 12357 --proxy-prob 0.4 --early-layers 2,3,4 --enable-nn true
|
||||
|
||||
####################################
|
||||
# tensorboard
|
||||
####################################
|
||||
# tensorboard dev upload --logdir /exp/rhuang/meta/icefall/egs/librispeech/ASR/$exp_dir/tensorboard --description `pwd`
|
||||
# wandb sync $exp_dir/tensorboard
|
||||
|
||||
# https://github.com/k2-fsa/icefall/issues/1298
|
||||
# python -c "import wandb; wandb.init(project='icefall-asr-gigaspeech-zipformer-2023-10-20')"
|
||||
# wandb sync zipformer/exp/tensorboard -p icefall-asr-gigaspeech-zipformer-2023-10-20
|
||||
|
||||
# https://stackoverflow.com/questions/37987839/how-can-i-run-tensorboard-on-a-remote-server
|
||||
# ssh -L 16006:127.0.0.1:6006 rhuang@test1.hltcoe.jhu.edu
|
||||
# tensorboard --logdir $exp_dir/tensorboard --port 6006
|
||||
# http://localhost:16006
|
||||
|
||||
# no-biasing: /exp/rhuang/icefall_latest/egs/spgispeech/ASR/pruned_transducer_stateless7/exp_500_norm/tensorboard/
|
||||
# bi_enc: /exp/rhuang/icefall_latest/egs/spgispeech/ASR/pruned_transducer_stateless7_context/exp/exp_libri_full_c-1_stage1/
|
||||
# single_enc: /exp/rhuang/icefall_latest/egs/spgispeech/ASR/pruned_transducer_stateless7_context/exp/exp_libri_full_c-1_stage1_single_enc
|
||||
|
||||
|
||||
# exp_dir=pruned_transducer_stateless7_context_ali/exp
|
||||
# n_distractors=0
|
||||
# max_duration=1200
|
||||
# python /exp/rhuang/meta/icefall/egs/spgispeech/ASR/pruned_transducer_stateless7_context_ali/train.py --world-size 1 --use-fp16 true --max-duration $max_duration --exp-dir $exp_dir --bpe-model "data/lang_bpe_500/bpe.model" --prune-range 5 --use-fp16 true --context-dir "data/uniphore_contexts/" --keep-ratio 1.0 --start-epoch 2 --num-epochs 30 --is-bi-context-encoder false --is-pretrained-context-encoder false --is-full-context true --n-distractors $n_distractors
|
||||
|
||||
|
||||
####### debug: RuntimeError: grad_scale is too small, exiting: 8.470329472543003e-22
|
||||
# encoder_out=rs["encoder_out"]; contexts_h=rs["contexts_h"]; contexts_mask=rs["contexts_mask"]
|
||||
# queries=encoder_out; contexts=contexts_h; contexts_mask=contexts_mask; need_weights=True
|
||||
# md = rs["encoder_biasing_adapter"]
|
||||
# with torch.cuda.amp.autocast(enabled=True):
|
||||
# queries = md.proj_in(queries)
|
||||
# print("queries:", torch.any(torch.isnan(queries) | torch.isinf(queries)))
|
||||
# attn_output, attn_output_weights = md.multihead_attn(queries,contexts,contexts,key_padding_mask=contexts_mask,need_weights=need_weights,)
|
||||
# print(torch.any(torch.isnan(attn_output) | torch.isinf(attn_output)))
|
||||
# print(torch.any(torch.isnan(attn_output_weights) | torch.isinf(attn_output_weights)))
|
||||
# output = md.proj_out(attn_output)
|
||||
# print(torch.any(torch.isnan(output) | torch.isinf(output)))
|
||||
|
||||
# encoder_out=rs["encoder_out"]; contexts_h=rs["contexts_h"]; contexts_mask=rs["contexts_mask"]
|
||||
# queries=encoder_out; contexts=contexts_h; contexts_mask=contexts_mask; need_weights=True
|
||||
# md = rs["encoder_biasing_adapter"]
|
||||
# with torch.cuda.amp.autocast(enabled=False):
|
||||
# queries = md.proj_in(queries)
|
||||
# print("queries:", torch.any(torch.isnan(queries) | torch.isinf(queries)))
|
||||
# attn_output, attn_output_weights = md.multihead_attn(queries,contexts,contexts,key_padding_mask=contexts_mask,need_weights=need_weights,)
|
||||
# print(torch.any(torch.isnan(attn_output) | torch.isinf(attn_output)))
|
||||
# print(torch.any(torch.isnan(attn_output_weights) | torch.isinf(attn_output_weights)))
|
||||
# output = md.proj_out(attn_output)
|
||||
# print(torch.any(torch.isnan(output) | torch.isinf(output)))
|
||||
|
||||
# encoder_out=rs["encoder_out"]; contexts_h=rs["contexts_h"]; contexts_mask=rs["contexts_mask"]
|
||||
# queries=encoder_out; contexts=contexts_h; contexts_mask=contexts_mask; need_weights=True
|
||||
# md = rs["encoder_biasing_adapter"]
|
||||
# queries = md.proj_in(queries)
|
||||
# print("queries:", torch.any(torch.isnan(queries) | torch.isinf(queries)))
|
||||
# attn_output, attn_output_weights = md.multihead_attn(queries,contexts,contexts,key_padding_mask=contexts_mask,need_weights=need_weights,)
|
||||
# print(torch.any(torch.isnan(attn_output) | torch.isinf(attn_output)))
|
||||
# print(torch.any(torch.isnan(attn_output_weights) | torch.isinf(attn_output_weights)))
|
||||
# output = md.proj_out(attn_output)
|
||||
# print(torch.any(torch.isnan(output) | torch.isinf(output)))
|
||||
|
||||
|
||||
|
||||
# Reference command for the pruned_transducer_stateless7_context_proxy_all_layers recipe (kept for reference, not executed by this script):
# python pruned_transducer_stateless7_context_proxy_all_layers/train.py --world-size 4 --use-fp16 true --max-duration $max_duration --exp-dir $exp_dir --bpe-model "data/lang_bpe_500/bpe.model" --prune-range 5 --full-libri true --context-dir "data/fbai-speech/is21_deep_bias/" --keep-ratio 1.0 --start-epoch 2 --num-epochs 30 --is-pretrained-context-encoder false --is-reused-context-encoder false --is-full-context $is_full_context --n-distractors $n_distractors --start-epoch 14 --num-epochs 40 --master-port 12357 --proxy-prob 0.2 --keep-ratio 0.8 --throwaway-prob 0.7 --n-distractors 100 --start-epoch 29 --num-epochs 40 --world-size 4 --max-duration 1800
|
||||
130
egs/librispeech/ASR/pruned_transducer_stateless7_contextual/test_compute_ali.py
Executable file
@ -0,0 +1,130 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script compares the word-level alignments generated by modified_beam_search decoding
|
||||
(in ./pruned_transducer_stateless7/compute_ali.py) to the reference alignments generated
|
||||
by the torchaudio framework (in ./add_alignments.sh).
|
||||
|
||||
Usage:
|
||||
|
||||
./pruned_transducer_stateless7/compute_ali.py \
|
||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--dataset test-clean \
|
||||
--max-duration 300 \
|
||||
--beam-size 4 \
|
||||
--cuts-out-dir data/fbank_ali_beam_search
|
||||
|
||||
And then you can run:
|
||||
|
||||
./pruned_transducer_stateless7/test_compute_ali.py \
|
||||
--cuts-out-dir ./data/fbank_ali_test \
|
||||
--cuts-ref-dir ./data/fbank_ali_torch \
|
||||
--dataset train-clean-100
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import load_manifest
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cuts-out-dir",
|
||||
type=Path,
|
||||
default="./data/fbank_ali",
|
||||
help="The dir that saves the generated cuts manifests with alignments",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cuts-ref-dir",
|
||||
type=Path,
|
||||
default="./data/fbank_ali_torch",
|
||||
help="The dir that saves the reference cuts manifests with alignments",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=True,
|
||||
help="""The name of the dataset:
|
||||
Possible values are:
|
||||
- test-clean
|
||||
- test-other
|
||||
- train-clean-100
|
||||
- train-clean-360
|
||||
- train-other-500
|
||||
- dev-clean
|
||||
- dev-other
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
|
||||
cuts_out_jsonl = args.cuts_out_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz"
|
||||
cuts_ref_jsonl = args.cuts_ref_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz"
|
||||
|
||||
logging.info(f"Loading {cuts_out_jsonl} and {cuts_ref_jsonl}")
|
||||
cuts_out = load_manifest(cuts_out_jsonl)
|
||||
cuts_ref = load_manifest(cuts_ref_jsonl)
|
||||
cuts_ref = cuts_ref.sort_like(cuts_out)
|
||||
|
||||
all_time_diffs = []
|
||||
for cut_out, cut_ref in zip(cuts_out, cuts_ref):
|
||||
time_out = [
|
||||
ali.start
|
||||
for ali in cut_out.supervisions[0].alignment["word"]
|
||||
if ali.symbol != ""
|
||||
]
|
||||
time_ref = [
|
||||
ali.start
|
||||
for ali in cut_ref.supervisions[0].alignment["word"]
|
||||
if ali.symbol != ""
|
||||
]
|
||||
assert len(time_out) == len(time_ref), (len(time_out), len(time_ref))
|
||||
diff = [
|
||||
round(abs(out - ref), ndigits=3) for out, ref in zip(time_out, time_ref)
|
||||
]
|
||||
all_time_diffs += diff
|
||||
|
||||
all_time_diffs = torch.tensor(all_time_diffs)
|
||||
logging.info(
|
||||
f"For the word-level alignments abs difference on dataset {args.dataset}, "
|
||||
f"mean: {'%.2f' % all_time_diffs.mean()}s, std: {'%.2f' % all_time_diffs.std()}s"
|
||||
)
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
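Note (not part of the diff): the comparison above relies on Lhotse word-level alignments attached to each cut's first supervision. Below is a minimal sketch of that per-word start-time comparison for a single cut pair, assuming Lhotse's AlignmentItem(symbol, start, duration) structure; the words and times are made up for illustration.

from lhotse.supervision import AlignmentItem

ref = [
    AlignmentItem(symbol="HELLO", start=0.50, duration=0.30),
    AlignmentItem(symbol="WORLD", start=0.85, duration=0.40),
]
hyp = [
    AlignmentItem(symbol="HELLO", start=0.52, duration=0.28),
    AlignmentItem(symbol="WORLD", start=0.80, duration=0.45),
]

# Same per-word start-time difference as computed in main(), for one cut pair.
diffs = [
    round(abs(h.start - r.start), ndigits=3)
    for h, r in zip(hyp, ref)
    if h.symbol != ""
]
print(diffs)  # [0.02, 0.05]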
egs/librispeech/ASR/pruned_transducer_stateless7_contextual/test_model.py (Executable file, 68 lines)
@ -0,0 +1,68 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
To run this file, do:

    cd icefall/egs/librispeech/ASR
    python ./pruned_transducer_stateless7/test_model.py
"""

import torch

from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model


def test_model():
    params = get_params()
    params.vocab_size = 500
    params.blank_id = 0
    params.context_size = 2
    params.num_encoder_layers = "2,4,3,2,4"
    params.feedforward_dims = "1024,1024,2048,2048,1024"
    params.nhead = "8,8,8,8,8"
    params.encoder_dims = "384,384,384,384,384"
    params.attention_dims = "192,192,192,192,192"
    params.encoder_unmasked_dims = "256,256,256,256,256"
    params.zipformer_downsampling_factors = "1,2,4,8,2"
    params.cnn_module_kernels = "31,31,31,31,31"
    params.decoder_dim = 512
    params.joiner_dim = 512
    model = get_transducer_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    print(f"Number of model parameters: {num_param}")

    # Test jit script
    convert_scaled_to_non_scaled(model, inplace=True)
    # We won't use the forward() method of the model in C++, so just ignore
    # it here.
    # Otherwise, one of its arguments is a ragged tensor and is not
    # torch scriptable.
    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
    print("Using torch.jit.script")
    model = torch.jit.script(model)


def main():
    test_model()


if __name__ == "__main__":
    main()
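Note (not part of the diff): test_model.py scripts the transducer only after marking the non-scriptable forward() with torch.jit.ignore. A minimal, purely illustrative sketch of that general pattern on a toy module: methods that TorchScript cannot compile are excluded so the rest of the module can still be scripted.

import torch


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @torch.jit.ignore
    def not_scriptable(self, anything) -> None:
        # Arbitrary Python that the TorchScript compiler never touches.
        print("skipped by the compiler:", anything)


scripted = torch.jit.script(Toy())
print(scripted(torch.ones(2)))  # tensor([2., 2.])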
@ -0,0 +1,374 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file tests that the models can be exported to ONNX.
"""
import os

from icefall import is_module_available

if not is_module_available("onnxruntime"):
    raise ValueError("Please 'pip install onnxruntime' first.")

import onnxruntime as ort
import torch
from scaling_converter import convert_scaled_to_non_scaled
from zipformer import (
    Conv2dSubsampling,
    RelPositionalEncoding,
    Zipformer,
    ZipformerEncoder,
    ZipformerEncoderLayer,
)

ort.set_default_logger_severity(3)


def test_conv2d_subsampling():
    filename = "conv2d_subsampling.onnx"
    opset_version = 13
    N = 30
    T = 50
    num_features = 80
    d_model = 512
    x = torch.rand(N, T, num_features)

    encoder_embed = Conv2dSubsampling(num_features, d_model)
    encoder_embed.eval()
    encoder_embed = convert_scaled_to_non_scaled(encoder_embed, inplace=True)

    torch.onnx.export(
        encoder_embed,
        x,
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x"],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "y": {0: "N", 1: "T"},
        },
    )

    options = ort.SessionOptions()
    options.inter_op_num_threads = 1
    options.intra_op_num_threads = 1

    session = ort.InferenceSession(
        filename,
        sess_options=options,
    )

    input_nodes = session.get_inputs()
    assert input_nodes[0].name == "x"
    assert input_nodes[0].shape == ["N", "T", num_features]

    inputs = {input_nodes[0].name: x.numpy()}

    onnx_y = session.run(["y"], inputs)[0]

    onnx_y = torch.from_numpy(onnx_y)
    torch_y = encoder_embed(x)
    assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()

    os.remove(filename)


def test_rel_pos():
    filename = "rel_pos.onnx"

    opset_version = 13
    N = 30
    T = 50
    num_features = 80
    d_model = 512
    x = torch.rand(N, T, num_features)

    encoder_pos = RelPositionalEncoding(d_model, dropout_rate=0.1)
    encoder_pos.eval()
    encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True)

    x = x.permute(1, 0, 2)

    torch.onnx.export(
        encoder_pos,
        x,
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x"],
        output_names=["pos_emb"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "pos_emb": {0: "N", 1: "T"},
        },
    )

    options = ort.SessionOptions()
    options.inter_op_num_threads = 1
    options.intra_op_num_threads = 1

    session = ort.InferenceSession(
        filename,
        sess_options=options,
    )

    input_nodes = session.get_inputs()
    assert input_nodes[0].name == "x"
    assert input_nodes[0].shape == ["N", "T", num_features]

    inputs = {input_nodes[0].name: x.numpy()}
    onnx_pos_emb = session.run(["pos_emb"], inputs)
    onnx_pos_emb = torch.from_numpy(onnx_pos_emb[0])

    torch_pos_emb = encoder_pos(x)
    assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), (
        (onnx_pos_emb - torch_pos_emb).abs().max()
    )
    print(onnx_pos_emb.abs().sum(), torch_pos_emb.abs().sum())

    os.remove(filename)


def test_zipformer_encoder_layer():
    filename = "zipformer_encoder_layer.onnx"
    opset_version = 13
    N = 30
    T = 50

    d_model = 384
    attention_dim = 192
    nhead = 8
    feedforward_dim = 1024
    dropout = 0.1
    cnn_module_kernel = 31
    pos_dim = 4

    x = torch.rand(N, T, d_model)

    encoder_pos = RelPositionalEncoding(d_model, dropout)
    encoder_pos.eval()
    encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True)

    x = x.permute(1, 0, 2)
    pos_emb = encoder_pos(x)

    encoder_layer = ZipformerEncoderLayer(
        d_model,
        attention_dim,
        nhead,
        feedforward_dim,
        dropout,
        cnn_module_kernel,
        pos_dim,
    )
    encoder_layer.eval()
    encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True)

    torch.onnx.export(
        encoder_layer,
        (x, pos_emb),
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x", "pos_emb"],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "T", 1: "N"},
            "pos_emb": {0: "N", 1: "T"},
            "y": {0: "T", 1: "N"},
        },
    )

    options = ort.SessionOptions()
    options.inter_op_num_threads = 1
    options.intra_op_num_threads = 1

    session = ort.InferenceSession(
        filename,
        sess_options=options,
    )

    input_nodes = session.get_inputs()
    inputs = {
        input_nodes[0].name: x.numpy(),
        input_nodes[1].name: pos_emb.numpy(),
    }
    onnx_y = session.run(["y"], inputs)[0]
    onnx_y = torch.from_numpy(onnx_y)

    torch_y = encoder_layer(x, pos_emb)
    assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()

    print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)

    os.remove(filename)


def test_zipformer_encoder():
    filename = "zipformer_encoder.onnx"

    opset_version = 13
    N = 3
    T = 15

    d_model = 512
    attention_dim = 192
    nhead = 8
    feedforward_dim = 1024
    dropout = 0.1
    cnn_module_kernel = 31
    pos_dim = 4
    num_encoder_layers = 12

    warmup_batches = 4000.0
    warmup_begin = warmup_batches / (num_encoder_layers + 1)
    warmup_end = warmup_batches / (num_encoder_layers + 1)

    x = torch.rand(N, T, d_model)

    encoder_layer = ZipformerEncoderLayer(
        d_model,
        attention_dim,
        nhead,
        feedforward_dim,
        dropout,
        cnn_module_kernel,
        pos_dim,
    )
    encoder = ZipformerEncoder(
        encoder_layer, num_encoder_layers, dropout, warmup_begin, warmup_end
    )
    encoder.eval()
    encoder = convert_scaled_to_non_scaled(encoder, inplace=True)

    # jit_model = torch.jit.trace(encoder, (pos_emb))

    torch_y = encoder(x)

    torch.onnx.export(
        encoder,
        (x),
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x"],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "T", 1: "N"},
            "y": {0: "T", 1: "N"},
        },
    )

    options = ort.SessionOptions()
    options.inter_op_num_threads = 1
    options.intra_op_num_threads = 1

    session = ort.InferenceSession(
        filename,
        sess_options=options,
    )

    input_nodes = session.get_inputs()
    inputs = {
        input_nodes[0].name: x.numpy(),
    }
    onnx_y = session.run(["y"], inputs)[0]
    onnx_y = torch.from_numpy(onnx_y)

    torch_y = encoder(x)
    assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()

    print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)

    os.remove(filename)


def test_zipformer():
    filename = "zipformer.onnx"
    opset_version = 11
    N = 3
    T = 15
    num_features = 80
    x = torch.rand(N, T, num_features)
    x_lens = torch.full((N,), fill_value=T, dtype=torch.int64)

    zipformer = Zipformer(num_features=num_features)
    zipformer.eval()
    zipformer = convert_scaled_to_non_scaled(zipformer, inplace=True)

    # jit_model = torch.jit.trace(zipformer, (x, x_lens))
    torch.onnx.export(
        zipformer,
        (x, x_lens),
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x", "x_lens"],
        output_names=["y", "y_lens"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_lens": {0: "N"},
            "y": {0: "N", 1: "T"},
            "y_lens": {0: "N"},
        },
    )
    options = ort.SessionOptions()
    options.inter_op_num_threads = 1
    options.intra_op_num_threads = 1

    session = ort.InferenceSession(
        filename,
        sess_options=options,
    )

    input_nodes = session.get_inputs()
    inputs = {
        input_nodes[0].name: x.numpy(),
        input_nodes[1].name: x_lens.numpy(),
    }
    onnx_y, onnx_y_lens = session.run(["y", "y_lens"], inputs)
    onnx_y = torch.from_numpy(onnx_y)
    onnx_y_lens = torch.from_numpy(onnx_y_lens)

    torch_y, torch_y_lens = zipformer(x, x_lens)
    assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()

    assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), (
        (onnx_y_lens - torch_y_lens).abs().max()
    )
    print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)
    print(onnx_y_lens, torch_y_lens)

    os.remove(filename)


@torch.no_grad()
def main():
    test_conv2d_subsampling()
    test_rel_pos()
    test_zipformer_encoder_layer()
    test_zipformer_encoder()
    test_zipformer()


if __name__ == "__main__":
    torch.manual_seed(20221011)
    main()
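Note (not part of the diff): every test above follows the same export-and-verify recipe. A self-contained sketch of that recipe on a plain Linear layer; the module, shapes, and file name here are placeholders, not the Zipformer components themselves.

import os

import onnxruntime as ort
import torch

model = torch.nn.Linear(8, 4).eval()
x = torch.rand(3, 8)
filename = "tmp_linear.onnx"

# 1) Export with named inputs/outputs and a dynamic batch dimension.
torch.onnx.export(
    model,
    x,
    filename,
    opset_version=13,
    input_names=["x"],
    output_names=["y"],
    dynamic_axes={"x": {0: "N"}, "y": {0: "N"}},
)

# 2) Run the exported graph with onnxruntime.
session = ort.InferenceSession(filename, providers=["CPUExecutionProvider"])
onnx_y = torch.from_numpy(session.run(["y"], {"x": x.numpy()})[0])

# 3) Compare against the eager PyTorch output.
torch_y = model(x)
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()

os.remove(filename)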
egs/librispeech/ASR/pruned_transducer_stateless7_contextual/train.py (Executable file, 1537 lines)
File diff suppressed because it is too large
@ -0,0 +1,83 @@
import logging

import torch
from transformers import BertModel, BertTokenizer


class BertEncoder:
    def __init__(self, device=None, **kwargs):
        # https://huggingface.co/bert-base-uncased
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # TODO: fast_tokenizers: https://huggingface.co/docs/transformers/v4.27.2/en/fast_tokenizers

        self.bert_model = BertModel.from_pretrained("bert-base-uncased")
        self.bert_model.eval()
        for param in self.bert_model.parameters():
            param.requires_grad = False

        self.name = self.bert_model.config._name_or_path
        self.embedding_size = 768

        num_param = sum([p.numel() for p in self.bert_model.parameters()])
        logging.info(f"Number of parameters in '{self.name}': {num_param}")

        if device is not None:
            self.bert_model.to(device)
            logging.info(f"Loaded '{self.name}' to {device}")
            logging.info(
                f"cuda.memory_allocated: {torch.cuda.memory_allocated()/1024/1024:.2f} MB"
            )

    def _encode_strings(
        self,
        word_list,
    ):
        # word_list is a list of uncased strings

        encoded_input = self.tokenizer(word_list, return_tensors="pt", padding=True)
        encoded_input = encoded_input.to(self.bert_model.device)
        out = self.bert_model(**encoded_input).pooler_output

        # TODO:
        # 1. compare with some online API
        # 2. smaller or no dropout
        # 3. other activation function: different ranges, sigmoid
        # 4. compare the range with lstm encoder or rnnt hidden representation
        # 5. more layers, transformer layers: how to connect two spaces

        # out is of the shape: len(word_list) * 768
        return out

    def encode_strings(
        self,
        word_list,
        batch_size=6000,
        silent=False,
    ):
        """
        Encode a list of uncased strings into a list of embeddings.

        Args:
          word_list:
            A list of words, where each word is a string.
        Returns:
          embeddings:
            A list of embeddings (on CPU), of the shape len(word_list) * 768.
        """
        i = 0
        embeddings_list = list()
        while i < len(word_list):
            if not silent and int(i / 10000) % 5 == 0:
                logging.info(
                    f"Using '{self.name}' to encode the wordlist: {i}/{len(word_list)}"
                )
            wlist = word_list[i : i + batch_size]
            embeddings = self._encode_strings(wlist)
            embeddings = embeddings.detach().cpu()  # To save GPU memory
            embeddings = list(embeddings)
            embeddings_list.extend(embeddings)
            i += batch_size
        if not silent:
            logging.info(f"Done, len(embeddings_list)={len(embeddings_list)}")
        return embeddings_list

    def free_up(self):
        self.bert_model = self.bert_model.to(torch.device("cpu"))
        torch.cuda.empty_cache()
        logging.info(
            f"cuda.memory_allocated: {torch.cuda.memory_allocated()/1024/1024:.2f} MB"
        )
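Note (not part of the diff): a hypothetical usage of BertEncoder; the device and word list below are illustrative and assume the class is importable and a GPU is available.

import logging

import torch

logging.basicConfig(level=logging.INFO)

encoder = BertEncoder(device=torch.device("cuda", 0))
embeddings = encoder.encode_strings(["NEURAL", "BIASING", "ICEFALL"])
assert len(embeddings) == 3
assert embeddings[0].shape == (768,)  # one pooled BERT vector per word, on CPU
encoder.free_up()  # move BERT back to CPU and release cached GPU memory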
@ -0,0 +1,81 @@
import io
import logging

import fasttext
import torch


class FastTextEncoder:
    def __init__(self, embeddings_path=None, model_path=None, **kwargs):
        logging.info(f"Loading word embeddings from: {embeddings_path}")
        self.word_to_vector = self.load_vectors(embeddings_path)
        logging.info(f"Number of word embeddings: {len(self.word_to_vector)}")

        self.model_path = model_path
        self.model = None

        self.name = "FastText"
        self.embedding_size = 300

    def load_vectors(self, fname):
        fin = io.open(fname, "r", encoding="utf-8", newline="\n", errors="ignore")
        # n, d = map(int, fin.readline().split())
        data = {}
        for line in fin:
            tokens = line.rstrip().split(" ")
            w = tokens[0].upper()
            embedding = list(map(float, tokens[1:]))
            data[w] = torch.tensor(embedding)
        return data

    def _encode_strings(
        self,
        word_list,
        lower_case=True,
    ):
        """
        Encode unseen strings.
        """
        if len(word_list) == 0:
            return []

        if self.model is None:
            logging.info(f"Lazy loading FastText model from: {self.model_path}")
            self.model = fasttext.FastText.load_model(
                self.model_path
            )  # e.g. "/exp/rhuang/fastText/cc.en.300.bin"
        out = [
            self.model[w.lower()] if lower_case else self.model[w] for w in word_list
        ]

        for w, embedding in zip(word_list, out):
            self.word_to_vector[w] = torch.tensor(embedding)

        return out

    def encode_strings(
        self,
        word_list,
        batch_size=6000,
        silent=False,
    ):
        """
        Encode a list of uncased strings into a list of embeddings.

        Args:
          word_list:
            A list of words, where each word is a string.
        Returns:
          embeddings:
            A list of embeddings (on CPU), of the shape len(word_list) * 300.
        """
        embeddings_list = list()
        for i, w in enumerate(word_list):
            if not silent and i % 50000 == 0:
                logging.info(
                    f"Using FastText to encode the wordlist: {i}/{len(word_list)}"
                )

            if w in self.word_to_vector:
                embeddings_list.append(self.word_to_vector[w])
            else:
                embedding = self._encode_strings([w], lower_case=True)[0]
                embeddings_list.append(embedding)
        if not silent:
            logging.info(f"Done, len(embeddings_list)={len(embeddings_list)}")
        return embeddings_list

    def free_up(self):
        pass
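Note (not part of the diff): a hypothetical usage of FastTextEncoder; the paths below are placeholders, since the embedding and model files themselves are not part of this commit.

encoder = FastTextEncoder(
    embeddings_path="path/to/crawl-300d-2M.vec",  # plain-text .vec embeddings
    model_path="path/to/cc.en.300.bin",  # full model, loaded lazily for unseen words
)
embeddings = encoder.encode_strings(["NEURAL", "BIASING"])
assert len(embeddings) == 2
assert len(embeddings[0]) == 300  # FastText vectors are 300-dimensional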
File diff suppressed because it is too large