diff --git a/egs/librispeech/WSASR/README.md b/egs/librispeech/WSASR/README.md
new file mode 100644
index 000000000..3b8822fd2
--- /dev/null
+++ b/egs/librispeech/WSASR/README.md
@@ -0,0 +1,224 @@
+# Introduction
+
+This is a weakly supervised ASR recipe for the LibriSpeech (clean 100 hours) dataset. We train a
+conformer model using [Bypass Temporal Classification](https://arxiv.org/pdf/2306.01031.pdf) (BTC)/[Omni-temporal Classification](https://arxiv.org/pdf/2309.15796.pdf) (OTC) on transcripts that contain synthetic errors. In this README, we describe
+the task and the BTC/OTC training process.
+
+Note that OTC is an extension of BTC and supports all BTC functionality. Therefore, in the following, we only describe OTC.
+
+## Task
+We propose BTC/OTC to directly train an ASR system leveraging weak supervision, i.e., speech with non-verbatim transcripts. This is achieved by using a special token $\star$ to model uncertainties (i.e., substitution errors, insertion errors, and deletion errors)
+within the WFST framework during training.
+
+*Figure: Examples of errors (substitution, insertion, and deletion) in the transcript. The grey box is the verbatim transcript and the red box is the inaccurate transcript. Inaccurate words are marked in bold.*
+
+We modify $G(\mathbf{y})$ by adding a self-loop arc (labeled $\star$) to each state and a bypass arc (labeled $\star$) parallel to each original arc.
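+
+As an illustration, the following Python sketch builds such a modified $G_{\text{otc}}(\mathbf{y})$ for a toy token sequence, as a plain list of arcs `(src_state, dst_state, label, score)`. The helper name, token IDs, and penalty values are made up for illustration only; in this recipe the graph is actually built by `icefall.otc_graph_compiler.OtcTrainingGraphCompiler`.
+
+```
+def build_otc_g(token_ids, star_id, self_loop_penalty, bypass_penalty):
+    """Toy sketch: linear G(y) plus a star self-loop on every state and a
+    star bypass arc parallel to every transcript arc."""
+    arcs = []
+    for i, token in enumerate(token_ids):
+        arcs.append((i, i, star_id, self_loop_penalty))   # self-loop arc
+        arcs.append((i, i + 1, token, 0.0))               # original transcript arc
+        arcs.append((i, i + 1, star_id, bypass_penalty))  # bypass arc
+    final_state = len(token_ids)
+    arcs.append((final_state, final_state, star_id, self_loop_penalty))
+    return arcs, final_state
+
+
+# Transcript "a b" with token IDs 3 and 4; the star token has ID 2.
+arcs, final_state = build_otc_g([3, 4], star_id=2,
+                                self_loop_penalty=-0.5, bypass_penalty=-3.0)
+for arc in arcs:
+    print(arc)
+```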
+
+We incorporate a penalty strategy and apply separate configurations to the self-loop arc and the bypass arc. The penalties for the $i$-th training epoch are set as
+
+$$\lambda_{1_{i}} = \beta_{1} * \tau_{1}^{i},\quad \lambda_{2_{i}} = \beta_{2} * \tau_{2}^{i},$$
+
+where $\beta_{1}$ and $\beta_{2}$ are the initial penalties, which encourage the model to rely more on the given transcript at the start of training.
+They decay exponentially by factors $\tau_{1}, \tau_{2} \in (0, 1)$, gradually encouraging the model to align speech with $\star$ when it gets confused.
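+
+As a minimal Python sketch of this schedule, the snippet below plugs in the initial weights and decay factors that are passed to the training command later in this README (this only illustrates the formula above; whether the training script indexes epochs from 0 or 1 is not asserted here):
+
+```
+initial_bypass_weight = -19
+initial_self_loop_weight = 3.75
+bypass_weight_decay = 0.975
+self_loop_weight_decay = 0.999
+
+for epoch in range(5):
+    bypass_weight = initial_bypass_weight * bypass_weight_decay**epoch
+    self_loop_weight = initial_self_loop_weight * self_loop_weight_decay**epoch
+    print(f"epoch {epoch}: bypass={bypass_weight:.3f}, self-loop={self_loop_weight:.3f}")
+```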
+
+After composing the modified WFST $G_{\text{otc}}(\mathbf{y})$ with $L$ and $T$, the OTC training graph is shown in this figure:
+
+*Figure: OTC training graph. The self-loop arcs and bypass arcs are highlighted in green and blue, respectively.*
+
+The $\star$ is represented as the average probability of all non-blank tokens.
+
+For example, if the log-probabilities of the non-blank tokens "a" and "b" are $(-1.2, -2.3)$ at the first frame and $(-1.9, -0.5)$ at the second frame, the weight of $\star$ is the log of their average probability: $\log \frac{e^{-1.2} + e^{-2.3}}{2} \approx -1.6$ and $\log \frac{e^{-1.9} + e^{-0.5}}{2} \approx -1.0$, respectively.
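+
+A minimal Python sketch of this computation, using the illustrative log-probabilities above and assuming the blank token sits at index 0:
+
+```
+import math
+
+import torch
+
+# (num_frames, vocab_size) log-probabilities; column 0 is blank, columns 1-2 are "a" and "b".
+log_probs = torch.tensor([[-0.1, -1.2, -2.3],
+                          [-0.2, -1.9, -0.5]])
+non_blank = log_probs[:, 1:]
+# log of the average probability of all non-blank tokens, per frame
+star_log_prob = torch.logsumexp(non_blank, dim=-1) - math.log(non_blank.size(-1))
+print(star_log_prob)  # approximately [-1.61, -0.97], i.e. roughly -1.6 and -1.0
+```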
+
+## Description of the recipe
+### Preparation
+```
+# feature_type can be ssl or fbank
+feature_type=ssl
+feature_dir="data/${feature_type}"
+manifest_dir="${feature_dir}"
+lang_dir="data/lang"
+lm_dir="data/lm"
+exp_dir="conformer_ctc2/exp"
+otc_token="<star>"
+
+./prepare.sh \
+ --feature-type "${feature_type}" \
+ --feature-dir "${feature_dir}" \
+ --lang-dir "${lang_dir}" \
+ --lm-dir "${lm_dir}" \
+ --otc-token "${otc_token}"
+```
+This script adds the 'otc_token' ('\<star\>') and its corresponding sentencepiece token ('▁\<star\>') to 'words.txt' and 'tokens.txt', respectively. It also computes SSL features using the 'wav2vec2-base' model. (You can use a GPU to accelerate feature extraction.)
+
+### Adding synthetic errors to the transcript (train-clean-100) [optional]
+```
+sub_er=0.17
+ins_er=0.17
+del_er=0.17
+synthetic_train_manifest="librispeech_cuts_train-clean-100_${sub_er}_${ins_er}_${del_er}.jsonl.gz"
+
+./local/make_error_cutset.py \
+ --input-cutset "${manifest_dir}/librispeech_cuts_train-clean-100.jsonl.gz" \
+ --words-file "${lang_dir}/words.txt" \
+ --sub-error-rate "${sub_er}" \
+ --ins-error-rate "${ins_er}" \
+ --del-error-rate "${del_er}" \
+ --output-cutset "${manifest_dir}/${synthetic_train_manifest}"
+```
+This script generates synthetic substitution, insertion, and deletion errors in the transcript with ratios 'sub_er', 'ins_er', and 'del_er', respectively. The original transcript is saved as 'verbatim transcript' in the cutset, along with information on how the transcript is corrupted:
+
+ - '[hello]' indicates the original word 'hello' is substituted by another word
+ - '[]' indicates an extra word is inserted into the transcript
+ - '-hello-' indicates the word 'hello' is deleted from the transcript
+
+So if the original transcript is "have a nice day" and the synthetic one is "a very good day", the 'verbatim transcript' would be:
+```
+original: have a nice day
+synthetic: a very good day
+verbatim: -have- a [] [nice] day
+```
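+
+For intuition, here is a heavily simplified, hypothetical Python sketch of that corruption process. The real logic lives in `local/make_error_cutset.py` (which is given 'words.txt'); the helper below just samples replacement and inserted words from a small list:
+
+```
+import random
+
+def corrupt_transcript(words, sub_er, ins_er, del_er, vocab):
+    """Return (synthetic transcript, verbatim transcript with markers)."""
+    synthetic, verbatim = [], []
+    for word in words:
+        if random.random() < ins_er:        # an extra word sneaks in before this one
+            synthetic.append(random.choice(vocab))
+            verbatim.append("[]")
+        r = random.random()
+        if r < del_er:                      # this word is dropped from the transcript
+            verbatim.append(f"-{word}-")
+        elif r < del_er + sub_er:           # this word is replaced by another word
+            synthetic.append(random.choice(vocab))
+            verbatim.append(f"[{word}]")
+        else:                               # this word is kept as-is
+            synthetic.append(word)
+            verbatim.append(word)
+    return " ".join(synthetic), " ".join(verbatim)
+
+print(corrupt_transcript("have a nice day".split(), 0.17, 0.17, 0.17,
+                         vocab=["very", "good", "world"]))
+```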
+
+### Training
+The training uses synthetic data based on the train-clean-100 subset.
+```
+otc_lang_dir=data/lang_bpe_200
+
+allow_bypass_arc=true
+allow_self_loop_arc=true
+initial_bypass_weight=-19
+initial_self_loop_weight=3.75
+bypass_weight_decay=0.975
+self_loop_weight_decay=0.999
+
+show_alignment=true
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+./conformer_ctc2/train.py \
+ --world-size 4 \
+ --manifest-dir "${manifest_dir}" \
+ --train-manifest "${synthetic_train_manifest}" \
+ --exp-dir "${exp_dir}" \
+ --lang-dir "${otc_lang_dir}" \
+ --otc-token "${otc_token}" \
+ --allow-bypass-arc "${allow_bypass_arc}" \
+ --allow-self-loop-arc "${allow_self_loop_arc}" \
+ --initial-bypass-weight "${initial_bypass_weight}" \
+ --initial-self-loop-weight "${initial_self_loop_weight}" \
+ --bypass-weight-decay "${bypass_weight_decay}" \
+ --self-loop-weight-decay "${self_loop_weight_decay}" \
+ --show-alignment "${show_alignment}"
+```
+The bypass arc deals with substitution and insertion errors, while the self-loop arc deals with deletion errors. Passing "--show-alignment" prints the best alignment during training, which is very helpful for tuning hyperparameters and debugging.
+
+### Decoding
+```
+export CUDA_VISIBLE_DEVICES="0"
+./conformer_ctc2/decode.py \
+ --manifest-dir "${manifest_dir}" \
+ --exp-dir "${exp_dir}" \
+ --lang-dir "${otc_lang_dir}" \
+ --lm-dir "${lm_dir}" \
+ --otc-token "${otc_token}"
+```
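+
+The "1best" results below are obtained with `--blank-bias -4`. In `conformer_ctc2/decode.py` this bias is simply added to the blank column of the log-probabilities before the lattice is built, which discourages blank emissions; a minimal Python sketch (shapes and values are illustrative):
+
+```
+import torch
+
+nnet_output = torch.randn(4, 100, 200).log_softmax(dim=-1)  # (N, T, vocab); index 0 is blank
+blank_bias = -4.0
+nnet_output[:, :, 0] += blank_bias
+```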
+
+### Results (ctc-greedy-search)
+
+| Training Criterion | ssl (test-clean) | ssl (test-other) | fbank (test-clean) | fbank (test-other) |
+|--------------------|------------------|------------------|--------------------|--------------------|
+| CTC                | 100.0            | 100.0            | 99.89              | 99.98              |
+| OTC                | 11.89            | 25.46            | 20.14              | 44.24              |
+
+### Results (1best, blank_bias=-4)
+
+| Training Criterion | ssl (test-clean) | ssl (test-other) | fbank (test-clean) | fbank (test-other) |
+|--------------------|------------------|------------------|--------------------|--------------------|
+| CTC                | 98.40            | 98.68            | 99.79              | 99.86              |
+| OTC                | 6.59             | 15.98            | 11.78              | 32.38              |
+
+## Pre-trained Model
+Pre-trained model:
+
+## Citations
+```
+@inproceedings{gao2023bypass,
+ title={Bypass Temporal Classification: Weakly Supervised Automatic Speech Recognition with Imperfect Transcripts},
+ author={Gao, Dongji and Wiesner, Matthew and Xu, Hainan and Garcia, Leibny Paola and Povey, Daniel and Khudanpur, Sanjeev},
+ booktitle={INTERSPEECH},
+ year={2023}
+}
+
+@inproceedings{gao2023learning,
+ title={Learning from Flawed Data: Weakly Supervised Automatic Speech Recognition},
+ author={Gao, Dongji and Xu, Hainan and Raj, Desh and Garcia, Leibny Paola and Povey, Daniel and Khudanpur, Sanjeev},
+ booktitle={IEEE ASRU},
+ year={2023}
+}
+```
diff --git a/egs/librispeech/WSASR/conformer_ctc2/__init__.py b/egs/librispeech/WSASR/conformer_ctc2/__init__.py
new file mode 120000
index 000000000..43a85af20
--- /dev/null
+++ b/egs/librispeech/WSASR/conformer_ctc2/__init__.py
@@ -0,0 +1 @@
+../../ASR/pruned_transducer_stateless2/__init__.py
\ No newline at end of file
diff --git a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py
new file mode 100644
index 000000000..1b6991bcd
--- /dev/null
+++ b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py
@@ -0,0 +1,369 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
+# 2023 Johns Hopkins University (author: Dongji Gao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, load_manifest, load_manifest_lazy
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SingleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class LibriSpeechAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
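+
+ A typical usage sketch (mirroring how the accompanying train/decode
+ scripts use this class; exact call sites may differ):
+
+ parser = argparse.ArgumentParser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ librispeech = LibriSpeechAsrDataModule(args)
+ train_dl = librispeech.train_dataloaders(librispeech.train_clean_100_cuts())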
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--full-libri",
+ type=str2bool,
+ default=False,
+ help="""Used only when --mini-libri is False.When enabled,
+ use 960h LibriSpeech. Otherwise, use 100h subset.""",
+ )
+ group.add_argument(
+ "--mini-libri",
+ type=str2bool,
+ default=False,
+ help="True for mini librispeech",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/ssl"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ group.add_argument(
+ "--train-manifest",
+ type=str,
+ default="librispeech_cuts_train-clean-100.jsonl.gz",
+ help="Train manifest file.",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ transforms = []
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ input_strategy=eval(self.args.input_strategy)(),
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SingleCutSampler.")
+ train_sampler = SingleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_clean_5_cuts(self) -> CutSet:
+ logging.info("mini_librispeech: About to get train-clean-5 cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_clean_100_cuts(self) -> CutSet:
+ logging.info("About to get train-clean-100 cuts")
+ return load_manifest_lazy(self.args.manifest_dir / self.args.train_manifest)
+
+ @lru_cache()
+ def train_all_shuf_cuts(self) -> CutSet:
+ logging.info(
+ "About to get the shuffled train-clean-100, \
+ train-clean-360 and train-other-500 cuts"
+ )
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_clean_2_cuts(self) -> CutSet:
+ logging.info("mini_librispeech: About to get dev-clean-2 cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_clean_cuts(self) -> CutSet:
+ logging.info("About to get dev-clean cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_other_cuts(self) -> CutSet:
+ logging.info("About to get dev-other cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_clean_cuts(self) -> CutSet:
+ logging.info("About to get test-clean cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_other_cuts(self) -> CutSet:
+ logging.info("About to get test-other cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
+ )
diff --git a/egs/librispeech/WSASR/conformer_ctc2/attention.py b/egs/librispeech/WSASR/conformer_ctc2/attention.py
new file mode 120000
index 000000000..e808a6f20
--- /dev/null
+++ b/egs/librispeech/WSASR/conformer_ctc2/attention.py
@@ -0,0 +1 @@
+../../ASR/conformer_ctc2/attention.py
\ No newline at end of file
diff --git a/egs/librispeech/WSASR/conformer_ctc2/conformer.py b/egs/librispeech/WSASR/conformer_ctc2/conformer.py
new file mode 100644
index 000000000..db4821d37
--- /dev/null
+++ b/egs/librispeech/WSASR/conformer_ctc2/conformer.py
@@ -0,0 +1,949 @@
+#!/usr/bin/env python3
+# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+# 2022 Xiaomi Corp. (author: Quandong Wang)
+# 2023 Johns Hopkins University (author: Dongji Gao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+import warnings
+from typing import Optional, Tuple
+
+import torch
+from scaling import (
+ ActivationBalancer,
+ BasicNorm,
+ DoubleSwish,
+ ScaledConv1d,
+ ScaledLinear,
+)
+from subsampling import Conv2dSubsampling, Conv2dSubsampling2
+from torch import Tensor, nn
+from transformer import Supervisions, Transformer, encoder_padding_mask
+
+
+class Conformer(Transformer):
+ """
+ Args:
+ num_features (int): Number of input features
+ num_classes (int): Number of output classes
+ subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
+ d_model (int): attention dimension, also the output dimension
+ nhead (int): number of attention heads
+ dim_feedforward (int): feedforward dimension
+ num_encoder_layers (int): number of encoder layers
+ num_decoder_layers (int): number of decoder layers
+ dropout (float): dropout rate
+ layer_dropout (float): layer-dropout rate.
+ cnn_module_kernel (int): Kernel size of convolution module
+ vgg_frontend (bool): whether to use vgg frontend.
+ """
+
+ def __init__(
+ self,
+ num_features: int,
+ num_classes: int,
+ subsampling_factor: int = 2,
+ d_model: int = 256,
+ nhead: int = 4,
+ dim_feedforward: int = 2048,
+ num_encoder_layers: int = 12,
+ num_decoder_layers: int = 6,
+ dropout: float = 0.2,
+ layer_dropout: float = 0.075,
+ cnn_module_kernel: int = 31,
+ ) -> None:
+ super(Conformer, self).__init__(
+ num_features=num_features,
+ num_classes=num_classes,
+ subsampling_factor=subsampling_factor,
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ num_encoder_layers=num_encoder_layers,
+ num_decoder_layers=num_decoder_layers,
+ dropout=dropout,
+ layer_dropout=layer_dropout,
+ )
+
+ self.num_features = num_features
+ self.subsampling_factor = subsampling_factor
+ if subsampling_factor != 4 and subsampling_factor != 2:
+ raise NotImplementedError("Support only 'subsampling_factor=4 or 2'.")
+
+ # self.encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, T//subsampling_factor, d_model).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> T//subsampling_factor
+ # (2) embedding: num_features -> d_model
+ if self.subsampling_factor == 4:
+ self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+ elif self.subsampling_factor == 2:
+ self.encoder_embed = Conv2dSubsampling2(num_features, d_model)
+
+ self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
+ encoder_layer = ConformerEncoderLayer(
+ d_model,
+ nhead,
+ dim_feedforward,
+ dropout,
+ layer_dropout,
+ cnn_module_kernel,
+ )
+ self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
+
+ def run_encoder(
+ self,
+ x: torch.Tensor,
+ supervisions: Optional[Supervisions] = None,
+ warmup: float = 1.0,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (batch_size, seq_len, feature_dim).
+ supervisions:
+ Supervision in lhotse format.
+ See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
+ CAUTION: It contains length information, i.e., start and number of
+ frames, before subsampling
+ It is read directly from the batch, without any sorting. It is used
+ to compute encoder padding mask, which is used as memory key padding
+ mask for the decoder.
+ warmup:
+ A floating point value that gradually increases from 0 throughout
+ training; when it is >= 1.0 we are "fully warmed up". It is used
+ to turn modules on sequentially.
+ Returns:
+ Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
+ Tensor: Mask tensor of dimension (batch_size, input_length)
+ """
+ x = self.encoder_embed(x)
+ x, pos_emb = self.encoder_pos(x)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ mask = encoder_padding_mask(x.size(0), self.subsampling_factor, supervisions)
+ if mask is not None:
+ mask = mask.to(x.device)
+
+ # Caution: We assume the subsampling factor is 2 or 4!
+
+ x = self.encoder(
+ x, pos_emb, src_key_padding_mask=mask, warmup=warmup
+ ) # (T, N, C)
+
+ # x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ # return x, lengths
+ return x, mask
+
+
+class ConformerEncoderLayer(nn.Module):
+ """
+ ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
+ See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
+
+ Args:
+ d_model: the number of expected features in the input (required).
+ nhead: the number of heads in the multiheadattention models (required).
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
+ dropout: the dropout value (default=0.1).
+ cnn_module_kernel (int): Kernel size of convolution module.
+
+ Examples::
+ >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
+ >>> src = torch.rand(10, 32, 512)
+ >>> pos_emb = torch.rand(32, 19, 512)
+ >>> out = encoder_layer(src, pos_emb)
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ nhead: int,
+ dim_feedforward: int = 2048,
+ dropout: float = 0.1,
+ layer_dropout: float = 0.075,
+ cnn_module_kernel: int = 31,
+ ) -> None:
+ super(ConformerEncoderLayer, self).__init__()
+
+ self.layer_dropout = layer_dropout
+
+ self.d_model = d_model
+
+ self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+
+ self.feed_forward = nn.Sequential(
+ ScaledLinear(d_model, dim_feedforward),
+ ActivationBalancer(channel_dim=-1),
+ DoubleSwish(),
+ nn.Dropout(dropout),
+ ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+ )
+
+ self.feed_forward_macaron = nn.Sequential(
+ ScaledLinear(d_model, dim_feedforward),
+ ActivationBalancer(channel_dim=-1),
+ DoubleSwish(),
+ nn.Dropout(dropout),
+ ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+ )
+
+ self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+
+ self.norm_final = BasicNorm(d_model)
+
+ # try to ensure the output is close to zero-mean (or at least, zero-median).
+ self.balancer = ActivationBalancer(
+ channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+ )
+
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ warmup: float = 1.0,
+ ) -> Tensor:
+ """
+ Pass the input through the encoder layer.
+
+ Args:
+ src: the sequence to the encoder layer (required).
+ pos_emb: Positional embedding tensor (required).
+ src_mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+ warmup: controls selective bypass of layers; if < 1.0, we will
+ bypass layers more frequently.
+
+ Shape:
+ src: (S, N, E).
+ pos_emb: (N, 2*S-1, E)
+ src_mask: (S, S).
+ src_key_padding_mask: (N, S).
+ S is the source sequence length, N is the batch size, E is the feature number
+ """
+ src_orig = src
+
+ warmup_scale = min(0.1 + warmup, 1.0)
+ # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
+ # completely bypass it.
+ if self.training:
+ alpha = (
+ warmup_scale
+ if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+ else 0.1
+ )
+ else:
+ alpha = 1.0
+
+ # macaron style feed forward module
+ src = src + self.dropout(self.feed_forward_macaron(src))
+
+ # multi-headed self-attention module
+ src_att = self.self_attn(
+ src,
+ src,
+ src,
+ pos_emb=pos_emb,
+ attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask,
+ )[0]
+ src = src + self.dropout(src_att)
+
+ # convolution module
+ src = src + self.dropout(
+ self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+ )
+
+ # feed forward module
+ src = src + self.dropout(self.feed_forward(src))
+
+ src = self.norm_final(self.balancer(src))
+
+ if alpha != 1.0:
+ src = alpha * src + (1 - alpha) * src_orig
+
+ return src
+
+
+class ConformerEncoder(nn.Module):
+ r"""ConformerEncoder is a stack of N encoder layers
+
+ Args:
+ encoder_layer: an instance of the ConformerEncoderLayer() class (required).
+ num_layers: the number of sub-encoder-layers in the encoder (required).
+
+ Examples::
+ >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
+ >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
+ >>> src = torch.rand(10, 32, 512)
+ >>> pos_emb = torch.rand(32, 19, 512)
+ >>> out = conformer_encoder(src, pos_emb)
+ """
+
+ def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+ )
+ self.num_layers = num_layers
+
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ warmup: float = 1.0,
+ ) -> Tensor:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required).
+ pos_emb: Positional embedding tensor (required).
+ mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+
+ Shape:
+ src: (S, N, E).
+ pos_emb: (N, 2*S-1, E)
+ mask: (S, S).
+ src_key_padding_mask: (N, S).
+ S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
+
+ """
+ output = src
+
+ for i, mod in enumerate(self.layers):
+ output = mod(
+ output,
+ pos_emb,
+ src_mask=mask,
+ src_key_padding_mask=src_key_padding_mask,
+ warmup=warmup,
+ )
+
+ return output
+
+
+class RelPositionalEncoding(torch.nn.Module):
+ """Relative positional encoding module.
+
+ See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+ Args:
+ d_model: Embedding dimension.
+ dropout_rate: Dropout rate.
+ max_len: Maximum input length.
+
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+ """Construct an PositionalEncoding object."""
+ super(RelPositionalEncoding, self).__init__()
+ self.d_model = d_model
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+ def extend_pe(self, x: Tensor) -> None:
+ """Reset the positional encodings."""
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
+ # Note: TorchScript doesn't implement operator== for torch.Device
+ if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` means to the position of query vector and `j` means the
+ # position of key vector. We use positive relative positions when keys
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
+ pe_positive = torch.zeros(x.size(1), self.d_model)
+ pe_negative = torch.zeros(x.size(1), self.d_model)
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.d_model)
+ )
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+ # Reverse the order of positive indices and concat both positive and
+ # negative indices. This is used to support the shifting trick
+ # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+ pe_negative = pe_negative[1:].unsqueeze(0)
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
+
+ """
+ self.extend_pe(x)
+ pos_emb = self.pe[
+ :,
+ self.pe.size(1) // 2
+ - x.size(1)
+ + 1 : self.pe.size(1) // 2 # noqa E203
+ + x.size(1),
+ ]
+ return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+ r"""Multi-Head Attention layer with relative position encoding
+
+ See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+
+ Args:
+ embed_dim: total dimension of the model.
+ num_heads: parallel attention heads.
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+
+ Examples::
+
+ >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ ) -> None:
+ super(RelPositionMultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert (
+ self.head_dim * num_heads == self.embed_dim
+ ), "embed_dim must be divisible by num_heads"
+
+ self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
+ self.out_proj = ScaledLinear(
+ embed_dim, embed_dim, bias=True, initial_scale=0.25
+ )
+
+ # linear transformation for positional encoding.
+ self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+ self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
+ self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
+ self._reset_parameters()
+
+ def _pos_bias_u(self):
+ return self.pos_bias_u * self.pos_bias_u_scale.exp()
+
+ def _pos_bias_v(self):
+ return self.pos_bias_v * self.pos_bias_v_scale.exp()
+
+ def _reset_parameters(self) -> None:
+ nn.init.normal_(self.pos_bias_u, std=0.01)
+ nn.init.normal_(self.pos_bias_v, std=0.01)
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ r"""
+ Args:
+ query, key, value: map a query and a set of key-value pairs to an output.
+ pos_emb: Positional embedding tensor
+ key_padding_mask: if provided, specified padding elements in the key will
+ be ignored by the attention. When given a binary mask and a value is True,
+ the corresponding value on the attention layer will be ignored. When given
+ a byte mask and a value is non-zero, the corresponding value on the attention
+ layer will be ignored
+ need_weights: output attn_output_weights.
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+
+ Shape:
+ - Inputs:
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+ S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+ is provided, it will be added to the attention weight.
+
+ - Outputs:
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+ E is the embedding dimension.
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+ L is the target sequence length, S is the source sequence length.
+ """
+ return self.multi_head_attention_forward(
+ query,
+ key,
+ value,
+ pos_emb,
+ self.embed_dim,
+ self.num_heads,
+ self.in_proj.get_weight(),
+ self.in_proj.get_bias(),
+ self.dropout,
+ self.out_proj.get_weight(),
+ self.out_proj.get_bias(),
+ training=self.training,
+ key_padding_mask=key_padding_mask,
+ need_weights=need_weights,
+ attn_mask=attn_mask,
+ )
+
+ def rel_shift(self, x: Tensor) -> Tensor:
+ """Compute relative positional encoding.
+
+ Args:
+ x: Input tensor (batch, head, time1, 2*time1-1).
+ time1 means the length of query vector.
+
+ Returns:
+ Tensor: tensor of shape (batch, head, time1, time2)
+ (note: time2 has the same value as time1, but it is for
+ the key, while time1 is for the query).
+ """
+ (batch_size, num_heads, time1, n) = x.shape
+ assert n == 2 * time1 - 1
+ # Note: TorchScript requires explicit arg for stride()
+ batch_stride = x.stride(0)
+ head_stride = x.stride(1)
+ time1_stride = x.stride(2)
+ n_stride = x.stride(3)
+ return x.as_strided(
+ (batch_size, num_heads, time1, time1),
+ (batch_stride, head_stride, time1_stride - n_stride, n_stride),
+ storage_offset=n_stride * (time1 - 1),
+ )
+
+ def multi_head_attention_forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ pos_emb: Tensor,
+ embed_dim_to_check: int,
+ num_heads: int,
+ in_proj_weight: Tensor,
+ in_proj_bias: Tensor,
+ dropout_p: float,
+ out_proj_weight: Tensor,
+ out_proj_bias: Tensor,
+ training: bool = True,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ r"""
+ Args:
+ query, key, value: map a query and a set of key-value pairs to an output.
+ pos_emb: Positional embedding tensor
+ embed_dim_to_check: total dimension of the model.
+ num_heads: parallel attention heads.
+ in_proj_weight, in_proj_bias: input projection weight and bias.
+ dropout_p: probability of an element to be zeroed.
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
+ training: apply dropout if is ``True``.
+ key_padding_mask: if provided, specified padding elements in the key will
+ be ignored by the attention. This is an binary mask. When the value is True,
+ the corresponding value on the attention layer will be filled with -inf.
+ need_weights: output attn_output_weights.
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+
+ Shape:
+ Inputs:
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
+ length, N is the batch size, E is the embedding dimension.
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+ will be unchanged. If a BoolTensor is provided, the positions with the
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+ is provided, it will be added to the attention weight.
+
+ Outputs:
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+ E is the embedding dimension.
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+ L is the target sequence length, S is the source sequence length.
+ """
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == embed_dim_to_check
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+
+ head_dim = embed_dim // num_heads
+ assert (
+ head_dim * num_heads == embed_dim
+ ), "embed_dim must be divisible by num_heads"
+
+ scaling = float(head_dim) ** -0.5
+
+ if torch.equal(query, key) and torch.equal(key, value):
+ # self-attention
+ q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+ 3, dim=-1
+ )
+
+ elif torch.equal(key, value):
+ # encoder-decoder attention
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = 0
+ _end = embed_dim
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ q = nn.functional.linear(query, _w, _b)
+
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim
+ _end = None
+ _w = in_proj_weight[_start:, :]
+ if _b is not None:
+ _b = _b[_start:]
+ k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
+
+ else:
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = 0
+ _end = embed_dim
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ q = nn.functional.linear(query, _w, _b)
+
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim
+ _end = embed_dim * 2
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ k = nn.functional.linear(key, _w, _b)
+
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim * 2
+ _end = None
+ _w = in_proj_weight[_start:, :]
+ if _b is not None:
+ _b = _b[_start:]
+ v = nn.functional.linear(value, _w, _b)
+
+ if attn_mask is not None:
+ assert (
+ attn_mask.dtype == torch.float32
+ or attn_mask.dtype == torch.float64
+ or attn_mask.dtype == torch.float16
+ or attn_mask.dtype == torch.uint8
+ or attn_mask.dtype == torch.bool
+ ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
+ attn_mask.dtype
+ )
+ if attn_mask.dtype == torch.uint8:
+ warnings.warn(
+ "Byte tensor for attn_mask is deprecated. Use bool tensor instead."
+ )
+ attn_mask = attn_mask.to(torch.bool)
+
+ if attn_mask.dim() == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+ raise RuntimeError("The size of the 2D attn_mask is not correct.")
+ elif attn_mask.dim() == 3:
+ if list(attn_mask.size()) != [
+ bsz * num_heads,
+ query.size(0),
+ key.size(0),
+ ]:
+ raise RuntimeError("The size of the 3D attn_mask is not correct.")
+ else:
+ raise RuntimeError(
+ "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+ )
+ # attn_mask's dim is 3 now.
+
+ # convert ByteTensor key_padding_mask to bool
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+ warnings.warn(
+ "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+ )
+ key_padding_mask = key_padding_mask.to(torch.bool)
+
+ q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
+ k = k.contiguous().view(-1, bsz, num_heads, head_dim)
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+
+ src_len = k.size(0)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz, "{} == {}".format(
+ key_padding_mask.size(0), bsz
+ )
+ assert key_padding_mask.size(1) == src_len, "{} == {}".format(
+ key_padding_mask.size(1), src_len
+ )
+
+ q = q.transpose(0, 1) # (batch, time1, head, d_k)
+
+ pos_emb_bsz = pos_emb.size(0)
+ assert pos_emb_bsz in (1, bsz) # actually it is 1
+ p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
+ p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
+
+ q_with_bias_u = (q + self._pos_bias_u()).transpose(
+ 1, 2
+ ) # (batch, head, time1, d_k)
+
+ q_with_bias_v = (q + self._pos_bias_v()).transpose(
+ 1, 2
+ ) # (batch, head, time1, d_k)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+ k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
+
+ # compute matrix b and matrix d
+ matrix_bd = torch.matmul(
+ q_with_bias_v, p.transpose(-2, -1)
+ ) # (batch, head, time1, 2*time1-1)
+ matrix_bd = self.rel_shift(matrix_bd)
+
+ attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2)
+
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+
+ assert list(attn_output_weights.size()) == [
+ bsz * num_heads,
+ tgt_len,
+ src_len,
+ ]
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_output_weights.masked_fill_(attn_mask, float("-inf"))
+ else:
+ attn_output_weights += attn_mask
+
+ if key_padding_mask is not None:
+ attn_output_weights = attn_output_weights.view(
+ bsz, num_heads, tgt_len, src_len
+ )
+ attn_output_weights = attn_output_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ float("-inf"),
+ )
+ attn_output_weights = attn_output_weights.view(
+ bsz * num_heads, tgt_len, src_len
+ )
+
+ attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+ attn_output_weights = nn.functional.dropout(
+ attn_output_weights, p=dropout_p, training=training
+ )
+
+ attn_output = torch.bmm(attn_output_weights, v)
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+ attn_output = (
+ attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ )
+ attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
+
+ if need_weights:
+ # average attention weights over heads
+ attn_output_weights = attn_output_weights.view(
+ bsz, num_heads, tgt_len, src_len
+ )
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
+ else:
+ return attn_output, None
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model.
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+
+ Args:
+ channels (int): The number of channels of conv layers.
+ kernel_size (int): Kernel size of conv layers.
+ bias (bool): Whether to use bias in conv layers (default=True).
+
+ """
+
+ def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
+ """Construct an ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+ # kernerl_size should be a odd number for 'SAME' padding
+ assert (kernel_size - 1) % 2 == 0
+
+ self.pointwise_conv1 = ScaledConv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+
+ # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
+ # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
+ # but sometimes, for some reason, for layer 0 the rms ends up being very large,
+ # between 50 and 100 for different channels. This will cause very peaky and
+ # sparse derivatives for the sigmoid gating function, which will tend to make
+ # the loss function not learn effectively. (for most layers the average absolute values
+ # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
+ # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
+ # layers, which likely breaks down as 0.5 for the "linear" half and
+ # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we
+ # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
+ # it will be in a better position to start learning something, i.e. to latch onto
+ # the correct range.
+ self.deriv_balancer1 = ActivationBalancer(
+ channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+ )
+
+ self.depthwise_conv = ScaledConv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ groups=channels,
+ bias=bias,
+ )
+
+ self.deriv_balancer2 = ActivationBalancer(
+ channel_dim=1, min_positive=0.05, max_positive=1.0
+ )
+
+ self.activation = DoubleSwish()
+
+ self.pointwise_conv2 = ScaledConv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ initial_scale=0.25,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Compute convolution module.
+
+ Args:
+ x: Input tensor (#time, batch, channels).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+
+ Returns:
+ Tensor: Output tensor (#time, batch, channels).
+
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
+
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channels, time)
+
+ x = self.deriv_balancer1(x)
+ x = nn.functional.glu(x, dim=1) # (batch, channels, time)
+
+ # 1D Depthwise Conv
+ if src_key_padding_mask is not None:
+ x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+ x = self.depthwise_conv(x)
+
+ x = self.deriv_balancer2(x)
+ x = self.activation(x)
+
+ x = self.pointwise_conv2(x) # (batch, channel, time)
+
+ return x.permute(2, 0, 1)
+
+
+if __name__ == "__main__":
+ feature_dim = 50
+ c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
+ batch_size = 5
+ seq_len = 20
+ # Just make sure the forward pass runs.
+ f = c(
+ torch.randn(batch_size, seq_len, feature_dim),
+ torch.full((batch_size,), seq_len, dtype=torch.int64),
+ warmup=0.5,
+ )
diff --git a/egs/librispeech/WSASR/conformer_ctc2/decode.py b/egs/librispeech/WSASR/conformer_ctc2/decode.py
new file mode 100755
index 000000000..3fa045533
--- /dev/null
+++ b/egs/librispeech/WSASR/conformer_ctc2/decode.py
@@ -0,0 +1,718 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
+# Fangjun Kuang,
+# Quandong Wang)
+# 2023 Johns Hopkins University (Author: Dongji Gao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.decode import get_lattice, one_best_decoding
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
+from icefall.utils import (
+ AttributeDict,
+ get_texts,
+ load_averaged_model,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--otc-token",
+ type=str,
+ default="",
+ help="OTC token",
+ )
+
+ parser.add_argument(
+ "--blank-bias",
+ type=float,
+ default=0,
+ help="bias (log-prob) added to blank token during decoding",
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=20,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=1,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="ctc-greedy-search",
+ help="""Decoding method.
+ Supported values are:
+ - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+ model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+ It needs neither a lexicon nor an n-gram LM.
+ - (1) ctc-greedy-search. It only uses the CTC output and a sentencepiece
+ model for decoding. It produces the same results as ctc-decoding.
+ - (2) 1best. Extract the best path from the decoding lattice as the
+ decoding result.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--num-decoder-layers",
+ type=int,
+ default=0,
+ help="""Number of decoder layer of transformer decoder.
+ Setting this to 0 will not create the decoder at all (pure CTC model)
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="conformer_ctc2/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_bpe_200",
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--lm-dir",
+ type=str,
+ default="data/lm",
+ help="""The n-gram LM dir.
+ It should contain either G_4_gram.pt or G_4_gram.fst.txt
+ """,
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict(
+ {
+ # parameters for conformer
+ "subsampling_factor": 2,
+ "feature_dim": 768,
+ "nhead": 8,
+ "dim_feedforward": 2048,
+ "encoder_dim": 512,
+ "num_encoder_layers": 12,
+ # parameters for decoding
+ "search_beam": 20,
+ "output_beam": 8,
+ "min_active_states": 30,
+ "max_active_states": 10000,
+ "use_double_scores": True,
+ "env_info": get_env_info(),
+ }
+ )
+ return params
+
+
+def ctc_greedy_search(
+ nnet_output: torch.Tensor,
+ memory: torch.Tensor,
+ memory_key_padding_mask: torch.Tensor,
+) -> List[List[int]]:
+ """Apply CTC greedy search
+
+ Args:
+ nnet_output (torch.Tensor): (batch, max_len, vocab_size), CTC log-probs
+ memory (torch.Tensor): (max_len, batch, feature_dim), encoder output
+ memory_key_padding_mask (torch.Tensor): (batch, max_len), encoder padding mask
+ Returns:
+ List[List[int]]: best path result
+ """
+ batch_size = memory.shape[1]
+ # Let's assume B = batch_size
+ encoder_out = memory
+ encoder_mask = memory_key_padding_mask
+ maxlen = encoder_out.size(0)
+
+ ctc_probs = nnet_output # (B, maxlen, vocab_size)
+ topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1)
+ topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
+ topk_index = topk_index.masked_fill_(encoder_mask, 0) # (B, maxlen)
+ hyps = [hyp.tolist() for hyp in topk_index]
+ scores = topk_prob.max(1)
+ hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
+ return hyps, scores
+
+
+def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
+ # from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
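+ # e.g., [0, 3, 3, 0, 0, 4, 4, 5] (with 0 = blank) -> [3, 4, 5]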
+ new_hyp: List[int] = []
+ cur = 0
+ while cur < len(hyp):
+ if hyp[cur] != 0:
+ new_hyp.append(hyp[cur])
+ prev = cur
+ while cur < len(hyp) and hyp[cur] == hyp[prev]:
+ cur += 1
+ return new_hyp
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ HLG: Optional[k2.Fsa],
+ H: Optional[k2.Fsa],
+ bpe_model: Optional[spm.SentencePieceProcessor],
+ batch: dict,
+ word_table: k2.SymbolTable,
+ sos_id: int,
+ eos_id: int,
+ G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if no rescoring is used, the key is the string `no_rescore`.
+ If LM rescoring is used, the key is the string `lm_scale_xxx`,
+ where `xxx` is the value of `lm_scale`. An example key is
+ `lm_scale_0.7`
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+
+ - params.method is "1best", it uses 1best decoding without LM rescoring.
+
+ model:
+ The neural model.
+ HLG:
+ The decoding graph. Used only when params.method is NOT ctc-decoding.
+ H:
+ The ctc topo. Used only when params.method is ctc-decoding.
+ bpe_model:
+ The BPE model. Used only when params.method is ctc-decoding.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ sos_id:
+ The token ID of the SOS.
+ eos_id:
+ The token ID of the EOS.
+ G:
+ An LM. It is not None when params.method is "nbest-rescoring"
+ or "whole-lattice-rescoring". In general, the G in HLG
+ is a 3-gram LM, while this G is a 4-gram LM.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict. Note: If it decodes to nothing, then return None.
+ """
+ if HLG is not None:
+ device = HLG.device
+ else:
+ device = H.device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+
+ nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
+ # nnet_output is (N, T, C)
+ nnet_output[:, :, 0] += params.blank_bias
+
+ supervision_segments = torch.stack(
+ (
+ supervisions["sequence_idx"],
+ torch.div(
+ supervisions["start_frame"],
+ params.subsampling_factor,
+ rounding_mode="trunc",
+ ),
+ torch.div(
+ supervisions["num_frames"],
+ params.subsampling_factor,
+ rounding_mode="trunc",
+ ),
+ ),
+ 1,
+ ).to(torch.int32)
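+    # Each row of supervision_segments is
+    # (sequence_idx, start_frame, num_frames), with start_frame and
+    # num_frames already divided by the subsampling factor.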
+
+ if H is None:
+ assert HLG is not None
+ decoding_graph = HLG
+ else:
+ assert HLG is None
+ assert bpe_model is not None
+ decoding_graph = H
+
+ lattice = get_lattice(
+ nnet_output=nnet_output,
+ decoding_graph=decoding_graph,
+ supervision_segments=supervision_segments,
+ search_beam=params.search_beam,
+ output_beam=params.output_beam,
+ min_active_states=params.min_active_states,
+ max_active_states=params.max_active_states,
+ subsampling_factor=params.subsampling_factor + 2,
+ )
+
+ if params.method == "ctc-decoding":
+ best_path = one_best_decoding(
+ lattice=lattice, use_double_scores=params.use_double_scores
+ )
+ # Note: `best_path.aux_labels` contains token IDs, not word IDs
+ # since we are using H, not HLG here.
+ #
+    # token_ids is a list-of-list of IDs
+ token_ids = get_texts(best_path)
+
+ # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+ hyps = bpe_model.decode(token_ids)
+
+ # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+ hyps = [s.split() for s in hyps]
+ key = "ctc-decoding"
+ return {key: hyps}
+
+ if params.method == "ctc-greedy-search":
+ hyps, _ = ctc_greedy_search(
+ nnet_output,
+ memory,
+ memory_key_padding_mask,
+ )
+
+ # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+ hyps = bpe_model.decode(hyps)
+
+ # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+ hyps = [s.split() for s in hyps]
+ key = "ctc-greedy-search"
+ return {key: hyps}
+
+ if params.method in ["1best"]:
+ best_path = one_best_decoding(
+ lattice=lattice, use_double_scores=params.use_double_scores
+ )
+ key = "no_rescore"
+
+ hyps = get_texts(best_path)
+ hyps = [[word_table[i] for i in ids] for ids in hyps]
+
+ return {key: hyps}
+ else:
+ assert False, f"Unsupported decoding method: {params.method}"
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ HLG: Optional[k2.Fsa],
+ H: Optional[k2.Fsa],
+ bpe_model: Optional[spm.SentencePieceProcessor],
+ word_table: k2.SymbolTable,
+ sos_id: int,
+ eos_id: int,
+ G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ HLG:
+ The decoding graph. Used only when params.method is NOT ctc-decoding.
+ H:
+ The ctc topo. Used only when params.method is ctc-decoding.
+ bpe_model:
+ The BPE model. Used only when params.method is ctc-decoding.
+ word_table:
+ It is the word symbol table.
+ sos_id:
+ The token ID for SOS.
+ eos_id:
+ The token ID for EOS.
+ G:
+ An LM. It is not None when params.method is "nbest-rescoring"
+ or "whole-lattice-rescoring". In general, the G in HLG
+ is a 3-gram LM, while this G is a 4-gram LM.
+ Returns:
+      Return a dict, whose key may be "no_rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ HLG=HLG,
+ H=H,
+ bpe_model=bpe_model,
+ batch=batch,
+ word_table=word_table,
+ G=G,
+ sos_id=sos_id,
+ eos_id=eos_id,
+ )
+
+ if hyps_dict is not None:
+ for lm_scale, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[lm_scale].extend(this_batch)
+ else:
+ assert len(results) > 0, "It should not decode to empty in the first batch!"
+ this_batch = []
+ hyp_words = []
+            for cut_id, ref_text in zip(cut_ids, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+ for lm_scale in results.keys():
+ results[lm_scale].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % 100 == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ if params.method in ("attention-decoder", "rnn-lm"):
+ # Set it to False since there are too many logs.
+ enable_log = False
+ else:
+ enable_log = True
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ if enable_log:
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=enable_log
+ )
+ test_set_wers[key] = wer
+
+ if enable_log:
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+ args.lm_dir = Path(args.lm_dir)
+ assert "▁" not in args.otc_token
+ args.otc_token = f"▁{args.otc_token}"
+
+ params = get_params()
+ params.update(vars(args))
+
+ setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
+ logging.info("Decoding started")
+ logging.info(params)
+
+ lexicon = Lexicon(params.lang_dir)
+ # remove otc_token from decoding units
+ max_token_id = max(lexicon.tokens) - 1
+ num_classes = max_token_id + 1 # +1 for the blank
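+    # The OTC token is assumed to be the last entry in tokens.txt (see the
+    # comment in conformer_ctc2/train.py), so subtracting 1 from the largest
+    # token id drops it from the decoding units while keeping the blank (id 0).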
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ graph_compiler = OtcTrainingGraphCompiler(
+ params.lang_dir,
+ params.otc_token,
+ device=device,
+ sos_token="",
+ eos_token="",
+ )
+ sos_id = graph_compiler.sos_id
+ eos_id = graph_compiler.eos_id
+
+ params.num_classes = num_classes
+ params.sos_id = sos_id
+ params.eos_id = eos_id
+
+ if params.method == "ctc-decoding" or params.method == "ctc-greedy-search":
+ HLG = None
+ H = k2.ctc_topo(
+ max_token=max_token_id,
+ modified=False,
+ device=device,
+ )
+ bpe_model = spm.SentencePieceProcessor()
+ bpe_model.load(str(params.lang_dir / "bpe.model"))
+ else:
+ H = None
+ bpe_model = None
+ HLG = k2.Fsa.from_dict(
+ torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+ )
+ assert HLG.requires_grad is False
+
+ if not hasattr(HLG, "lm_scores"):
+ HLG.lm_scores = HLG.scores.clone()
+
+ G = None
+
+ model = Conformer(
+ num_features=params.feature_dim,
+ nhead=params.nhead,
+ d_model=params.encoder_dim,
+ num_classes=num_classes,
+ subsampling_factor=params.subsampling_factor,
+ num_encoder_layers=params.num_encoder_layers,
+ num_decoder_layers=params.num_decoder_layers,
+ )
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ test_clean_cuts = librispeech.test_clean_cuts()
+ test_other_cuts = librispeech.test_other_cuts()
+
+ test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+ test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+ test_sets = ["test-clean", "test-other"]
+ test_dl = [test_clean_dl, test_other_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ HLG=HLG,
+ H=H,
+ bpe_model=bpe_model,
+ word_table=lexicon.word_table,
+ G=G,
+ sos_id=sos_id,
+ eos_id=eos_id,
+ )
+
+ save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+
+ logging.info("Done!")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/WSASR/conformer_ctc2/export.py b/egs/librispeech/WSASR/conformer_ctc2/export.py
new file mode 120000
index 000000000..5f484e391
--- /dev/null
+++ b/egs/librispeech/WSASR/conformer_ctc2/export.py
@@ -0,0 +1 @@
+../../ASR/conformer_ctc2/export.py
\ No newline at end of file
diff --git a/egs/librispeech/WSASR/conformer_ctc2/label_smoothing.py b/egs/librispeech/WSASR/conformer_ctc2/label_smoothing.py
new file mode 120000
index 000000000..c050ea637
--- /dev/null
+++ b/egs/librispeech/WSASR/conformer_ctc2/label_smoothing.py
@@ -0,0 +1 @@
+../../ASR/conformer_ctc/label_smoothing.py
\ No newline at end of file
diff --git a/egs/librispeech/WSASR/conformer_ctc2/optim.py b/egs/librispeech/WSASR/conformer_ctc2/optim.py
new file mode 120000
index 000000000..db836b5e0
--- /dev/null
+++ b/egs/librispeech/WSASR/conformer_ctc2/optim.py
@@ -0,0 +1 @@
+../../ASR/pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/WSASR/conformer_ctc2/scaling.py b/egs/librispeech/WSASR/conformer_ctc2/scaling.py
new file mode 120000
index 000000000..bd0abfeee
--- /dev/null
+++ b/egs/librispeech/WSASR/conformer_ctc2/scaling.py
@@ -0,0 +1 @@
+../../ASR/pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/librispeech/WSASR/conformer_ctc2/subsampling.py b/egs/librispeech/WSASR/conformer_ctc2/subsampling.py
new file mode 100644
index 000000000..2ba802866
--- /dev/null
+++ b/egs/librispeech/WSASR/conformer_ctc2/subsampling.py
@@ -0,0 +1,184 @@
+#!/usr/bin/env python3
+# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+# 2022 Xiaomi Corporation (author: Quandong Wang)
+# 2023 Johns Hopkins University (author: Dongji Gao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from scaling import (
+ ActivationBalancer,
+ BasicNorm,
+ DoubleSwish,
+ ScaledConv2d,
+ ScaledLinear,
+)
+
+
+class Conv2dSubsampling(torch.nn.Module):
+ """Convolutional 2D subsampling (to 1/4 length).
+
+ Convert an input of shape (N, T, idim) to an output
+ with shape (N, T', odim), where
+ T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
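+    For example, T = 100 gives T' = ((100-1)//2 - 1)//2 = 24.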
+
+ It is based on
+ https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ layer1_channels: int = 8,
+ layer2_channels: int = 32,
+ layer3_channels: int = 128,
+ ) -> None:
+ """
+ Args:
+ in_channels:
+ Number of channels in. The input shape is (N, T, in_channels).
+ Caution: It requires: T >=7, in_channels >=7
+          out_channels:
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
+          layer1_channels:
+            Number of channels in layer1
+          layer2_channels:
+            Number of channels in layer2
+          layer3_channels:
+            Number of channels in layer3
+ """
+ assert in_channels >= 7
+ super().__init__()
+
+ self.conv = torch.nn.Sequential(
+ ScaledConv2d(
+ in_channels=1,
+ out_channels=layer1_channels,
+ kernel_size=3,
+ padding=1,
+ ),
+ ActivationBalancer(channel_dim=1),
+ DoubleSwish(),
+ ScaledConv2d(
+ in_channels=layer1_channels,
+ out_channels=layer2_channels,
+ kernel_size=3,
+ stride=2,
+ ),
+ ActivationBalancer(channel_dim=1),
+ DoubleSwish(),
+ ScaledConv2d(
+ in_channels=layer2_channels,
+ out_channels=layer3_channels,
+ kernel_size=3,
+ stride=2,
+ ),
+ ActivationBalancer(channel_dim=1),
+ DoubleSwish(),
+ )
+ self.out = ScaledLinear(
+ layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
+ )
+ # set learn_eps=False because out_norm is preceded by `out`, and `out`
+ # itself has learned scale, so the extra degree of freedom is not
+ # needed.
+ self.out_norm = BasicNorm(out_channels, learn_eps=False)
+ # constrain median of output to be close to zero.
+ self.out_balancer = ActivationBalancer(
+ channel_dim=-1, min_positive=0.45, max_positive=0.55
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Subsample x.
+
+ Args:
+ x:
+ Its shape is (N, T, idim).
+
+ Returns:
+ Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
+ """
+ # On entry, x is (N, T, idim)
+ x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+ x = self.conv(x)
+ # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
+ x = self.out_norm(x)
+ x = self.out_balancer(x)
+ return x
+
+
+class Conv2dSubsampling2(torch.nn.Module):
+ """Convolutional 2D subsampling (to 1/2 length).
+
+ Convert an input of shape (N, T, idim) to an output
+ with shape (N, T', odim) where
+ T' = (T - 1) // 2 - 2, which approximates T' == T // 2
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ layer1_channels: int = 8,
+ layer2_channels: int = 32,
+ layer3_channels: int = 128,
+ ) -> None:
+ assert in_channels >= 7
+ super().__init__()
+
+ self.conv = torch.nn.Sequential(
+ ScaledConv2d(
+ in_channels=1,
+ out_channels=layer1_channels,
+ kernel_size=3,
+ padding=1,
+ ),
+ ActivationBalancer(channel_dim=1),
+ DoubleSwish(),
+ ScaledConv2d(
+ in_channels=layer1_channels,
+ out_channels=layer2_channels,
+ kernel_size=3,
+ stride=2,
+ ),
+ ActivationBalancer(channel_dim=1),
+ DoubleSwish(),
+ ScaledConv2d(
+ in_channels=layer2_channels,
+ out_channels=layer3_channels,
+ kernel_size=3,
+ stride=1,
+ ),
+ ActivationBalancer(channel_dim=1),
+ DoubleSwish(),
+ )
+ self.out = ScaledLinear(
+ layer3_channels * ((in_channels - 1) // 2 - 2), out_channels
+ )
+ self.out_norm = BasicNorm(out_channels, learn_eps=False)
+ self.out_balancer = ActivationBalancer(
+ channel_dim=-1, min_positive=0.45, max_positive=0.55
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x.unsqueeze(1)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x = self.out_norm(x)
+ x = self.out_balancer(x)
+ return x
diff --git a/egs/librispeech/WSASR/conformer_ctc2/train.py b/egs/librispeech/WSASR/conformer_ctc2/train.py
new file mode 100755
index 000000000..fe6c5af91
--- /dev/null
+++ b/egs/librispeech/WSASR/conformer_ctc2/train.py
@@ -0,0 +1,1115 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Quandong Wang)
+# 2023 Johns Hopkins University (author: Dongji Gao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./conformer_ctc2/train.py \
+ --world-size 4 \
+ --manifest-dir data/ssl \
+ --train-manifest librispeech_cuts_train-clean-100_0.17_0.17_0.17.jsonl.gz \
+ --exp-dir conformer_ctc2/exp \
+ --lang-dir data/lang_bpe_200 \
+ --otc-token "" \
+ --allow-bypass-arc true \
+ --allow-self-loop-arc true \
+ --initial-bypass-weight -19 \
+ --initial-self-loop-weight 3.75 \
+ --bypass-weight-decay 0.975 \
+ --self-loop-weight-decay 0.999 \
+ --show-alignment true
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.decode import one_best_decoding
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ encode_supervisions_otc,
+ get_texts,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=20,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="conformer_ctc2/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_bpe_200",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--initial-lr",
+ type=float,
+ default=0.003,
+ help="""The initial learning rate. This value should not need to be
+ changed.""",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate decreases.
+ We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=6,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--att-rate",
+ type=float,
+ default=0.0,
+ help="""The attention rate.
+ The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss
+ """,
+ )
+
+ parser.add_argument(
+ "--num-decoder-layers",
+ type=int,
+ default=0,
+ help="""Number of decoder layer of transformer decoder.
+ Setting this to 0 will not create the decoder at all (pure CTC model)
+ """,
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=8000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=10,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=100,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--otc-token",
+ type=str,
+ default="_",
+ help="OTC token",
+ )
+
+ parser.add_argument(
+ "--allow-bypass-arc",
+ type=str2bool,
+ default=True,
+ help="""Whether to add bypass arc to training graph for substitution
+ and insertion errors (wrong or extra words in the transcript).""",
+ )
+
+ parser.add_argument(
+ "--allow-self-loop-arc",
+ type=str2bool,
+ default=True,
+ help="""Whether to self-loop bypass arc to training graph for deletion errors
+ (missing words in the transcript).""",
+ )
+
+ parser.add_argument(
+ "--initial-bypass-weight",
+ type=float,
+ default=0.0,
+ help="Initial weight associated with bypass arc",
+ )
+
+ parser.add_argument(
+ "--initial-self-loop-weight",
+ type=float,
+ default=0.0,
+ help="Initial weight associated with self-loop arc",
+ )
+
+ parser.add_argument(
+ "--bypass-weight-decay",
+ type=float,
+ default=1.0,
+ help="""Weight decay factor of bypass arc weight:
+        bypass_arc_weight = initial_bypass_weight * bypass_weight_decay ^ (i-th epoch)""",
+ )
+
+ parser.add_argument(
+ "--self-loop-weight-decay",
+ type=float,
+ default=1.0,
+ help="""Weight decay factor of self-loop arc weight:
+        self_loop_arc_weight = initial_self_loop_weight * self_loop_weight_decay ^ (i-th epoch)""",
+ )
+
+ parser.add_argument(
+ "--show-alignment",
+ type=str2bool,
+ default=True,
+ help="Whether to print OTC alignment during training",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+ epochs.
+
+    - log_interval: Print training loss if `batch_idx % log_interval` is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+    - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - beam_size: It is used in k2.ctc_loss
+
+ - reduction: It is used in k2.ctc_loss
+
+ - use_double_scores: It is used in k2.ctc_loss
+
+ - warm_step: The warm_step for Noam optimizer.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 1,
+ "reset_interval": 200,
+ "valid_interval": 800, # For the 100h subset, use 800
+ "alignment_interval": 25,
+ # parameters for conformer
+ "feature_dim": 768,
+ "subsampling_factor": 2,
+ "encoder_dim": 512,
+ "nhead": 8,
+ "dim_feedforward": 2048,
+ "num_encoder_layers": 12,
+ # parameters for ctc loss
+ "beam_size": 10,
+ "reduction": "sum",
+ "use_double_scores": True,
+ # parameters for Noam
+ "model_warm_step": 3000, # arg given to model, not for lrate
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ batch: dict,
+ graph_compiler: OtcTrainingGraphCompiler,
+ is_training: bool,
+ warmup: float = 2.0,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute OTC loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Conformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ graph_compiler:
+ It is used to build a decoding graph from a ctc topo and training
+ transcript. The training transcript is contained in the given `batch`,
+ while the ctc topo is built when this compiler is instantiated.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ warmup: a floating point value which increases throughout training;
+ values >= 1.0 are fully warmed up and have all modules present.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ with torch.set_grad_enabled(is_training):
+ nnet_output, encoder_memory, memory_mask = model(
+ feature, supervisions, warmup=warmup
+ )
+ # Set the probability of OTC token as the average of non-blank tokens
+ # under the assumption that blank is the first and
+ # OTC token is the last token in tokens.txt
+ _, _, V = nnet_output.shape
+
+ otc_token_log_prob = torch.logsumexp(
+ nnet_output[:, :, 1:], dim=-1, keepdim=True
+ ) - torch.log(torch.tensor([V - 1])).to(device)
+
+ nnet_output = torch.cat([nnet_output, otc_token_log_prob], dim=-1)
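+        # For instance (illustrative), with V = 3 and non-blank log-probs
+        # [-1.0, -2.0] on a frame, the OTC token gets
+        # log((e^-1.0 + e^-2.0) / 2) ≈ -1.38 on that frame.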
+
+ # NOTE: We need `encode_supervisions` to sort sequences with
+ # different duration in decreasing order, required by
+ # `k2.intersect_dense` called in `k2.ctc_loss`
+ supervision_segments, texts, utt_ids, verbatim_texts = encode_supervisions_otc(
+ supervisions, subsampling_factor=params.subsampling_factor
+ )
+
+ bypass_weight = graph_compiler.initial_bypass_weight * (
+ graph_compiler.bypass_weight_decay ** (params.cur_epoch - 1)
+ )
+ self_loop_weight = graph_compiler.initial_self_loop_weight * (
+ graph_compiler.self_loop_weight_decay ** (params.cur_epoch - 1)
+ )
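+        # Illustrative example with the values from the usage string at the
+        # top of this file (--initial-bypass-weight -19,
+        # --bypass-weight-decay 0.975): at cur_epoch = 3,
+        # bypass_weight = -19 * 0.975 ** 2 ≈ -18.06.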
+
+ decoding_graph = graph_compiler.compile(
+ texts=texts,
+ allow_bypass_arc=params.allow_bypass_arc,
+ allow_self_loop_arc=params.allow_self_loop_arc,
+ bypass_weight=bypass_weight,
+ self_loop_weight=self_loop_weight,
+ )
+
+ dense_fsa_vec = k2.DenseFsaVec(
+ nnet_output,
+ supervision_segments,
+ allow_truncate=3,
+ )
+
+ otc_loss = k2.ctc_loss(
+ decoding_graph=decoding_graph,
+ dense_fsa_vec=dense_fsa_vec,
+ output_beam=params.beam_size,
+ reduction=params.reduction,
+ use_double_scores=params.use_double_scores,
+ )
+
+ assert params.att_rate == 0.0
+ loss = otc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+ info["otc_loss"] = otc_loss.detach().cpu().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+
+ # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
+ info["utterances"] = feature.size(0)
+ # averaged input duration in frames over utterances
+ info["utt_duration"] = feature_lens.sum().item()
+ # averaged padding proportion over utterances
+ info["utt_pad_proportion"] = (
+ ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+ )
+
+ if params.show_alignment:
+ if params.batch_idx_train % params.alignment_interval == 0:
+                # Compute the lattice and the best path once for the whole
+                # batch, then log the alignment of each utterance.
+                lattice = k2.intersect_dense(
+                    decoding_graph,
+                    dense_fsa_vec,
+                    params.beam_size,
+                )
+                best_path = one_best_decoding(
+                    lattice=lattice,
+                    use_double_scores=params.use_double_scores,
+                )
+                all_hyp_ids = get_texts(best_path)
+
+                for index, utt_id in enumerate(utt_ids):
+                    verbatim_text = verbatim_texts[index]
+                    hyp_text_list = [
+                        graph_compiler.token_table[i] for i in all_hyp_ids[index]
+                    ]
+                    hyp_text = "".join(hyp_text_list).replace("▁", " ")
+
+                    logging.info(f"[utterance id]: {utt_id}")
+                    logging.info(f"[verbatim text]: {verbatim_text}")
+                    logging.info(f"[best alignment]: {hyp_text}")
+
+                logging.info(f"[bypass arc weight]: {bypass_weight}")
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: OtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ graph_compiler=graph_compiler,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: OtcTrainingGraphCompiler,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ graph_compiler:
+ It is used to convert transcripts to FSAs.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ graph_compiler=graph_compiler,
+ is_training=True,
+ warmup=(params.batch_idx_train / params.model_warm_step),
+ )
+ # summary stats
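+        # tot_loss keeps an exponentially decaying sum of recent batch stats:
+        # older batches are down-weighted by (1 - 1/reset_interval) per step,
+        # so it roughly reflects the last `reset_interval` batches.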
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+        # NOTE: We use reduction==sum and loss is computed over utterances
+        # in the batch and there is no normalization to it so far.
+        try:
+            scaler.scale(loss).backward()
+ except RuntimeError as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(f"failing batch size:{batch_size} ")
+ raise
+
+ scheduler.step_batch(params.batch_idx_train)
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+
+ if params.print_diagnostics and batch_idx == 30:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}"
+ )
+ if loss_info["otc_loss"] == float("inf"):
+ logging.error("Your loss contains inf, something goes wrong")
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+ params.valid_interval = 1600
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+ logging.info(params)
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+
+ graph_compiler = OtcTrainingGraphCompiler(
+ params.lang_dir,
+ otc_token=params.otc_token,
+ device=device,
+ initial_bypass_weight=params.initial_bypass_weight,
+ initial_self_loop_weight=params.initial_self_loop_weight,
+ bypass_weight_decay=params.bypass_weight_decay,
+ self_loop_weight_decay=params.self_loop_weight_decay,
+ )
+
+ # remove OTC token as it is the average of all non-blank tokens
+ max_token_id = graph_compiler.get_max_token_id() - 1
+ # add blank
+ num_classes = max_token_id + 1
+
+ logging.info("About to create model")
+ model = Conformer(
+ num_features=params.feature_dim,
+ nhead=params.nhead,
+ d_model=params.encoder_dim,
+ num_classes=num_classes,
+ subsampling_factor=params.subsampling_factor,
+ num_encoder_layers=params.num_encoder_layers,
+ num_decoder_layers=params.num_decoder_layers,
+ )
+
+ print(model)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank])
+
+ optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ diagnostic = diagnostics.attach_diagnostics(model)
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ train_cuts = librispeech.train_clean_100_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ return 1.0 <= c.duration <= 20.0
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ valid_cuts += librispeech.dev_other_cuts()
+ valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+ if params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: OtcTrainingGraphCompiler,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+            # Use warmup = 0.0 (i.e., not warmed up) so that this sanity check
+            # exercises the model in the same state as the very first batches
+            # of training.
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ graph_compiler=graph_compiler,
+ is_training=True,
+ warmup=0.0,
+ )
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ except RuntimeError as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ assert "▁" not in args.otc_token
+ args.otc_token = f"▁{args.otc_token}"
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/WSASR/conformer_ctc2/transformer.py b/egs/librispeech/WSASR/conformer_ctc2/transformer.py
new file mode 100644
index 000000000..41e6cd357
--- /dev/null
+++ b/egs/librispeech/WSASR/conformer_ctc2/transformer.py
@@ -0,0 +1,1055 @@
+# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+# Copyright 2022 Xiaomi Corp. (author: Quandong Wang)
+# 2023 Johns Hopkins University (author: Dongji Gao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from attention import MultiheadAttention
+from label_smoothing import LabelSmoothingLoss
+from scaling import (
+ ActivationBalancer,
+ BasicNorm,
+ DoubleSwish,
+ ScaledEmbedding,
+ ScaledLinear,
+)
+from subsampling import Conv2dSubsampling
+from torch.nn.utils.rnn import pad_sequence
+
+# Note: TorchScript requires Dict/List/etc. to be fully typed.
+Supervisions = Dict[str, torch.Tensor]
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ num_features: int,
+ num_classes: int,
+ subsampling_factor: int = 4,
+ d_model: int = 256,
+ nhead: int = 4,
+ dim_feedforward: int = 2048,
+ num_encoder_layers: int = 12,
+ num_decoder_layers: int = 6,
+ dropout: float = 0.1,
+ layer_dropout: float = 0.075,
+ ) -> None:
+ """
+ Args:
+ num_features:
+ The input dimension of the model.
+ num_classes:
+ The output dimension of the model.
+ subsampling_factor:
+ Number of output frames is num_in_frames // subsampling_factor.
+        Currently, subsampling_factor MUST be 4 or 2.
+ d_model:
+ Attention dimension.
+ nhead:
+ Number of heads in multi-head attention.
+        Must satisfy d_model % nhead == 0.
+ dim_feedforward:
+ The output dimension of the feedforward layers in encoder/decoder.
+ num_encoder_layers:
+ Number of encoder layers.
+ num_decoder_layers:
+ Number of decoder layers.
+ dropout:
+ Dropout in encoder/decoder.
+ layer_dropout (float): layer-dropout rate.
+ """
+ super().__init__()
+
+ self.num_features = num_features
+ self.num_classes = num_classes
+ self.subsampling_factor = subsampling_factor
+ if subsampling_factor != 4 and subsampling_factor != 2:
+ raise NotImplementedError("Support only 'subsampling_factor=4 or 2'.")
+
+ # self.encoder_embed converts the input of shape (N, T, num_classes)
+ # to the shape (N, T//subsampling_factor, d_model).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> T//subsampling_factor
+ # (2) embedding: num_classes -> d_model
+ self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+
+ self.encoder_pos = PositionalEncoding(d_model, dropout)
+
+ encoder_layer = TransformerEncoderLayer(
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ dropout=dropout,
+ layer_dropout=layer_dropout,
+ )
+
+ self.encoder = TransformerEncoder(
+ encoder_layer=encoder_layer,
+ num_layers=num_encoder_layers,
+ )
+
+ # TODO(fangjun): remove dropout
+ self.encoder_output_layer = nn.Sequential(
+ nn.Dropout(p=dropout), ScaledLinear(d_model, num_classes, bias=True)
+ )
+
+ if num_decoder_layers > 0:
+ self.decoder_num_class = (
+ self.num_classes
+ ) # bpe model already has sos/eos symbol
+
+ self.decoder_embed = ScaledEmbedding(
+ num_embeddings=self.decoder_num_class, embedding_dim=d_model
+ )
+ self.decoder_pos = PositionalEncoding(d_model, dropout)
+
+ decoder_layer = TransformerDecoderLayer(
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ dropout=dropout,
+ )
+
+ self.decoder = TransformerDecoder(
+ decoder_layer=decoder_layer,
+ num_layers=num_decoder_layers,
+ )
+
+ self.decoder_output_layer = ScaledLinear(
+ d_model, self.decoder_num_class, bias=True
+ )
+
+ self.decoder_criterion = LabelSmoothingLoss()
+ else:
+ self.decoder_criterion = None
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ supervision: Optional[Supervisions] = None,
+ warmup: float = 1.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (N, T, C).
+ supervision:
+ Supervision in lhotse format.
+ See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
+ (CAUTION: It contains length information, i.e., start and number of
+ frames, before subsampling)
+ warmup:
+ A floating point value that gradually increases from 0 throughout
+ training; when it is >= 1.0 we are "fully warmed up". It is used
+ to turn modules on sequentially.
+
+ Returns:
+ Return a tuple containing 3 tensors:
+ - CTC output for ctc decoding. Its shape is (N, T, C)
+ - Encoder output with shape (T, N, C). It can be used as key and
+ value for the decoder.
+ - Encoder output padding mask. It can be used as
+ memory_key_padding_mask for the decoder. Its shape is (N, T).
+ It is None if `supervision` is None.
+ """
+
+ encoder_memory, memory_key_padding_mask = self.run_encoder(
+ x, supervision, warmup
+ )
+
+ x = self.ctc_output(encoder_memory)
+ return x, encoder_memory, memory_key_padding_mask
+
+ def run_encoder(
+ self,
+ x: torch.Tensor,
+ supervisions: Optional[Supervisions] = None,
+ warmup: float = 1.0,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Run the transformer encoder.
+
+ Args:
+ x:
+ The model input. Its shape is (N, T, C).
+ supervisions:
+ Supervision in lhotse format.
+ See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
+ CAUTION: It contains length information, i.e., start and number of
+ frames, before subsampling
+ It is read directly from the batch, without any sorting. It is used
+ to compute the encoder padding mask, which is used as memory key
+ padding mask for the decoder.
+ Returns:
+ Return a tuple with two tensors:
+ - The encoder output, with shape (T, N, C)
+ - encoder padding mask, with shape (N, T).
+ The mask is None if `supervisions` is None.
+ It is used as memory key padding mask in the decoder.
+ """
+ x = self.encoder_embed(x)
+ x = self.encoder_pos(x)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ mask = encoder_padding_mask(x.size(0), supervisions)
+ mask = mask.to(x.device) if mask is not None else None
+ x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (T, N, C)
+
+ return x, mask
+
+ def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x:
+ The output tensor from the transformer encoder.
+ Its shape is (T, N, C)
+
+ Returns:
+ Return a tensor that can be used for CTC decoding.
+ Its shape is (N, T, C)
+ """
+ x = self.encoder_output_layer(x)
+        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+ x = nn.functional.log_softmax(x, dim=-1) # (N, T, C)
+ return x
+
+ @torch.jit.export
+ def decoder_forward(
+ self,
+ memory: torch.Tensor,
+ memory_key_padding_mask: torch.Tensor,
+ token_ids: List[List[int]],
+ sos_id: int,
+ eos_id: int,
+ ) -> torch.Tensor:
+ """
+ Args:
+ memory:
+ It's the output of the encoder with shape (T, N, C)
+ memory_key_padding_mask:
+ The padding mask from the encoder.
+ token_ids:
+ A list of lists of token IDs. Each sublist contains IDs for an utterance.
+ The IDs can be either phone IDs or word piece IDs.
+ sos_id:
+ sos token id
+ eos_id:
+ eos token id
+
+ Returns:
+ A scalar, the **sum** of label smoothing loss over utterances
+ in the batch without any normalization.
+ """
+ ys_in = add_sos(token_ids, sos_id=sos_id)
+ ys_in = [torch.tensor(y) for y in ys_in]
+ ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+
+ ys_out = add_eos(token_ids, eos_id=eos_id)
+ ys_out = [torch.tensor(y) for y in ys_out]
+ ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+
+ device = memory.device
+ ys_in_pad = ys_in_pad.to(device)
+ ys_out_pad = ys_out_pad.to(device)
+
+ tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+
+ tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+ # TODO: Use length information to create the decoder padding mask
+ # We set the first column to False since the first column in ys_in_pad
+ # contains sos_id, which is the same as eos_id in our current setting.
+ tgt_key_padding_mask[:, 0] = False
+
+ tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
+ tgt = self.decoder_pos(tgt)
+ tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ pred_pad = self.decoder(
+ tgt=tgt,
+ memory=memory,
+ tgt_mask=tgt_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ ) # (T, N, C)
+ pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
+ pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
+
+ decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
+
+ return decoder_loss
+
+ @torch.jit.export
+ def decoder_nll(
+ self,
+ memory: torch.Tensor,
+ memory_key_padding_mask: torch.Tensor,
+ token_ids: List[torch.Tensor],
+ sos_id: int,
+ eos_id: int,
+ ) -> torch.Tensor:
+ """
+ Args:
+ memory:
+ It's the output of the encoder with shape (T, N, C)
+ memory_key_padding_mask:
+ The padding mask from the encoder.
+ token_ids:
+ A list of lists of token IDs (e.g., word piece IDs), one sublist
+ per utterance. When called from torchscript (C++), each element
+ is a 1-D tensor instead (see the conversion below).
+ sos_id:
+ The token ID for SOS.
+ eos_id:
+ The token ID for EOS.
+ Returns:
+ A 2-D tensor of shape (len(token_ids), max_token_length)
+ representing the cross entropy loss (i.e., negative log-likelihood).
+ """
+ # The common part between this function and decoder_forward could be
+ # extracted as a separate function.
+ if isinstance(token_ids[0], torch.Tensor):
+ # This branch is executed by torchscript in C++.
+ # See https://github.com/k2-fsa/k2/pull/870
+ # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
+ token_ids = [tolist(t) for t in token_ids]
+
+ ys_in = add_sos(token_ids, sos_id=sos_id)
+ ys_in = [torch.tensor(y) for y in ys_in]
+ ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+
+ ys_out = add_eos(token_ids, eos_id=eos_id)
+ ys_out = [torch.tensor(y) for y in ys_out]
+ ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+
+ device = memory.device
+ ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
+ ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
+
+ tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+
+ tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+ # TODO: Use length information to create the decoder padding mask
+ # We set the first column to False since the first column in ys_in_pad
+ # contains sos_id, which is the same as eos_id in our current setting.
+ tgt_key_padding_mask[:, 0] = False
+
+ tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
+ tgt = self.decoder_pos(tgt)
+ tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
+ pred_pad = self.decoder(
+ tgt=tgt,
+ memory=memory,
+ tgt_mask=tgt_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ ) # (T, B, F)
+ pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F)
+ pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F)
+ # nll: negative log-likelihood
+ nll = torch.nn.functional.cross_entropy(
+ pred_pad.view(-1, self.decoder_num_class),
+ ys_out_pad.view(-1),
+ ignore_index=-1,
+ reduction="none",
+ )
+
+ nll = nll.view(pred_pad.shape[0], -1)
+
+ return nll
+
+
+class TransformerEncoderLayer(nn.Module):
+ """
+ Modified from torch.nn.TransformerEncoderLayer.
+
+ Args:
+ d_model:
+ the number of expected features in the input (required).
+ nhead:
+ the number of heads in the multiheadattention models (required).
+ dim_feedforward:
+ the dimension of the feedforward network model (default=2048).
+ dropout:
+ the dropout value (default=0.1).
+ activation:
+ the activation function of intermediate layer, relu or
+ gelu (default=relu).
+
+ Examples::
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
+ >>> src = torch.rand(10, 32, 512)
+ >>> out = encoder_layer(src)
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ nhead: int,
+ dim_feedforward: int = 2048,
+ dropout: float = 0.1,
+ layer_dropout: float = 0.075,
+ ) -> None:
+ super(TransformerEncoderLayer, self).__init__()
+
+ self.layer_dropout = layer_dropout
+
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=0.0)
+ # Implementation of Feedforward model
+
+ self.feed_forward = nn.Sequential(
+ ScaledLinear(d_model, dim_feedforward),
+ ActivationBalancer(channel_dim=-1),
+ DoubleSwish(),
+ nn.Dropout(dropout),
+ ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+ )
+
+ self.norm_final = BasicNorm(d_model)
+
+ # try to ensure the output is close to zero-mean (or at least, zero-median).
+ self.balancer = ActivationBalancer(
+ channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+ )
+
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(
+ self,
+ src: torch.Tensor,
+ src_mask: Optional[torch.Tensor] = None,
+ src_key_padding_mask: Optional[torch.Tensor] = None,
+ warmup: float = 1.0,
+ ) -> torch.Tensor:
+ """
+ Pass the input through the encoder layer.
+
+ Args:
+ src: the sequence to the encoder layer (required).
+ src_mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional)
+ warmup: controls selective bypass of layers; if < 1.0, we will
+ bypass layers more frequently.
+
+ Shape:
+ src: (S, N, E).
+ src_mask: (S, S).
+ src_key_padding_mask: (N, S).
+ S is the source sequence length, T is the target sequence length,
+ N is the batch size, E is the feature number
+ """
+ src_orig = src
+
+ warmup_scale = min(0.1 + warmup, 1.0)
+ # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
+ # completely bypass it.
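+ # For example, with the default layer_dropout=0.075 and warmup=0.3,
+ # warmup_scale is 0.4, so alpha is 0.4 with probability 0.925 and 0.1
+ # otherwise; once warmup >= 0.9, warmup_scale saturates at 1.0.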
+ if self.training:
+ alpha = (
+ warmup_scale
+ if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+ else 0.1
+ )
+ else:
+ alpha = 1.0
+
+ # src_att = self.self_attn(src, src, src, src_mask)
+ src_att = self.self_attn(
+ src,
+ src,
+ src,
+ attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask,
+ )[0]
+ src = src + self.dropout(src_att)
+
+ src = src + self.dropout(self.feed_forward(src))
+
+ src = self.norm_final(self.balancer(src))
+
+ if alpha != 1.0:
+ src = alpha * src + (1 - alpha) * src_orig
+
+ return src
+
+
+class TransformerDecoderLayer(nn.Module):
+ """
+ Modified from torch.nn.TransformerDecoderLayer.
+ Add support of normalize_before,
+ i.e., use layer_norm before the first block.
+
+ Args:
+ d_model:
+ the number of expected features in the input (required).
+ nhead:
+ the number of heads in the multiheadattention models (required).
+ dim_feedforward:
+ the dimension of the feedforward network model (default=2048).
+ dropout:
+ the dropout value (default=0.1).
+ activation:
+ the activation function of intermediate layer, relu or
+ gelu (default=relu).
+
+ Examples::
+ >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
+ >>> memory = torch.rand(10, 32, 512)
+ >>> tgt = torch.rand(20, 32, 512)
+ >>> out = decoder_layer(tgt, memory)
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ nhead: int,
+ dim_feedforward: int = 2048,
+ dropout: float = 0.1,
+ layer_dropout: float = 0.075,
+ normalize_before: bool = True,
+ ) -> None:
+ super(TransformerDecoderLayer, self).__init__()
+ self.layer_dropout = layer_dropout
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=0.0)
+ self.src_attn = MultiheadAttention(d_model, nhead, dropout=0.0)
+ # Implementation of Feedforward model
+ self.feed_forward = nn.Sequential(
+ ScaledLinear(d_model, dim_feedforward),
+ ActivationBalancer(channel_dim=-1),
+ DoubleSwish(),
+ nn.Dropout(dropout),
+ ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+ )
+
+ self.norm_final = BasicNorm(d_model)
+
+ # try to ensure the output is close to zero-mean (or at least, zero-median).
+ self.balancer = ActivationBalancer(
+ channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+ )
+
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(
+ self,
+ tgt: torch.Tensor,
+ memory: torch.Tensor,
+ tgt_mask: Optional[torch.Tensor] = None,
+ memory_mask: Optional[torch.Tensor] = None,
+ tgt_key_padding_mask: Optional[torch.Tensor] = None,
+ memory_key_padding_mask: Optional[torch.Tensor] = None,
+ warmup: float = 1.0,
+ ) -> torch.Tensor:
+ """Pass the inputs (and mask) through the decoder layer.
+
+ Args:
+ tgt:
+ the sequence to the decoder layer (required).
+ memory:
+ the sequence from the last layer of the encoder (required).
+ tgt_mask:
+ the mask for the tgt sequence (optional).
+ memory_mask:
+ the mask for the memory sequence (optional).
+ tgt_key_padding_mask:
+ the mask for the tgt keys per batch (optional).
+ memory_key_padding_mask:
+ the mask for the memory keys per batch (optional).
+ warmup: controls selective bypass of layers; if < 1.0, we will
+ bypass layers more frequently.
+
+ Shape:
+ tgt: (T, N, E).
+ memory: (S, N, E).
+ tgt_mask: (T, T).
+ memory_mask: (T, S).
+ tgt_key_padding_mask: (N, T).
+ memory_key_padding_mask: (N, S).
+ S is the source sequence length, T is the target sequence length,
+ N is the batch size, E is the feature number
+ """
+ tgt_orig = tgt
+
+ warmup_scale = min(0.1 + warmup, 1.0)
+ # alpha = 1.0 means fully use this decoder layer, 0.0 would mean
+ # completely bypass it.
+ if self.training:
+ alpha = (
+ warmup_scale
+ if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+ else 0.1
+ )
+ else:
+ alpha = 1.0
+
+ # tgt_att = self.self_attn(tgt, tgt, tgt, tgt_mask)
+ tgt_att = self.self_attn(
+ tgt,
+ tgt,
+ tgt,
+ attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout(tgt_att)
+
+ # src_att = self.src_attn(tgt, memory, memory, memory_mask)
+ src_att = self.src_attn(
+ tgt,
+ memory,
+ memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout(src_att)
+
+ tgt = tgt + self.dropout(self.feed_forward(tgt))
+
+ tgt = self.norm_final(self.balancer(tgt))
+
+ if alpha != 1.0:
+ tgt = alpha * tgt + (1 - alpha) * tgt_orig
+
+ return tgt
+
+
+class TransformerEncoder(nn.Module):
+ r"""TransformerEncoder is a stack of N encoder layers
+
+ Args:
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
+ num_layers: the number of sub-encoder-layers in the encoder (required).
+
+ Examples::
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
+ >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
+ >>> src = torch.rand(10, 32, 512)
+ >>> out = transformer_encoder(src)
+ """
+
+ def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+ )
+ self.num_layers = num_layers
+
+ def forward(
+ self,
+ src: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ src_key_padding_mask: Optional[torch.Tensor] = None,
+ warmup: float = 1.0,
+ ) -> torch.Tensor:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required).
+ mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+
+ Shape:
+ src: (S, N, E).
+ mask: (S, S).
+ src_key_padding_mask: (N, S).
+ S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
+
+ """
+ output = src
+
+ for mod in self.layers:
+ output = mod(
+ output,
+ src_mask=mask,
+ src_key_padding_mask=src_key_padding_mask,
+ warmup=warmup,
+ )
+
+ return output
+
+
+class TransformerDecoder(nn.Module):
+ r"""TransformerDecoder is a stack of N decoder layers
+
+ Args:
+ decoder_layer: an instance of the TransformerDecoderLayer() class (required).
+ num_layers: the number of sub-decoder-layers in the decoder (required).
+
+ Examples::
+ >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
+ >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)
+ >>> memory = torch.rand(10, 32, 512)
+ >>> tgt = torch.rand(10, 32, 512)
+ >>> out = transformer_decoder(tgt, memory)
+ """
+
+ def __init__(self, decoder_layer: nn.Module, num_layers: int) -> None:
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(decoder_layer) for i in range(num_layers)]
+ )
+ self.num_layers = num_layers
+
+ def forward(
+ self,
+ tgt: torch.Tensor,
+ memory: torch.Tensor,
+ tgt_mask: Optional[torch.Tensor] = None,
+ memory_mask: Optional[torch.Tensor] = None,
+ tgt_key_padding_mask: Optional[torch.Tensor] = None,
+ memory_key_padding_mask: Optional[torch.Tensor] = None,
+ warmup: float = 1.0,
+ ) -> torch.Tensor:
+ r"""Pass the input through the decoder layers in turn.
+
+ Args:
+ tgt: the sequence to the decoder (required).
+ memory: the sequence from the last layer of the encoder (required).
+ tgt_mask: the mask for the tgt sequence (optional).
+ memory_mask: the mask for the memory sequence (optional).
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
+
+ Shape:
+ tgt: (T, N, E).
+ memory: (S, N, E).
+ tgt_mask: (T, T).
+ memory_mask: (T, S).
+ tgt_key_padding_mask: (N, T).
+ memory_key_padding_mask: (N, S).
+ S is the source sequence length, T is the target sequence length,
+ N is the batch size, E is the feature number
+
+ """
+ output = tgt
+
+ for mod in self.layers:
+ output = mod(
+ output,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ warmup=warmup,
+ )
+
+ return output
+
+
+class PositionalEncoding(nn.Module):
+ """This class implements the positional encoding
+ proposed in the following paper:
+
+ - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
+
+ PE(pos, 2i) = sin(pos / (10000^(2i/d_model)))
+ PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))
+
+ Note::
+
+ 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
+ = exp(-1 * 2i / d_model * log(10000))
+ = exp(2i * -(log(10000) / d_model))
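+
+ Example (shapes only; the sizes are illustrative)::
+
+ >>> pos_enc = PositionalEncoding(d_model=256, dropout=0.1)
+ >>> x = torch.rand(2, 100, 256) # (N, T, C)
+ >>> y = pos_enc(x) # y has shape (2, 100, 256)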
+ """
+
+ def __init__(self, d_model: int, dropout: float = 0.1) -> None:
+ """
+ Args:
+ d_model:
+ Embedding dimension.
+ dropout:
+ Dropout probability to be applied to the output of this module.
+ """
+ super().__init__()
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = nn.Dropout(p=dropout)
+ # not doing: self.pe = None because of errors thrown by torchscript
+ self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
+
+ def extend_pe(self, x: torch.Tensor) -> None:
+ """Extend the time t in the positional encoding if required.
+
+ The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+ is (N, T, d_model). If T > T1, then we change the shape of self.pe
+ to (1, T, d_model). Otherwise, nothing is done.
+
+ Args:
+ x:
+ It is a tensor of shape (N, T, C).
+ Returns:
+ Return None.
+ """
+ if self.pe is not None:
+ if self.pe.size(1) >= x.size(1):
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.d_model)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ # Now pe is of shape (1, T, d_model), where T is x.size(1)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Add positional encoding.
+
+ Args:
+ x:
+ Its shape is (N, T, C)
+
+ Returns:
+ Return a tensor of shape (N, T, C)
+ """
+ self.extend_pe(x)
+ x = x * self.xscale + self.pe[:, : x.size(1), :]
+ return self.dropout(x)
+
+
+class Noam(object):
+ """
+ Implements Noam optimizer.
+
+ Proposed in
+ "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
+
+ Modified from
+ https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
+
+ Args:
+ params:
+ iterable of parameters to optimize or dicts defining parameter groups
+ model_size:
+ attention dimension of the transformer model
+ factor:
+ learning rate factor
+ warm_step:
+ warmup steps
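+
+ Note:
+ The learning rate grows linearly for the first `warm_step` steps and
+ then decays proportionally to step ** -0.5. With the defaults
+ (model_size=256, factor=10.0, warm_step=25000), the peak rate is
+ reached at step == warm_step and equals
+ 10.0 * 256 ** -0.5 * 25000 ** -0.5 ≈ 3.95e-3.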
+ """
+
+ def __init__(
+ self,
+ params,
+ model_size: int = 256,
+ factor: float = 10.0,
+ warm_step: int = 25000,
+ weight_decay=0,
+ ) -> None:
+ """Construct an Noam object."""
+ self.optimizer = torch.optim.Adam(
+ params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
+ )
+ self._step = 0
+ self.warmup = warm_step
+ self.factor = factor
+ self.model_size = model_size
+ self._rate = 0
+
+ @property
+ def param_groups(self):
+ """Return param_groups."""
+ return self.optimizer.param_groups
+
+ def step(self):
+ """Update parameters and rate."""
+ self._step += 1
+ rate = self.rate()
+ for p in self.optimizer.param_groups:
+ p["lr"] = rate
+ self._rate = rate
+ self.optimizer.step()
+
+ def rate(self, step=None):
+ """Implement `lrate` above."""
+ if step is None:
+ step = self._step
+ return (
+ self.factor
+ * self.model_size ** (-0.5)
+ * min(step ** (-0.5), step * self.warmup ** (-1.5))
+ )
+
+ def zero_grad(self):
+ """Reset gradient."""
+ self.optimizer.zero_grad()
+
+ def state_dict(self):
+ """Return state_dict."""
+ return {
+ "_step": self._step,
+ "warmup": self.warmup,
+ "factor": self.factor,
+ "model_size": self.model_size,
+ "_rate": self._rate,
+ "optimizer": self.optimizer.state_dict(),
+ }
+
+ def load_state_dict(self, state_dict):
+ """Load state_dict."""
+ for key, value in state_dict.items():
+ if key == "optimizer":
+ self.optimizer.load_state_dict(state_dict["optimizer"])
+ else:
+ setattr(self, key, value)
+
+
+def encoder_padding_mask(
+ max_len: int,
+ subsampling_factor: Optional[int] = 4,
+ supervisions: Optional[Supervisions] = None,
+) -> Optional[torch.Tensor]:
+ """Make mask tensor containing indexes of padded part.
+
+ TODO::
+ This function **assumes** that the model uses
+ a subsampling factor of 4 or 2. We should remove that
+ assumption later.
+
+ Args:
+ max_len:
+ Maximum length of input features.
+ CAUTION: It is the length after subsampling.
+ subsampling_factor:
+ The subsampling factor of the model, either 4 or 2 (see the TODO
+ note above).
+ supervisions:
+ Supervision in lhotse format.
+ See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
+ (CAUTION: It contains length information, i.e., start and number of
+ frames, before subsampling)
+
+ Returns:
+ Tensor: Mask tensor of dimension (batch_size, input_length),
+ where True denotes the masked (padded) positions.
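+
+ For example, if the subsampled lengths work out to [2, 4] and
+ ``max_len`` is 4, the returned mask is::
+
+ tensor([[False, False, True, True],
+ [False, False, False, False]])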
+ """
+ if supervisions is None:
+ return None
+
+ supervision_segments = torch.stack(
+ (
+ supervisions["sequence_idx"],
+ supervisions["start_frame"],
+ supervisions["num_frames"],
+ ),
+ 1,
+ ).to(torch.int32)
+
+ lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
+ for idx in range(supervision_segments.size(0)):
+ # Note: TorchScript doesn't allow to unpack tensors as tuples
+ sequence_idx = supervision_segments[idx, 0].item()
+ start_frame = supervision_segments[idx, 1].item()
+ num_frames = supervision_segments[idx, 2].item()
+ lengths[sequence_idx] = start_frame + num_frames
+
+ if subsampling_factor == 4:
+ lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
+ elif subsampling_factor == 2:
+ lengths = [(i - 1) // 2 - 2 for i in lengths]
+ bs = int(len(lengths))
+ seq_range = torch.arange(0, max_len, dtype=torch.int64)
+ seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
+ # Note: TorchScript doesn't implement Tensor.new()
+ seq_length_expand = torch.tensor(
+ lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype
+ ).unsqueeze(-1)
+ mask = seq_range_expand >= seq_length_expand
+
+ return mask
+
+
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
+ """Generate a length mask for input.
+
+ The masked positions are filled with True,
+ and unmasked positions are filled with False.
+
+ Args:
+ ys_pad:
+ padded tensor of dimension (batch_size, input_length).
+ ignore_id:
+ the ignored number (the padding number) in ys_pad
+
+ Returns:
+ Tensor:
+ a bool tensor of the same shape as the input tensor.
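+
+ For example (illustrative values, with ignore_id=-1)::
+
+ >>> ys_pad = torch.tensor([[1, 2, 3, -1, -1], [4, 5, -1, -1, -1]])
+ >>> decoder_padding_mask(ys_pad)
+ tensor([[False, False, False, True, True],
+ [False, False, True, True, True]])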
+ """
+ ys_mask = ys_pad == ignore_id
+ return ys_mask
+
+
+def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
+ """Generate a square mask for the sequence. The masked positions are
+ filled with float('-inf'). Unmasked positions are filled with float(0.0).
+ The mask can be used for masked self-attention.
+
+ For instance, if sz is 3, it returns::
+
+ tensor([[0., -inf, -inf],
+ [0., 0., -inf],
+ [0., 0., 0]])
+
+ Args:
+ sz: mask size
+
+ Returns:
+ A square mask of dimension (sz, sz)
+ """
+ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+ mask = (
+ mask.float()
+ .masked_fill(mask == 0, float("-inf"))
+ .masked_fill(mask == 1, float(0.0))
+ )
+ return mask
+
+
+def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
+ """Prepend sos_id to each utterance.
+
+ Args:
+ token_ids:
+ A list-of-list of token IDs. Each sublist contains
+ token IDs (e.g., word piece IDs) of an utterance.
+ sos_id:
+ The ID of the SOS token.
+
+ Return:
+ Return a new list-of-list, where each sublist starts
+ with SOS ID.
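+
+ For example (illustrative IDs)::
+
+ >>> add_sos([[2, 3], [5]], sos_id=1)
+ [[1, 2, 3], [1, 5]]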
+ """
+ return [[sos_id] + utt for utt in token_ids]
+
+
+def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
+ """Append eos_id to each utterance.
+
+ Args:
+ token_ids:
+ A list-of-list of token IDs. Each sublist contains
+ token IDs (e.g., word piece IDs) of an utterance.
+ eos_id:
+ The ID of the EOS token.
+
+ Return:
+ Return a new list-of-list, where each sublist ends
+ with EOS ID.
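+
+ For example (illustrative IDs)::
+
+ >>> add_eos([[2, 3], [5]], eos_id=1)
+ [[2, 3, 1], [5, 1]]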
+ """
+ return [utt + [eos_id] for utt in token_ids]
+
+
+def tolist(t: torch.Tensor) -> List[int]:
+ """Used by jit"""
+ return torch.jit.annotate(List[int], t.tolist())
diff --git a/egs/librispeech/WSASR/figures/del.png b/egs/librispeech/WSASR/figures/del.png
new file mode 100644
index 000000000..38973980b
Binary files /dev/null and b/egs/librispeech/WSASR/figures/del.png differ
diff --git a/egs/librispeech/WSASR/figures/ins.png b/egs/librispeech/WSASR/figures/ins.png
new file mode 100644
index 000000000..2d0e807a9
Binary files /dev/null and b/egs/librispeech/WSASR/figures/ins.png differ
diff --git a/egs/librispeech/WSASR/figures/otc_emission.drawio.png b/egs/librispeech/WSASR/figures/otc_emission.drawio.png
new file mode 100644
index 000000000..6cea5531d
Binary files /dev/null and b/egs/librispeech/WSASR/figures/otc_emission.drawio.png differ
diff --git a/egs/librispeech/WSASR/figures/otc_g.png b/egs/librispeech/WSASR/figures/otc_g.png
new file mode 100644
index 000000000..ebad49180
Binary files /dev/null and b/egs/librispeech/WSASR/figures/otc_g.png differ
diff --git a/egs/librispeech/WSASR/figures/otc_training_graph.drawio.png b/egs/librispeech/WSASR/figures/otc_training_graph.drawio.png
new file mode 100644
index 000000000..8978158d8
Binary files /dev/null and b/egs/librispeech/WSASR/figures/otc_training_graph.drawio.png differ
diff --git a/egs/librispeech/WSASR/figures/sub.png b/egs/librispeech/WSASR/figures/sub.png
new file mode 100644
index 000000000..5674e9feb
Binary files /dev/null and b/egs/librispeech/WSASR/figures/sub.png differ
diff --git a/egs/librispeech/WSASR/local/compile_hlg.py b/egs/librispeech/WSASR/local/compile_hlg.py
new file mode 100755
index 000000000..63791f4cc
--- /dev/null
+++ b/egs/librispeech/WSASR/local/compile_hlg.py
@@ -0,0 +1,173 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input lang_dir and generates HLG from
+
+ - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
+ - L, the lexicon, built from lang_dir/L_disambig.pt
+
+ Caution: We use a lexicon that contains disambiguation symbols
+
+ - G, the LM, built from data/lm/G_n_gram.fst.txt
+
+The generated HLG is saved in $lang_dir/HLG.pt
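+
+Usage example (directory names are illustrative):
+
+ ./local/compile_hlg.py \
+ --lm-dir data/lm \
+ --lang-dir data/lang \
+ --lm G_3_gram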
+"""
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+from icefall.lexicon import Lexicon
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lm",
+ type=str,
+ default="G_3_gram",
+ help="""Stem name for LM used in HLG compiling.
+ """,
+ )
+ parser.add_argument(
+ "--lm-dir",
+ type=str,
+ help="""LM directory.
+ """,
+ )
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ """,
+ )
+
+ return parser.parse_args()
+
+
+def compile_HLG(lm_dir: str, lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
+ """
+ Args:
+ lm_dir:
+ The LM directory. It should contain {lm}.fst.txt or a cached {lm}.pt.
+ lang_dir:
+ The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+ lm:
+ The stem name of the LM, e.g., G_3_gram.
+
+ Return:
+ An FSA representing HLG.
+ """
+ lexicon = Lexicon(lang_dir)
+ max_token_id = max(lexicon.tokens)
+ logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
+ H = k2.ctc_topo(max_token_id)
+ L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+
+ if Path(f"{lm_dir}/{lm}.pt").is_file():
+ logging.info(f"Loading pre-compiled {lm}")
+ d = torch.load(f"{lm_dir}/{lm}.pt")
+ G = k2.Fsa.from_dict(d)
+ else:
+ logging.info(f"Loading {lm}.fst.txt")
+ with open(f"{lm_dir}/{lm}.fst.txt") as f:
+ G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+ torch.save(G.as_dict(), f"{lm_dir}/{lm}.pt")
+
+ first_token_disambig_id = lexicon.token_table["#0"]
+ first_word_disambig_id = lexicon.word_table["#0"]
+
+ L = k2.arc_sort(L)
+ G = k2.arc_sort(G)
+
+ logging.info("Intersecting L and G")
+ LG = k2.compose(L, G)
+ logging.info(f"LG shape: {LG.shape}")
+
+ logging.info("Connecting LG")
+ LG = k2.connect(LG)
+ logging.info(f"LG shape after k2.connect: {LG.shape}")
+
+ logging.info(type(LG.aux_labels))
+ logging.info("Determinizing LG")
+
+ LG = k2.determinize(LG)
+ logging.info(type(LG.aux_labels))
+
+ logging.info("Connecting LG after k2.determinize")
+ LG = k2.connect(LG)
+
+ logging.info("Removing disambiguation symbols on LG")
+
+ LG.labels[LG.labels >= first_token_disambig_id] = 0
+ # See https://github.com/k2-fsa/k2/issues/874
+ # for why we need to set LG.properties to None
+ LG.__dict__["_properties"] = None
+
+ assert isinstance(LG.aux_labels, k2.RaggedTensor)
+ LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
+
+ LG = k2.remove_epsilon(LG)
+ logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
+
+ LG = k2.connect(LG)
+ LG.aux_labels = LG.aux_labels.remove_values_eq(0)
+
+ logging.info("Arc sorting LG")
+ LG = k2.arc_sort(LG)
+
+ logging.info("Composing H and LG")
+ # CAUTION: The name of the inner_labels is fixed
+ # to `tokens`. If you want to change it, please
+ # also change other places in icefall that are using
+ # it.
+ HLG = k2.compose(H, LG, inner_labels="tokens")
+
+ logging.info("Connecting LG")
+ HLG = k2.connect(HLG)
+
+ logging.info("Arc sorting LG")
+ HLG = k2.arc_sort(HLG)
+ logging.info(f"HLG.shape: {HLG.shape}")
+
+ return HLG
+
+
+def main():
+ args = get_args()
+ lm_dir = Path(args.lm_dir)
+ lang_dir = Path(args.lang_dir)
+
+ if (lang_dir / "HLG.pt").is_file():
+ logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
+ return
+
+ logging.info(f"Processing {lang_dir}")
+
+ HLG = compile_HLG(lm_dir, lang_dir, args.lm)
+ logging.info(f"Saving HLG.pt to {lang_dir}")
+ torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/librispeech/WSASR/local/compute_fbank_librispeech.py b/egs/librispeech/WSASR/local/compute_fbank_librispeech.py
new file mode 100755
index 000000000..a387d54c9
--- /dev/null
+++ b/egs/librispeech/WSASR/local/compute_fbank_librispeech.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the LibriSpeech dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
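+
+Usage example (all arguments are optional; values are illustrative):
+
+ ./local/compute_fbank_librispeech.py \
+ --dataset "train-clean-100" \
+ --perturb-speed true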
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import sentencepiece as spm
+import torch
+from filter_cuts import filter_cuts
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor, str2bool
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ help="""Path to the bpe.model. If not None, we will remove short and
+ long utterances before extracting features""",
+ )
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="""Dataset parts to compute fbank. If None, we will use all""",
+ )
+
+ parser.add_argument(
+ "--perturb-speed",
+ type=str2bool,
+ default=True,
+ help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
+ )
+
+ return parser.parse_args()
+
+
+def compute_fbank_librispeech(
+ bpe_model: Optional[str] = None,
+ dataset: Optional[str] = None,
+ perturb_speed: Optional[bool] = True,
+):
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+ num_jobs = min(15, os.cpu_count())
+ num_mel_bins = 80
+
+ if bpe_model:
+ logging.info(f"Loading {bpe_model}")
+ sp = spm.SentencePieceProcessor()
+ sp.load(bpe_model)
+
+ if dataset is None:
+ dataset_parts = (
+ "dev-clean",
+ "dev-other",
+ "test-clean",
+ "test-other",
+ "train-clean-100",
+ )
+ else:
+ dataset_parts = dataset.split(" ", -1)
+
+ prefix = "librispeech"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ assert len(manifests) == len(dataset_parts), (
+ len(manifests),
+ len(dataset_parts),
+ list(manifests.keys()),
+ dataset_parts,
+ )
+
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+ with get_executor() as ex: # Initialize the executor only once.
+ for partition, m in manifests.items():
+ cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+
+ if "train" in partition:
+ if bpe_model:
+ cut_set = filter_cuts(cut_set, sp)
+ if perturb_speed:
+ logging.info(f"Doing speed perturb")
+ cut_set = (
+ cut_set
+ + cut_set.perturb_speed(0.9)
+ + cut_set.perturb_speed(1.1)
+ )
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 80,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+ cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ args = get_args()
+ logging.info(vars(args))
+ compute_fbank_librispeech(
+ bpe_model=args.bpe_model,
+ dataset=args.dataset,
+ perturb_speed=args.perturb_speed,
+ )
diff --git a/egs/librispeech/WSASR/local/compute_ssl_librispeech.py b/egs/librispeech/WSASR/local/compute_ssl_librispeech.py
new file mode 100755
index 000000000..f405c468c
--- /dev/null
+++ b/egs/librispeech/WSASR/local/compute_ssl_librispeech.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2023 Johns Hopkins University (author: Dongji Gao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes SSL (wav2vec2) features of the LibriSpeech dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated SSL features are saved in data/ssl.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import S3PRLSSL, CutSet, NumpyFilesWriter, S3PRLSSLConfig
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_ssl_librispeech():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/ssl")
+ num_jobs = 1
+
+ dataset_parts = (
+ "dev-clean",
+ "dev-other",
+ "test-clean",
+ "test-other",
+ "train-clean-100",
+ )
+ prefix = "librispeech"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ assert len(manifests) == len(dataset_parts), (
+ len(manifests),
+ len(dataset_parts),
+ list(manifests.keys()),
+ dataset_parts,
+ )
+
+ extractor = S3PRLSSL(S3PRLSSLConfig(ssl_model="wav2vec2", device="cuda"))
+
+ with get_executor() as ex: # Initialize the executor only once.
+ for partition, m in manifests.items():
+ cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ storage_type=NumpyFilesWriter,
+ )
+ cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ compute_ssl_librispeech()
diff --git a/egs/librispeech/WSASR/local/filter_cuts.py b/egs/librispeech/WSASR/local/filter_cuts.py
new file mode 100644
index 000000000..fbcc9e24a
--- /dev/null
+++ b/egs/librispeech/WSASR/local/filter_cuts.py
@@ -0,0 +1,160 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script removes short and long utterances from a cutset.
+
+Caution:
+ You may need to tune the thresholds for your own dataset.
+
+Usage example:
+
+ python3 ./local/filter_cuts.py \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --in-cuts data/fbank/librispeech_cuts_test-clean.jsonl.gz \
+ --out-cuts data/fbank-filtered/librispeech_cuts_test-clean.jsonl.gz
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.cut import Cut
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--bpe-model",
+ type=Path,
+ help="Path to the bpe.model",
+ )
+
+ parser.add_argument(
+ "--in-cuts",
+ type=Path,
+ help="Path to the input cutset",
+ )
+
+ parser.add_argument(
+ "--out-cuts",
+ type=Path,
+ help="Path to the output cutset",
+ )
+
+ return parser.parse_args()
+
+
+def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor):
+ total = 0 # number of total utterances before removal
+ removed = 0 # number of removed utterances
+
+ def remove_short_and_long_utterances(c: Cut):
+ """Return False to exclude the input cut"""
+ nonlocal removed, total
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ./display_manifest_statistics.py
+ #
+ # You should use ./display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ total += 1
+ if c.duration < 1.0 or c.duration > 20.0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ removed += 1
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./pruned_transducer_stateless2/conformer.py, the
+ # conv module uses the following expression
+ # for subsampling
+ if c.num_frames is None:
+ num_frames = c.duration * 100 # approximate
+ else:
+ num_frames = c.num_frames
+
+ T = ((num_frames - 1) // 2 - 1) // 2
+ # Note: for ./lstm_transducer_stateless/lstm.py, the formula is
+ # T = ((num_frames - 3) // 2 - 1) // 2
+
+ # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is
+ # T = ((num_frames - 7) // 2 + 1) // 2
+
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ removed += 1
+ return False
+
+ return True
+
+ # We use to_eager() here so that we can print out the value of total
+ # and removed below.
+ ans = cut_set.filter(remove_short_and_long_utterances).to_eager()
+ ratio = removed / total * 100
+ logging.info(
+ f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed."
+ )
+ return ans
+
+
+def main():
+ args = get_args()
+ logging.info(vars(args))
+
+ if args.out_cuts.is_file():
+ logging.info(f"{args.out_cuts} already exists - skipping")
+ return
+
+ assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist"
+ assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist"
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(str(args.bpe_model))
+
+ cut_set = load_manifest_lazy(args.in_cuts)
+ assert isinstance(cut_set, CutSet)
+
+ cut_set = filter_cuts(cut_set, sp)
+ logging.info(f"Saving to {args.out_cuts}")
+ args.out_cuts.parent.mkdir(parents=True, exist_ok=True)
+ cut_set.to_file(args.out_cuts)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/librispeech/WSASR/local/get_words_from_lexicon.py b/egs/librispeech/WSASR/local/get_words_from_lexicon.py
new file mode 100755
index 000000000..0cc740b36
--- /dev/null
+++ b/egs/librispeech/WSASR/local/get_words_from_lexicon.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+
+import argparse
+from pathlib import Path
+
+from icefall.lexicon import read_lexicon
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ It should contain a file lexicon.txt.
+ Generated files by this script are saved into this directory.
+ """,
+ )
+
+ parser.add_argument(
+ "--otc-token",
+ type=str,
+ help="OTC token to be added to words.txt",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+ lang_dir = Path(args.lang_dir)
+ otc_token = args.otc_token
+
+ lexicon = read_lexicon(lang_dir / "lexicon.txt")
+ ans = set()
+ for word, _ in lexicon:
+ ans.add(word)
+ sorted_ans = sorted(list(ans))
+ words = [""] + sorted_ans + [otc_token] + ["#0", "", ""]
+
+ words_file = lang_dir / "words.txt"
+ with open(words_file, "w") as wf:
+ for i, word in enumerate(words):
+ wf.write(f"{word} {i}\n")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/WSASR/local/make_error_cutset.py b/egs/librispeech/WSASR/local/make_error_cutset.py
new file mode 100755
index 000000000..8463a380e
--- /dev/null
+++ b/egs/librispeech/WSASR/local/make_error_cutset.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python3
+
+# Copyright 2023 Johns Hopkins University (author: Dongji Gao)
+
+import argparse
+import random
+from pathlib import Path
+from typing import List
+
+from lhotse import CutSet, load_manifest
+from lhotse.cut.base import Cut
+
+from icefall.utils import str2bool
+
+
+def get_args():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument(
+ "--input-cutset",
+ type=str,
+ help="Supervision manifest that contains verbatim transcript",
+ )
+
+ parser.add_argument(
+ "--words-file",
+ type=str,
+ help="words.txt file",
+ )
+
+ parser.add_argument(
+ "--otc-token",
+ type=str,
+ help="OTC token in words.txt",
+ )
+
+ parser.add_argument(
+ "--sub-error-rate",
+ type=float,
+ default=0.0,
+ help="Substitution error rate",
+ )
+
+ parser.add_argument(
+ "--ins-error-rate",
+ type=float,
+ default=0.0,
+ help="Insertion error rate",
+ )
+
+ parser.add_argument(
+ "--del-error-rate",
+ type=float,
+ default=0.0,
+ help="Deletion error rate",
+ )
+
+ parser.add_argument(
+ "--output-cutset",
+ type=str,
+ default="",
+ help="Supervision manifest that contains modified non-verbatim transcript",
+ )
+
+ parser.add_argument("--verbose", type=str2bool, help="show details of errors")
+ return parser.parse_args()
+
+
+def check_args(args):
+ total_error_rate = args.sub_error_rate + args.ins_error_rate + args.del_error_rate
+ assert args.sub_error_rate >= 0 and args.sub_error_rate <= 1.0
+ assert args.ins_error_rate >= 0 and args.ins_error_rate <= 1.0
+ assert args.del_error_rate >= 0 and args.del_error_rate <= 1.0
+ assert total_error_rate <= 1.0
+
+
+def get_word_list(token_path: str) -> List:
+ word_list = []
+ with open(Path(token_path), "r") as tp:
+ for line in tp.readlines():
+ token = line.split()[0]
+ assert token not in word_list
+ word_list.append(token)
+ return word_list
+
+
+def modify_cut_text(
+ cut: Cut,
+ words_list: List,
+ non_words: List,
+ sub_ratio: float = 0.0,
+ ins_ratio: float = 0.0,
+ del_ratio: float = 0.0,
+):
+ text = cut.supervisions[0].text
+ text_list = text.split()
+
+ # We save the modified information of the original verbatim text for debugging
+ marked_verbatim_text_list = []
+ modified_text_list = []
+
+ del_index_set = set()
+ sub_index_set = set()
+ ins_index_set = set()
+
+ # We follow the order: deletion -> substitution -> insertion
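+ # In the marked verbatim text, a deleted word is rendered as "-word-", a
+ # substituted word as "[word]", and an inserted position as "[]". For
+ # example, "the cat sat" could become "-the- [cat] sat []": the modified
+ # text then drops "the", replaces "cat" with a random word, keeps "sat",
+ # and appends an extra random word for the insertion.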
+ for token in text_list:
+ marked_token = token
+ modified_token = token
+
+ prob = random.random()
+
+ if prob <= del_ratio:
+ marked_token = f"-{token}-"
+ modified_token = ""
+ elif prob <= del_ratio + sub_ratio + ins_ratio:
+ if prob <= del_ratio + sub_ratio:
+ marked_token = f"[{token}]"
+ else:
+ marked_verbatim_text_list.append(marked_token)
+ modified_text_list.append(modified_token)
+ marked_token = "[]"
+
+ # get new_token
+ while (
+ modified_token == token
+ or modified_token in non_words
+ or modified_token.startswith("#")
+ ):
+ modified_token = random.choice(words_list)
+
+ marked_verbatim_text_list.append(marked_token)
+ modified_text_list.append(modified_token)
+
+ marked_text = " ".join(marked_verbatim_text_list)
+ modified_text = " ".join(modified_text_list)
+
+ if not hasattr(cut.supervisions[0], "verbatim_text"):
+ cut.supervisions[0].verbatim_text = marked_text
+ cut.supervisions[0].text = modified_text
+
+ return cut
+
+
+def main():
+ args = get_args()
+ check_args(args)
+
+ otc_token = args.otc_token
+ # non-lexical symbols that must never be sampled as substitution words
+ non_words = set(("sil", "<eps>", "<UNK>"))
+ non_words.add(otc_token)
+
+ words_list = get_word_list(args.words_file)
+ cutset = load_manifest(Path(args.input_cutset))
+
+ cuts = []
+
+ for cut in cutset:
+ modified_cut = modify_cut_text(
+ cut=cut,
+ words_list=words_list,
+ non_words=non_words,
+ sub_ratio=args.sub_error_rate,
+ ins_ratio=args.ins_error_rate,
+ del_ratio=args.del_error_rate,
+ )
+ cuts.append(modified_cut)
+
+ output_cutset = CutSet.from_cuts(cuts)
+ output_cutset.to_file(args.output_cutset)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/WSASR/local/prepare_lang.py b/egs/librispeech/WSASR/local/prepare_lang.py
new file mode 100755
index 000000000..d913756a1
--- /dev/null
+++ b/egs/librispeech/WSASR/local/prepare_lang.py
@@ -0,0 +1,413 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
+consisting of words and tokens (i.e., phones) and does the following:
+
+1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
+
+2. Generate tokens.txt, the token table mapping a token to a unique integer.
+
+3. Generate words.txt, the word table mapping a word to a unique integer.
+
+4. Generate L.pt, in k2 format. It can be loaded by
+
+ d = torch.load("L.pt")
+ lexicon = k2.Fsa.from_dict(d)
+
+5. Generate L_disambig.pt, in k2 format.
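+
+Usage example (the directory name is illustrative):
+
+ ./local/prepare_lang.py --lang-dir data/lang_phone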
+"""
+import argparse
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import k2
+import torch
+
+from icefall.lexicon import read_lexicon, write_lexicon
+from icefall.utils import str2bool
+
+Lexicon = List[Tuple[str, List[str]]]
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ It should contain a file lexicon.txt.
+ Generated files by this script are saved into this directory.
+ """,
+ )
+
+ parser.add_argument(
+ "--debug",
+ type=str2bool,
+ default=False,
+ help="""True for debugging, which will generate
+ a visualization of the lexicon FST.
+
+ Caution: If your lexicon contains hundreds of thousands
+ of lines, please set it to False!
+ """,
+ )
+
+ return parser.parse_args()
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+ """Write a symbol to ID mapping to a file.
+
+ Note:
+ No need to implement `read_mapping` as it can be done
+ through :func:`k2.SymbolTable.from_file`.
+
+ Args:
+ filename:
+ Filename to save the mapping.
+ sym2id:
+ A dict mapping symbols to IDs.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf-8") as f:
+ for sym, i in sym2id.items():
+ f.write(f"{sym} {i}\n")
+
+
+def get_tokens(lexicon: Lexicon) -> List[str]:
+ """Get tokens from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique tokens.
+ """
+ ans = set()
+ for _, tokens in lexicon:
+ ans.update(tokens)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def get_words(lexicon: Lexicon) -> List[str]:
+ """Get words from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique words.
+ """
+ ans = set()
+ for word, _ in lexicon:
+ ans.add(word)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
+ """It adds pseudo-token disambiguation symbols #1, #2 and so on
+ at the ends of tokens to ensure that all pronunciations are different,
+ and that none is a prefix of another.
+
+ See also add_lex_disambig.pl from kaldi.
+
+ Args:
+ lexicon:
+ It is returned by :func:`read_lexicon`.
+ Returns:
+ Return a tuple with two elements:
+
+ - The output lexicon with disambiguation symbols
+ - The ID of the max disambiguation symbol that appears
+ in the lexicon
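+
+ For example, if two entries share the token sequence "h e", the first
+ becomes "h e #1" and the second "h e #2"; an entry whose token sequence
+ is a prefix of another entry also receives a disambiguation symbol.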
+ """
+
+ # (1) Work out the count of each token-sequence in the
+ # lexicon.
+ count = defaultdict(int)
+ for _, tokens in lexicon:
+ count[" ".join(tokens)] += 1
+
+ # (2) For each left sub-sequence of each token-sequence, note down
+ # that it exists (for identifying prefixes of longer strings).
+ issubseq = defaultdict(int)
+ for _, tokens in lexicon:
+ tokens = tokens.copy()
+ tokens.pop()
+ while tokens:
+ issubseq[" ".join(tokens)] = 1
+ tokens.pop()
+
+ # (3) For each entry in the lexicon:
+ # if the token sequence is unique and is not a
+ # prefix of another word, no disambig symbol.
+ # Else output #1, or #2, #3, ... if the same token-seq
+ # has already been assigned a disambig symbol.
+ ans = []
+
+ # We start with #1 since #0 has its own purpose
+ first_allowed_disambig = 1
+ max_disambig = first_allowed_disambig - 1
+ last_used_disambig_symbol_of = defaultdict(int)
+
+ for word, tokens in lexicon:
+ tokenseq = " ".join(tokens)
+ assert tokenseq != ""
+ if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
+ ans.append((word, tokens))
+ continue
+
+ cur_disambig = last_used_disambig_symbol_of[tokenseq]
+ if cur_disambig == 0:
+ cur_disambig = first_allowed_disambig
+ else:
+ cur_disambig += 1
+
+ if cur_disambig > max_disambig:
+ max_disambig = cur_disambig
+ last_used_disambig_symbol_of[tokenseq] = cur_disambig
+ tokenseq += f" #{cur_disambig}"
+ ans.append((word, tokenseq.split()))
+ return ans, max_disambig
+
+
+def generate_id_map(symbols: List[str]) -> Dict[str, int]:
+ """Generate ID maps, i.e., map a symbol to a unique ID.
+
+ Args:
+ symbols:
+ A list of unique symbols.
+ Returns:
+ A dict containing the mapping between symbols and IDs.
+ """
+ return {sym: i for i, sym in enumerate(symbols)}
+
+
+def add_self_loops(
+ arcs: List[List[Any]], disambig_token: int, disambig_word: int
+) -> List[List[Any]]:
+ """Adds self-loops to states of an FST to propagate disambiguation symbols
+ through it. They are added on each state with non-epsilon output symbols
+ on at least one arc out of the state.
+
+ See also fstaddselfloops.pl from Kaldi. One difference is that
+ Kaldi uses OpenFst style FSTs and it has multiple final states.
+ This function uses k2 style FSTs and it does not need to add self-loops
+ to the final state.
+
+ The input label of a self-loop is `disambig_token`, while the output
+ label is `disambig_word`.
+
+ Args:
+ arcs:
+ A list-of-list. The sublist contains
+ `[src_state, dest_state, label, aux_label, score]`
+ disambig_token:
+ It is the token ID of the symbol `#0`.
+ disambig_word:
+ It is the word ID of the symbol `#0`.
+
+ Return:
+ Return new `arcs` containing self-loops.
+ """
+ states_needs_self_loops = set()
+ for arc in arcs:
+ src, dst, ilabel, olabel, score = arc
+ if olabel != 0:
+ states_needs_self_loops.add(src)
+
+ ans = []
+ for s in states_needs_self_loops:
+ ans.append([s, s, disambig_token, disambig_word, 0])
+
+ return arcs + ans
+
+
+def lexicon_to_fst(
+ lexicon: Lexicon,
+ token2id: Dict[str, int],
+ word2id: Dict[str, int],
+ sil_token: str = "SIL",
+ sil_prob: float = 0.5,
+ need_self_loops: bool = False,
+) -> k2.Fsa:
+ """Convert a lexicon to an FST (in k2 format) with optional silence at
+ the beginning and end of each word.
+
+ Args:
+ lexicon:
+ The input lexicon. See also :func:`read_lexicon`
+ token2id:
+ A dict mapping tokens to IDs.
+ word2id:
+ A dict mapping words to IDs.
+ sil_token:
+ The silence token.
+ sil_prob:
+ The probability for adding a silence at the beginning and end
+ of the word.
+ need_self_loops:
+ If True, add self-loop to states with non-epsilon output symbols
+ on at least one arc out of the state. The input label for this
+ self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+ Returns:
+ Return an instance of `k2.Fsa` representing the given lexicon.
+ """
+ assert sil_prob > 0.0 and sil_prob < 1.0
+ # CAUTION: we use score, i.e, negative cost.
+ sil_score = math.log(sil_prob)
+ no_sil_score = math.log(1.0 - sil_prob)
+
+ start_state = 0
+ loop_state = 1 # words enter and leave from here
+ sil_state = 2 # words terminate here when followed by silence; this state
+ # has a silence transition to loop_state.
+ next_state = 3 # the next un-allocated state, will be incremented as we go.
+ arcs = []
+
+ assert token2id["<eps>"] == 0
+ assert word2id["<eps>"] == 0
+
+ eps = 0
+
+ sil_token = token2id[sil_token]
+
+ arcs.append([start_state, loop_state, eps, eps, no_sil_score])
+ arcs.append([start_state, sil_state, eps, eps, sil_score])
+ arcs.append([sil_state, loop_state, sil_token, eps, 0])
+
+ for word, tokens in lexicon:
+ assert len(tokens) > 0, f"{word} has no pronunciations"
+ cur_state = loop_state
+
+ word = word2id[word]
+ tokens = [token2id[i] for i in tokens]
+
+ for i in range(len(tokens) - 1):
+ w = word if i == 0 else eps
+ arcs.append([cur_state, next_state, tokens[i], w, 0])
+
+ cur_state = next_state
+ next_state += 1
+
+ # now for the last token of this word
+ # It has two out-going arcs, one to the loop state,
+ # the other one to the sil_state.
+ i = len(tokens) - 1
+ w = word if i == 0 else eps
+ arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
+ arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
+
+ if need_self_loops:
+ disambig_token = token2id["#0"]
+ disambig_word = word2id["#0"]
+ arcs = add_self_loops(
+ arcs,
+ disambig_token=disambig_token,
+ disambig_word=disambig_word,
+ )
+
+ final_state = next_state
+ arcs.append([loop_state, final_state, -1, -1, 0])
+ arcs.append([final_state])
+
+ arcs = sorted(arcs, key=lambda arc: arc[0])
+ arcs = [[str(i) for i in arc] for arc in arcs]
+ arcs = [" ".join(arc) for arc in arcs]
+ arcs = "\n".join(arcs)
+
+ fsa = k2.Fsa.from_str(arcs, acceptor=False)
+ return fsa
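+
+
+# Example (token/word symbols shown instead of integer IDs for readability):
+# for a lexicon entry ("HELLO", ["HH", "AH0", "L", "OW1"]) with sil_prob=0.5,
+# the arcs produced above are roughly:
+#
+#   1 3 HH HELLO 0           # leave loop_state; word label on the first arc only
+#   3 4 AH0 <eps> 0
+#   4 5 L <eps> 0
+#   5 1 OW1 <eps> log(0.5)   # back to loop_state (no silence)
+#   5 2 OW1 <eps> log(0.5)   # to sil_state (optional silence follows)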
+
+
+def main():
+ args = get_args()
+ lang_dir = Path(args.lang_dir)
+ lexicon_filename = lang_dir / "lexicon.txt"
+ sil_token = "SIL"
+ sil_prob = 0.5
+
+ lexicon = read_lexicon(lexicon_filename)
+ tokens = get_tokens(lexicon)
+ words = get_words(lexicon)
+
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+ for i in range(max_disambig + 1):
+ disambig = f"#{i}"
+ assert disambig not in tokens
+ tokens.append(f"#{i}")
+
+    assert "<eps>" not in tokens
+    tokens = ["<eps>"] + tokens
+
+    assert "<eps>" not in words
+    assert "#0" not in words
+    assert "<s>" not in words
+    assert "</s>" not in words
+
+    words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
+
+ token2id = generate_id_map(tokens)
+ word2id = generate_id_map(words)
+
+ write_mapping(lang_dir / "tokens.txt", token2id)
+ write_mapping(lang_dir / "words.txt", word2id)
+ write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+ L = lexicon_to_fst(
+ lexicon,
+ token2id=token2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ )
+
+ L_disambig = lexicon_to_fst(
+ lexicon_disambig,
+ token2id=token2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ need_self_loops=True,
+ )
+ torch.save(L.as_dict(), lang_dir / "L.pt")
+ torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
+
+ if args.debug:
+ labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
+ aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+ L.labels_sym = labels_sym
+ L.aux_labels_sym = aux_labels_sym
+ L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
+
+ L_disambig.labels_sym = labels_sym
+ L_disambig.aux_labels_sym = aux_labels_sym
+ L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py b/egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py
new file mode 100755
index 000000000..415bdff6f
--- /dev/null
+++ b/egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py
@@ -0,0 +1,295 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2023 Johns Hopkins University (author: Dongji Gao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+
+"""
+
+This script takes as input `lang_dir`, which should contain::
+
+ - lang_dir/bpe.model,
+ - lang_dir/words.txt
+
+and generates the following files in the directory `lang_dir`:
+
+ - lexicon.txt
+ - lexicon_disambig.txt
+ - L.pt
+ - L_disambig.pt
+ - tokens.txt
+"""
+
+import argparse
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+from prepare_lang import (
+ Lexicon,
+ add_disambig_symbols,
+ add_self_loops,
+ write_lexicon,
+ write_mapping,
+)
+
+from icefall.utils import str2bool
+
+
+def lexicon_to_fst_no_sil(
+ lexicon: Lexicon,
+ token2id: Dict[str, int],
+ word2id: Dict[str, int],
+ need_self_loops: bool = False,
+) -> k2.Fsa:
+ """Convert a lexicon to an FST (in k2 format).
+
+ Args:
+ lexicon:
+ The input lexicon. See also :func:`read_lexicon`
+ token2id:
+ A dict mapping tokens to IDs.
+ word2id:
+ A dict mapping words to IDs.
+ need_self_loops:
+ If True, add self-loop to states with non-epsilon output symbols
+ on at least one arc out of the state. The input label for this
+ self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+ Returns:
+ Return an instance of `k2.Fsa` representing the given lexicon.
+ """
+ loop_state = 0 # words enter and leave from here
+ next_state = 1 # the next un-allocated state, will be incremented as we go
+
+ arcs = []
+
+ # The blank symbol is defined in local/train_bpe_model.py
+    assert token2id["<blk>"] == 0
+    assert word2id["<eps>"] == 0
+
+ eps = 0
+
+ for word, pieces in lexicon:
+ assert len(pieces) > 0, f"{word} has no pronunciations"
+ cur_state = loop_state
+
+ word = word2id[word]
+ pieces = [token2id[i] for i in pieces]
+
+ for i in range(len(pieces) - 1):
+ w = word if i == 0 else eps
+ arcs.append([cur_state, next_state, pieces[i], w, 0])
+
+ cur_state = next_state
+ next_state += 1
+
+ # now for the last piece of this word
+ i = len(pieces) - 1
+ w = word if i == 0 else eps
+ arcs.append([cur_state, loop_state, pieces[i], w, 0])
+
+ if need_self_loops:
+ disambig_token = token2id["#0"]
+ disambig_word = word2id["#0"]
+ arcs = add_self_loops(
+ arcs,
+ disambig_token=disambig_token,
+ disambig_word=disambig_word,
+ )
+
+ final_state = next_state
+ arcs.append([loop_state, final_state, -1, -1, 0])
+ arcs.append([final_state])
+
+ arcs = sorted(arcs, key=lambda arc: arc[0])
+ arcs = [[str(i) for i in arc] for arc in arcs]
+ arcs = [" ".join(arc) for arc in arcs]
+ arcs = "\n".join(arcs)
+
+ fsa = k2.Fsa.from_str(arcs, acceptor=False)
+ return fsa
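+
+
+# Example (hypothetical BPE pieces): for a lexicon entry ("HELLO", ["▁HE", "LLO"]),
+# the arcs produced above are roughly:
+#
+#   0 1 ▁HE HELLO 0     # leave loop_state; word label only on the first arc
+#   1 0 LLO <eps> 0     # the last piece goes straight back to loop_state
+#
+# i.e. the same structure as `lexicon_to_fst` but without optional silence.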
+
+
+def generate_otc_lexicon(
+ model_file: str,
+ words: List[str],
+ oov: str,
+ otc_token: str,
+) -> Tuple[Lexicon, Dict[str, int]]:
+ """Generate a lexicon from a BPE model.
+
+ Args:
+ model_file:
+ Path to a sentencepiece model.
+ words:
+ A list of strings representing words.
+      oov:
+        The out-of-vocabulary word in the lexicon.
+      otc_token:
+        The OTC token in the lexicon.
+    Returns:
+      Return a tuple with two elements:
+        - A lexicon, i.e., a list of ``(word, pieces)`` tuples.
+        - A dict mapping tokens (word pieces) to IDs.
+ """
+ sp = spm.SentencePieceProcessor()
+ sp.load(str(model_file))
+
+ # Convert word to word piece IDs instead of word piece strings
+ # to avoid OOV tokens.
+ words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int)
+
+ # Now convert word piece IDs back to word piece strings.
+ words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids]
+
+ lexicon = []
+ for word, pieces in zip(words, words_pieces):
+ lexicon.append((word, pieces))
+
+ lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))
+ token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
+
+ # Add OTC token to the last.
+ lexicon.append((otc_token, [f"▁{otc_token}"]))
+ otc_token_index = len(token2id)
+ token2id[f"▁{otc_token}"] = otc_token_index
+
+ return lexicon, token2id
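+
+
+# For instance, with otc_token="<star>" (used here as an illustrative value),
+# the last lexicon entry becomes ("<star>", ["▁<star>"]) and "▁<star>" is
+# appended to token2id with ID sp.vocab_size(), right after the regular BPE
+# pieces.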
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ It should contain the bpe.model and words.txt
+ """,
+ )
+
+ parser.add_argument(
+ "--oov",
+ type=str,
+        default="<UNK>",
+        help="The out-of-vocabulary word in the lexicon.",
+ )
+
+ parser.add_argument(
+ "--otc-token",
+ type=str,
+        default="<star>",
+        help="The OTC token in the lexicon.",
+ )
+
+ parser.add_argument(
+ "--debug",
+ type=str2bool,
+ default=False,
+ help="""True for debugging, which will generate
+ a visualization of the lexicon FST.
+
+ Caution: If your lexicon contains hundreds of thousands
+ of lines, please set it to False!
+
+ See "test/test_bpe_lexicon.py" for usage.
+ """,
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+ lang_dir = Path(args.lang_dir)
+ model_file = lang_dir / "bpe.model"
+ otc_token = args.otc_token
+
+ word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+ words = word_sym_table.symbols
+
+ excluded = [
+        "<eps>",
+        "!SIL",
+        "<SPOKEN_NOISE>",
+        args.oov,
+        otc_token,
+        "#0",
+        "<s>",
+        "</s>",
+ ]
+
+ for w in excluded:
+ if w in words:
+ words.remove(w)
+
+ lexicon, token_sym_table = generate_otc_lexicon(
+ model_file, words, args.oov, otc_token
+ )
+
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+ next_token_id = max(token_sym_table.values()) + 1
+ for i in range(max_disambig + 1):
+ disambig = f"#{i}"
+ assert disambig not in token_sym_table
+ token_sym_table[disambig] = next_token_id
+ next_token_id += 1
+
+    word_sym_table.add("#0")
+    word_sym_table.add("<s>")
+    word_sym_table.add("</s>")
+
+ write_mapping(lang_dir / "tokens.txt", token_sym_table)
+
+ write_lexicon(lang_dir / "lexicon.txt", lexicon)
+ write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+ L = lexicon_to_fst_no_sil(
+ lexicon,
+ token2id=token_sym_table,
+ word2id=word_sym_table,
+ )
+
+ L_disambig = lexicon_to_fst_no_sil(
+ lexicon_disambig,
+ token2id=token_sym_table,
+ word2id=word_sym_table,
+ need_self_loops=True,
+ )
+ torch.save(L.as_dict(), lang_dir / "L.pt")
+ torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
+
+ if args.debug:
+ labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
+ aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+ L.labels_sym = labels_sym
+ L.aux_labels_sym = aux_labels_sym
+ L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
+
+ L_disambig.labels_sym = labels_sym
+ L_disambig.aux_labels_sym = aux_labels_sym
+ L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/WSASR/local/train_bpe_model.py b/egs/librispeech/WSASR/local/train_bpe_model.py
new file mode 100755
index 000000000..43142aee4
--- /dev/null
+++ b/egs/librispeech/WSASR/local/train_bpe_model.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# You can install sentencepiece via:
+#
+# pip install sentencepiece
+#
+# Due to an issue reported in
+# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
+#
+# Please install a version >=0.1.96
+
+import argparse
+import shutil
+from pathlib import Path
+
+import sentencepiece as spm
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ The generated bpe.model is saved to this directory.
+ """,
+ )
+
+ parser.add_argument(
+ "--transcript",
+ type=str,
+ help="Training transcript.",
+ )
+
+ parser.add_argument(
+ "--vocab-size",
+ type=int,
+ help="Vocabulary size for BPE training",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+ vocab_size = args.vocab_size
+ lang_dir = Path(args.lang_dir)
+
+ model_type = "unigram"
+
+ model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
+ train_text = args.transcript
+ character_coverage = 1.0
+ input_sentence_size = 100000000
+
+    user_defined_symbols = ["<blk>", "<sos/eos>"]
+ unk_id = len(user_defined_symbols)
+ # Note: unk_id is fixed to 2.
+ # If you change it, you should also change other
+ # places that are using it.
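+    # For example, after training one could sanity-check the special symbols
+    # (a hedged sketch, not executed by this script):
+    #
+    #   sp = spm.SentencePieceProcessor()
+    #   sp.load(f"{lang_dir}/bpe.model")
+    #   assert sp.piece_to_id("<blk>") == 0
+    #   assert sp.piece_to_id("<sos/eos>") == 1
+    #   assert sp.unk_id() == 2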
+
+ model_file = Path(model_prefix + ".model")
+ if not model_file.is_file():
+ spm.SentencePieceTrainer.train(
+ input=train_text,
+ vocab_size=vocab_size,
+ model_type=model_type,
+ model_prefix=model_prefix,
+ input_sentence_size=input_sentence_size,
+ character_coverage=character_coverage,
+ user_defined_symbols=user_defined_symbols,
+ unk_id=unk_id,
+ bos_id=-1,
+ eos_id=-1,
+ )
+ else:
+ print(f"{model_file} exists - skipping")
+ return
+
+ shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/WSASR/local/validate_bpe_lexicon.py b/egs/librispeech/WSASR/local/validate_bpe_lexicon.py
new file mode 100755
index 000000000..16a489c11
--- /dev/null
+++ b/egs/librispeech/WSASR/local/validate_bpe_lexicon.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks that there are no OOV tokens in the BPE-based lexicon.
+
+Usage example:
+
+ python3 ./local/validate_bpe_lexicon.py \
+ --lexicon /path/to/lexicon.txt \
+ --bpe-model /path/to/bpe.model
+"""
+
+import argparse
+from pathlib import Path
+from typing import List, Tuple
+
+import sentencepiece as spm
+
+from icefall.lexicon import read_lexicon
+
+# Map word to word pieces
+Lexicon = List[Tuple[str, List[str]]]
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--lexicon",
+ required=True,
+ type=Path,
+ help="Path to lexicon.txt",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ required=True,
+ type=Path,
+ help="Path to bpe.model",
+ )
+
+ parser.add_argument(
+ "--otc-token",
+ required=True,
+ type=str,
+ help="OTC token",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+ assert args.lexicon.is_file(), args.lexicon
+ assert args.bpe_model.is_file(), args.bpe_model
+
+ lexicon = read_lexicon(args.lexicon)
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(str(args.bpe_model))
+
+ word_pieces = set(sp.id_to_piece(list(range(sp.vocab_size()))))
+ word_pieces.add(f"▁{args.otc_token}")
+ for word, pieces in lexicon:
+ for p in pieces:
+ if p not in word_pieces:
+ raise ValueError(f"The word {word} contains an OOV token {p}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/WSASR/local/validate_manifest.py b/egs/librispeech/WSASR/local/validate_manifest.py
new file mode 100755
index 000000000..f620b91ea
--- /dev/null
+++ b/egs/librispeech/WSASR/local/validate_manifest.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks the following assumptions of the generated manifest:
+
+- Single supervision per cut
+- Supervision time bounds are within cut time bounds
+
+We will add more checks later if needed.
+
+Usage example:
+
+ python3 ./local/validate_manifest.py \
+ ./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz
+
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.cut import Cut
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "manifest",
+ type=Path,
+ help="Path to the manifest file",
+ )
+
+ return parser.parse_args()
+
+
+def validate_one_supervision_per_cut(c: Cut):
+ if len(c.supervisions) != 1:
+ raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions")
+
+
+def validate_supervision_and_cut_time_bounds(c: Cut):
+ s = c.supervisions[0]
+ if s.start < c.start:
+ raise ValueError(
+ f"{c.id}: Supervision start time {s.start} is less "
+ f"than cut start time {c.start}"
+ )
+
+ if s.end > c.end:
+ raise ValueError(
+ f"{c.id}: Supervision end time {s.end} is larger "
+ f"than cut end time {c.end}"
+ )
+
+
+def main():
+ args = get_args()
+
+ manifest = args.manifest
+ logging.info(f"Validating {manifest}")
+
+ assert manifest.is_file(), f"{manifest} does not exist"
+ cut_set = load_manifest_lazy(manifest)
+ assert isinstance(cut_set, CutSet)
+
+ for c in cut_set:
+ validate_one_supervision_per_cut(c)
+ validate_supervision_and_cut_time_bounds(c)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/librispeech/WSASR/prepare.sh b/egs/librispeech/WSASR/prepare.sh
new file mode 100755
index 000000000..f6a922fde
--- /dev/null
+++ b/egs/librispeech/WSASR/prepare.sh
@@ -0,0 +1,233 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+nj=15
+stage=-1
+stop_stage=100
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/LibriSpeech
+# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
+# You can download them from https://www.openslr.org/12
+#
+# - $dl_dir/lm
+# This directory contains the following files downloaded from
+# http://www.openslr.org/resources/11
+#
+# - 3-gram.pruned.1e-7.arpa.gz
+# - 3-gram.pruned.1e-7.arpa
+# - 4-gram.arpa.gz
+# - 4-gram.arpa
+# - librispeech-vocab.txt
+# - librispeech-lexicon.txt
+# - librispeech-lm-norm.txt.gz
+#
+otc_token="<star>"
+feature_type="ssl"
+
+dl_dir=$PWD/download
+manifests_dir="data/manifests"
+feature_dir="data/${feature_type}"
+lang_dir="data/lang"
+lm_dir="data/lm"
+
+perturb_speed=false
+
+# ssl or fbank
+
+. ./cmd.sh
+. shared/parse_options.sh || exit 1
+
+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_xxx,
+# data/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+ 200
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: ${dl_dir}"
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+ log "Stage -1: Download LM"
+ mkdir -p ${dl_dir}/lm
+ if [ ! -e ${dl_dir}/lm/.done ]; then
+ ./local/download_lm.py --out-dir=${dl_dir}/lm
+ touch ${dl_dir}/lm/.done
+ fi
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # If you have pre-downloaded it to /path/to/LibriSpeech,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
+ #
+ if [ ! -d $dl_dir/LibriSpeech/train-clean-100 ]; then
+ lhotse download librispeech --full ${dl_dir}
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare LibriSpeech manifest"
+ # We assume that you have downloaded the LibriSpeech corpus
+ # to $dl_dir/LibriSpeech
+ mkdir -p data/manifests
+ if [ ! -e data/manifests/.librispeech.done ]; then
+ lhotse prepare librispeech -j ${nj} \
+ -p dev-clean \
+ -p dev-other \
+ -p test-clean \
+ -p test-other \
+ -p train-clean-100 "${dl_dir}/LibriSpeech" "${manifests_dir}"
+ touch data/manifests/.librispeech.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Compute ${feature_type} feature for librispeech (train-clean-100)"
+ mkdir -p "${feature_dir}"
+ if [ ! -e "${feature_dir}/.librispeech.done" ]; then
+ if [ "${feature_type}" = ssl ]; then
+ ./local/compute_ssl_librispeech.py
+ elif [ "${feature_type}" = fbank ]; then
+ ./local/compute_fbank_librispeech.py --perturb-speed ${perturb_speed}
+ else
+      log "Error: unsupported --feature-type '${feature_type}'"
+ exit 2
+ fi
+
+    touch "${feature_dir}/.librispeech.done"
+ fi
+
+ if [ ! -e "${feature_dir}/.librispeech-validated.done" ]; then
+    log "Validating ${feature_dir} for LibriSpeech"
+ parts=(
+ train-clean-100
+ test-clean
+ test-other
+ dev-clean
+ dev-other
+ )
+ for part in ${parts[@]}; do
+ python3 ./local/validate_manifest.py \
+ "${feature_dir}/librispeech_cuts_${part}.jsonl.gz"
+ done
+ touch "${feature_dir}/.librispeech-validated.done"
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare words.txt"
+ mkdir -p ${lang_dir}
+
+  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+ cat - $dl_dir/lm/librispeech-lexicon.txt |
+ sort | uniq > ${lang_dir}/lexicon.txt
+
+ local/get_words_from_lexicon.py \
+ --lang-dir ${lang_dir} \
+ --otc-token ${otc_token}
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Prepare BPE based lang"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ bpe_lang_dir="data/lang_bpe_${vocab_size}"
+ mkdir -p "${bpe_lang_dir}"
+ # We reuse words.txt from phone based lexicon
+ # so that the two can share G.pt later.
+ cp "${lang_dir}/words.txt" "${bpe_lang_dir}"
+
+ if [ ! -f "${bpe_lang_dir}/transcript_words.txt" ]; then
+ log "Generate data for BPE training"
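+      # Each *.trans.txt line has the form "<utterance-id> TRANSCRIPT WORDS ...";
+      # `cut -d " " -f 2-` below keeps only the transcript words.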
+ files=$(
+ find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
+ find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
+ find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
+ )
+ for f in ${files[@]}; do
+ cat $f | cut -d " " -f 2-
+ done > "${bpe_lang_dir}/transcript_words.txt"
+ fi
+
+ if [ ! -f ${bpe_lang_dir}/bpe.model ]; then
+ ./local/train_bpe_model.py \
+ --lang-dir ${bpe_lang_dir} \
+ --vocab-size ${vocab_size} \
+ --transcript ${bpe_lang_dir}/transcript_words.txt
+ fi
+
+ if [ ! -f ${bpe_lang_dir}/L_disambig.pt ]; then
+ ./local/prepare_otc_lang_bpe.py \
+ --lang-dir "${bpe_lang_dir}" \
+ --otc-token "${otc_token}"
+
+ log "Validating ${bpe_lang_dir}/lexicon.txt"
+ ./local/validate_bpe_lexicon.py \
+ --lexicon ${bpe_lang_dir}/lexicon.txt \
+ --bpe-model ${bpe_lang_dir}/bpe.model \
+ --otc-token "${otc_token}"
+ fi
+ done
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Prepare G"
+  # We assume you have installed kaldilm. If not, please install it
+  # using: pip install kaldilm
+
+ mkdir -p "${lm_dir}"
+ if [ ! -f ${lm_dir}/G_3_gram.fst.txt ]; then
+ # It is used in building HLG
+ python3 -m kaldilm \
+ --read-symbol-table="${lang_dir}/words.txt" \
+ --disambig-symbol='#0' \
+ --max-order=3 \
+ ${dl_dir}/lm/3-gram.pruned.1e-7.arpa > ${lm_dir}/G_3_gram.fst.txt
+ fi
+
+ if [ ! -f ${lm_dir}/G_4_gram.fst.txt ]; then
+ # It is used for LM rescoring
+ python3 -m kaldilm \
+ --read-symbol-table="${lang_dir}/words.txt" \
+ --disambig-symbol='#0' \
+ --max-order=4 \
+ ${dl_dir}/lm/4-gram.arpa > ${lm_dir}/G_4_gram.fst.txt
+ fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Compile HLG"
+  # Note: if ./local/compile_hlg.py throws OOM,
+ # please switch to the following command
+ #
+ # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ bpe_lang_dir="data/lang_bpe_${vocab_size}"
+ echo "LM DIR: ${lm_dir}"
+ ./local/compile_hlg.py \
+ --lm-dir "${lm_dir}" \
+ --lang-dir "${bpe_lang_dir}"
+ done
+fi
diff --git a/icefall/otc_graph_compiler.py b/icefall/otc_graph_compiler.py
new file mode 100644
index 000000000..bfd679452
--- /dev/null
+++ b/icefall/otc_graph_compiler.py
@@ -0,0 +1,246 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2023 Johns Hopkins University (author: Dongji Gao)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pathlib import Path
+from typing import List, Union
+
+import k2
+import sentencepiece as spm
+import torch
+
+from icefall.utils import str2bool
+
+
+class OtcTrainingGraphCompiler(object):
+ def __init__(
+ self,
+ lang_dir: Path,
+ otc_token: str,
+ device: Union[str, torch.device] = "cpu",
+        sos_token: str = "<sos/eos>",
+        eos_token: str = "<sos/eos>",
+ initial_bypass_weight: float = 0.0,
+ initial_self_loop_weight: float = 0.0,
+ bypass_weight_decay: float = 0.0,
+ self_loop_weight_decay: float = 0.0,
+ ) -> None:
+ """
+ Args:
+ lang_dir:
+ This directory is expected to contain the following files:
+
+ - bpe.model
+ - words.txt
+ otc_token:
+            The special token in OTC that represents all non-blank tokens.
+ device:
+ It indicates CPU or CUDA.
+ sos_token:
+ The word piece that represents sos.
+ eos_token:
+ The word piece that represents eos.
+          initial_bypass_weight:
+            The initial weight (log score) placed on bypass arcs.
+          initial_self_loop_weight:
+            The initial weight (log score) placed on self-loop arcs.
+          bypass_weight_decay:
+            Per-epoch decay factor for the bypass arc weight.
+          self_loop_weight_decay:
+            Per-epoch decay factor for the self-loop arc weight.
+        """
+ lang_dir = Path(lang_dir)
+ bpe_model_file = lang_dir / "bpe.model"
+ sp = spm.SentencePieceProcessor()
+ sp.load(str(bpe_model_file))
+ self.sp = sp
+ self.token_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
+
+ self.otc_token = otc_token
+ assert self.otc_token in self.token_table
+
+ self.device = device
+
+ self.sos_id = self.sp.piece_to_id(sos_token)
+ self.eos_id = self.sp.piece_to_id(eos_token)
+
+ assert self.sos_id != self.sp.unk_id()
+ assert self.eos_id != self.sp.unk_id()
+
+ max_token_id = self.get_max_token_id()
+ ctc_topo = k2.ctc_topo(max_token_id, modified=False)
+ self.ctc_topo = ctc_topo.to(self.device)
+
+ self.initial_bypass_weight = initial_bypass_weight
+ self.initial_self_loop_weight = initial_self_loop_weight
+ self.bypass_weight_decay = bypass_weight_decay
+ self.self_loop_weight_decay = self_loop_weight_decay
+
+ def get_max_token_id(self):
+ max_token_id = 0
+ for symbol in self.token_table.symbols:
+ if not symbol.startswith("#"):
+ max_token_id = max(self.token_table[symbol], max_token_id)
+ assert max_token_id > 0
+
+ return max_token_id
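+
+    # Example (hypothetical IDs): if tokens.txt contains "<blk> 0 ... ▁<star> 500
+    # #0 501 #1 502", get_max_token_id() returns 500, since disambiguation
+    # symbols ("#...") are excluded from the CTC topology.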
+
+ def make_arc(
+ self,
+ from_state: int,
+ to_state: int,
+ symbol: Union[str, int],
+ weight: float,
+ ):
+ return f"{from_state} {to_state} {symbol} {weight}"
+
+ def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
+ """Convert a list of texts to a list-of-list of piece IDs.
+
+ Args:
+ texts:
+ It is a list of strings. Each string consists of space(s)
+ separated words. An example containing two strings is given below:
+
+ ['HELLO ICEFALL', 'HELLO k2']
+ Returns:
+ Return a list-of-list of piece IDs.
+ """
+ return self.sp.encode(texts, out_type=int)
+
+ def compile(
+ self,
+ texts: List[str],
+ allow_bypass_arc: str2bool = True,
+ allow_self_loop_arc: str2bool = True,
+ bypass_weight: float = 0.0,
+ self_loop_weight: float = 0.0,
+ ) -> k2.Fsa:
+        """Build an OTC graph from texts (a list of transcripts).
+
+ Args:
+ texts:
+ A list of strings. Each string contains a sentence for an utterance.
+ A sentence consists of spaces separated words. An example `texts`
+ looks like:
+ ['hello icefall', 'CTC training with k2']
+ allow_bypass_arc:
+ Whether to add bypass arc to training graph for substitution
+ and insertion errors (wrong or extra words in the transcript).
+ allow_self_loop_arc:
+ Whether to add self-loop arc to training graph for deletion
+ errors (missing words in the transcript).
+ bypass_weight:
+ Weight associated with bypass arc.
+ self_loop_weight:
+ Weight associated with self-loop arc.
+
+ Return:
+ Return an FsaVec, which is the result of composing a
+ CTC topology with OTC FSAs constructed from the given texts.
+ """
+
+ transcript_fsa = self.convert_transcript_to_fsa(
+ texts,
+ self.otc_token,
+ allow_bypass_arc,
+ allow_self_loop_arc,
+ bypass_weight,
+ self_loop_weight,
+ )
+ transcript_fsa = transcript_fsa.to(self.device)
+ fsa_with_self_loop = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
+ fsa_with_self_loop = k2.arc_sort(fsa_with_self_loop)
+
+ graph = k2.compose(
+ self.ctc_topo,
+ fsa_with_self_loop,
+ treat_epsilons_specially=False,
+ )
+ assert graph.requires_grad is False
+
+ return graph
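+
+    # A minimal usage sketch (paths and weights are illustrative assumptions):
+    #
+    #   compiler = OtcTrainingGraphCompiler("data/lang_bpe_200", otc_token="<star>")
+    #   graphs = compiler.compile(
+    #       ["hello icefall"], bypass_weight=-0.3, self_loop_weight=-0.3
+    #   )
+    #   # `graphs` is an FsaVec; it can be intersected with the network output
+    #   # (e.g. via k2.intersect_dense) to compute the OTC loss for the batch.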
+
+ def convert_transcript_to_fsa(
+ self,
+ texts: List[str],
+ otc_token: str,
+ allow_bypass_arc: str2bool = True,
+ allow_self_loop_arc: str2bool = True,
+ bypass_weight: float = 0.0,
+ self_loop_weight: float = 0.0,
+ ):
+ otc_token_id = self.token_table[otc_token]
+
+ transcript_fsa_list = []
+ for text in texts:
+ text_piece_ids = []
+
+ for word in text.split():
+ piece_ids = self.sp.encode(word, out_type=int)
+ text_piece_ids.append(piece_ids)
+
+ arcs = []
+ start_state = 0
+ cur_state = start_state
+ next_state = 1
+
+ for piece_ids in text_piece_ids:
+ bypass_cur_state = cur_state
+
+ if allow_self_loop_arc:
+ self_loop_arc = self.make_arc(
+ cur_state,
+ cur_state,
+ otc_token_id,
+ self_loop_weight,
+ )
+ arcs.append(self_loop_arc)
+
+ for piece_id in piece_ids:
+ arc = self.make_arc(cur_state, next_state, piece_id, 0.0)
+ arcs.append(arc)
+
+ cur_state = next_state
+ next_state += 1
+
+ bypass_next_state = cur_state
+ if allow_bypass_arc:
+ bypass_arc = self.make_arc(
+ bypass_cur_state,
+ bypass_next_state,
+ otc_token_id,
+ bypass_weight,
+ )
+ arcs.append(bypass_arc)
+ bypass_cur_state = cur_state
+
+ if allow_self_loop_arc:
+ self_loop_arc = self.make_arc(
+ cur_state,
+ cur_state,
+ otc_token_id,
+ self_loop_weight,
+ )
+ arcs.append(self_loop_arc)
+
+ # Deal with final state
+ final_state = next_state
+ final_arc = self.make_arc(cur_state, final_state, -1, 0.0)
+ arcs.append(final_arc)
+ arcs.append(f"{final_state}")
+ sorted_arcs = sorted(arcs, key=lambda a: int(a.split()[0]))
+
+ transcript_fsa = k2.Fsa.from_str("\n".join(sorted_arcs))
+ transcript_fsa = k2.arc_sort(transcript_fsa)
+ transcript_fsa_list.append(transcript_fsa)
+
+ transcript_fsa_vec = k2.create_fsa_vec(transcript_fsa_list)
+
+ return transcript_fsa_vec
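+
+    # For a single transcript "hi" whose BPE pieces are, say, ["▁hi"] with
+    # piece ID 42 and OTC token ID 500 (hypothetical values), the arcs built
+    # above are roughly:
+    #
+    #   0 0 500 self_loop_weight   # <star> self-loop (deletions)
+    #   0 1 42 0.0                 # the actual piece
+    #   0 1 500 bypass_weight      # <star> bypass arc (substitutions/insertions)
+    #   1 1 500 self_loop_weight   # trailing <star> self-loop
+    #   1 2 -1 0.0                 # final arc
+    #   2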
diff --git a/icefall/utils.py b/icefall/utils.py
index 947d79438..8fda3a4ca 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -263,6 +263,70 @@ def get_texts(
return aux_labels.tolist()
+def encode_supervisions_otc(
+ supervisions: dict,
+ subsampling_factor: int,
+ token_ids: Optional[List[List[int]]] = None,
+) -> Tuple[torch.Tensor, Union[List[str], List[List[int]]], List[str], List[str]]:
+    """
+    Encodes Lhotse's ``batch["supervisions"]`` dict into a supervision tensor,
+    a list of transcription strings or token indexes, a list of cut IDs, and
+    a list of verbatim texts.
+
+    The supervision tensor has shape ``(batch_size, 3)``.
+    Its second dimension contains information about sequence index [0],
+    start frames [1] and num frames [2].
+
+    The batch items might become re-ordered during this operation -- the
+    returned tensor and lists are guaranteed to be consistent with
+    each other.
+    """
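+    # Example of the returned values for a batch of two cuts (hypothetical data):
+    #   supervision_segments:  tensor([[0, 0, 320], [1, 0, 290]], dtype=torch.int32)
+    #   res:                   ["HELLO ICEFALL", "HELLO K2"] (or their token IDs)
+    #   sorted_ids:            ["cut-1", "cut-2"]
+    #   sorted_verbatim_texts: ["HELLO ICEFALL", ""]  # "" if no verbatim_text field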
+ supervision_segments = torch.stack(
+ (
+ supervisions["sequence_idx"],
+ torch.div(
+ supervisions["start_frame"],
+ subsampling_factor,
+ rounding_mode="floor",
+ ),
+ torch.div(
+ supervisions["num_frames"],
+ subsampling_factor,
+ rounding_mode="floor",
+ ),
+ ),
+ 1,
+ ).to(torch.int32)
+
+ indices = torch.argsort(supervision_segments[:, 2], descending=True)
+ supervision_segments = supervision_segments[indices]
+
+ ids = []
+ verbatim_texts = []
+ sorted_ids = []
+ sorted_verbatim_texts = []
+
+ for cut in supervisions["cut"]:
+ id = cut.id
+ if hasattr(cut.supervisions[0], "verbatim_text"):
+ verbatim_text = cut.supervisions[0].verbatim_text
+ else:
+ verbatim_text = ""
+ ids.append(id)
+ verbatim_texts.append(verbatim_text)
+
+ for index in indices.tolist():
+ sorted_ids.append(ids[index])
+ sorted_verbatim_texts.append(verbatim_texts[index])
+
+ if token_ids is None:
+ texts = supervisions["text"]
+ res = [texts[idx] for idx in indices]
+ else:
+ res = [token_ids[idx] for idx in indices]
+
+ return supervision_segments, res, sorted_ids, sorted_verbatim_texts
+
+
@dataclass
class DecodingResults:
# timestamps[i][k] contains the frame number on which tokens[i][k]