mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Add scripts and recipe for BTC/OTC (#1255)
This commit is contained in:
parent
8181d19860
commit
3abc290c11
224
egs/librispeech/WSASR/README.md
Normal file
224
egs/librispeech/WSASR/README.md
Normal file
@ -0,0 +1,224 @@
|
||||
# Introduction
|
||||
|
||||
This is a weakly supervised ASR recipe for the LibriSpeech (clean 100 hours) dataset. We train a
|
||||
conformer model using [Bypass Temporal Classification](https://arxiv.org/pdf/2306.01031.pdf) (BTC)/[Omni-temporal Classification](https://arxiv.org/pdf/2309.15796.pdf) (OTC) with transcripts with synthetic errors. In this README, we will describe
|
||||
the task and the BTC/OTC training process.
|
||||
|
||||
Note that OTC is an extension of BTC and supports all BTC functions. Therefore, in the following, we only describe OTC.
|
||||
## Task
|
||||
We propose BTC/OTC to directly train an ASR system leveraging weak supervision, i.e., speech with non-verbatim transcripts. This is achieved by using a special token $\star$ to model uncertainties (i.e., substitution errors, insertion errors, and deletion errors)
|
||||
within the WFST framework during training.
|
||||
|
||||
|
||||
<div style="display: flex;flex; justify-content: space-between">
|
||||
<figure style="flex: 2; text-align: center; margin: 5px;">
|
||||
<img src="figures/sub.png" alt="Image 1" width="25%" />
|
||||
|
||||
</figure>
|
||||
<figure style="flex: 2; text-align: center; margin: 5px;">
|
||||
<img src="figures/ins.png" alt="Image 2" width="25%" />
|
||||
|
||||
</figure>
|
||||
<figure style="flex: 2; text-align: center;margin: 5px;">
|
||||
<img src="figures/del.png" alt="Image 3" width="25%" />
|
||||
|
||||
</figure>
|
||||
</div>
|
||||
<figcaption> Examples of errors (substitution, insertion, and deletion) in the transcript. The grey box is the verbatim transcript and the red box is the inaccurate transcript. Inaccurate words are marked in bold.</figcaption> <br><br>
|
||||
|
||||
|
||||
We modify $G(\mathbf{y})$ by adding self-loop arcs into each state and bypass arcs into each arc.
|
||||
<p align="center">
|
||||
<img src="figures/otc_g.png" alt="Image Alt Text" width="50%" />
|
||||
|
||||
</p>
|
||||
|
||||
We incorporate the penalty strategy and apply different configurations for the self-loop arc and bypass arc. The penalties are set as
|
||||
|
||||
$$\lambda_{1_{i}} = \beta_{1} * \tau_{1}^{i},\quad \lambda_{2_{i}} = \beta_{2} * \tau_{2}^{i}$$
|
||||
|
||||
for the $i$-th training epoch. $\beta$ is the initial penalty that encourages the model to rely more on the given transcript at the start of training.
|
||||
It decays exponentially by a factor of $\tau \in (0, 1)$, gradually encouraging the model to align speech with $\star$ when getting confused.
|
||||
|
||||
After composing the modified WFST $G_{\text{otc}}(\mathbf{y})$ with $L$ and $T$, the OTC training graph is shown in this figure:
|
||||
<figure style="text-align: center">
|
||||
<img src="figures/otc_training_graph.drawio.png" alt="Image Alt Text" />
|
||||
<figcaption>OTC training graph. The self-loop arcs and bypass arcs are highlighted in green and blue, respectively.</figcaption>
|
||||
</figure>
|
||||
|
||||
The $\star$ is represented as the average probability of all non-blank tokens.
|
||||
<p align="center">
|
||||
<img src="figures/otc_emission.drawio.png" width="50%" />
|
||||
</p>
|
||||
|
||||
The weight of $\star$ is the log average probability of "a" and "b": $\log \frac{e^{-1.2} + e^{-2.3}}{2} = -1.6$ and $\log \frac{e^{-1.9} + e^{-0.5}}{2} = -1.0$ for 2 frames.
|
||||
|
||||
## Description of the recipe
|
||||
### Preparation
|
||||
```
|
||||
# feature_type can be ssl or fbank
|
||||
feature_type=ssl
|
||||
feature_dir="data/${feature_type}"
|
||||
manifest_dir="${feature_dir}"
|
||||
lang_dir="data/lang"
|
||||
lm_dir="data/lm"
|
||||
exp_dir="conformer_ctc2/exp"
|
||||
otc_token="<star>"
|
||||
|
||||
./prepare.sh \
|
||||
--feature-type "${feature_type}" \
|
||||
--feature-dir "${feature_dir}" \
|
||||
--lang-dir "${lang_dir}" \
|
||||
--lm-dir "${lm_dir}" \
|
||||
--otc-token "${otc_token}"
|
||||
```
|
||||
This script adds the 'otc_token' ('\<star\>') and its corresponding sentence-piece ('▁\<star\>') to 'words.txt' and 'tokens.txt,' respectively. Additionally, it computes SSL features using the 'wav2vec2-base' model. (You can use GPU to accelerate feature extraction).
|
||||
|
||||
### Making synthetic errors to the transcript (train-clean-100) [optional]
|
||||
```
|
||||
sub_er=0.17
|
||||
ins_er=0.17
|
||||
del_er=0.17
|
||||
synthetic_train_manifest="librispeech_cuts_train-clean-100_${sub_er}_${ins_er}_${del_er}.jsonl.gz"
|
||||
|
||||
./local/make_error_cutset.py \
|
||||
--input-cutset "${manifest_dir}/librispeech_cuts_train-clean-100.jsonl.gz" \
|
||||
--words-file "${lang_dir}/words.txt" \
|
||||
--sub-error-rate "${sub_er}" \
|
||||
--ins-error-rate "${ins_er}" \
|
||||
--del-error-rate "${del_er}" \
|
||||
--output-cutset "${manifest_dir}/${synthetic_train_manifest}"
|
||||
```
|
||||
This script generates synthetic substitution, insertion, and deletion errors in the transcript with ratios 'sub_er', 'ins_er', and 'del_er', respectively. The original transcript is saved as 'verbatim transcript' in the cutset, along with information on how the transcript is corrupted:
|
||||
|
||||
- '[hello]' indicates the original word 'hello' is substituted by another word
|
||||
- '[]' indicates an extra word is inserted into the transcript
|
||||
- '-hello-' indicates the word 'hello' is deleted from the transcript
|
||||
|
||||
So if the original transcript is "have a nice day" and the synthetic one is "a very good day", the 'verbatim transcript' would be:
|
||||
```
|
||||
original: have a nice day
|
||||
synthetic: a very good day
|
||||
verbatim: -have- a [] [nice] day
|
||||
```
|
||||
|
||||
### Training
|
||||
The training uses synthetic data based on the train-clean-100 subset.
|
||||
```
|
||||
otc_lang_dir=data/lang_bpe_200
|
||||
|
||||
allow_bypass_arc=true
|
||||
allow_self_loop_arc=true
|
||||
initial_bypass_weight=-19
|
||||
initial_self_loop_weight=3.75
|
||||
bypass_weight_decay=0.975
|
||||
self_loop_weight_decay=0.999
|
||||
|
||||
show_alignment=true
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
./conformer_ctc2/train.py \
|
||||
--world-size 4 \
|
||||
--manifest-dir "${manifest_dir}" \
|
||||
--train-manifest "${synthetic_train_manifest}" \
|
||||
--exp-dir "${exp_dir}" \
|
||||
--lang-dir "${otc_lang_dir}" \
|
||||
--otc-token "${otc_token}" \
|
||||
--allow-bypass-arc "${allow_bypass_arc}" \
|
||||
--allow-self-loop-arc "${allow_self_loop_arc}" \
|
||||
--initial-bypass-weight "${initial_bypass_weight}" \
|
||||
--initial-self-loop-weight "${initial_self_loop_weight}" \
|
||||
--bypass-weight-decay "${bypass_weight_decay}" \
|
||||
--self-loop-weight-decay "${self_loop_weight_decay}" \
|
||||
--show-alignment "${show_alignment}"
|
||||
```
|
||||
The bypass arc deals with substitution and insertion errors, while the self-loop arc deals with deletion errors. Using "--show-alignment" would print the best alignment during training, which is very helpful for tuning hyperparameters and debugging.
|
||||
|
||||
### Decoding
|
||||
```
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
./conformer_ctc2/decode.py \
|
||||
--manifest-dir "${manifest_dir}" \
|
||||
--exp-dir "${exp_dir}" \
|
||||
--lang-dir "${otc_lang_dir}" \
|
||||
--lm-dir "${lm_dir}" \
|
||||
--otc-token "${otc_token}"
|
||||
```
|
||||
|
||||
### Results (ctc-greedy-search)
|
||||
<table>
|
||||
<tr>
|
||||
<td rowspan=2>Training Criterion</td>
|
||||
<td colspan=2>ssl</td>
|
||||
<td colspan=2>fbank</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>test-clean</td>
|
||||
<td>test-other</td>
|
||||
<td>test-clean</td>
|
||||
<td>test-other</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>CTC</td>
|
||||
<td>100.0</td>
|
||||
<td>100.0</td>
|
||||
<td>99.89</td>
|
||||
<td>99.98</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>OTC</td>
|
||||
<td>11.89</td>
|
||||
<td>25.46</td>
|
||||
<td>20.14</td>
|
||||
<td>44.24</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### Results (1best, blank_bias=-4)
|
||||
<table>
|
||||
<tr>
|
||||
<td rowspan=2>Training Criterion</td>
|
||||
<td colspan=2>ssl</td>
|
||||
<td colspan=2>fbank</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>test-clean</td>
|
||||
<td>test-other</td>
|
||||
<td>test-clean</td>
|
||||
<td>test-other</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>CTC</td>
|
||||
<td>98.40</td>
|
||||
<td>98.68</td>
|
||||
<td>99.79</td>
|
||||
<td>99.86</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>OTC</td>
|
||||
<td>6.59</td>
|
||||
<td>15.98</td>
|
||||
<td>11.78</td>
|
||||
<td>32.38</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Pre-trained Model
|
||||
Pre-trained model: <https://huggingface.co/dgao/icefall-otc-librispeech-conformer-ctc2>
|
||||
|
||||
## Citations
|
||||
```
|
||||
@inproceedings{gao2023bypass,
|
||||
title={Bypass Temporal Classification: Weakly Supervised Automatic Speech Recognition with Imperfect Transcripts},
|
||||
author={Gao, Dongji and Wiesner, Matthew and Xu, Hainan and Garcia, Leibny Paola and Povey, Daniel and Khudanpur, Sanjeev},
|
||||
booktitle={INTERSPEECH},
|
||||
year={2023}
|
||||
}
|
||||
|
||||
@inproceedings{gao2023learning,
|
||||
title={Learning from Flawed Data: Weakly Supervised Automatic Speech Recognition},
|
||||
author={Gao, Dongji and Xu, Hainan and Raj, Desh and Garcia, Leibny Paola and Povey, Daniel and Khudanpur, Sanjeev},
|
||||
booktitle={IEEE ASRU},
|
||||
year={2023}
|
||||
}
|
||||
```
|
1
egs/librispeech/WSASR/conformer_ctc2/__init__.py
Symbolic link
1
egs/librispeech/WSASR/conformer_ctc2/__init__.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/pruned_transducer_stateless2/__init__.py
|
369
egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py
Normal file
369
egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py
Normal file
@ -0,0 +1,369 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
# 2023 John Hopkins University (author: Dongji Gao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriSpeechAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--full-libri",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Used only when --mini-libri is False.When enabled,
|
||||
use 960h LibriSpeech. Otherwise, use 100h subset.""",
|
||||
)
|
||||
group.add_argument(
|
||||
"--mini-libri",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="True for mini librispeech",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/ssl"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--train-manifest",
|
||||
type=str,
|
||||
default="librispeech_cuts_train-clean-100.jsonl.gz",
|
||||
help="Train manifest file.",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SpeechRecognitionDataset(
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_5_cuts(self) -> CutSet:
|
||||
logging.info("mini_librispeech: About to get train-clean-5 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_100_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-100 cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / self.args.train_manifest)
|
||||
|
||||
@lru_cache()
|
||||
def train_all_shuf_cuts(self) -> CutSet:
|
||||
logging.info(
|
||||
"About to get the shuffled train-clean-100, \
|
||||
train-clean-360 and train-other-500 cuts"
|
||||
)
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_2_cuts(self) -> CutSet:
|
||||
logging.info("mini_librispeech: About to get dev-clean-2 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
|
||||
)
|
1
egs/librispeech/WSASR/conformer_ctc2/attention.py
Symbolic link
1
egs/librispeech/WSASR/conformer_ctc2/attention.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/conformer_ctc2/attention.py
|
949
egs/librispeech/WSASR/conformer_ctc2/conformer.py
Normal file
949
egs/librispeech/WSASR/conformer_ctc2/conformer.py
Normal file
@ -0,0 +1,949 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||
# 2022 Xiaomi Corp. (author: Quandong Wang)
|
||||
# 2023 Johns Hopkins University (author: Dongji Gao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
ScaledConv1d,
|
||||
ScaledLinear,
|
||||
)
|
||||
from subsampling import Conv2dSubsampling, Conv2dSubsampling2
|
||||
from torch import Tensor, nn
|
||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
class Conformer(Transformer):
|
||||
"""
|
||||
Args:
|
||||
num_features (int): Number of input features
|
||||
num_classes (int): Number of output classes
|
||||
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
|
||||
d_model (int): attention dimension, also the output dimension
|
||||
nhead (int): number of head
|
||||
dim_feedforward (int): feedforward dimention
|
||||
num_encoder_layers (int): number of encoder layers
|
||||
num_decoder_layers (int): number of decoder layers
|
||||
dropout (float): dropout rate
|
||||
layer_dropout (float): layer-dropout rate.
|
||||
cnn_module_kernel (int): Kernel size of convolution module
|
||||
vgg_frontend (bool): whether to use vgg frontend.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_features: int,
|
||||
num_classes: int,
|
||||
subsampling_factor: int = 2,
|
||||
d_model: int = 256,
|
||||
nhead: int = 4,
|
||||
dim_feedforward: int = 2048,
|
||||
num_encoder_layers: int = 12,
|
||||
num_decoder_layers: int = 6,
|
||||
dropout: float = 0.2,
|
||||
layer_dropout: float = 0.075,
|
||||
cnn_module_kernel: int = 31,
|
||||
) -> None:
|
||||
super(Conformer, self).__init__(
|
||||
num_features=num_features,
|
||||
num_classes=num_classes,
|
||||
subsampling_factor=subsampling_factor,
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
dim_feedforward=dim_feedforward,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
num_decoder_layers=num_decoder_layers,
|
||||
dropout=dropout,
|
||||
layer_dropout=layer_dropout,
|
||||
)
|
||||
|
||||
self.num_features = num_features
|
||||
self.subsampling_factor = subsampling_factor
|
||||
if subsampling_factor != 4 and subsampling_factor != 2:
|
||||
raise NotImplementedError("Support only 'subsampling_factor=4 or 2'.")
|
||||
|
||||
# self.encoder_embed converts the input of shape (N, T, num_features)
|
||||
# to the shape (N, T//subsampling_factor, d_model).
|
||||
# That is, it does two things simultaneously:
|
||||
# (1) subsampling: T -> T//subsampling_factor
|
||||
# (2) embedding: num_features -> d_model
|
||||
if self.subsampling_factor == 4:
|
||||
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
||||
elif self.subsampling_factor == 2:
|
||||
self.encoder_embed = Conv2dSubsampling2(num_features, d_model)
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
|
||||
encoder_layer = ConformerEncoderLayer(
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward,
|
||||
dropout,
|
||||
layer_dropout,
|
||||
cnn_module_kernel,
|
||||
)
|
||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||
|
||||
def run_encoder(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
supervisions: Optional[Supervisions] = None,
|
||||
warmup: float = 1.0,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||
supervisions:
|
||||
Supervision in lhotse format.
|
||||
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
||||
CAUTION: It contains length information, i.e., start and number of
|
||||
frames, before subsampling
|
||||
It is read directly from the batch, without any sorting. It is used
|
||||
to compute encoder padding mask, which is used as memory key padding
|
||||
mask for the decoder.
|
||||
warmup:
|
||||
A floating point value that gradually increases from 0 throughout
|
||||
training; when it is >= 1.0 we are "fully warmed up". It is used
|
||||
to turn modules on sequentially.
|
||||
Returns:
|
||||
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
|
||||
Tensor: Mask tensor of dimension (batch_size, input_length)
|
||||
"""
|
||||
x = self.encoder_embed(x)
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
mask = encoder_padding_mask(x.size(0), self.subsampling_factor, supervisions)
|
||||
if mask is not None:
|
||||
mask = mask.to(x.device)
|
||||
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
|
||||
x = self.encoder(
|
||||
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
|
||||
) # (T, N, C)
|
||||
|
||||
# x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
# return x, lengths
|
||||
return x, mask
|
||||
|
||||
|
||||
class ConformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
|
||||
See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
|
||||
|
||||
Args:
|
||||
d_model: the number of expected features in the input (required).
|
||||
nhead: the number of heads in the multiheadattention models (required).
|
||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||
dropout: the dropout value (default=0.1).
|
||||
cnn_module_kernel (int): Kernel size of convolution module.
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||
>>> src = torch.rand(10, 32, 512)
|
||||
>>> pos_emb = torch.rand(32, 19, 512)
|
||||
>>> out = encoder_layer(src, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
nhead: int,
|
||||
dim_feedforward: int = 2048,
|
||||
dropout: float = 0.1,
|
||||
layer_dropout: float = 0.075,
|
||||
cnn_module_kernel: int = 31,
|
||||
) -> None:
|
||||
super(ConformerEncoderLayer, self).__init__()
|
||||
|
||||
self.layer_dropout = layer_dropout
|
||||
|
||||
self.d_model = d_model
|
||||
|
||||
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
|
||||
|
||||
self.feed_forward = nn.Sequential(
|
||||
ScaledLinear(d_model, dim_feedforward),
|
||||
ActivationBalancer(channel_dim=-1),
|
||||
DoubleSwish(),
|
||||
nn.Dropout(dropout),
|
||||
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
|
||||
)
|
||||
|
||||
self.feed_forward_macaron = nn.Sequential(
|
||||
ScaledLinear(d_model, dim_feedforward),
|
||||
ActivationBalancer(channel_dim=-1),
|
||||
DoubleSwish(),
|
||||
nn.Dropout(dropout),
|
||||
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
|
||||
)
|
||||
|
||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||
|
||||
self.norm_final = BasicNorm(d_model)
|
||||
|
||||
# try to ensure the output is close to zero-mean (or at least, zero-median).
|
||||
self.balancer = ActivationBalancer(
|
||||
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder layer (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||
bypass layers more frequently.
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E)
|
||||
src_mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, N is the batch size, E is the feature number
|
||||
"""
|
||||
src_orig = src
|
||||
|
||||
warmup_scale = min(0.1 + warmup, 1.0)
|
||||
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
|
||||
# completely bypass it.
|
||||
if self.training:
|
||||
alpha = (
|
||||
warmup_scale
|
||||
if torch.rand(()).item() <= (1.0 - self.layer_dropout)
|
||||
else 0.1
|
||||
)
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
# macaron style feed forward module
|
||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||
|
||||
# multi-headed self-attention module
|
||||
src_att = self.self_attn(
|
||||
src,
|
||||
src,
|
||||
src,
|
||||
pos_emb=pos_emb,
|
||||
attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask,
|
||||
)[0]
|
||||
src = src + self.dropout(src_att)
|
||||
|
||||
# convolution module
|
||||
src = src + self.dropout(
|
||||
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
|
||||
)
|
||||
|
||||
# feed forward module
|
||||
src = src + self.dropout(self.feed_forward(src))
|
||||
|
||||
src = self.norm_final(self.balancer(src))
|
||||
|
||||
if alpha != 1.0:
|
||||
src = alpha * src + (1 - alpha) * src_orig
|
||||
|
||||
return src
|
||||
|
||||
|
||||
class ConformerEncoder(nn.Module):
|
||||
r"""ConformerEncoder is a stack of N encoder layers
|
||||
|
||||
Args:
|
||||
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
|
||||
num_layers: the number of sub-encoder-layers in the encoder (required).
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
|
||||
>>> src = torch.rand(10, 32, 512)
|
||||
>>> pos_emb = torch.rand(32, 19, 512)
|
||||
>>> out = conformer_encoder(src, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||
)
|
||||
self.num_layers = num_layers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E)
|
||||
mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||
|
||||
"""
|
||||
output = src
|
||||
|
||||
for i, mod in enumerate(self.layers):
|
||||
output = mod(
|
||||
output,
|
||||
pos_emb,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class RelPositionalEncoding(torch.nn.Module):
|
||||
"""Relative positional encoding module.
|
||||
|
||||
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
|
||||
|
||||
Args:
|
||||
d_model: Embedding dimension.
|
||||
dropout_rate: Dropout rate.
|
||||
max_len: Maximum input length.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(RelPositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x: Tensor) -> None:
|
||||
"""Reset the positional encodings."""
|
||||
if self.pe is not None:
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
# Suppose `i` means to the position of query vecotr and `j` means the
|
||||
# position of key vector. We use position relative positions when keys
|
||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
||||
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
||||
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
||||
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
||||
|
||||
# Reserve the order of positive indices and concat both positive and
|
||||
# negative indices. This is used to support the shifting trick
|
||||
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
||||
pe_negative = pe_negative[1:].unsqueeze(0)
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
pos_emb = self.pe[
|
||||
:,
|
||||
self.pe.size(1) // 2
|
||||
- x.size(1)
|
||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||
+ x.size(1),
|
||||
]
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
|
||||
|
||||
class RelPositionMultiheadAttention(nn.Module):
|
||||
r"""Multi-Head Attention layer with relative position encoding
|
||||
|
||||
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
|
||||
Args:
|
||||
embed_dim: total dimension of the model.
|
||||
num_heads: parallel attention heads.
|
||||
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
|
||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
dropout: float = 0.0,
|
||||
) -> None:
|
||||
super(RelPositionMultiheadAttention, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
|
||||
self.out_proj = ScaledLinear(
|
||||
embed_dim, embed_dim, bias=True, initial_scale=0.25
|
||||
)
|
||||
|
||||
# linear transformation for positional encoding.
|
||||
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
|
||||
# these two learnable bias are used in matrix c and matrix d
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||
self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
|
||||
self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
|
||||
self._reset_parameters()
|
||||
|
||||
def _pos_bias_u(self):
|
||||
return self.pos_bias_u * self.pos_bias_u_scale.exp()
|
||||
|
||||
def _pos_bias_v(self):
|
||||
return self.pos_bias_v * self.pos_bias_v_scale.exp()
|
||||
|
||||
def _reset_parameters(self) -> None:
|
||||
nn.init.normal_(self.pos_bias_u, std=0.01)
|
||||
nn.init.normal_(self.pos_bias_v, std=0.01)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
pos_emb: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
pos_emb: Positional embedding tensor
|
||||
key_padding_mask: if provided, specified padding elements in the key will
|
||||
be ignored by the attention. When given a binary mask and a value is True,
|
||||
the corresponding value on the attention layer will be ignored. When given
|
||||
a byte mask and a value is non-zero, the corresponding value on the attention
|
||||
layer will be ignored
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
|
||||
Shape:
|
||||
- Inputs:
|
||||
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||
If a ByteTensor is provided, the non-zero positions will be ignored while the position
|
||||
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
|
||||
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
||||
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
||||
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
|
||||
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
||||
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
||||
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
||||
is provided, it will be added to the attention weight.
|
||||
|
||||
- Outputs:
|
||||
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||
E is the embedding dimension.
|
||||
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||
L is the target sequence length, S is the source sequence length.
|
||||
"""
|
||||
return self.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
pos_emb,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj.get_weight(),
|
||||
self.in_proj.get_bias(),
|
||||
self.dropout,
|
||||
self.out_proj.get_weight(),
|
||||
self.out_proj.get_bias(),
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
|
||||
def rel_shift(self, x: Tensor) -> Tensor:
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
x: Input tensor (batch, head, time1, 2*time1-1).
|
||||
time1 means the length of query vector.
|
||||
|
||||
Returns:
|
||||
Tensor: tensor of shape (batch, head, time1, time2)
|
||||
(note: time2 has the same value as time1, but it is for
|
||||
the key, while time1 is for the query).
|
||||
"""
|
||||
(batch_size, num_heads, time1, n) = x.shape
|
||||
assert n == 2 * time1 - 1
|
||||
# Note: TorchScript requires explicit arg for stride()
|
||||
batch_stride = x.stride(0)
|
||||
head_stride = x.stride(1)
|
||||
time1_stride = x.stride(2)
|
||||
n_stride = x.stride(3)
|
||||
return x.as_strided(
|
||||
(batch_size, num_heads, time1, time1),
|
||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||
storage_offset=n_stride * (time1 - 1),
|
||||
)
|
||||
|
||||
def multi_head_attention_forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
pos_emb: Tensor,
|
||||
embed_dim_to_check: int,
|
||||
num_heads: int,
|
||||
in_proj_weight: Tensor,
|
||||
in_proj_bias: Tensor,
|
||||
dropout_p: float,
|
||||
out_proj_weight: Tensor,
|
||||
out_proj_bias: Tensor,
|
||||
training: bool = True,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
pos_emb: Positional embedding tensor
|
||||
embed_dim_to_check: total dimension of the model.
|
||||
num_heads: parallel attention heads.
|
||||
in_proj_weight, in_proj_bias: input projection weight and bias.
|
||||
dropout_p: probability of an element to be zeroed.
|
||||
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
||||
training: apply dropout if is ``True``.
|
||||
key_padding_mask: if provided, specified padding elements in the key will
|
||||
be ignored by the attention. This is an binary mask. When the value is True,
|
||||
the corresponding value on the attention layer will be filled with -inf.
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
|
||||
Shape:
|
||||
Inputs:
|
||||
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
|
||||
length, N is the batch size, E is the embedding dimension.
|
||||
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
|
||||
will be unchanged. If a BoolTensor is provided, the positions with the
|
||||
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
||||
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
||||
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
||||
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
||||
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
||||
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
||||
is provided, it will be added to the attention weight.
|
||||
|
||||
Outputs:
|
||||
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||
E is the embedding dimension.
|
||||
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||
L is the target sequence length, S is the source sequence length.
|
||||
"""
|
||||
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
assert embed_dim == embed_dim_to_check
|
||||
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
||||
|
||||
head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
head_dim * num_heads == embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
scaling = float(head_dim) ** -0.5
|
||||
|
||||
if torch.equal(query, key) and torch.equal(key, value):
|
||||
# self-attention
|
||||
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
|
||||
3, dim=-1
|
||||
)
|
||||
|
||||
elif torch.equal(key, value):
|
||||
# encoder-decoder attention
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = 0
|
||||
_end = embed_dim
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
q = nn.functional.linear(query, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
_end = None
|
||||
_w = in_proj_weight[_start:, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:]
|
||||
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
|
||||
|
||||
else:
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = 0
|
||||
_end = embed_dim
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
q = nn.functional.linear(query, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
_end = embed_dim * 2
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
k = nn.functional.linear(key, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim * 2
|
||||
_end = None
|
||||
_w = in_proj_weight[_start:, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:]
|
||||
v = nn.functional.linear(value, _w, _b)
|
||||
|
||||
if attn_mask is not None:
|
||||
assert (
|
||||
attn_mask.dtype == torch.float32
|
||||
or attn_mask.dtype == torch.float64
|
||||
or attn_mask.dtype == torch.float16
|
||||
or attn_mask.dtype == torch.uint8
|
||||
or attn_mask.dtype == torch.bool
|
||||
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
|
||||
attn_mask.dtype
|
||||
)
|
||||
if attn_mask.dtype == torch.uint8:
|
||||
warnings.warn(
|
||||
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
|
||||
)
|
||||
attn_mask = attn_mask.to(torch.bool)
|
||||
|
||||
if attn_mask.dim() == 2:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
||||
raise RuntimeError("The size of the 2D attn_mask is not correct.")
|
||||
elif attn_mask.dim() == 3:
|
||||
if list(attn_mask.size()) != [
|
||||
bsz * num_heads,
|
||||
query.size(0),
|
||||
key.size(0),
|
||||
]:
|
||||
raise RuntimeError("The size of the 3D attn_mask is not correct.")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
|
||||
)
|
||||
# attn_mask's dim is 3 now.
|
||||
|
||||
# convert ByteTensor key_padding_mask to bool
|
||||
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
|
||||
warnings.warn(
|
||||
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
|
||||
)
|
||||
key_padding_mask = key_padding_mask.to(torch.bool)
|
||||
|
||||
q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
|
||||
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
|
||||
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
||||
|
||||
src_len = k.size(0)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
|
||||
key_padding_mask.size(0), bsz
|
||||
)
|
||||
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
|
||||
key_padding_mask.size(1), src_len
|
||||
)
|
||||
|
||||
q = q.transpose(0, 1) # (batch, time1, head, d_k)
|
||||
|
||||
pos_emb_bsz = pos_emb.size(0)
|
||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
||||
|
||||
q_with_bias_u = (q + self._pos_bias_u()).transpose(
|
||||
1, 2
|
||||
) # (batch, head, time1, d_k)
|
||||
|
||||
q_with_bias_v = (q + self._pos_bias_v()).transpose(
|
||||
1, 2
|
||||
) # (batch, head, time1, d_k)
|
||||
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
|
||||
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
|
||||
|
||||
# compute matrix b and matrix d
|
||||
matrix_bd = torch.matmul(
|
||||
q_with_bias_v, p.transpose(-2, -1)
|
||||
) # (batch, head, time1, 2*time1-1)
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2)
|
||||
|
||||
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
|
||||
|
||||
assert list(attn_output_weights.size()) == [
|
||||
bsz * num_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
]
|
||||
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
|
||||
else:
|
||||
attn_output_weights += attn_mask
|
||||
|
||||
if key_padding_mask is not None:
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz, num_heads, tgt_len, src_len
|
||||
)
|
||||
attn_output_weights = attn_output_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
||||
float("-inf"),
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, src_len
|
||||
)
|
||||
|
||||
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
|
||||
attn_output_weights = nn.functional.dropout(
|
||||
attn_output_weights, p=dropout_p, training=training
|
||||
)
|
||||
|
||||
attn_output = torch.bmm(attn_output_weights, v)
|
||||
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
||||
attn_output = (
|
||||
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||
)
|
||||
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
|
||||
|
||||
if need_weights:
|
||||
# average attention weights over heads
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz, num_heads, tgt_len, src_len
|
||||
)
|
||||
return attn_output, attn_output_weights.sum(dim=1) / num_heads
|
||||
else:
|
||||
return attn_output, None
|
||||
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Conformer model.
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
|
||||
|
||||
Args:
|
||||
channels (int): The number of channels of conv layers.
|
||||
kernel_size (int): Kernerl size of conv layers.
|
||||
bias (bool): Whether to use bias in conv layers (default=True).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
# kernerl_size should be a odd number for 'SAME' padding
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
|
||||
self.pointwise_conv1 = ScaledConv1d(
|
||||
channels,
|
||||
2 * channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
# after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
|
||||
# For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
|
||||
# but sometimes, for some reason, for layer 0 the rms ends up being very large,
|
||||
# between 50 and 100 for different channels. This will cause very peaky and
|
||||
# sparse derivatives for the sigmoid gating function, which will tend to make
|
||||
# the loss function not learn effectively. (for most layers the average absolute values
|
||||
# are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
|
||||
# at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
|
||||
# layers, which likely breaks down as 0.5 for the "linear" half and
|
||||
# 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we
|
||||
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
|
||||
# it will be in a better position to start learning something, i.e. to latch onto
|
||||
# the correct range.
|
||||
self.deriv_balancer1 = ActivationBalancer(
|
||||
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
|
||||
)
|
||||
|
||||
self.depthwise_conv = ScaledConv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.deriv_balancer2 = ActivationBalancer(
|
||||
channel_dim=1, min_positive=0.05, max_positive=1.0
|
||||
)
|
||||
|
||||
self.activation = DoubleSwish()
|
||||
|
||||
self.pointwise_conv2 = ScaledConv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
initial_scale=0.25,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
"""Compute convolution module.
|
||||
|
||||
Args:
|
||||
x: Input tensor (#time, batch, channels).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (#time, batch, channels).
|
||||
|
||||
"""
|
||||
# exchange the temporal dimension and the feature dimension
|
||||
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
||||
|
||||
# GLU mechanism
|
||||
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
|
||||
|
||||
x = self.deriv_balancer1(x)
|
||||
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
||||
|
||||
# 1D Depthwise Conv
|
||||
if src_key_padding_mask is not None:
|
||||
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
|
||||
x = self.depthwise_conv(x)
|
||||
|
||||
x = self.deriv_balancer2(x)
|
||||
x = self.activation(x)
|
||||
|
||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||
|
||||
return x.permute(2, 0, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
feature_dim = 50
|
||||
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
|
||||
batch_size = 5
|
||||
seq_len = 20
|
||||
# Just make sure the forward pass runs.
|
||||
f = c(
|
||||
torch.randn(batch_size, seq_len, feature_dim),
|
||||
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
||||
warmup=0.5,
|
||||
)
|
718
egs/librispeech/WSASR/conformer_ctc2/decode.py
Executable file
718
egs/librispeech/WSASR/conformer_ctc2/decode.py
Executable file
@ -0,0 +1,718 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
|
||||
# Fangjun Kuang,
|
||||
# Quandong Wang)
|
||||
# 2023 Johns Hopkins University (Author: Dongji Gao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.decode import get_lattice, one_best_decoding
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
load_averaged_model,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--otc-token",
|
||||
type=str,
|
||||
default="<star>",
|
||||
help="OTC token",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-bias",
|
||||
type=float,
|
||||
default=0,
|
||||
help="bias (log-prob) added to blank token during decoding",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=20,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="ctc-greedy-search",
|
||||
help="""Decoding method.
|
||||
Supported values are:
|
||||
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
|
||||
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||
It needs neither a lexicon nor an n-gram LM.
|
||||
- (1) ctc-greedy-search. It only use CTC output and a sentence piece
|
||||
model for decoding. It produces the same results with ctc-decoding.
|
||||
- (2) 1best. Extract the best path from the decoding lattice as the
|
||||
decoding result.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-decoder-layers",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""Number of decoder layer of transformer decoder.
|
||||
Setting this to 0 will not create the decoder at all (pure CTC model)
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="conformer_ctc2/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_bpe_200",
|
||||
help="The lang dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-dir",
|
||||
type=str,
|
||||
default="data/lm",
|
||||
help="""The n-gram LM dir.
|
||||
It should contain either G_4_gram.pt or G_4_gram.fst.txt
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
# parameters for conformer
|
||||
"subsampling_factor": 2,
|
||||
"feature_dim": 768,
|
||||
"nhead": 8,
|
||||
"dim_feedforward": 2048,
|
||||
"encoder_dim": 512,
|
||||
"num_encoder_layers": 12,
|
||||
# parameters for decoding
|
||||
"search_beam": 20,
|
||||
"output_beam": 8,
|
||||
"min_active_states": 30,
|
||||
"max_active_states": 10000,
|
||||
"use_double_scores": True,
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def ctc_greedy_search(
|
||||
nnet_output: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
memory_key_padding_mask: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
"""Apply CTC greedy search
|
||||
|
||||
Args:
|
||||
speech (torch.Tensor): (batch, max_len, feat_dim)
|
||||
speech_length (torch.Tensor): (batch, )
|
||||
Returns:
|
||||
List[List[int]]: best path result
|
||||
"""
|
||||
batch_size = memory.shape[1]
|
||||
# Let's assume B = batch_size
|
||||
encoder_out = memory
|
||||
encoder_mask = memory_key_padding_mask
|
||||
maxlen = encoder_out.size(0)
|
||||
|
||||
ctc_probs = nnet_output # (B, maxlen, vocab_size)
|
||||
topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1)
|
||||
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
|
||||
topk_index = topk_index.masked_fill_(encoder_mask, 0) # (B, maxlen)
|
||||
hyps = [hyp.tolist() for hyp in topk_index]
|
||||
scores = topk_prob.max(1)
|
||||
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
|
||||
return hyps, scores
|
||||
|
||||
|
||||
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
|
||||
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
|
||||
new_hyp: List[int] = []
|
||||
cur = 0
|
||||
while cur < len(hyp):
|
||||
if hyp[cur] != 0:
|
||||
new_hyp.append(hyp[cur])
|
||||
prev = cur
|
||||
while cur < len(hyp) and hyp[cur] == hyp[prev]:
|
||||
cur += 1
|
||||
return new_hyp
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
batch: dict,
|
||||
word_table: k2.SymbolTable,
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if no rescoring is used, the key is the string `no_rescore`.
|
||||
If LM rescoring is used, the key is the string `lm_scale_xxx`,
|
||||
where `xxx` is the value of `lm_scale`. An example key is
|
||||
`lm_scale_0.7`
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
|
||||
- params.method is "1best", it uses 1best decoding without LM rescoring.
|
||||
|
||||
model:
|
||||
The neural model.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.method is ctc-decoding.
|
||||
bpe_model:
|
||||
The BPE model. Used only when params.method is ctc-decoding.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
sos_id:
|
||||
The token ID of the SOS.
|
||||
eos_id:
|
||||
The token ID of the EOS.
|
||||
G:
|
||||
An LM. It is not None when params.method is "nbest-rescoring"
|
||||
or "whole-lattice-rescoring". In general, the G in HLG
|
||||
is a 3-gram LM, while this G is a 4-gram LM.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict. Note: If it decodes to nothing, then return None.
|
||||
"""
|
||||
if HLG is not None:
|
||||
device = HLG.device
|
||||
else:
|
||||
device = H.device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
|
||||
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
|
||||
# nnet_output is (N, T, C)
|
||||
nnet_output[:, :, 0] += params.blank_bias
|
||||
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
torch.div(
|
||||
supervisions["start_frame"],
|
||||
params.subsampling_factor,
|
||||
rounding_mode="trunc",
|
||||
),
|
||||
torch.div(
|
||||
supervisions["num_frames"],
|
||||
params.subsampling_factor,
|
||||
rounding_mode="trunc",
|
||||
),
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
|
||||
|
||||
if H is None:
|
||||
assert HLG is not None
|
||||
decoding_graph = HLG
|
||||
else:
|
||||
assert HLG is None
|
||||
assert bpe_model is not None
|
||||
decoding_graph = H
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
decoding_graph=decoding_graph,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor + 2,
|
||||
)
|
||||
|
||||
if params.method == "ctc-decoding":
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
# Note: `best_path.aux_labels` contains token IDs, not word IDs
|
||||
# since we are using H, not HLG here.
|
||||
#
|
||||
# token_ids is a lit-of-list of IDs
|
||||
token_ids = get_texts(best_path)
|
||||
|
||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
|
||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||
hyps = [s.split() for s in hyps]
|
||||
key = "ctc-decoding"
|
||||
return {key: hyps}
|
||||
|
||||
if params.method == "ctc-greedy-search":
|
||||
hyps, _ = ctc_greedy_search(
|
||||
nnet_output,
|
||||
memory,
|
||||
memory_key_padding_mask,
|
||||
)
|
||||
|
||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||
hyps = bpe_model.decode(hyps)
|
||||
|
||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||
hyps = [s.split() for s in hyps]
|
||||
key = "ctc-greedy-search"
|
||||
return {key: hyps}
|
||||
|
||||
if params.method in ["1best"]:
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
key = "no_rescore"
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
assert False, f"Unsupported decoding method: {params.method}"
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
word_table: k2.SymbolTable,
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.method is ctc-decoding.
|
||||
bpe_model:
|
||||
The BPE model. Used only when params.method is ctc-decoding.
|
||||
word_table:
|
||||
It is the word symbol table.
|
||||
sos_id:
|
||||
The token ID for SOS.
|
||||
eos_id:
|
||||
The token ID for EOS.
|
||||
G:
|
||||
An LM. It is not None when params.method is "nbest-rescoring"
|
||||
or "whole-lattice-rescoring". In general, the G in HLG
|
||||
is a 3-gram LM, while this G is a 4-gram LM.
|
||||
Returns:
|
||||
Return a dict, whose key may be "no-rescore" if no LM rescoring
|
||||
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
batch=batch,
|
||||
word_table=word_table,
|
||||
G=G,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
)
|
||||
|
||||
if hyps_dict is not None:
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
else:
|
||||
assert len(results) > 0, "It should not decode to empty in the first batch!"
|
||||
this_batch = []
|
||||
hyp_words = []
|
||||
for ref_text in texts:
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
|
||||
for lm_scale in results.keys():
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
if params.method in ("attention-decoder", "rnn-lm"):
|
||||
# Set it to False since there are too many logs.
|
||||
enable_log = False
|
||||
else:
|
||||
enable_log = True
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=enable_log
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
if enable_log:
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
args.lm_dir = Path(args.lm_dir)
|
||||
assert "▁" not in args.otc_token
|
||||
args.otc_token = f"▁{args.otc_token}"
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
|
||||
logging.info("Decoding started")
|
||||
logging.info(params)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
# remove otc_token from decoding units
|
||||
max_token_id = max(lexicon.tokens) - 1
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
graph_compiler = OtcTrainingGraphCompiler(
|
||||
params.lang_dir,
|
||||
params.otc_token,
|
||||
device=device,
|
||||
sos_token="<sos/eos>",
|
||||
eos_token="<sos/eos>",
|
||||
)
|
||||
sos_id = graph_compiler.sos_id
|
||||
eos_id = graph_compiler.eos_id
|
||||
|
||||
params.num_classes = num_classes
|
||||
params.sos_id = sos_id
|
||||
params.eos_id = eos_id
|
||||
|
||||
if params.method == "ctc-decoding" or params.method == "ctc-greedy-search":
|
||||
HLG = None
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
modified=False,
|
||||
device=device,
|
||||
)
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
||||
else:
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
G = None
|
||||
|
||||
model = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
nhead=params.nhead,
|
||||
d_model=params.encoder_dim,
|
||||
num_classes=num_classes,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
num_decoder_layers=params.num_decoder_layers,
|
||||
)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
word_table=lexicon.word_table,
|
||||
G=G,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
)
|
||||
|
||||
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/librispeech/WSASR/conformer_ctc2/export.py
Symbolic link
1
egs/librispeech/WSASR/conformer_ctc2/export.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/conformer_ctc2/export.py
|
1
egs/librispeech/WSASR/conformer_ctc2/label_smoothing.py
Symbolic link
1
egs/librispeech/WSASR/conformer_ctc2/label_smoothing.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/conformer_ctc/label_smoothing.py
|
1
egs/librispeech/WSASR/conformer_ctc2/optim.py
Symbolic link
1
egs/librispeech/WSASR/conformer_ctc2/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/pruned_transducer_stateless2/optim.py
|
1
egs/librispeech/WSASR/conformer_ctc2/scaling.py
Symbolic link
1
egs/librispeech/WSASR/conformer_ctc2/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/pruned_transducer_stateless2/scaling.py
|
184
egs/librispeech/WSASR/conformer_ctc2/subsampling.py
Normal file
184
egs/librispeech/WSASR/conformer_ctc2/subsampling.py
Normal file
@ -0,0 +1,184 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||
# 2022 Xiaomi Corporation (author: Quandong Wang)
|
||||
# 2023 Johns Hopkins University (author: Dongji Gao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
ScaledConv2d,
|
||||
ScaledLinear,
|
||||
)
|
||||
|
||||
|
||||
class Conv2dSubsampling(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/4 length).
|
||||
|
||||
Convert an input of shape (N, T, idim) to an output
|
||||
with shape (N, T', odim), where
|
||||
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
|
||||
|
||||
It is based on
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
layer1_channels: int = 8,
|
||||
layer2_channels: int = 32,
|
||||
layer3_channels: int = 128,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
in_channels:
|
||||
Number of channels in. The input shape is (N, T, in_channels).
|
||||
Caution: It requires: T >=7, in_channels >=7
|
||||
out_channels
|
||||
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
|
||||
layer1_channels:
|
||||
Number of channels in layer1
|
||||
layer1_channels:
|
||||
Number of channels in layer2
|
||||
"""
|
||||
assert in_channels >= 7
|
||||
super().__init__()
|
||||
|
||||
self.conv = torch.nn.Sequential(
|
||||
ScaledConv2d(
|
||||
in_channels=1,
|
||||
out_channels=layer1_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
ScaledConv2d(
|
||||
in_channels=layer1_channels,
|
||||
out_channels=layer2_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
ScaledConv2d(
|
||||
in_channels=layer2_channels,
|
||||
out_channels=layer3_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
)
|
||||
self.out = ScaledLinear(
|
||||
layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
|
||||
)
|
||||
# set learn_eps=False because out_norm is preceded by `out`, and `out`
|
||||
# itself has learned scale, so the extra degree of freedom is not
|
||||
# needed.
|
||||
self.out_norm = BasicNorm(out_channels, learn_eps=False)
|
||||
# constrain median of output to be close to zero.
|
||||
self.out_balancer = ActivationBalancer(
|
||||
channel_dim=-1, min_positive=0.45, max_positive=0.55
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is (N, T, idim).
|
||||
|
||||
Returns:
|
||||
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
|
||||
"""
|
||||
# On entry, x is (N, T, idim)
|
||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||
x = self.conv(x)
|
||||
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
||||
x = self.out_norm(x)
|
||||
x = self.out_balancer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Conv2dSubsampling2(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/2 length).
|
||||
|
||||
Convert an input of shape (N, T, idim) to an output
|
||||
with shape (N, T', odim) where
|
||||
T' = (T - 1) // 2 - 2, which approximates T' == T // 2
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
layer1_channels: int = 8,
|
||||
layer2_channels: int = 32,
|
||||
layer3_channels: int = 128,
|
||||
) -> None:
|
||||
assert in_channels >= 7
|
||||
super().__init__()
|
||||
|
||||
self.conv = torch.nn.Sequential(
|
||||
ScaledConv2d(
|
||||
in_channels=1,
|
||||
out_channels=layer1_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
ScaledConv2d(
|
||||
in_channels=layer1_channels,
|
||||
out_channels=layer2_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
ScaledConv2d(
|
||||
in_channels=layer2_channels,
|
||||
out_channels=layer3_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
)
|
||||
self.out = ScaledLinear(
|
||||
layer3_channels * ((in_channels - 1) // 2 - 2), out_channels
|
||||
)
|
||||
self.out_norm = BasicNorm(out_channels, learn_eps=False)
|
||||
self.out_balancer = ActivationBalancer(
|
||||
channel_dim=-1, min_positive=0.45, max_positive=0.55
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.unsqueeze(1)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
x = self.out_norm(x)
|
||||
x = self.out_balancer(x)
|
||||
return x
|
1115
egs/librispeech/WSASR/conformer_ctc2/train.py
Executable file
1115
egs/librispeech/WSASR/conformer_ctc2/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1055
egs/librispeech/WSASR/conformer_ctc2/transformer.py
Normal file
1055
egs/librispeech/WSASR/conformer_ctc2/transformer.py
Normal file
File diff suppressed because it is too large
Load Diff
BIN
egs/librispeech/WSASR/figures/del.png
Normal file
BIN
egs/librispeech/WSASR/figures/del.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 14 KiB |
BIN
egs/librispeech/WSASR/figures/ins.png
Normal file
BIN
egs/librispeech/WSASR/figures/ins.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 16 KiB |
BIN
egs/librispeech/WSASR/figures/otc_emission.drawio.png
Normal file
BIN
egs/librispeech/WSASR/figures/otc_emission.drawio.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 39 KiB |
BIN
egs/librispeech/WSASR/figures/otc_g.png
Normal file
BIN
egs/librispeech/WSASR/figures/otc_g.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 33 KiB |
BIN
egs/librispeech/WSASR/figures/otc_training_graph.drawio.png
Normal file
BIN
egs/librispeech/WSASR/figures/otc_training_graph.drawio.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 150 KiB |
BIN
egs/librispeech/WSASR/figures/sub.png
Normal file
BIN
egs/librispeech/WSASR/figures/sub.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 16 KiB |
173
egs/librispeech/WSASR/local/compile_hlg.py
Executable file
173
egs/librispeech/WSASR/local/compile_hlg.py
Executable file
@ -0,0 +1,173 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script takes as input lang_dir and generates HLG from
|
||||
|
||||
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
|
||||
- L, the lexicon, built from lang_dir/L_disambig.pt
|
||||
|
||||
Caution: We use a lexicon that contains disambiguation symbols
|
||||
|
||||
- G, the LM, built from data/lm/G_n_gram.fst.txt
|
||||
|
||||
The generated HLG is saved in $lang_dir/HLG.pt
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lm",
|
||||
type=str,
|
||||
default="G_3_gram",
|
||||
help="""Stem name for LM used in HLG compiling.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lm-dir",
|
||||
type=str,
|
||||
help="""LM directory.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compile_HLG(lm_dir: str, lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
|
||||
lm:
|
||||
The language stem base name.
|
||||
|
||||
Return:
|
||||
An FSA representing HLG.
|
||||
"""
|
||||
lexicon = Lexicon(lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
|
||||
H = k2.ctc_topo(max_token_id)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
|
||||
if Path(f"{lm_dir}/{lm}.pt").is_file():
|
||||
logging.info(f"Loading pre-compiled {lm}")
|
||||
d = torch.load(f"{lm_dir}/{lm}.pt")
|
||||
G = k2.Fsa.from_dict(d)
|
||||
else:
|
||||
logging.info(f"Loading {lm}.fst.txt")
|
||||
with open(f"{lm_dir}/{lm}.fst.txt") as f:
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
torch.save(G.as_dict(), f"{lm_dir}/{lm}.pt")
|
||||
|
||||
first_token_disambig_id = lexicon.token_table["#0"]
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
L = k2.arc_sort(L)
|
||||
G = k2.arc_sort(G)
|
||||
|
||||
logging.info("Intersecting L and G")
|
||||
LG = k2.compose(L, G)
|
||||
logging.info(f"LG shape: {LG.shape}")
|
||||
|
||||
logging.info("Connecting LG")
|
||||
LG = k2.connect(LG)
|
||||
logging.info(f"LG shape after k2.connect: {LG.shape}")
|
||||
|
||||
logging.info(type(LG.aux_labels))
|
||||
logging.info("Determinizing LG")
|
||||
|
||||
LG = k2.determinize(LG)
|
||||
logging.info(type(LG.aux_labels))
|
||||
|
||||
logging.info("Connecting LG after k2.determinize")
|
||||
LG = k2.connect(LG)
|
||||
|
||||
logging.info("Removing disambiguation symbols on LG")
|
||||
|
||||
LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# See https://github.com/k2-fsa/k2/issues/874
|
||||
# for why we need to set LG.properties to None
|
||||
LG.__dict__["_properties"] = None
|
||||
|
||||
assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
||||
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||
|
||||
LG = k2.remove_epsilon(LG)
|
||||
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
|
||||
|
||||
LG = k2.connect(LG)
|
||||
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
|
||||
|
||||
logging.info("Arc sorting LG")
|
||||
LG = k2.arc_sort(LG)
|
||||
|
||||
logging.info("Composing H and LG")
|
||||
# CAUTION: The name of the inner_labels is fixed
|
||||
# to `tokens`. If you want to change it, please
|
||||
# also change other places in icefall that are using
|
||||
# it.
|
||||
HLG = k2.compose(H, LG, inner_labels="tokens")
|
||||
|
||||
logging.info("Connecting LG")
|
||||
HLG = k2.connect(HLG)
|
||||
|
||||
logging.info("Arc sorting LG")
|
||||
HLG = k2.arc_sort(HLG)
|
||||
logging.info(f"HLG.shape: {HLG.shape}")
|
||||
|
||||
return HLG
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
lm_dir = Path(args.lm_dir)
|
||||
lang_dir = Path(args.lang_dir)
|
||||
|
||||
if (lang_dir / "HLG.pt").is_file():
|
||||
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
|
||||
return
|
||||
|
||||
logging.info(f"Processing {lang_dir}")
|
||||
|
||||
HLG = compile_HLG(lm_dir, lang_dir, args.lm)
|
||||
logging.info(f"Saving HLG.pt to {lang_dir}")
|
||||
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
162
egs/librispeech/WSASR/local/compute_fbank_librispeech.py
Executable file
162
egs/librispeech/WSASR/local/compute_fbank_librispeech.py
Executable file
@ -0,0 +1,162 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the LibriSpeech dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from filter_cuts import filter_cuts
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to the bpe.model. If not None, we will remove short and
|
||||
long utterances before extracting features""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
help="""Dataset parts to compute fbank. If None, we will use all""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--perturb-speed",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compute_fbank_librispeech(
|
||||
bpe_model: Optional[str] = None,
|
||||
dataset: Optional[str] = None,
|
||||
perturb_speed: Optional[bool] = True,
|
||||
):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
num_mel_bins = 80
|
||||
|
||||
if bpe_model:
|
||||
logging.info(f"Loading {bpe_model}")
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(bpe_model)
|
||||
|
||||
if dataset is None:
|
||||
dataset_parts = (
|
||||
"dev-clean",
|
||||
"dev-other",
|
||||
"test-clean",
|
||||
"test-other",
|
||||
"train-clean-100",
|
||||
)
|
||||
else:
|
||||
dataset_parts = dataset.split(" ", -1)
|
||||
|
||||
prefix = "librispeech"
|
||||
suffix = "jsonl.gz"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
assert len(manifests) == len(dataset_parts), (
|
||||
len(manifests),
|
||||
len(dataset_parts),
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
|
||||
if (output_dir / cuts_filename).is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
continue
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
|
||||
if "train" in partition:
|
||||
if bpe_model:
|
||||
cut_set = filter_cuts(cut_set, sp)
|
||||
if perturb_speed:
|
||||
logging.info(f"Doing speed perturb")
|
||||
cut_set = (
|
||||
cut_set
|
||||
+ cut_set.perturb_speed(0.9)
|
||||
+ cut_set.perturb_speed(1.1)
|
||||
)
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||
# when an executor is specified, make more partitions
|
||||
num_jobs=num_jobs if ex is None else 80,
|
||||
executor=ex,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
cut_set.to_file(output_dir / cuts_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
args = get_args()
|
||||
logging.info(vars(args))
|
||||
compute_fbank_librispeech(
|
||||
bpe_model=args.bpe_model,
|
||||
dataset=args.dataset,
|
||||
perturb_speed=args.perturb_speed,
|
||||
)
|
100
egs/librispeech/WSASR/local/compute_ssl_librispeech.py
Executable file
100
egs/librispeech/WSASR/local/compute_ssl_librispeech.py
Executable file
@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# 2023 Johns Hopkins University (author: Dongji Gao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the LibriSpeech dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import S3PRLSSL, CutSet, NumpyFilesWriter, S3PRLSSLConfig
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_ssl_librispeech():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/ssl")
|
||||
num_jobs = 1
|
||||
|
||||
dataset_parts = (
|
||||
"dev-clean",
|
||||
"dev-other",
|
||||
"test-clean",
|
||||
"test-other",
|
||||
"train-clean-100",
|
||||
)
|
||||
prefix = "librispeech"
|
||||
suffix = "jsonl.gz"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
assert len(manifests) == len(dataset_parts), (
|
||||
len(manifests),
|
||||
len(dataset_parts),
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = S3PRLSSL(S3PRLSSLConfig(ssl_model="wav2vec2", device="cuda"))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
|
||||
if (output_dir / cuts_filename).is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
continue
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||
storage_type=NumpyFilesWriter,
|
||||
)
|
||||
cut_set.to_file(output_dir / cuts_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
compute_ssl_librispeech()
|
160
egs/librispeech/WSASR/local/filter_cuts.py
Normal file
160
egs/librispeech/WSASR/local/filter_cuts.py
Normal file
@ -0,0 +1,160 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script removes short and long utterances from a cutset.
|
||||
|
||||
Caution:
|
||||
You may need to tune the thresholds for your own dataset.
|
||||
|
||||
Usage example:
|
||||
|
||||
python3 ./local/filter_cuts.py \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--in-cuts data/fbank/librispeech_cuts_test-clean.jsonl.gz \
|
||||
--out-cuts data/fbank-filtered/librispeech_cuts_test-clean.jsonl.gz
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from lhotse.cut import Cut
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=Path,
|
||||
help="Path to the bpe.model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--in-cuts",
|
||||
type=Path,
|
||||
help="Path to the input cutset",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--out-cuts",
|
||||
type=Path,
|
||||
help="Path to the output cutset",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor):
|
||||
total = 0 # number of total utterances before removal
|
||||
removed = 0 # number of removed utterances
|
||||
|
||||
def remove_short_and_long_utterances(c: Cut):
|
||||
"""Return False to exclude the input cut"""
|
||||
nonlocal removed, total
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
#
|
||||
# Caution: There is a reason to select 20.0 here. Please see
|
||||
# ./display_manifest_statistics.py
|
||||
#
|
||||
# You should use ./display_manifest_statistics.py to get
|
||||
# an utterance duration distribution for your dataset to select
|
||||
# the threshold
|
||||
total += 1
|
||||
if c.duration < 1.0 or c.duration > 20.0:
|
||||
logging.warning(
|
||||
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
)
|
||||
removed += 1
|
||||
return False
|
||||
|
||||
# In pruned RNN-T, we require that T >= S
|
||||
# where T is the number of feature frames after subsampling
|
||||
# and S is the number of tokens in the utterance
|
||||
|
||||
# In ./pruned_transducer_stateless2/conformer.py, the
|
||||
# conv module uses the following expression
|
||||
# for subsampling
|
||||
if c.num_frames is None:
|
||||
num_frames = c.duration * 100 # approximate
|
||||
else:
|
||||
num_frames = c.num_frames
|
||||
|
||||
T = ((num_frames - 1) // 2 - 1) // 2
|
||||
# Note: for ./lstm_transducer_stateless/lstm.py, the formula is
|
||||
# T = ((num_frames - 3) // 2 - 1) // 2
|
||||
|
||||
# Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is
|
||||
# T = ((num_frames - 7) // 2 + 1) // 2
|
||||
|
||||
tokens = sp.encode(c.supervisions[0].text, out_type=str)
|
||||
|
||||
if T < len(tokens):
|
||||
logging.warning(
|
||||
f"Exclude cut with ID {c.id} from training. "
|
||||
f"Number of frames (before subsampling): {c.num_frames}. "
|
||||
f"Number of frames (after subsampling): {T}. "
|
||||
f"Text: {c.supervisions[0].text}. "
|
||||
f"Tokens: {tokens}. "
|
||||
f"Number of tokens: {len(tokens)}"
|
||||
)
|
||||
removed += 1
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# We use to_eager() here so that we can print out the value of total
|
||||
# and removed below.
|
||||
ans = cut_set.filter(remove_short_and_long_utterances).to_eager()
|
||||
ratio = removed / total * 100
|
||||
logging.info(
|
||||
f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed."
|
||||
)
|
||||
return ans
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
if args.out_cuts.is_file():
|
||||
logging.info(f"{args.out_cuts} already exists - skipping")
|
||||
return
|
||||
|
||||
assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist"
|
||||
assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist"
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(str(args.bpe_model))
|
||||
|
||||
cut_set = load_manifest_lazy(args.in_cuts)
|
||||
assert isinstance(cut_set, CutSet)
|
||||
|
||||
cut_set = filter_cuts(cut_set, sp)
|
||||
logging.info(f"Saving to {args.out_cuts}")
|
||||
args.out_cuts.parent.mkdir(parents=True, exist_ok=True)
|
||||
cut_set.to_file(args.out_cuts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
48
egs/librispeech/WSASR/local/get_words_from_lexicon.py
Executable file
48
egs/librispeech/WSASR/local/get_words_from_lexicon.py
Executable file
@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from icefall.lexicon import read_lexicon
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
It should contain a file lexicon.txt.
|
||||
Generated files by this script are saved into this directory.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--otc-token",
|
||||
type=str,
|
||||
help="OTC token to be added to words.txt",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
otc_token = args.otc_token
|
||||
|
||||
lexicon = read_lexicon(lang_dir / "lexicon.txt")
|
||||
ans = set()
|
||||
for word, _ in lexicon:
|
||||
ans.add(word)
|
||||
sorted_ans = sorted(list(ans))
|
||||
words = ["<eps>"] + sorted_ans + [otc_token] + ["#0", "<s>", "</s>"]
|
||||
|
||||
words_file = lang_dir / "words.txt"
|
||||
with open(words_file, "w") as wf:
|
||||
for i, word in enumerate(words):
|
||||
wf.write(f"{word} {i}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
175
egs/librispeech/WSASR/local/make_error_cutset.py
Executable file
175
egs/librispeech/WSASR/local/make_error_cutset.py
Executable file
@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2023 Johns Hopkins University (author: Dongji Gao)
|
||||
|
||||
import argparse
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from lhotse import CutSet, load_manifest
|
||||
from lhotse.cut.base import Cut
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-cutset",
|
||||
type=str,
|
||||
help="Supervision manifest that contains verbatim transcript",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--words-file",
|
||||
type=str,
|
||||
help="words.txt file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--otc-token",
|
||||
type=str,
|
||||
help="OTC token in words.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sub-error-rate",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Substitution error rate",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ins-error-rate",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Insertion error rate",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--del-error-rate",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Deletion error rate",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-cutset",
|
||||
type=str,
|
||||
default="",
|
||||
help="Supervision manifest that contains modified non-verbatim transcript",
|
||||
)
|
||||
|
||||
parser.add_argument("--verbose", type=str2bool, help="show details of errors")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def check_args(args):
|
||||
total_error_rate = args.sub_error_rate + args.ins_error_rate + args.del_error_rate
|
||||
assert args.sub_error_rate >= 0 and args.sub_error_rate <= 1.0
|
||||
assert args.ins_error_rate >= 0 and args.sub_error_rate <= 1.0
|
||||
assert args.del_error_rate >= 0 and args.sub_error_rate <= 1.0
|
||||
assert total_error_rate <= 1.0
|
||||
|
||||
|
||||
def get_word_list(token_path: str) -> List:
|
||||
word_list = []
|
||||
with open(Path(token_path), "r") as tp:
|
||||
for line in tp.readlines():
|
||||
token = line.split()[0]
|
||||
assert token not in word_list
|
||||
word_list.append(token)
|
||||
return word_list
|
||||
|
||||
|
||||
def modify_cut_text(
|
||||
cut: Cut,
|
||||
words_list: List,
|
||||
non_words: List,
|
||||
sub_ratio: float = 0.0,
|
||||
ins_ratio: float = 0.0,
|
||||
del_ratio: float = 0.0,
|
||||
):
|
||||
text = cut.supervisions[0].text
|
||||
text_list = text.split()
|
||||
|
||||
# We save the modified information of the original verbatim text for debugging
|
||||
marked_verbatim_text_list = []
|
||||
modified_text_list = []
|
||||
|
||||
del_index_set = set()
|
||||
sub_index_set = set()
|
||||
ins_index_set = set()
|
||||
|
||||
# We follow the order: deletion -> substitution -> insertion
|
||||
for token in text_list:
|
||||
marked_token = token
|
||||
modified_token = token
|
||||
|
||||
prob = random.random()
|
||||
|
||||
if prob <= del_ratio:
|
||||
marked_token = f"-{token}-"
|
||||
modified_token = ""
|
||||
elif prob <= del_ratio + sub_ratio + ins_ratio:
|
||||
if prob <= del_ratio + sub_ratio:
|
||||
marked_token = f"[{token}]"
|
||||
else:
|
||||
marked_verbatim_text_list.append(marked_token)
|
||||
modified_text_list.append(modified_token)
|
||||
marked_token = "[]"
|
||||
|
||||
# get new_token
|
||||
while (
|
||||
modified_token == token
|
||||
or modified_token in non_words
|
||||
or modified_token.startswith("#")
|
||||
):
|
||||
modified_token = random.choice(words_list)
|
||||
|
||||
marked_verbatim_text_list.append(marked_token)
|
||||
modified_text_list.append(modified_token)
|
||||
|
||||
marked_text = " ".join(marked_verbatim_text_list)
|
||||
modified_text = " ".join(modified_text_list)
|
||||
|
||||
if not hasattr(cut.supervisions[0], "verbatim_text"):
|
||||
cut.supervisions[0].verbatim_text = marked_text
|
||||
cut.supervisions[0].text = modified_text
|
||||
|
||||
return cut
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
check_args(args)
|
||||
|
||||
otc_token = args.otc_token
|
||||
non_words = set(("sil", "<UNK>", "<eps>"))
|
||||
non_words.add(otc_token)
|
||||
|
||||
words_list = get_word_list(args.words_file)
|
||||
cutset = load_manifest(Path(args.input_cutset))
|
||||
|
||||
cuts = []
|
||||
|
||||
for cut in cutset:
|
||||
modified_cut = modify_cut_text(
|
||||
cut=cut,
|
||||
words_list=words_list,
|
||||
non_words=non_words,
|
||||
sub_ratio=args.sub_error_rate,
|
||||
ins_ratio=args.ins_error_rate,
|
||||
del_ratio=args.del_error_rate,
|
||||
)
|
||||
cuts.append(modified_cut)
|
||||
|
||||
output_cutset = CutSet.from_cuts(cuts)
|
||||
output_cutset.to_file(args.output_cutset)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
413
egs/librispeech/WSASR/local/prepare_lang.py
Executable file
413
egs/librispeech/WSASR/local/prepare_lang.py
Executable file
@ -0,0 +1,413 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
|
||||
consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
|
||||
|
||||
2. Generate tokens.txt, the token table mapping a token to a unique integer.
|
||||
|
||||
3. Generate words.txt, the word table mapping a word to a unique integer.
|
||||
|
||||
4. Generate L.pt, in k2 format. It can be loaded by
|
||||
|
||||
d = torch.load("L.pt")
|
||||
lexicon = k2.Fsa.from_dict(d)
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
"""
|
||||
import argparse
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import read_lexicon, write_lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
Lexicon = List[Tuple[str, List[str]]]
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
It should contain a file lexicon.txt.
|
||||
Generated files by this script are saved into this directory.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True for debugging, which will generate
|
||||
a visualization of the lexicon FST.
|
||||
|
||||
Caution: If your lexicon contains hundreds of thousands
|
||||
of lines, please set it to False!
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
|
||||
"""Write a symbol to ID mapping to a file.
|
||||
|
||||
Note:
|
||||
No need to implement `read_mapping` as it can be done
|
||||
through :func:`k2.SymbolTable.from_file`.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename to save the mapping.
|
||||
sym2id:
|
||||
A dict mapping symbols to IDs.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
for sym, i in sym2id.items():
|
||||
f.write(f"{sym} {i}\n")
|
||||
|
||||
|
||||
def get_tokens(lexicon: Lexicon) -> List[str]:
|
||||
"""Get tokens from a lexicon.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
It is the return value of :func:`read_lexicon`.
|
||||
Returns:
|
||||
Return a list of unique tokens.
|
||||
"""
|
||||
ans = set()
|
||||
for _, tokens in lexicon:
|
||||
ans.update(tokens)
|
||||
sorted_ans = sorted(list(ans))
|
||||
return sorted_ans
|
||||
|
||||
|
||||
def get_words(lexicon: Lexicon) -> List[str]:
|
||||
"""Get words from a lexicon.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
It is the return value of :func:`read_lexicon`.
|
||||
Returns:
|
||||
Return a list of unique words.
|
||||
"""
|
||||
ans = set()
|
||||
for word, _ in lexicon:
|
||||
ans.add(word)
|
||||
sorted_ans = sorted(list(ans))
|
||||
return sorted_ans
|
||||
|
||||
|
||||
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
|
||||
"""It adds pseudo-token disambiguation symbols #1, #2 and so on
|
||||
at the ends of tokens to ensure that all pronunciations are different,
|
||||
and that none is a prefix of another.
|
||||
|
||||
See also add_lex_disambig.pl from kaldi.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
It is returned by :func:`read_lexicon`.
|
||||
Returns:
|
||||
Return a tuple with two elements:
|
||||
|
||||
- The output lexicon with disambiguation symbols
|
||||
- The ID of the max disambiguation symbol that appears
|
||||
in the lexicon
|
||||
"""
|
||||
|
||||
# (1) Work out the count of each token-sequence in the
|
||||
# lexicon.
|
||||
count = defaultdict(int)
|
||||
for _, tokens in lexicon:
|
||||
count[" ".join(tokens)] += 1
|
||||
|
||||
# (2) For each left sub-sequence of each token-sequence, note down
|
||||
# that it exists (for identifying prefixes of longer strings).
|
||||
issubseq = defaultdict(int)
|
||||
for _, tokens in lexicon:
|
||||
tokens = tokens.copy()
|
||||
tokens.pop()
|
||||
while tokens:
|
||||
issubseq[" ".join(tokens)] = 1
|
||||
tokens.pop()
|
||||
|
||||
# (3) For each entry in the lexicon:
|
||||
# if the token sequence is unique and is not a
|
||||
# prefix of another word, no disambig symbol.
|
||||
# Else output #1, or #2, #3, ... if the same token-seq
|
||||
# has already been assigned a disambig symbol.
|
||||
ans = []
|
||||
|
||||
# We start with #1 since #0 has its own purpose
|
||||
first_allowed_disambig = 1
|
||||
max_disambig = first_allowed_disambig - 1
|
||||
last_used_disambig_symbol_of = defaultdict(int)
|
||||
|
||||
for word, tokens in lexicon:
|
||||
tokenseq = " ".join(tokens)
|
||||
assert tokenseq != ""
|
||||
if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
|
||||
ans.append((word, tokens))
|
||||
continue
|
||||
|
||||
cur_disambig = last_used_disambig_symbol_of[tokenseq]
|
||||
if cur_disambig == 0:
|
||||
cur_disambig = first_allowed_disambig
|
||||
else:
|
||||
cur_disambig += 1
|
||||
|
||||
if cur_disambig > max_disambig:
|
||||
max_disambig = cur_disambig
|
||||
last_used_disambig_symbol_of[tokenseq] = cur_disambig
|
||||
tokenseq += f" #{cur_disambig}"
|
||||
ans.append((word, tokenseq.split()))
|
||||
return ans, max_disambig
|
||||
|
||||
|
||||
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
|
||||
"""Generate ID maps, i.e., map a symbol to a unique ID.
|
||||
|
||||
Args:
|
||||
symbols:
|
||||
A list of unique symbols.
|
||||
Returns:
|
||||
A dict containing the mapping between symbols and IDs.
|
||||
"""
|
||||
return {sym: i for i, sym in enumerate(symbols)}
|
||||
|
||||
|
||||
def add_self_loops(
|
||||
arcs: List[List[Any]], disambig_token: int, disambig_word: int
|
||||
) -> List[List[Any]]:
|
||||
"""Adds self-loops to states of an FST to propagate disambiguation symbols
|
||||
through it. They are added on each state with non-epsilon output symbols
|
||||
on at least one arc out of the state.
|
||||
|
||||
See also fstaddselfloops.pl from Kaldi. One difference is that
|
||||
Kaldi uses OpenFst style FSTs and it has multiple final states.
|
||||
This function uses k2 style FSTs and it does not need to add self-loops
|
||||
to the final state.
|
||||
|
||||
The input label of a self-loop is `disambig_token`, while the output
|
||||
label is `disambig_word`.
|
||||
|
||||
Args:
|
||||
arcs:
|
||||
A list-of-list. The sublist contains
|
||||
`[src_state, dest_state, label, aux_label, score]`
|
||||
disambig_token:
|
||||
It is the token ID of the symbol `#0`.
|
||||
disambig_word:
|
||||
It is the word ID of the symbol `#0`.
|
||||
|
||||
Return:
|
||||
Return new `arcs` containing self-loops.
|
||||
"""
|
||||
states_needs_self_loops = set()
|
||||
for arc in arcs:
|
||||
src, dst, ilabel, olabel, score = arc
|
||||
if olabel != 0:
|
||||
states_needs_self_loops.add(src)
|
||||
|
||||
ans = []
|
||||
for s in states_needs_self_loops:
|
||||
ans.append([s, s, disambig_token, disambig_word, 0])
|
||||
|
||||
return arcs + ans
|
||||
|
||||
|
||||
def lexicon_to_fst(
|
||||
lexicon: Lexicon,
|
||||
token2id: Dict[str, int],
|
||||
word2id: Dict[str, int],
|
||||
sil_token: str = "SIL",
|
||||
sil_prob: float = 0.5,
|
||||
need_self_loops: bool = False,
|
||||
) -> k2.Fsa:
|
||||
"""Convert a lexicon to an FST (in k2 format) with optional silence at
|
||||
the beginning and end of each word.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
The input lexicon. See also :func:`read_lexicon`
|
||||
token2id:
|
||||
A dict mapping tokens to IDs.
|
||||
word2id:
|
||||
A dict mapping words to IDs.
|
||||
sil_token:
|
||||
The silence token.
|
||||
sil_prob:
|
||||
The probability for adding a silence at the beginning and end
|
||||
of the word.
|
||||
need_self_loops:
|
||||
If True, add self-loop to states with non-epsilon output symbols
|
||||
on at least one arc out of the state. The input label for this
|
||||
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
|
||||
Returns:
|
||||
Return an instance of `k2.Fsa` representing the given lexicon.
|
||||
"""
|
||||
assert sil_prob > 0.0 and sil_prob < 1.0
|
||||
# CAUTION: we use score, i.e, negative cost.
|
||||
sil_score = math.log(sil_prob)
|
||||
no_sil_score = math.log(1.0 - sil_prob)
|
||||
|
||||
start_state = 0
|
||||
loop_state = 1 # words enter and leave from here
|
||||
sil_state = 2 # words terminate here when followed by silence; this state
|
||||
# has a silence transition to loop_state.
|
||||
next_state = 3 # the next un-allocated state, will be incremented as we go.
|
||||
arcs = []
|
||||
|
||||
assert token2id["<eps>"] == 0
|
||||
assert word2id["<eps>"] == 0
|
||||
|
||||
eps = 0
|
||||
|
||||
sil_token = token2id[sil_token]
|
||||
|
||||
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
|
||||
arcs.append([start_state, sil_state, eps, eps, sil_score])
|
||||
arcs.append([sil_state, loop_state, sil_token, eps, 0])
|
||||
|
||||
for word, tokens in lexicon:
|
||||
assert len(tokens) > 0, f"{word} has no pronunciations"
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
tokens = [token2id[i] for i in tokens]
|
||||
|
||||
for i in range(len(tokens) - 1):
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, next_state, tokens[i], w, 0])
|
||||
|
||||
cur_state = next_state
|
||||
next_state += 1
|
||||
|
||||
# now for the last token of this word
|
||||
# It has two out-going arcs, one to the loop state,
|
||||
# the other one to the sil_state.
|
||||
i = len(tokens) - 1
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
|
||||
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
|
||||
|
||||
if need_self_loops:
|
||||
disambig_token = token2id["#0"]
|
||||
disambig_word = word2id["#0"]
|
||||
arcs = add_self_loops(
|
||||
arcs,
|
||||
disambig_token=disambig_token,
|
||||
disambig_word=disambig_word,
|
||||
)
|
||||
|
||||
final_state = next_state
|
||||
arcs.append([loop_state, final_state, -1, -1, 0])
|
||||
arcs.append([final_state])
|
||||
|
||||
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||
arcs = [" ".join(arc) for arc in arcs]
|
||||
arcs = "\n".join(arcs)
|
||||
|
||||
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||
return fsa
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
lexicon_filename = lang_dir / "lexicon.txt"
|
||||
sil_token = "SIL"
|
||||
sil_prob = 0.5
|
||||
|
||||
lexicon = read_lexicon(lexicon_filename)
|
||||
tokens = get_tokens(lexicon)
|
||||
words = get_words(lexicon)
|
||||
|
||||
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
||||
|
||||
for i in range(max_disambig + 1):
|
||||
disambig = f"#{i}"
|
||||
assert disambig not in tokens
|
||||
tokens.append(f"#{i}")
|
||||
|
||||
assert "<eps>" not in tokens
|
||||
tokens = ["<eps>"] + tokens
|
||||
|
||||
assert "<eps>" not in words
|
||||
assert "#0" not in words
|
||||
assert "<s>" not in words
|
||||
assert "</s>" not in words
|
||||
|
||||
words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
|
||||
|
||||
token2id = generate_id_map(tokens)
|
||||
word2id = generate_id_map(words)
|
||||
|
||||
write_mapping(lang_dir / "tokens.txt", token2id)
|
||||
write_mapping(lang_dir / "words.txt", word2id)
|
||||
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||
|
||||
L = lexicon_to_fst(
|
||||
lexicon,
|
||||
token2id=token2id,
|
||||
word2id=word2id,
|
||||
sil_token=sil_token,
|
||||
sil_prob=sil_prob,
|
||||
)
|
||||
|
||||
L_disambig = lexicon_to_fst(
|
||||
lexicon_disambig,
|
||||
token2id=token2id,
|
||||
word2id=word2id,
|
||||
sil_token=sil_token,
|
||||
sil_prob=sil_prob,
|
||||
need_self_loops=True,
|
||||
)
|
||||
torch.save(L.as_dict(), lang_dir / "L.pt")
|
||||
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
|
||||
|
||||
if args.debug:
|
||||
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
|
||||
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
||||
L.labels_sym = labels_sym
|
||||
L.aux_labels_sym = aux_labels_sym
|
||||
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
|
||||
|
||||
L_disambig.labels_sym = labels_sym
|
||||
L_disambig.aux_labels_sym = aux_labels_sym
|
||||
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
295
egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py
Executable file
295
egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py
Executable file
@ -0,0 +1,295 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# 2023 Johns Hopkins University (author: Dongji Gao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
|
||||
This script takes as input `lang_dir`, which should contain::
|
||||
|
||||
- lang_dir/bpe.model,
|
||||
- lang_dir/words.txt
|
||||
|
||||
and generates the following files in the directory `lang_dir`:
|
||||
|
||||
- lexicon.txt
|
||||
- lexicon_disambig.txt
|
||||
- L.pt
|
||||
- L_disambig.pt
|
||||
- tokens.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from prepare_lang import (
|
||||
Lexicon,
|
||||
add_disambig_symbols,
|
||||
add_self_loops,
|
||||
write_lexicon,
|
||||
write_mapping,
|
||||
)
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def lexicon_to_fst_no_sil(
|
||||
lexicon: Lexicon,
|
||||
token2id: Dict[str, int],
|
||||
word2id: Dict[str, int],
|
||||
need_self_loops: bool = False,
|
||||
) -> k2.Fsa:
|
||||
"""Convert a lexicon to an FST (in k2 format).
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
The input lexicon. See also :func:`read_lexicon`
|
||||
token2id:
|
||||
A dict mapping tokens to IDs.
|
||||
word2id:
|
||||
A dict mapping words to IDs.
|
||||
need_self_loops:
|
||||
If True, add self-loop to states with non-epsilon output symbols
|
||||
on at least one arc out of the state. The input label for this
|
||||
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
|
||||
Returns:
|
||||
Return an instance of `k2.Fsa` representing the given lexicon.
|
||||
"""
|
||||
loop_state = 0 # words enter and leave from here
|
||||
next_state = 1 # the next un-allocated state, will be incremented as we go
|
||||
|
||||
arcs = []
|
||||
|
||||
# The blank symbol <blk> is defined in local/train_bpe_model.py
|
||||
assert token2id["<blk>"] == 0
|
||||
assert word2id["<eps>"] == 0
|
||||
|
||||
eps = 0
|
||||
|
||||
for word, pieces in lexicon:
|
||||
assert len(pieces) > 0, f"{word} has no pronunciations"
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
pieces = [token2id[i] for i in pieces]
|
||||
|
||||
for i in range(len(pieces) - 1):
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, next_state, pieces[i], w, 0])
|
||||
|
||||
cur_state = next_state
|
||||
next_state += 1
|
||||
|
||||
# now for the last piece of this word
|
||||
i = len(pieces) - 1
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, loop_state, pieces[i], w, 0])
|
||||
|
||||
if need_self_loops:
|
||||
disambig_token = token2id["#0"]
|
||||
disambig_word = word2id["#0"]
|
||||
arcs = add_self_loops(
|
||||
arcs,
|
||||
disambig_token=disambig_token,
|
||||
disambig_word=disambig_word,
|
||||
)
|
||||
|
||||
final_state = next_state
|
||||
arcs.append([loop_state, final_state, -1, -1, 0])
|
||||
arcs.append([final_state])
|
||||
|
||||
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||
arcs = [" ".join(arc) for arc in arcs]
|
||||
arcs = "\n".join(arcs)
|
||||
|
||||
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||
return fsa
|
||||
|
||||
|
||||
def generate_otc_lexicon(
|
||||
model_file: str,
|
||||
words: List[str],
|
||||
oov: str,
|
||||
otc_token: str,
|
||||
) -> Tuple[Lexicon, Dict[str, int]]:
|
||||
"""Generate a lexicon from a BPE model.
|
||||
|
||||
Args:
|
||||
model_file:
|
||||
Path to a sentencepiece model.
|
||||
words:
|
||||
A list of strings representing words.
|
||||
oov:
|
||||
The out of vocabulary word in lexicon.
|
||||
otc_token:
|
||||
The OTC token in lexicon.
|
||||
Returns:
|
||||
Return a tuple with two elements:
|
||||
- A dict whose keys are words and values are the corresponding
|
||||
word pieces.
|
||||
- A dict representing the token symbol, mapping from tokens to IDs.
|
||||
"""
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(str(model_file))
|
||||
|
||||
# Convert word to word piece IDs instead of word piece strings
|
||||
# to avoid OOV tokens.
|
||||
words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int)
|
||||
|
||||
# Now convert word piece IDs back to word piece strings.
|
||||
words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids]
|
||||
|
||||
lexicon = []
|
||||
for word, pieces in zip(words, words_pieces):
|
||||
lexicon.append((word, pieces))
|
||||
|
||||
lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))
|
||||
token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
|
||||
|
||||
# Add OTC token to the last.
|
||||
lexicon.append((otc_token, [f"▁{otc_token}"]))
|
||||
otc_token_index = len(token2id)
|
||||
token2id[f"▁{otc_token}"] = otc_token_index
|
||||
|
||||
return lexicon, token2id
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
It should contain the bpe.model and words.txt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--oov",
|
||||
type=str,
|
||||
default="<UNK>",
|
||||
help="The out of vocabulary word in lexicon.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--otc-token",
|
||||
type=str,
|
||||
default="<star>",
|
||||
help="The OTC token in lexicon.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True for debugging, which will generate
|
||||
a visualization of the lexicon FST.
|
||||
|
||||
Caution: If your lexicon contains hundreds of thousands
|
||||
of lines, please set it to False!
|
||||
|
||||
See "test/test_bpe_lexicon.py" for usage.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
model_file = lang_dir / "bpe.model"
|
||||
otc_token = args.otc_token
|
||||
|
||||
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
||||
words = word_sym_table.symbols
|
||||
|
||||
excluded = [
|
||||
"<eps>",
|
||||
"!SIL",
|
||||
"<SPOKEN_NOISE>",
|
||||
args.oov,
|
||||
otc_token,
|
||||
"#0",
|
||||
"<s>",
|
||||
"</s>",
|
||||
]
|
||||
|
||||
for w in excluded:
|
||||
if w in words:
|
||||
words.remove(w)
|
||||
|
||||
lexicon, token_sym_table = generate_otc_lexicon(
|
||||
model_file, words, args.oov, otc_token
|
||||
)
|
||||
|
||||
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
||||
|
||||
next_token_id = max(token_sym_table.values()) + 1
|
||||
for i in range(max_disambig + 1):
|
||||
disambig = f"#{i}"
|
||||
assert disambig not in token_sym_table
|
||||
token_sym_table[disambig] = next_token_id
|
||||
next_token_id += 1
|
||||
|
||||
word_sym_table.add("#0")
|
||||
word_sym_table.add("<s>")
|
||||
word_sym_table.add("</s>")
|
||||
|
||||
write_mapping(lang_dir / "tokens.txt", token_sym_table)
|
||||
|
||||
write_lexicon(lang_dir / "lexicon.txt", lexicon)
|
||||
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||
|
||||
L = lexicon_to_fst_no_sil(
|
||||
lexicon,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
)
|
||||
|
||||
L_disambig = lexicon_to_fst_no_sil(
|
||||
lexicon_disambig,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
need_self_loops=True,
|
||||
)
|
||||
torch.save(L.as_dict(), lang_dir / "L.pt")
|
||||
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
|
||||
|
||||
if args.debug:
|
||||
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
|
||||
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
||||
L.labels_sym = labels_sym
|
||||
L.aux_labels_sym = aux_labels_sym
|
||||
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
|
||||
|
||||
L_disambig.labels_sym = labels_sym
|
||||
L_disambig.aux_labels_sym = aux_labels_sym
|
||||
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
100
egs/librispeech/WSASR/local/train_bpe_model.py
Executable file
100
egs/librispeech/WSASR/local/train_bpe_model.py
Executable file
@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# You can install sentencepiece via:
|
||||
#
|
||||
# pip install sentencepiece
|
||||
#
|
||||
# Due to an issue reported in
|
||||
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
|
||||
#
|
||||
# Please install a version >=0.1.96
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
The generated bpe.model is saved to this directory.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--transcript",
|
||||
type=str,
|
||||
help="Training transcript.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vocab-size",
|
||||
type=int,
|
||||
help="Vocabulary size for BPE training",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
vocab_size = args.vocab_size
|
||||
lang_dir = Path(args.lang_dir)
|
||||
|
||||
model_type = "unigram"
|
||||
|
||||
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
|
||||
train_text = args.transcript
|
||||
character_coverage = 1.0
|
||||
input_sentence_size = 100000000
|
||||
|
||||
user_defined_symbols = ["<blk>", "<sos/eos>"]
|
||||
unk_id = len(user_defined_symbols)
|
||||
# Note: unk_id is fixed to 2.
|
||||
# If you change it, you should also change other
|
||||
# places that are using it.
|
||||
|
||||
model_file = Path(model_prefix + ".model")
|
||||
if not model_file.is_file():
|
||||
spm.SentencePieceTrainer.train(
|
||||
input=train_text,
|
||||
vocab_size=vocab_size,
|
||||
model_type=model_type,
|
||||
model_prefix=model_prefix,
|
||||
input_sentence_size=input_sentence_size,
|
||||
character_coverage=character_coverage,
|
||||
user_defined_symbols=user_defined_symbols,
|
||||
unk_id=unk_id,
|
||||
bos_id=-1,
|
||||
eos_id=-1,
|
||||
)
|
||||
else:
|
||||
print(f"{model_file} exists - skipping")
|
||||
return
|
||||
|
||||
shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
85
egs/librispeech/WSASR/local/validate_bpe_lexicon.py
Executable file
85
egs/librispeech/WSASR/local/validate_bpe_lexicon.py
Executable file
@ -0,0 +1,85 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script checks that there are no OOV tokens in the BPE-based lexicon.
|
||||
|
||||
Usage example:
|
||||
|
||||
python3 ./local/validate_bpe_lexicon.py \
|
||||
--lexicon /path/to/lexicon.txt \
|
||||
--bpe-model /path/to/bpe.model
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from icefall.lexicon import read_lexicon
|
||||
|
||||
# Map word to word pieces
|
||||
Lexicon = List[Tuple[str, List[str]]]
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--lexicon",
|
||||
required=True,
|
||||
type=Path,
|
||||
help="Path to lexicon.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
required=True,
|
||||
type=Path,
|
||||
help="Path to bpe.model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--otc-token",
|
||||
required=True,
|
||||
type=str,
|
||||
help="OTC token",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
assert args.lexicon.is_file(), args.lexicon
|
||||
assert args.bpe_model.is_file(), args.bpe_model
|
||||
|
||||
lexicon = read_lexicon(args.lexicon)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(str(args.bpe_model))
|
||||
|
||||
word_pieces = set(sp.id_to_piece(list(range(sp.vocab_size()))))
|
||||
word_pieces.add(f"▁{args.otc_token}")
|
||||
for word, pieces in lexicon:
|
||||
for p in pieces:
|
||||
if p not in word_pieces:
|
||||
raise ValueError(f"The word {word} contains an OOV token {p}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
92
egs/librispeech/WSASR/local/validate_manifest.py
Executable file
92
egs/librispeech/WSASR/local/validate_manifest.py
Executable file
@ -0,0 +1,92 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script checks the following assumptions of the generated manifest:
|
||||
|
||||
- Single supervision per cut
|
||||
- Supervision time bounds are within cut time bounds
|
||||
|
||||
We will add more checks later if needed.
|
||||
|
||||
Usage example:
|
||||
|
||||
python3 ./local/validate_manifest.py \
|
||||
./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from lhotse.cut import Cut
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"manifest",
|
||||
type=Path,
|
||||
help="Path to the manifest file",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def validate_one_supervision_per_cut(c: Cut):
|
||||
if len(c.supervisions) != 1:
|
||||
raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions")
|
||||
|
||||
|
||||
def validate_supervision_and_cut_time_bounds(c: Cut):
|
||||
s = c.supervisions[0]
|
||||
if s.start < c.start:
|
||||
raise ValueError(
|
||||
f"{c.id}: Supervision start time {s.start} is less "
|
||||
f"than cut start time {c.start}"
|
||||
)
|
||||
|
||||
if s.end > c.end:
|
||||
raise ValueError(
|
||||
f"{c.id}: Supervision end time {s.end} is larger "
|
||||
f"than cut end time {c.end}"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
|
||||
manifest = args.manifest
|
||||
logging.info(f"Validating {manifest}")
|
||||
|
||||
assert manifest.is_file(), f"{manifest} does not exist"
|
||||
cut_set = load_manifest_lazy(manifest)
|
||||
assert isinstance(cut_set, CutSet)
|
||||
|
||||
for c in cut_set:
|
||||
validate_one_supervision_per_cut(c)
|
||||
validate_supervision_and_cut_time_bounds(c)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
233
egs/librispeech/WSASR/prepare.sh
Executable file
233
egs/librispeech/WSASR/prepare.sh
Executable file
@ -0,0 +1,233 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
nj=15
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
# directories and files. If not, they will be downloaded
|
||||
# by this script automatically.
|
||||
#
|
||||
# - $dl_dir/LibriSpeech
|
||||
# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
|
||||
# You can download them from https://www.openslr.org/12
|
||||
#
|
||||
# - $dl_dir/lm
|
||||
# This directory contains the following files downloaded from
|
||||
# http://www.openslr.org/resources/11
|
||||
#
|
||||
# - 3-gram.pruned.1e-7.arpa.gz
|
||||
# - 3-gram.pruned.1e-7.arpa
|
||||
# - 4-gram.arpa.gz
|
||||
# - 4-gram.arpa
|
||||
# - librispeech-vocab.txt
|
||||
# - librispeech-lexicon.txt
|
||||
# - librispeech-lm-norm.txt.gz
|
||||
#
|
||||
otc_token="<star>"
|
||||
feature_type="ssl"
|
||||
|
||||
dl_dir=$PWD/download
|
||||
manifests_dir="data/manifests"
|
||||
feature_dir="data/${feature_type}"
|
||||
lang_dir="data/lang"
|
||||
lm_dir="data/lm"
|
||||
|
||||
perturb_speed=false
|
||||
|
||||
# ssl or fbank
|
||||
|
||||
. ./cmd.sh
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
# vocab size for sentence piece models.
|
||||
# It will generate data/lang_bpe_xxx,
|
||||
# data/lang_bpe_yyy if the array contains xxx, yyy
|
||||
vocab_sizes=(
|
||||
200
|
||||
)
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "dl_dir: ${dl_dir}"
|
||||
|
||||
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
log "Stage -1: Download LM"
|
||||
mkdir -p ${dl_dir}/lm
|
||||
if [ ! -e ${dl_dir}/lm/.done ]; then
|
||||
./local/download_lm.py --out-dir=${dl_dir}/lm
|
||||
touch ${dl_dir}/lm/.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Download data"
|
||||
|
||||
# If you have pre-downloaded it to /path/to/LibriSpeech,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
|
||||
#
|
||||
if [ ! -d $dl_dir/LibriSpeech/train-clean-100 ]; then
|
||||
lhotse download librispeech --full ${dl_dir}
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare LibriSpeech manifest"
|
||||
# We assume that you have downloaded the LibriSpeech corpus
|
||||
# to $dl_dir/LibriSpeech
|
||||
mkdir -p data/manifests
|
||||
if [ ! -e data/manifests/.librispeech.done ]; then
|
||||
lhotse prepare librispeech -j ${nj} \
|
||||
-p dev-clean \
|
||||
-p dev-other \
|
||||
-p test-clean \
|
||||
-p test-other \
|
||||
-p train-clean-100 "${dl_dir}/LibriSpeech" "${manifests_dir}"
|
||||
touch data/manifests/.librispeech.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Compute ${feature_type} feature for librispeech (train-clean-100)"
|
||||
mkdir -p "${feature_dir}"
|
||||
if [ ! -e "${feature_dir}/.librispeech.done" ]; then
|
||||
if [ "${feature_type}" = ssl ]; then
|
||||
./local/compute_ssl_librispeech.py
|
||||
elif [ "${feature_type}" = fbank ]; then
|
||||
./local/compute_fbank_librispeech.py --perturb-speed ${perturb_speed}
|
||||
else
|
||||
log "Error: not supported --feature-type '${feature_type}'"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
touch "${feature_dir}.librispeech.done"
|
||||
fi
|
||||
|
||||
if [ ! -e "${feature_dir}/.librispeech-validated.done" ]; then
|
||||
log "Validating data/ssl for LibriSpeech"
|
||||
parts=(
|
||||
train-clean-100
|
||||
test-clean
|
||||
test-other
|
||||
dev-clean
|
||||
dev-other
|
||||
)
|
||||
for part in ${parts[@]}; do
|
||||
python3 ./local/validate_manifest.py \
|
||||
"${feature_dir}/librispeech_cuts_${part}.jsonl.gz"
|
||||
done
|
||||
touch "${feature_dir}/.librispeech-validated.done"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Prepare words.txt"
|
||||
mkdir -p ${lang_dir}
|
||||
|
||||
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
|
||||
cat - $dl_dir/lm/librispeech-lexicon.txt |
|
||||
sort | uniq > ${lang_dir}/lexicon.txt
|
||||
|
||||
local/get_words_from_lexicon.py \
|
||||
--lang-dir ${lang_dir} \
|
||||
--otc-token ${otc_token}
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Prepare BPE based lang"
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
bpe_lang_dir="data/lang_bpe_${vocab_size}"
|
||||
mkdir -p "${bpe_lang_dir}"
|
||||
# We reuse words.txt from phone based lexicon
|
||||
# so that the two can share G.pt later.
|
||||
cp "${lang_dir}/words.txt" "${bpe_lang_dir}"
|
||||
|
||||
if [ ! -f "${bpe_lang_dir}/transcript_words.txt" ]; then
|
||||
log "Generate data for BPE training"
|
||||
files=$(
|
||||
find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
|
||||
find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
|
||||
find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
|
||||
)
|
||||
for f in ${files[@]}; do
|
||||
cat $f | cut -d " " -f 2-
|
||||
done > "${bpe_lang_dir}/transcript_words.txt"
|
||||
fi
|
||||
|
||||
if [ ! -f ${bpe_lang_dir}/bpe.model ]; then
|
||||
./local/train_bpe_model.py \
|
||||
--lang-dir ${bpe_lang_dir} \
|
||||
--vocab-size ${vocab_size} \
|
||||
--transcript ${bpe_lang_dir}/transcript_words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f ${bpe_lang_dir}/L_disambig.pt ]; then
|
||||
./local/prepare_otc_lang_bpe.py \
|
||||
--lang-dir "${bpe_lang_dir}" \
|
||||
--otc-token "${otc_token}"
|
||||
|
||||
log "Validating ${bpe_lang_dir}/lexicon.txt"
|
||||
./local/validate_bpe_lexicon.py \
|
||||
--lexicon ${bpe_lang_dir}/lexicon.txt \
|
||||
--bpe-model ${bpe_lang_dir}/bpe.model \
|
||||
--otc-token "${otc_token}"
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Prepare G"
|
||||
# We assume you have install kaldilm, if not, please install
|
||||
# it using: pip install kaldilm
|
||||
|
||||
mkdir -p "${lm_dir}"
|
||||
if [ ! -f ${lm_dir}/G_3_gram.fst.txt ]; then
|
||||
# It is used in building HLG
|
||||
python3 -m kaldilm \
|
||||
--read-symbol-table="${lang_dir}/words.txt" \
|
||||
--disambig-symbol='#0' \
|
||||
--max-order=3 \
|
||||
${dl_dir}/lm/3-gram.pruned.1e-7.arpa > ${lm_dir}/G_3_gram.fst.txt
|
||||
fi
|
||||
|
||||
if [ ! -f ${lm_dir}/G_4_gram.fst.txt ]; then
|
||||
# It is used for LM rescoring
|
||||
python3 -m kaldilm \
|
||||
--read-symbol-table="${lang_dir}/words.txt" \
|
||||
--disambig-symbol='#0' \
|
||||
--max-order=4 \
|
||||
${dl_dir}/lm/4-gram.arpa > ${lm_dir}/G_4_gram.fst.txt
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Compile HLG"
|
||||
# Note If ./local/compile_hlg.py throws OOM,
|
||||
# please switch to the following command
|
||||
#
|
||||
# ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
bpe_lang_dir="data/lang_bpe_${vocab_size}"
|
||||
echo "LM DIR: ${lm_dir}"
|
||||
./local/compile_hlg.py \
|
||||
--lm-dir "${lm_dir}" \
|
||||
--lang-dir "${bpe_lang_dir}"
|
||||
done
|
||||
fi
|
246
icefall/otc_graph_compiler.py
Normal file
246
icefall/otc_graph_compiler.py
Normal file
@ -0,0 +1,246 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# 2023 Johns Hopkins University (author: Dongji Gao)
|
||||
#
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class OtcTrainingGraphCompiler(object):
|
||||
def __init__(
|
||||
self,
|
||||
lang_dir: Path,
|
||||
otc_token: str,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
sos_token: str = "<sos/eos>",
|
||||
eos_token: str = "<sos/eos>",
|
||||
initial_bypass_weight: float = 0.0,
|
||||
initial_self_loop_weight: float = 0.0,
|
||||
bypass_weight_decay: float = 0.0,
|
||||
self_loop_weight_decay: float = 0.0,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
This directory is expected to contain the following files:
|
||||
|
||||
- bpe.model
|
||||
- words.txt
|
||||
otc_token:
|
||||
The special token in OTC that represent all non-blank tokens
|
||||
device:
|
||||
It indicates CPU or CUDA.
|
||||
sos_token:
|
||||
The word piece that represents sos.
|
||||
eos_token:
|
||||
The word piece that represents eos.
|
||||
"""
|
||||
lang_dir = Path(lang_dir)
|
||||
bpe_model_file = lang_dir / "bpe.model"
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(str(bpe_model_file))
|
||||
self.sp = sp
|
||||
self.token_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
|
||||
|
||||
self.otc_token = otc_token
|
||||
assert self.otc_token in self.token_table
|
||||
|
||||
self.device = device
|
||||
|
||||
self.sos_id = self.sp.piece_to_id(sos_token)
|
||||
self.eos_id = self.sp.piece_to_id(eos_token)
|
||||
|
||||
assert self.sos_id != self.sp.unk_id()
|
||||
assert self.eos_id != self.sp.unk_id()
|
||||
|
||||
max_token_id = self.get_max_token_id()
|
||||
ctc_topo = k2.ctc_topo(max_token_id, modified=False)
|
||||
self.ctc_topo = ctc_topo.to(self.device)
|
||||
|
||||
self.initial_bypass_weight = initial_bypass_weight
|
||||
self.initial_self_loop_weight = initial_self_loop_weight
|
||||
self.bypass_weight_decay = bypass_weight_decay
|
||||
self.self_loop_weight_decay = self_loop_weight_decay
|
||||
|
||||
def get_max_token_id(self):
|
||||
max_token_id = 0
|
||||
for symbol in self.token_table.symbols:
|
||||
if not symbol.startswith("#"):
|
||||
max_token_id = max(self.token_table[symbol], max_token_id)
|
||||
assert max_token_id > 0
|
||||
|
||||
return max_token_id
|
||||
|
||||
def make_arc(
|
||||
self,
|
||||
from_state: int,
|
||||
to_state: int,
|
||||
symbol: Union[str, int],
|
||||
weight: float,
|
||||
):
|
||||
return f"{from_state} {to_state} {symbol} {weight}"
|
||||
|
||||
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
|
||||
"""Convert a list of texts to a list-of-list of piece IDs.
|
||||
|
||||
Args:
|
||||
texts:
|
||||
It is a list of strings. Each string consists of space(s)
|
||||
separated words. An example containing two strings is given below:
|
||||
|
||||
['HELLO ICEFALL', 'HELLO k2']
|
||||
Returns:
|
||||
Return a list-of-list of piece IDs.
|
||||
"""
|
||||
return self.sp.encode(texts, out_type=int)
|
||||
|
||||
def compile(
|
||||
self,
|
||||
texts: List[str],
|
||||
allow_bypass_arc: str2bool = True,
|
||||
allow_self_loop_arc: str2bool = True,
|
||||
bypass_weight: float = 0.0,
|
||||
self_loop_weight: float = 0.0,
|
||||
) -> k2.Fsa:
|
||||
"""Build a OTC graph from a texts (list of words).
|
||||
|
||||
Args:
|
||||
texts:
|
||||
A list of strings. Each string contains a sentence for an utterance.
|
||||
A sentence consists of spaces separated words. An example `texts`
|
||||
looks like:
|
||||
['hello icefall', 'CTC training with k2']
|
||||
allow_bypass_arc:
|
||||
Whether to add bypass arc to training graph for substitution
|
||||
and insertion errors (wrong or extra words in the transcript).
|
||||
allow_self_loop_arc:
|
||||
Whether to add self-loop arc to training graph for deletion
|
||||
errors (missing words in the transcript).
|
||||
bypass_weight:
|
||||
Weight associated with bypass arc.
|
||||
self_loop_weight:
|
||||
Weight associated with self-loop arc.
|
||||
|
||||
Return:
|
||||
Return an FsaVec, which is the result of composing a
|
||||
CTC topology with OTC FSAs constructed from the given texts.
|
||||
"""
|
||||
|
||||
transcript_fsa = self.convert_transcript_to_fsa(
|
||||
texts,
|
||||
self.otc_token,
|
||||
allow_bypass_arc,
|
||||
allow_self_loop_arc,
|
||||
bypass_weight,
|
||||
self_loop_weight,
|
||||
)
|
||||
transcript_fsa = transcript_fsa.to(self.device)
|
||||
fsa_with_self_loop = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
|
||||
fsa_with_self_loop = k2.arc_sort(fsa_with_self_loop)
|
||||
|
||||
graph = k2.compose(
|
||||
self.ctc_topo,
|
||||
fsa_with_self_loop,
|
||||
treat_epsilons_specially=False,
|
||||
)
|
||||
assert graph.requires_grad is False
|
||||
|
||||
return graph
|
||||
|
||||
def convert_transcript_to_fsa(
|
||||
self,
|
||||
texts: List[str],
|
||||
otc_token: str,
|
||||
allow_bypass_arc: str2bool = True,
|
||||
allow_self_loop_arc: str2bool = True,
|
||||
bypass_weight: float = 0.0,
|
||||
self_loop_weight: float = 0.0,
|
||||
):
|
||||
otc_token_id = self.token_table[otc_token]
|
||||
|
||||
transcript_fsa_list = []
|
||||
for text in texts:
|
||||
text_piece_ids = []
|
||||
|
||||
for word in text.split():
|
||||
piece_ids = self.sp.encode(word, out_type=int)
|
||||
text_piece_ids.append(piece_ids)
|
||||
|
||||
arcs = []
|
||||
start_state = 0
|
||||
cur_state = start_state
|
||||
next_state = 1
|
||||
|
||||
for piece_ids in text_piece_ids:
|
||||
bypass_cur_state = cur_state
|
||||
|
||||
if allow_self_loop_arc:
|
||||
self_loop_arc = self.make_arc(
|
||||
cur_state,
|
||||
cur_state,
|
||||
otc_token_id,
|
||||
self_loop_weight,
|
||||
)
|
||||
arcs.append(self_loop_arc)
|
||||
|
||||
for piece_id in piece_ids:
|
||||
arc = self.make_arc(cur_state, next_state, piece_id, 0.0)
|
||||
arcs.append(arc)
|
||||
|
||||
cur_state = next_state
|
||||
next_state += 1
|
||||
|
||||
bypass_next_state = cur_state
|
||||
if allow_bypass_arc:
|
||||
bypass_arc = self.make_arc(
|
||||
bypass_cur_state,
|
||||
bypass_next_state,
|
||||
otc_token_id,
|
||||
bypass_weight,
|
||||
)
|
||||
arcs.append(bypass_arc)
|
||||
bypass_cur_state = cur_state
|
||||
|
||||
if allow_self_loop_arc:
|
||||
self_loop_arc = self.make_arc(
|
||||
cur_state,
|
||||
cur_state,
|
||||
otc_token_id,
|
||||
self_loop_weight,
|
||||
)
|
||||
arcs.append(self_loop_arc)
|
||||
|
||||
# Deal with final state
|
||||
final_state = next_state
|
||||
final_arc = self.make_arc(cur_state, final_state, -1, 0.0)
|
||||
arcs.append(final_arc)
|
||||
arcs.append(f"{final_state}")
|
||||
sorted_arcs = sorted(arcs, key=lambda a: int(a.split()[0]))
|
||||
|
||||
transcript_fsa = k2.Fsa.from_str("\n".join(sorted_arcs))
|
||||
transcript_fsa = k2.arc_sort(transcript_fsa)
|
||||
transcript_fsa_list.append(transcript_fsa)
|
||||
|
||||
transcript_fsa_vec = k2.create_fsa_vec(transcript_fsa_list)
|
||||
|
||||
return transcript_fsa_vec
|
@ -263,6 +263,70 @@ def get_texts(
|
||||
return aux_labels.tolist()
|
||||
|
||||
|
||||
def encode_supervisions_otc(
|
||||
supervisions: dict,
|
||||
subsampling_factor: int,
|
||||
token_ids: Optional[List[List[int]]] = None,
|
||||
) -> Tuple[torch.Tensor, Union[List[str], List[List[int]]]]:
|
||||
"""
|
||||
Encodes Lhotse's ``batch["supervisions"]`` dict into
|
||||
a pair of torch Tensor, and a list of transcription strings or token indexes
|
||||
|
||||
The supervision tensor has shape ``(batch_size, 3)``.
|
||||
Its second dimension contains information about sequence index [0],
|
||||
start frames [1] and num frames [2].
|
||||
|
||||
The batch items might become re-ordered during this operation -- the
|
||||
returned tensor and list of strings are guaranteed to be consistent with
|
||||
each other.
|
||||
"""
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
torch.div(
|
||||
supervisions["start_frame"],
|
||||
subsampling_factor,
|
||||
rounding_mode="floor",
|
||||
),
|
||||
torch.div(
|
||||
supervisions["num_frames"],
|
||||
subsampling_factor,
|
||||
rounding_mode="floor",
|
||||
),
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
|
||||
|
||||
indices = torch.argsort(supervision_segments[:, 2], descending=True)
|
||||
supervision_segments = supervision_segments[indices]
|
||||
|
||||
ids = []
|
||||
verbatim_texts = []
|
||||
sorted_ids = []
|
||||
sorted_verbatim_texts = []
|
||||
|
||||
for cut in supervisions["cut"]:
|
||||
id = cut.id
|
||||
if hasattr(cut.supervisions[0], "verbatim_text"):
|
||||
verbatim_text = cut.supervisions[0].verbatim_text
|
||||
else:
|
||||
verbatim_text = ""
|
||||
ids.append(id)
|
||||
verbatim_texts.append(verbatim_text)
|
||||
|
||||
for index in indices.tolist():
|
||||
sorted_ids.append(ids[index])
|
||||
sorted_verbatim_texts.append(verbatim_texts[index])
|
||||
|
||||
if token_ids is None:
|
||||
texts = supervisions["text"]
|
||||
res = [texts[idx] for idx in indices]
|
||||
else:
|
||||
res = [token_ids[idx] for idx in indices]
|
||||
|
||||
return supervision_segments, res, sorted_ids, sorted_verbatim_texts
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecodingResults:
|
||||
# timestamps[i][k] contains the frame number on which tokens[i][k]
|
||||
|
Loading…
x
Reference in New Issue
Block a user