Merge bb58c85c839d9d0766fc57e217788aa17a7f85cb into 800bf4b6a2e32745e7d0c31dd78d473f1faff509

commit b6f48899c7
Author: uni-manjunath-ke
Date:   2023-10-27 12:19:42 +07:00 (committed by GitHub)

17 changed files with 10649 additions and 0 deletions


@@ -375,6 +375,38 @@ for m in greedy_search modified_beam_search fast_beam_search; do
done
```
### zipformer ctc streaming
| decoding method | test-clean | test-other | comment |
|----------------------|------------|------------|--------------------|
| greedy_search | 4.07 | 10.51 | --epoch 30 --avg 15|
| greedy_search | 4.0 | 10.3 | --epoch 30 --avg 9 |
The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
python ./zipformer_ctc_streaming/train.py \
--world-size 2 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir ./zipformer_ctc_streaming/exp \
--max-duration 100
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
./zipformer_ctc_streaming/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer_ctc_streaming/exp \
--max-duration 300 \
--decode-chunk-len 32 \
--method ctc-decoding \
--lm-dir data/lm \
--lang-dir data/lang_bpe_500
```
### pruned_transducer_stateless7 (Fine-tune with mux)
See <https://github.com/k2-fsa/icefall/pull/1059> for more details.
@@ -456,6 +488,39 @@ The decoding commands are:
  --max-states 64
```
### zipformer ctc streaming
| decoding method | test-clean | test-other | comment |
|----------------------|------------|------------|--------------------|
| greedy_search | 4.07 | 10.51 | --epoch 30 --avg 15|
| greedy_search | 4.0 | 10.3 | --epoch 30 --avg 9 |
The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
python ./zipformer_ctc_streaming/train.py \
--world-size 2 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir ./zipformer_ctc_streaming/exp \
--max-duration 100
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
./zipformer_ctc_streaming/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer_ctc_streaming/exp \
--max-duration 300 \
--decode-chunk-len 32 \
--method ctc-decoding \
--lm-dir data/lm \
--lang-dir data/lang_bpe_500
```
### pruned_transducer_stateless7 (zipformer + multidataset(LibriSpeech + GigaSpeech + CommonVoice 13.0))
See <https://github.com/k2-fsa/icefall/pull/1010> for more details.


@@ -0,0 +1,475 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
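# Give each dataloader worker a distinct, reproducible random seed, so that
# data augmentation differs across workers while runs stay repeatable.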
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LibriSpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="""Used only when --mini-libri is False.When enabled,
use 960h LibriSpeech. Otherwise, use 100h subset.""",
)
group.add_argument(
"--mini-libri",
type=str2bool,
default=False,
help="True for mini librispeech",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_clean_5_cuts(self) -> CutSet:
logging.info("mini_librispeech: About to get train-clean-5 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
)
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
)
@lru_cache()
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
        logging.info(
            "About to get the shuffled train-clean-100, "
            "train-clean-360 and train-other-500 cuts"
        )
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
)
@lru_cache()
def dev_clean_2_cuts(self) -> CutSet:
logging.info("mini_librispeech: About to get dev-clean-2 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
)
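
A minimal usage sketch of the datamodule above (not part of this commit): it assumes the fbank manifests under data/fbank already exist and imports the class the same way decode.py below does; the inline argument values are illustrative only.

```python
import argparse

from asr_datamodule import LibriSpeechAsrDataModule

parser = argparse.ArgumentParser()
# Registers --full-libri, --max-duration, --num-workers, etc. on the parser.
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--max-duration", "200"])

librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
train_dl = librispeech.train_dataloaders(train_cuts)
for batch in train_dl:
    features = batch["inputs"]  # (N, T, C) fbank features
    texts = batch["supervisions"]["text"]
    break
```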


@@ -0,0 +1,891 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from train import add_model_arguments, get_ctc_model, get_params
from transformer import encoder_padding_mask
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_rnn_lm,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
get_texts,
load_averaged_model,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=77,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=55,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--method",
type=str,
default="attention-decoder",
help="""Decoding method.
Supported values are:
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
- (5) attention-decoder. Extract n paths from the LM rescored
lattice, the path with the highest score is the decoding result.
- (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
you have trained an RNN LM using ./rnn_lm/train.py
- (7) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""The scale to be applied to `lattice.scores`.
        It's needed if you use any kind of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
A smaller value results in more unique paths.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="The lang dir",
)
parser.add_argument(
"--lm-dir",
type=str,
default="data/lm",
help="""The n-gram LM dir.
It should contain either G_3_gram.pt or G_3_gram.fst.txt
""",
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is rnn-lm.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is rnn-lm.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is rnn-lm.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
rnn_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
batch: dict,
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
    - value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding without LM rescoring.
- params.method is "nbest", it uses nbest decoding without LM rescoring.
- params.method is "nbest-rescoring", it uses nbest LM rescoring.
- params.method is "whole-lattice-rescoring", it uses whole lattice LM
rescoring.
model:
The neural model.
rnn_lm_model:
The neural model for RNN LM.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
The token ID of the EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return the decoding result. See above description for the format of
the returned dict. Note: If it decodes to nothing, then return None.
"""
if HLG is not None:
device = HLG.device
else:
device = H.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
nnet_output, _ = model.encoder(feature, feature_lens)
ctc_output = model.ctc_output(nnet_output)
# nnet_output is (N, T, C)
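    # Each row of supervision_segments describes one utterance:
    # (sequence index in the batch, start frame, number of frames),
    # with the frame columns divided by the encoder's subsampling factor.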
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // params.subsampling_factor,
supervisions["num_frames"] // params.subsampling_factor,
),
1,
).to(torch.int32)
if H is None:
assert HLG is not None
decoding_graph = HLG
else:
assert HLG is None
assert bpe_model is not None
decoding_graph = H
lattice = get_lattice(
nnet_output=ctc_output,
decoding_graph=decoding_graph,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "ctc-decoding":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
# Note: `best_path.aux_labels` contains token IDs, not word IDs
# since we are using H, not HLG here.
#
        # token_ids is a list-of-lists of IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "ctc-decoding"
return {key: hyps}
if params.method == "nbest-oracle":
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
# is only slightly worse than that of rescored lattices.
best_path = nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
word_table=word_table,
nbest_scale=params.nbest_scale,
oov="<UNK>",
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
return {key: hyps}
if params.method in ["1best", "nbest"]:
if params.method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no_rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale,
)
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
assert params.method in [
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
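    # Each LM scale yields one entry in the returned dict (e.g. the key
    # "lm_scale_0.7"); save_results() reports the WER for every setting.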
nnet_output = nnet_output.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
mask = encoder_padding_mask(nnet_output.size(0), supervisions)
mask = mask.to(nnet_output.device) if mask is not None else None
mmodel = model.decoder.module if hasattr(model.decoder, "module") else model.decoder
if params.method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
nbest_scale=params.nbest_scale,
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
elif params.method == "attention-decoder":
        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
)
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=mmodel,
memory=nnet_output,
memory_key_padding_mask=mask,
sos_id=sos_id,
eos_id=eos_id,
nbest_scale=params.nbest_scale,
)
elif params.method == "rnn-lm":
        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
)
best_path_dict = rescore_with_rnn_lm(
lattice=rescored_lattice,
num_paths=params.num_paths,
rnn_lm_model=rnn_lm_model,
model=mmodel,
memory=nnet_output,
memory_key_padding_mask=mask,
sos_id=sos_id,
eos_id=eos_id,
blank_id=0,
nbest_scale=params.nbest_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
ans = dict()
if best_path_dict is not None:
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
else:
ans = None
return ans
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
rnn_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
rnn_lm_model:
The neural model for RNN LM.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
word_table:
It is the word symbol table.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return a dict, whose key may be "no-rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
      Its value is a list of tuples. Each tuple contains three elements:
      the cut ID, the reference transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
rnn_lm_model=rnn_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
batch=batch,
word_table=word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
if hyps_dict is not None:
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
else:
assert len(results) > 0, "It should not decode to empty in the first batch!"
this_batch = []
hyp_words = []
            for cut_id, ref_text in zip(cut_ids, texts):
                ref_words = ref_text.split()
                this_batch.append((cut_id, ref_words, hyp_words))
for lm_scale in results.keys():
results[lm_scale].extend(this_batch)
num_cuts += len(texts)
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
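    # This writes, under params.exp_dir, for each decoding setting `key`:
    #   recogs-{test_set_name}-{key}.txt   the ref/hyp transcripts
    #   errs-{test_set_name}-{key}.txt     detailed error statistics
    #   wer-summary-{test_set_name}.txt    WER of every setting, best first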
if params.method in ("attention-decoder", "rnn-lm"):
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
args.lm_dir = Path(args.lm_dir)
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
logging.info("Decoding started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
params.vocab_size = num_classes
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
params.num_classes = num_classes
params.sos_id = sos_id
params.eos_id = eos_id
if params.method == "ctc-decoding":
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
else:
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
if params.method in (
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
):
if not (params.lm_dir / "G_3_gram.pt").is_file():
logging.info("Loading G_3_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_3_gram.fst.txt") as f:
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set G.properties to None
G.__dict__["_properties"] = None
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
# Save a dummy value so that it can be loaded in C++.
# See https://github.com/pytorch/pytorch/issues/67902
# for why we need to do this.
G.dummy = 1
torch.save(G.as_dict(), params.lm_dir / "G_3_gram.pt")
else:
logging.info("Loading pre-compiled G_3_gram.pt")
d = torch.load(params.lm_dir / "G_3_gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)
if params.method in [
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
# G.lm_scores is used to replace HLG.lm_scores during
# LM rescoring.
G.lm_scores = G.scores.clone()
else:
G = None
logging.info("About to create model")
model = get_ctc_model(params)
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
model.encoder.decode_chunk_size,
params.decode_chunk_len,
)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
rnn_lm_model = None
if params.method == "rnn-lm":
rnn_lm_model = RnnLmModel(
vocab_size=params.num_classes,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
if params.rnn_lm_avg == 1:
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
else:
rnn_lm_model = load_averaged_model(
params.rnn_lm_exp_dir,
rnn_lm_model,
params.rnn_lm_epoch,
params.rnn_lm_avg,
device,
)
rnn_lm_model.eval()
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
rnn_lm_model=rnn_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
word_table=lexicon.word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")
if __name__ == "__main__":
main()


@@ -0,0 +1,887 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from train import add_model_arguments, get_ctc_model, get_params
from transformer import encoder_padding_mask
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_rnn_lm,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
get_texts,
load_averaged_model,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=77,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=55,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--method",
type=str,
default="attention-decoder",
help="""Decoding method.
Supported values are:
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
- (5) attention-decoder. Extract n paths from the LM rescored
lattice, the path with the highest score is the decoding result.
- (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
you have trained an RNN LM using ./rnn_lm/train.py
- (7) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""The scale to be applied to `lattice.scores`.
        It's needed if you use any kind of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
A smaller value results in more unique paths.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="The lang dir",
)
parser.add_argument(
"--lm-dir",
type=str,
default="data/lm",
help="""The n-gram LM dir.
It should contain either G_4_gram.pt or G_4_gram.fst.txt
""",
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is rnn-lm.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is rnn-lm.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is rnn-lm.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
rnn_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
batch: dict,
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
    - value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding without LM rescoring.
- params.method is "nbest", it uses nbest decoding without LM rescoring.
- params.method is "nbest-rescoring", it uses nbest LM rescoring.
- params.method is "whole-lattice-rescoring", it uses whole lattice LM
rescoring.
model:
The neural model.
rnn_lm_model:
The neural model for RNN LM.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
The token ID of the EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return the decoding result. See above description for the format of
the returned dict. Note: If it decodes to nothing, then return None.
"""
if HLG is not None:
device = HLG.device
else:
device = H.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
nnet_output, _ = model.encoder(feature, feature_lens)
ctc_output = model.ctc_output(nnet_output)
# nnet_output is (N, T, C)
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // params.subsampling_factor,
supervisions["num_frames"] // params.subsampling_factor,
),
1,
).to(torch.int32)
if H is None:
assert HLG is not None
decoding_graph = HLG
else:
assert HLG is None
assert bpe_model is not None
decoding_graph = H
lattice = get_lattice(
nnet_output=ctc_output,
decoding_graph=decoding_graph,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "ctc-decoding":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
# Note: `best_path.aux_labels` contains token IDs, not word IDs
# since we are using H, not HLG here.
#
        # token_ids is a list-of-lists of IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "ctc-decoding"
return {key: hyps}
if params.method == "nbest-oracle":
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
# is only slightly worse than that of rescored lattices.
best_path = nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
word_table=word_table,
nbest_scale=params.nbest_scale,
oov="<UNK>",
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
return {key: hyps}
if params.method in ["1best", "nbest"]:
if params.method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no_rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale,
)
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
assert params.method in [
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
nnet_output = nnet_output.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
mask = encoder_padding_mask(nnet_output.size(0), supervisions)
mask = mask.to(nnet_output.device) if mask is not None else None
mmodel = model.decoder.module if hasattr(model.decoder, "module") else model.decoder
if params.method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
nbest_scale=params.nbest_scale,
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
elif params.method == "attention-decoder":
        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
)
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=mmodel,
memory=nnet_output,
memory_key_padding_mask=mask,
sos_id=sos_id,
eos_id=eos_id,
nbest_scale=params.nbest_scale,
)
elif params.method == "rnn-lm":
        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
)
best_path_dict = rescore_with_rnn_lm(
lattice=rescored_lattice,
num_paths=params.num_paths,
rnn_lm_model=rnn_lm_model,
model=mmodel,
memory=nnet_output,
memory_key_padding_mask=mask,
sos_id=sos_id,
eos_id=eos_id,
blank_id=0,
nbest_scale=params.nbest_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
ans = dict()
if best_path_dict is not None:
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
else:
ans = None
return ans
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
rnn_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
rnn_lm_model:
The neural model for RNN LM.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
word_table:
It is the word symbol table.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return a dict, whose key may be "no-rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
      Its value is a list of tuples. Each tuple contains three elements:
      the cut ID, the reference transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
rnn_lm_model=rnn_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
batch=batch,
word_table=word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
if hyps_dict is not None:
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
else:
assert len(results) > 0, "It should not decode to empty in the first batch!"
this_batch = []
hyp_words = []
            for cut_id, ref_text in zip(cut_ids, texts):
                ref_words = ref_text.split()
                this_batch.append((cut_id, ref_words, hyp_words))
for lm_scale in results.keys():
results[lm_scale].extend(this_batch)
num_cuts += len(texts)
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
if params.method in ("attention-decoder", "rnn-lm"):
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
args.lm_dir = Path(args.lm_dir)
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
logging.info("Decoding started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
params.vocab_size = num_classes
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
params.num_classes = num_classes
params.sos_id = sos_id
params.eos_id = eos_id
if params.method == "ctc-decoding":
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
else:
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
if params.method in (
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set G.properties to None
G.__dict__["_properties"] = None
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
# Save a dummy value so that it can be loaded in C++.
# See https://github.com/pytorch/pytorch/issues/67902
# for why we need to do this.
G.dummy = 1
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)
if params.method in [
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
# G.lm_scores is used to replace HLG.lm_scores during
# LM rescoring.
G.lm_scores = G.scores.clone()
else:
G = None
logging.info("About to create model")
model = get_ctc_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
rnn_lm_model = None
if params.method == "rnn-lm":
rnn_lm_model = RnnLmModel(
vocab_size=params.num_classes,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
if params.rnn_lm_avg == 1:
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
else:
rnn_lm_model = load_averaged_model(
params.rnn_lm_exp_dir,
rnn_lm_model,
params.rnn_lm_epoch,
params.rnn_lm_avg,
device,
)
rnn_lm_model.eval()
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
rnn_lm_model=rnn_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
word_table=lexicon.word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")
if __name__ == "__main__":
main()


@@ -0,0 +1,298 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from label_smoothing import LabelSmoothingLoss
from torch.nn.utils.rnn import pad_sequence
from transformer import PositionalEncoding, TransformerDecoderLayer
class Decoder(nn.Module):
"""This class implements Transformer based decoder for an attention-based encoder-decoder
model.
"""
def __init__(
self,
num_layers: int,
num_classes: int,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
dropout: float = 0.1,
normalize_before: bool = True,
):
"""
Args:
num_layers:
Number of layers.
num_classes:
Number of tokens of the modeling unit including blank.
d_model:
Dimension of the input embedding, and of the decoder output.
"""
super().__init__()
if num_layers > 0:
self.decoder_num_class = num_classes # bpe model already has sos/eos symbol
self.decoder_embed = nn.Embedding(
num_embeddings=self.decoder_num_class, embedding_dim=d_model
)
self.decoder_pos = PositionalEncoding(d_model, dropout)
decoder_layer = TransformerDecoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
normalize_before=normalize_before,
)
if normalize_before:
decoder_norm = nn.LayerNorm(d_model)
else:
decoder_norm = None
self.decoder = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=num_layers,
norm=decoder_norm,
)
self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
self.decoder_criterion = LabelSmoothingLoss()
else:
self.decoder_criterion = None
@torch.jit.export
def forward(
self,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
token_ids: List[List[int]],
sos_id: int,
eos_id: int,
) -> torch.Tensor:
"""
Args:
memory:
It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:
A list-of-list IDs. Each sublist contains IDs for an utterance.
The IDs can be either phone IDs or word piece IDs.
sos_id:
sos token id
eos_id:
eos token id
Returns:
A scalar, the **sum** of label smoothing loss over utterances
in the batch without any normalization.
"""
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
device = memory.device
ys_in_pad = ys_in_pad.to(device)
ys_out_pad = ys_out_pad.to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
# We set the first column to False since the first column in ys_in_pad
# contains sos_id, which is the same as eos_id in our current setting.
tgt_key_padding_mask[:, 0] = False
tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
tgt = self.decoder_pos(tgt)
tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
pred_pad = self.decoder(
tgt=tgt,
memory=memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
) # (T, N, C)
pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
return decoder_loss
@torch.jit.export
def decoder_nll(
self,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
token_ids: List[torch.Tensor],
sos_id: int,
eos_id: int,
) -> torch.Tensor:
"""
Args:
memory:
It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:
A list-of-list IDs (e.g., word piece IDs).
Each sublist represents an utterance.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
Returns:
A 2-D tensor of shape (len(token_ids), max_token_length)
representing the cross entropy loss (i.e., negative log-likelihood).
"""
# The common part between this function and decoder_forward could be
# extracted as a separate function.
if isinstance(token_ids[0], torch.Tensor):
# This branch is executed by torchscript in C++.
# See https://github.com/k2-fsa/k2/pull/870
# https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
token_ids = [tolist(t) for t in token_ids]
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
# We set the first column to False since the first column in ys_in_pad
# contains sos_id, which is the same as eos_id in our current setting.
tgt_key_padding_mask[:, 0] = False
tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
tgt = self.decoder_pos(tgt)
tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
pred_pad = self.decoder(
tgt=tgt,
memory=memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
) # (T, B, F)
pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F)
pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F)
# nll: negative log-likelihood
nll = torch.nn.functional.cross_entropy(
pred_pad.view(-1, self.decoder_num_class),
ys_out_pad.view(-1),
ignore_index=-1,
reduction="none",
)
nll = nll.view(pred_pad.shape[0], -1)
return nll
def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
"""Prepend sos_id to each utterance.
Args:
token_ids:
A list-of-list of token IDs. Each sublist contains
token IDs (e.g., word piece IDs) of an utterance.
sos_id:
The ID of the SOS token.
Return:
Return a new list-of-list, where each sublist starts
with SOS ID.
"""
return [[sos_id] + utt for utt in token_ids]
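# A quick, concrete example of the behavior of add_sos:
#   add_sos([[5, 9], [7]], sos_id=1) returns [[1, 5, 9], [1, 7]].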
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
"""Append eos_id to each utterance.
Args:
token_ids:
A list-of-list of token IDs. Each sublist contains
token IDs (e.g., word piece IDs) of an utterance.
eos_id:
The ID of the EOS token.
Return:
Return a new list-of-list, where each sublist ends
with EOS ID.
"""
return [utt + [eos_id] for utt in token_ids]
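# Likewise for add_eos:
#   add_eos([[5, 9], [7]], eos_id=2) returns [[5, 9, 2], [7, 2]].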
def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
"""Generate a length mask for input.
    Masked positions are filled with True;
    unmasked positions are filled with False.
Args:
ys_pad:
padded tensor of dimension (batch_size, input_length).
ignore_id:
the ignored number (the padding number) in ys_pad
Returns:
Tensor:
a bool tensor of the same shape as the input tensor.
"""
ys_mask = ys_pad == ignore_id
return ys_mask
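# For example, decoder_padding_mask(torch.tensor([[3, 5, -1]]), ignore_id=-1)
# returns tensor([[False, False, True]]).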
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
"""Generate a square mask for the sequence. The masked positions are
filled with float('-inf'). Unmasked positions are filled with float(0.0).
The mask can be used for masked self-attention.
For instance, if sz is 3, it returns::
tensor([[0., -inf, -inf],
[0., 0., -inf],
            [0., 0., 0.]])
Args:
sz: mask size
Returns:
A square mask of dimension (sz, sz)
"""
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = (
mask.float()
.masked_fill(mask == 0, float("-inf"))
.masked_fill(mask == 1, float(0.0))
)
return mask
def tolist(t: torch.Tensor) -> List[int]:
"""Used by jit"""
return torch.jit.annotate(List[int], t.tolist())
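# A minimal usage sketch for the Decoder above (the tensors are random
# placeholders, not real encoder output; sos_id == eos_id mirrors the
# assumption noted in forward()):
#
#   decoder = Decoder(num_layers=6, num_classes=500, d_model=256, nhead=4)
#   memory = torch.randn(50, 2, 256)                    # (T, N, C)
#   memory_mask = torch.zeros(2, 50, dtype=torch.bool)  # (N, T)
#   loss = decoder(memory, memory_mask, [[3, 7], [9]], sos_id=1, eos_id=1)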


@@ -0,0 +1,43 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
class EncoderInterface(nn.Module):
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (batch_size, input_seq_len, num_features)
containing the input features.
x_lens:
A tensor of shape (batch_size,) containing the number of frames
in `x` before padding.
Returns:
Return a tuple containing two tensors:
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
containing unnormalized probabilities, i.e., the output of a
linear layer.
- encoder_out_lens, a tensor of shape (batch_size,) containing
the number of frames in `encoder_out` before padding.
"""
raise NotImplementedError("Please implement it in a subclass")
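# A minimal sketch of a conforming subclass (hypothetical, for illustration
# only; real encoders in this recipe also subsample along the time axis):
#
#   class LinearEncoder(EncoderInterface):
#       def __init__(self, num_features: int, output_dim: int):
#           super().__init__()
#           self.proj = nn.Linear(num_features, output_dim)
#
#       def forward(self, x, x_lens):
#           return self.proj(x), x_lens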


@@ -0,0 +1,240 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
import argparse
import logging
from pathlib import Path
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_ctc_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer_ctc/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="""It contains language related input files such as "lexicon.txt"
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=True,
help="""True to save a model after applying torch.jit.script.
""",
)
add_model_arguments(parser)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
params.vocab_size = num_classes
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = get_ctc_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
        # torch scriptable.
convert_scaled_to_non_scaled(model, inplace=True)
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@@ -0,0 +1,109 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class LabelSmoothingLoss(torch.nn.Module):
"""
Implement the LabelSmoothingLoss proposed in the following paper
https://arxiv.org/pdf/1512.00567.pdf
(Rethinking the Inception Architecture for Computer Vision)
"""
def __init__(
self,
ignore_index: int = -1,
label_smoothing: float = 0.1,
reduction: str = "sum",
) -> None:
"""
Args:
ignore_index:
ignored class id
label_smoothing:
smoothing rate (0.0 means the conventional cross entropy loss)
reduction:
It has the same meaning as the reduction in
`torch.nn.CrossEntropyLoss`. It can be one of the following three
values: (1) "none": No reduction will be applied. (2) "mean": the
mean of the output is taken. (3) "sum": the output will be summed.
"""
super().__init__()
assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
assert reduction in ("none", "sum", "mean"), reduction
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute loss between x and target.
Args:
x:
prediction of dimension
(batch_size, input_length, number_of_classes).
target:
target masked with self.ignore_index of
dimension (batch_size, input_length).
Returns:
A scalar tensor containing the loss without normalization.
"""
assert x.ndim == 3
assert target.ndim == 2
assert x.shape[:2] == target.shape
num_classes = x.size(-1)
x = x.reshape(-1, num_classes)
# Now x is of shape (N*T, C)
# We don't want to change target in-place below,
# so we make a copy of it here
target = target.clone().reshape(-1)
ignored = target == self.ignore_index
# See https://github.com/k2-fsa/icefall/issues/240
# and https://github.com/k2-fsa/icefall/issues/297
# for why we don't use target[ignored] = 0 here
target = torch.where(ignored, torch.zeros_like(target), target)
true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x)
true_dist = (
true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes
)
# Set the value of ignored indexes to 0
#
# See https://github.com/k2-fsa/icefall/issues/240
# and https://github.com/k2-fsa/icefall/issues/297
# for why we don't use true_dist[ignored] = 0 here
true_dist = torch.where(
ignored.unsqueeze(1).repeat(1, true_dist.shape[1]),
torch.zeros_like(true_dist),
true_dist,
)
loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
if self.reduction == "sum":
return loss.sum()
elif self.reduction == "mean":
return loss.sum() / (~ignored).sum()
else:
return loss.sum(dim=-1)
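# A small usage sketch: 2 utterances, 3 time steps, 5 classes; -1 marks
# padded positions, which are excluded from the loss:
#
#   criterion = LabelSmoothingLoss(ignore_index=-1, label_smoothing=0.1)
#   x = torch.randn(2, 3, 5)
#   target = torch.tensor([[1, 2, -1], [3, -1, -1]])
#   loss = criterion(x, target)  # scalar: sum over the 3 valid positions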


@@ -0,0 +1,158 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from transformer import encoder_padding_mask
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.utils import encode_supervisions
class CTCModel(nn.Module):
"""It implements a CTC model with an auxiliary attention head."""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
encoder_dim: int,
vocab_size: int,
):
"""
Args:
encoder:
            An instance of `EncoderInterface`. The shared encoder for the
            CTC and attention branches.
decoder:
An instance of `nn.Module`. This is the decoder for the attention branch.
encoder_dim:
Dimension of the encoder output.
vocab_size:
Number of tokens of the modeling unit including blank.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
self.encoder = encoder
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
self.decoder = decoder
@torch.jit.ignore
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
        supervisions: Dict[str, torch.Tensor],
graph_compiler: BpeCtcTrainingGraphCompiler,
subsampling_factor: int = 1,
beam_size: int = 10,
reduction: str = "sum",
use_double_scores: bool = False,
) -> torch.Tensor:
"""
Args:
x:
Tensor of dimension (N, T, C) where N is the batch size,
T is the number of frames, and C is the feature dimension.
x_lens:
Tensor of dimension (N,) where N is the batch size.
supervisions:
Supervisions are used in training.
graph_compiler:
It is used to compile a decoding graph from texts.
subsampling_factor:
It is used to compute the `supervisions` for the encoder.
beam_size:
Beam size used in `k2.ctc_loss`.
reduction:
Reduction method used in `k2.ctc_loss`.
use_double_scores:
If True, use double precision in `k2.ctc_loss`.
Returns:
Return the CTC loss, attention loss, and the total number of frames.
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
nnet_output, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# compute ctc log-probs
ctc_output = self.ctc_output(nnet_output)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=subsampling_factor
)
num_frames = supervision_segments[:, 2].sum().item()
# Works with a BPE model
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
dense_fsa_vec = k2.DenseFsaVec(
ctc_output,
supervision_segments.cpu(),
allow_truncate=subsampling_factor - 1,
)
ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=beam_size,
reduction=reduction,
use_double_scores=use_double_scores,
)
if self.decoder is not None:
nnet_output = nnet_output.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
mmodel = (
self.decoder.module if hasattr(self.decoder, "module") else self.decoder
)
# Note: We need to generate an unsorted version of token_ids
# `encode_supervisions()` called above sorts text, but
# encoder_memory and memory_mask are not sorted, so we
# use an unsorted version `supervisions["text"]` to regenerate
# the token_ids
#
# See https://github.com/k2-fsa/icefall/issues/97
# for more details
unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
mask = encoder_padding_mask(nnet_output.size(0), supervisions)
mask = mask.to(nnet_output.device) if mask is not None else None
att_loss = mmodel.forward(
nnet_output,
mask,
token_ids=unsorted_token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
else:
att_loss = torch.tensor([0])
return ctc_loss, att_loss, num_frames
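# The caller is expected to combine the two losses; a common weighting
# (a sketch, with att_rate as a tunable interpolation weight) is:
#
#   loss = (1.0 - att_rate) * ctc_loss + att_rate * att_loss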

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,114 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file replaces various modules in a model.
Specifically, ActivationBalancer is replaced with an identity operator;
Whiten is also replaced with an identity operator;
BasicNorm is replaced by a module with `exp` removed.
"""
import copy
from typing import List
import torch
import torch.nn as nn
from scaling import ActivationBalancer, BasicNorm, Whiten
class NonScaledNorm(nn.Module):
"""See BasicNorm for doc"""
def __init__(
self,
num_channels: int,
eps_exp: float,
channel_dim: int = -1, # CAUTION: see documentation.
):
super().__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
self.eps_exp = eps_exp
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not torch.jit.is_tracing():
assert x.shape[self.channel_dim] == self.num_channels
scales = (
torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp
).pow(-0.5)
return x * scales
def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
    assert isinstance(basic_norm, BasicNorm), type(basic_norm)
norm = NonScaledNorm(
num_channels=basic_norm.num_channels,
eps_exp=basic_norm.eps.data.exp().item(),
channel_dim=basic_norm.channel_dim,
)
return norm
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
# get_submodule was added to nn.Module at v1.9.0
def get_submodule(model, target):
if target == "":
return model
atoms: List[str] = target.split(".")
mod: torch.nn.Module = model
for item in atoms:
if not hasattr(mod, item):
raise AttributeError(
mod._get_name() + " has no " "attribute `" + item + "`"
)
mod = getattr(mod, item)
if not isinstance(mod, torch.nn.Module):
raise AttributeError("`" + item + "` is not " "an nn.Module")
return mod
def convert_scaled_to_non_scaled(
model: nn.Module,
inplace: bool = False,
):
"""
Args:
model:
The model to be converted.
inplace:
If True, the input model is modified inplace.
If False, the input model is copied and we modify the copied version.
Return:
Return a model without scaled layers.
"""
if not inplace:
model = copy.deepcopy(model)
d = {}
for name, m in model.named_modules():
if isinstance(m, BasicNorm):
d[name] = convert_basic_norm(m)
elif isinstance(m, (ActivationBalancer, Whiten)):
d[name] = nn.Identity()
for k, v in d.items():
if "." in k:
parent, child = k.rsplit(".", maxsplit=1)
setattr(get_submodule(model, parent), child, v)
else:
setattr(model, k, v)
return model
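# Typical usage before scripting a model for deployment, mirroring
# export.py above:
#
#   model = convert_scaled_to_non_scaled(model, inplace=True)
#   scripted = torch.jit.script(model)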


@@ -0,0 +1,153 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(self, idim: int, odim: int) -> None:
"""
Args:
idim:
Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
assert idim >= 7
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
nn.ReLU(),
nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
nn.ReLU(),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
Returns:
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
return x
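# For example, Conv2dSubsampling maps an input of shape (N, 100, idim)
# to an output of shape (N, ((100 - 1) // 2 - 1) // 2, odim) = (N, 24, odim).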
class VggSubsampling(nn.Module):
"""Trying to follow the setup described in the following paper:
https://arxiv.org/pdf/1910.09799.pdf
This paper is not 100% explicit so I am guessing to some extent,
and trying to compare with other VGG implementations.
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
"""
def __init__(self, idim: int, odim: int) -> None:
"""Construct a VggSubsampling object.
This uses 2 VGG blocks with 2 Conv2d layers each,
subsampling its input by a factor of 4 in the time dimensions.
Args:
idim:
Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
super().__init__()
cur_channels = 1
layers = []
block_dims = [32, 64]
# The decision to use padding=1 for the 1st convolution, then padding=0
# for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
# a back-compatibility concern so that the number of frames at the
# output would be equal to:
# (((T-1)//2)-1)//2.
# We can consider changing this by using padding=1 on the
# 2nd convolution, so the num-frames at the output would be T//4.
for block_dim in block_dims:
layers.append(
torch.nn.Conv2d(
in_channels=cur_channels,
out_channels=block_dim,
kernel_size=3,
padding=1,
stride=1,
)
)
layers.append(torch.nn.ReLU())
layers.append(
torch.nn.Conv2d(
in_channels=block_dim,
out_channels=block_dim,
kernel_size=3,
padding=0,
stride=1,
)
)
layers.append(
torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
)
cur_channels = block_dim
self.layers = nn.Sequential(*layers)
self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
Returns:
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
x = x.unsqueeze(1)
x = self.layers(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
return x

File diff suppressed because it is too large


@@ -0,0 +1,928 @@
# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from label_smoothing import LabelSmoothingLoss
from subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]
class Transformer(nn.Module):
def __init__(
self,
num_features: int,
num_classes: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
num_decoder_layers: int = 6,
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: Union[float, bool] = 0.1,
) -> None:
"""
Args:
num_features:
The input dimension of the model.
num_classes:
The output dimension of the model.
subsampling_factor:
Number of output frames is num_in_frames // subsampling_factor.
Currently, subsampling_factor MUST be 4.
d_model:
Attention dimension.
nhead:
Number of heads in multi-head attention.
            Must satisfy d_model % nhead == 0.
dim_feedforward:
The output dimension of the feedforward layers in encoder/decoder.
num_encoder_layers:
Number of encoder layers.
num_decoder_layers:
Number of decoder layers.
dropout:
Dropout in encoder/decoder.
normalize_before:
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
use_feat_batchnorm:
True to use batchnorm for the input layer.
Float value to scale the input layer.
False to do nothing.
"""
super().__init__()
self.use_feat_batchnorm = use_feat_batchnorm
assert isinstance(use_feat_batchnorm, (float, bool))
if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)
self.num_features = num_features
self.num_classes = num_classes
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape (N, T, num_classes)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_classes -> d_model
if vgg_frontend:
self.encoder_embed = VggSubsampling(num_features, d_model)
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_pos = PositionalEncoding(d_model, dropout)
encoder_layer = TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
normalize_before=normalize_before,
)
if normalize_before:
encoder_norm = nn.LayerNorm(d_model)
else:
encoder_norm = None
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=num_encoder_layers,
norm=encoder_norm,
)
# TODO(fangjun): remove dropout
self.encoder_output_layer = nn.Sequential(
nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
)
if num_decoder_layers > 0:
self.decoder_num_class = (
self.num_classes
) # bpe model already has sos/eos symbol
self.decoder_embed = nn.Embedding(
num_embeddings=self.decoder_num_class, embedding_dim=d_model
)
self.decoder_pos = PositionalEncoding(d_model, dropout)
decoder_layer = TransformerDecoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
normalize_before=normalize_before,
)
if normalize_before:
decoder_norm = nn.LayerNorm(d_model)
else:
decoder_norm = None
self.decoder = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=num_decoder_layers,
norm=decoder_norm,
)
self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
self.decoder_criterion = LabelSmoothingLoss()
else:
self.decoder_criterion = None
def forward(
self, x: torch.Tensor, supervision: Optional[Supervisions] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
x:
The input tensor. Its shape is (N, T, C).
supervision:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
(CAUTION: It contains length information, i.e., start and number of
frames, before subsampling)
Returns:
Return a tuple containing 3 tensors:
- CTC output for ctc decoding. Its shape is (N, T, C)
- Encoder output with shape (T, N, C). It can be used as key and
value for the decoder.
- Encoder output padding mask. It can be used as
memory_key_padding_mask for the decoder. Its shape is (N, T).
It is None if `supervision` is None.
"""
if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
if isinstance(self.use_feat_batchnorm, float):
x *= self.use_feat_batchnorm
encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask
def run_encoder(
self, x: torch.Tensor, supervisions: Optional[Supervisions] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Run the transformer encoder.
Args:
x:
The model input. Its shape is (N, T, C).
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
to compute the encoder padding mask, which is used as memory key
padding mask for the decoder.
Returns:
Return a tuple with two tensors:
- The encoder output, with shape (T, N, C)
- encoder padding mask, with shape (N, T).
The mask is None if `supervisions` is None.
It is used as memory key padding mask in the decoder.
"""
x = self.encoder_embed(x)
x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
mask = encoder_padding_mask(x.size(0), supervisions)
mask = mask.to(x.device) if mask is not None else None
x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C)
return x, mask
def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
The output tensor from the transformer encoder.
Its shape is (T, N, C)
Returns:
Return a tensor that can be used for CTC decoding.
Its shape is (N, T, C)
"""
x = self.encoder_output_layer(x)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
x = nn.functional.log_softmax(x, dim=-1) # (N, T, C)
return x
@torch.jit.export
def decoder_forward(
self,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
token_ids: List[List[int]],
sos_id: int,
eos_id: int,
) -> torch.Tensor:
"""
Args:
memory:
It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:
A list-of-list IDs. Each sublist contains IDs for an utterance.
The IDs can be either phone IDs or word piece IDs.
sos_id:
sos token id
eos_id:
eos token id
Returns:
A scalar, the **sum** of label smoothing loss over utterances
in the batch without any normalization.
"""
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
device = memory.device
ys_in_pad = ys_in_pad.to(device)
ys_out_pad = ys_out_pad.to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
# We set the first column to False since the first column in ys_in_pad
# contains sos_id, which is the same as eos_id in our current setting.
tgt_key_padding_mask[:, 0] = False
tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
tgt = self.decoder_pos(tgt)
tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
pred_pad = self.decoder(
tgt=tgt,
memory=memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
) # (T, N, C)
pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
return decoder_loss
@torch.jit.export
def decoder_nll(
self,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
token_ids: List[torch.Tensor],
sos_id: int,
eos_id: int,
) -> torch.Tensor:
"""
Args:
memory:
It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:
A list-of-list IDs (e.g., word piece IDs).
Each sublist represents an utterance.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
Returns:
A 2-D tensor of shape (len(token_ids), max_token_length)
representing the cross entropy loss (i.e., negative log-likelihood).
"""
# The common part between this function and decoder_forward could be
# extracted as a separate function.
if isinstance(token_ids[0], torch.Tensor):
# This branch is executed by torchscript in C++.
# See https://github.com/k2-fsa/k2/pull/870
# https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
token_ids = [tolist(t) for t in token_ids]
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
# We set the first column to False since the first column in ys_in_pad
# contains sos_id, which is the same as eos_id in our current setting.
tgt_key_padding_mask[:, 0] = False
tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
tgt = self.decoder_pos(tgt)
tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
pred_pad = self.decoder(
tgt=tgt,
memory=memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
) # (T, B, F)
pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F)
pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F)
# nll: negative log-likelihood
nll = torch.nn.functional.cross_entropy(
pred_pad.view(-1, self.decoder_num_class),
ys_out_pad.view(-1),
ignore_index=-1,
reduction="none",
)
nll = nll.view(pred_pad.shape[0], -1)
return nll
class TransformerEncoderLayer(nn.Module):
"""
Modified from torch.nn.TransformerEncoderLayer.
Add support of normalize_before,
i.e., use layer_norm before the first block.
Args:
d_model:
the number of expected features in the input (required).
nhead:
the number of heads in the multiheadattention models (required).
dim_feedforward:
the dimension of the feedforward network model (default=2048).
dropout:
the dropout value (default=0.1).
activation:
the activation function of intermediate layer, relu or
gelu (default=relu).
normalize_before:
whether to use layer_norm before the first block.
Examples::
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> out = encoder_layer(src)
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: str = "relu",
normalize_before: bool = True,
) -> None:
super(TransformerEncoderLayer, self).__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def __setstate__(self, state):
if "activation" not in state:
state["activation"] = nn.functional.relu
super(TransformerEncoderLayer, self).__setstate__(state)
def forward(
self,
src: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional)
Shape:
src: (S, N, E).
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length,
N is the batch size, E is the feature number
"""
residual = src
if self.normalize_before:
src = self.norm1(src)
src2 = self.self_attn(
src,
src,
src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = residual + self.dropout1(src2)
if not self.normalize_before:
src = self.norm1(src)
residual = src
if self.normalize_before:
src = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src2)
if not self.normalize_before:
src = self.norm2(src)
return src
class TransformerDecoderLayer(nn.Module):
"""
Modified from torch.nn.TransformerDecoderLayer.
Add support of normalize_before,
i.e., use layer_norm before the first block.
Args:
d_model:
the number of expected features in the input (required).
nhead:
the number of heads in the multiheadattention models (required).
dim_feedforward:
the dimension of the feedforward network model (default=2048).
dropout:
the dropout value (default=0.1).
activation:
the activation function of intermediate layer, relu or
gelu (default=relu).
Examples::
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = decoder_layer(tgt, memory)
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: str = "relu",
normalize_before: bool = True,
) -> None:
super(TransformerDecoderLayer, self).__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def __setstate__(self, state):
if "activation" not in state:
state["activation"] = nn.functional.relu
super(TransformerDecoderLayer, self).__setstate__(state)
def forward(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_mask: Optional[torch.Tensor] = None,
memory_mask: Optional[torch.Tensor] = None,
tgt_key_padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt:
the sequence to the decoder layer (required).
memory:
the sequence from the last layer of the encoder (required).
tgt_mask:
the mask for the tgt sequence (optional).
memory_mask:
the mask for the memory sequence (optional).
tgt_key_padding_mask:
the mask for the tgt keys per batch (optional).
memory_key_padding_mask:
the mask for the memory keys per batch (optional).
Shape:
tgt: (T, N, E).
memory: (S, N, E).
tgt_mask: (T, T).
memory_mask: (T, S).
tgt_key_padding_mask: (N, T).
memory_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length,
N is the batch size, E is the feature number
"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
tgt2 = self.self_attn(
tgt,
tgt,
tgt,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask,
)[0]
tgt = residual + self.dropout1(tgt2)
if not self.normalize_before:
tgt = self.norm1(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm2(tgt)
tgt2 = self.src_attn(
tgt,
memory,
memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = residual + self.dropout2(tgt2)
if not self.normalize_before:
tgt = self.norm2(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = residual + self.dropout3(tgt2)
if not self.normalize_before:
tgt = self.norm3(tgt)
return tgt
def _get_activation_fn(activation: str):
if activation == "relu":
return nn.functional.relu
elif activation == "gelu":
return nn.functional.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
class PositionalEncoding(nn.Module):
"""This class implements the positional encoding
proposed in the following paper:
- Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
    PE(pos, 2i) = sin(pos / (10000^(2i/d_model)))
    PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))
    Note::
      1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
                               = exp(-1 * 2i / d_model * log(10000))
                               = exp(2i * -(log(10000) / d_model))
"""
def __init__(self, d_model: int, dropout: float = 0.1) -> None:
"""
Args:
d_model:
Embedding dimension.
dropout:
Dropout probability to be applied to the output of this module.
"""
super().__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout)
# not doing: self.pe = None because of errors thrown by torchscript
self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
def extend_pe(self, x: torch.Tensor) -> None:
"""Extend the time t in the positional encoding if required.
The shape of `self.pe` is (1, T1, d_model). The shape of the input x
        is (N, T, d_model). If T > T1, then we change the shape of self.pe
        to (1, T, d_model). Otherwise, nothing is done.
Args:
x:
It is a tensor of shape (N, T, C).
Returns:
Return None.
"""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
# Now pe is of shape (1, T, d_model), where T is x.size(1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Add positional encoding.
Args:
x:
Its shape is (N, T, C)
Returns:
Return a tensor of shape (N, T, C)
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1), :]
return self.dropout(x)
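# A quick usage sketch: PositionalEncoding scales the input by
# sqrt(d_model), adds the positional encodings, and applies dropout;
# the shape is unchanged.
#
#   pos_enc = PositionalEncoding(d_model=256, dropout=0.1)
#   y = pos_enc(torch.randn(2, 100, 256))  # y has shape (2, 100, 256)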
class Noam(object):
"""
Implements Noam optimizer.
Proposed in
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
Args:
params:
iterable of parameters to optimize or dicts defining parameter groups
model_size:
attention dimension of the transformer model
factor:
learning rate factor
warm_step:
warmup steps
"""
def __init__(
self,
params,
model_size: int = 256,
factor: float = 10.0,
warm_step: int = 25000,
weight_decay=0,
) -> None:
"""Construct an Noam object."""
self.optimizer = torch.optim.Adam(
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
)
self._step = 0
self.warmup = warm_step
self.factor = factor
self.model_size = model_size
self._rate = 0
@property
def param_groups(self):
"""Return param_groups."""
return self.optimizer.param_groups
def step(self):
"""Update parameters and rate."""
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p["lr"] = rate
self._rate = rate
self.optimizer.step()
def rate(self, step=None):
"""Implement `lrate` above."""
if step is None:
step = self._step
return (
self.factor
* self.model_size ** (-0.5)
* min(step ** (-0.5), step * self.warmup ** (-1.5))
)
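    # A worked check (assuming the default arguments above): with
    # model_size=256, factor=10.0 and warm_step=25000, the rate rises
    # linearly to its peak at step 25000,
    #     10.0 * 256 ** (-0.5) * 25000 ** (-0.5) ≈ 3.95e-3,
    # and decays as step ** (-0.5) afterwards.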
def zero_grad(self):
"""Reset gradient."""
self.optimizer.zero_grad()
def state_dict(self):
"""Return state_dict."""
return {
"_step": self._step,
"warmup": self.warmup,
"factor": self.factor,
"model_size": self.model_size,
"_rate": self._rate,
"optimizer": self.optimizer.state_dict(),
}
def load_state_dict(self, state_dict):
"""Load state_dict."""
for key, value in state_dict.items():
if key == "optimizer":
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
setattr(self, key, value)
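# A minimal usage sketch (the Linear model here is a stand-in, purely
# for illustration):
#
#     model = torch.nn.Linear(256, 256)
#     optimizer = Noam(model.parameters(), model_size=256,
#                      factor=10.0, warm_step=25000)
#     loss = model(torch.randn(4, 256)).sum()
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()  # sets lr from rate(), then runs the Adam step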
def encoder_padding_mask(
max_len: int, supervisions: Optional[Supervisions] = None
) -> Optional[torch.Tensor]:
"""Make mask tensor containing indexes of padded part.
TODO::
This function **assumes** that the model uses
a subsampling factor of 4. We should remove that
assumption later.
Args:
max_len:
Maximum length of input features.
CAUTION: It is the length after subsampling.
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
(CAUTION: It contains length information, i.e., start and number of
frames, before subsampling)
Returns:
        Tensor: Mask tensor of dimension (batch_size, input_length),
        where True denotes the masked (i.e., padded) indices.
"""
if supervisions is None:
return None
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"],
supervisions["num_frames"],
),
1,
).to(torch.int32)
lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
for idx in range(supervision_segments.size(0)):
        # Note: TorchScript doesn't allow unpacking tensors as tuples
sequence_idx = supervision_segments[idx, 0].item()
start_frame = supervision_segments[idx, 1].item()
num_frames = supervision_segments[idx, 2].item()
lengths[sequence_idx] = start_frame + num_frames
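    # Two Conv2d subsampling layers each roughly halve the time axis,
    # giving a total factor of 4: T' = ((T - 1) // 2 - 1) // 2.
    # E.g., 23 input frames -> ((23 - 1) // 2 - 1) // 2 = 5 frames.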
lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
bs = int(len(lengths))
seq_range = torch.arange(0, max_len, dtype=torch.int64)
seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
# Note: TorchScript doesn't implement Tensor.new()
seq_length_expand = torch.tensor(
lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype
).unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask
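# For instance (illustrative numbers), two sequences whose subsampled
# lengths come out as 3 and 5 with max_len=5 would produce:
#
#     tensor([[False, False, False,  True,  True],
#             [False, False, False, False, False]])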
def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
"""Generate a length mask for input.
The masked position are filled with True,
Unmasked positions are filled with False.
Args:
ys_pad:
padded tensor of dimension (batch_size, input_length).
ignore_id:
the ignored number (the padding number) in ys_pad
Returns:
Tensor:
a bool tensor of the same shape as the input tensor.
"""
ys_mask = ys_pad == ignore_id
return ys_mask
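# E.g., with the default ignore_id=-1:
#
#     decoder_padding_mask(torch.tensor([[1, 2, -1], [3, -1, -1]]))
#     # -> tensor([[False, False,  True],
#     #            [False,  True,  True]])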
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
"""Generate a square mask for the sequence. The masked positions are
filled with float('-inf'). Unmasked positions are filled with float(0.0).
The mask can be used for masked self-attention.
For instance, if sz is 3, it returns::
tensor([[0., -inf, -inf],
[0., 0., -inf],
                [0., 0., 0.]])
Args:
sz: mask size
Returns:
A square mask of dimension (sz, sz)
"""
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = (
mask.float()
.masked_fill(mask == 0, float("-inf"))
.masked_fill(mask == 1, float(0.0))
)
return mask
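# This mask is typically passed as `attn_mask` to the decoder
# self-attention so that position i can only attend to positions <= i.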
def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
"""Prepend sos_id to each utterance.
Args:
token_ids:
A list-of-list of token IDs. Each sublist contains
token IDs (e.g., word piece IDs) of an utterance.
sos_id:
The ID of the SOS token.
    Returns:
Return a new list-of-list, where each sublist starts
with SOS ID.
"""
return [[sos_id] + utt for utt in token_ids]
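# E.g., add_sos([[1, 2], [3]], sos_id=0) -> [[0, 1, 2], [0, 3]]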
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
"""Append eos_id to each utterance.
Args:
token_ids:
A list-of-list of token IDs. Each sublist contains
token IDs (e.g., word piece IDs) of an utterance.
eos_id:
The ID of the EOS token.
    Returns:
Return a new list-of-list, where each sublist ends
with EOS ID.
"""
return [utt + [eos_id] for utt in token_ids]
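# E.g., add_eos([[1, 2], [3]], eos_id=0) -> [[1, 2, 0], [3, 0]]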
def tolist(t: torch.Tensor) -> List[int]:
"""Used by jit"""
return torch.jit.annotate(List[int], t.tolist())
