Merge 4d7eefb02d201ee5c6cc5d7920bed9d3ed77ca22 into fba5e67d5e14c808cea7f2bf5ccc7fa0c248cc5c

Sathvik Udupa 2025-07-07 16:33:55 +08:00 committed by GitHub
commit 6b091921e8
33 changed files with 3403 additions and 0 deletions

egs/mucs/ASR/RESULTS.md Normal file

@@ -0,0 +1,74 @@
# Results for MUCS hi-en and bn-en

This page shows the WERs on the code-switched test sets of MUCS hi-en and bn-en.

## Conformer CTC

The following results are obtained with `run.sh`.
Specify the language through the `dataset` argument (`hi-en` or `bn-en`).
The n-gram LM is trained with kenlm on the training text corpus.
Here are the results with the different decoding methods.

### bn-en

| | test |
|-------------------------|-------|
| ctc decoding | 31.72 |
| 1best | 28.05 |
| nbest | 27.92 |
| nbest-rescoring | 27.22 |
| whole-lattice-rescoring | 27.24 |
| attention-decoder | 26.46 |

### hi-en

| | test |
|-------------------------|-------|
| ctc decoding | 31.43 |
| 1best | 28.48 |
| nbest | 28.55 |
| nbest-rescoring | 28.23 |
| whole-lattice-rescoring | 28.77 |
| attention-decoder | 28.16 |

The training command for reproducing these results is given below:
```bash
cd egs/mucs/ASR/
./prepare.sh
dataset="hi-en" #hi-en or bn-en
bpe=400
datadir=data_"$dataset"
./conformer_ctc/train.py \
--num-epochs 60 \
--max-duration 300 \
--exp-dir ./conformer_ctc/exp_"$dataset"_bpe"$bpe" \
--manifest-dir $datadir/fbank \
--lang-dir $datadir/lang_bpe_"$bpe" \
--enable-musan False \
```
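
Training progress can be monitored with TensorBoard: `train.py` writes event files under `<exp-dir>/tensorboard` when `--tensorboard` is left at its default of `True`, so for the hi-en run above something like the following works:

```bash
tensorboard --logdir ./conformer_ctc/exp_"$dataset"_bpe"$bpe"/tensorboard
```
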
The decoding command is given below:
```bash
dataset="hi-en" #hi-en or bn-en
bpe=400
datadir=data_"$dataset"
num_paths=10
max_duration=10
decode_methods="attention-decoder 1best nbest nbest-rescoring ctc-decoding whole-lattice-rescoring"
for decode_method in $decode_methods;
do
./conformer_ctc/decode.py \
--epoch 59 \
--avg 10 \
--manifest-dir $datadir/fbank \
--exp-dir ./conformer_ctc/exp_"$dataset"_bpe"$bpe" \
--max-duration $max_duration \
--lang-dir $datadir/lang_bpe_"$bpe" \
--lm-dir $datadir/"lm" \
--method $decode_method \
--num-paths $num_paths \
done
```
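
The n-gram LM used for rescoring is prepared by `prepare.sh`; in case you want to rebuild it manually, the sketch below shows one way to do it with kenlm's `lmplz` and `kaldilm`. The corpus path and the LM order are assumptions rather than the recipe's exact settings; `decode.py` only requires that `G_4_gram.fst.txt` (or `G_4_gram.pt`) ends up in `--lm-dir`.

```bash
# Hedged sketch, not the recipe's exact commands.
dataset="hi-en"
bpe=400
datadir=data_"$dataset"

# The LM training text path below is an assumption; see prepare.sh.
lmplz -o 4 < $datadir/lm/lm_train_text.txt > $datadir/lm/4gram.arpa

# Convert the ARPA LM to the FST text format expected by decode.py.
python3 -m kaldilm \
  --read-symbol-table="$datadir/lang_bpe_${bpe}/words.txt" \
  --disambig-symbol='#0' \
  --max-order=4 \
  $datadir/lm/4gram.arpa > $datadir/lm/G_4_gram.fst.txt
```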


@@ -0,0 +1,75 @@
## Introduction
Please visit
<https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html>
for how to run this recipe.
## How to compute framewise alignment information
### Step 1: Train a model
Please use `conformer_ctc/train.py` to train a model.
See <https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html>
for how to do it.
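
For this recipe, the exact training command (language selection, BPE size, manifest and lang dirs) is listed in `RESULTS.md`. A minimal single-GPU sketch with placeholder argument values looks like this:

```
./conformer_ctc/train.py \
  --world-size 1 \
  --num-epochs 60 \
  --max-duration 200 \
  --exp-dir conformer_ctc/exp \
  --lang-dir data/lang_bpe_500
```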
### Step 2: Compute framewise alignment
Run
```
# Choose a checkpoint and determine the number of checkpoints to average
epoch=30
avg=15
./conformer_ctc/ali.py \
--epoch $epoch \
--avg $avg \
--max-duration 500 \
--bucketing-sampler 0 \
--full-libri 1 \
--exp-dir conformer_ctc/exp \
--lang-dir data/lang_bpe_500 \
--ali-dir data/ali_500
```
and you will get four files inside the folder `data/ali_500`:
```
$ ls -lh data/ali_500
total 546M
-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:06 test_clean.pt
-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:07 test_other.pt
-rw-r--r-- 1 kuangfangjun root 542M Sep 28 11:36 train-960.pt
-rw-r--r-- 1 kuangfangjun root 2.1M Sep 28 11:38 valid.pt
```
**Note**: It can take more than 3 hours to compute the alignment
for the training dataset, which contains 960 * 3 = 2880 hours of data.
**Caution**: The model parameters in `conformer_ctc/ali.py` have to match those
in `conformer_ctc/train.py`.
**Caution**: You have to set the parameter `preserve_id` to `True` for `CutMix`.
Search `./conformer_ctc/asr_datamodule.py` for `preserve_id`.
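
For reference, the datamodule in this recipe already does this; keeping the original cut IDs is what later allows alignments to be looked up by cut ID:

```
transforms.append(
    CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
```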
### Step 3: Check your extracted alignments
There is a file `test_ali.py` in `icefall/test` that can be used to test your
alignments. It uses pre-computed alignments to modify a randomly generated
`nnet_output` and it checks that we can decode the correct transcripts
from the resulting `nnet_output`.
You should get something like the following if you run that script:
```
$ ./test/test_ali.py
['THE GOOD NATURED AUDIENCE IN PITY TO FALLEN MAJESTY SHOWED FOR ONCE GREATER DEFERENCE TO THE KING THAN TO THE MINISTER AND SUNG THE PSALM WHICH THE FORMER HAD CALLED FOR', 'THE OLD SERVANT TOLD HIM QUIETLY AS THEY CREPT BACK TO DWELL THAT THIS PASSAGE THAT LED FROM THE HUT IN THE PLEASANCE TO SHERWOOD AND THAT GEOFFREY FOR THE TIME WAS HIDING WITH THE OUTLAWS IN THE FOREST', 'FOR A WHILE SHE LAY IN HER CHAIR IN HAPPY DREAMY PLEASURE AT SUN AND BIRD AND TREE', "BUT THE ESSENCE OF LUTHER'S LECTURES IS THERE"]
['THE GOOD NATURED AUDIENCE IN PITY TO FALLEN MAJESTY SHOWED FOR ONCE GREATER DEFERENCE TO THE KING THAN TO THE MINISTER AND SUNG THE PSALM WHICH THE FORMER HAD CALLED FOR', 'THE OLD SERVANT TOLD HIM QUIETLY AS THEY CREPT BACK TO GAMEWELL THAT THIS PASSAGE WAY LED FROM THE HUT IN THE PLEASANCE TO SHERWOOD AND THAT GEOFFREY FOR THE TIME WAS HIDING WITH THE OUTLAWS IN THE FOREST', 'FOR A WHILE SHE LAY IN HER CHAIR IN HAPPY DREAMY PLEASURE AT SUN AND BIRD AND TREE', "BUT THE ESSENCE OF LUTHER'S LECTURES IS THERE"]
```
### Step 4: Use your alignments in training
Please refer to `conformer_mmi/train.py` for usage. Some useful
functions are listed below; a hedged usage sketch follows the list.

- `load_alignments()`: loads the alignments saved by `conformer_ctc/ali.py`
- `convert_alignments_to_tensor()`: converts the loaded alignments to PyTorch tensors
- `lookup_alignments()`: returns the alignments of utterances given their cut IDs
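
The sketch below strings these helpers together. It is not this recipe's actual code: the import path and the argument names are assumptions to be checked against `conformer_mmi/train.py`.

```
# Hedged sketch; module path and signatures are assumptions.
import torch
from icefall.ali import (  # assumed location of the helpers
    convert_alignments_to_tensor,
    load_alignments,
    lookup_alignments,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Load the alignments produced by ./conformer_ctc/ali.py
#    (assumed to return the subsampling factor and a dict: cut_id -> frame-level token IDs).
subsampling_factor, ali = load_alignments("data/ali_500/train-960.pt")

# 2. Convert the per-cut lists to tensors once, up front.
ali = convert_alignments_to_tensor(ali, device=device)

# 3. Inside the training loop, look up the alignments of the cuts in the
#    current batch (this is why CutMix must preserve cut IDs and why
#    --return-cuts must be True):
# cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
# batch_ali = lookup_alignments(cut_ids=cut_ids, alignments=ali, num_classes=num_classes)
```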


@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/__init__.py


@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/ali.py


@@ -0,0 +1,419 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class MUCSAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train, valid dataloader, and one test loader
This modified from librispeech asrmodule
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "mucs_cuts_train.jsonl.gz"
)
@lru_cache()
def dev_mucs_cuts(self) -> CutSet:
logging.info("About to get valid-mucs")
return load_manifest_lazy(
self.args.manifest_dir / "mucs_cuts_dev.jsonl.gz"
)
@lru_cache()
def test_mucs_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "mucs_cuts_test.jsonl.gz"
)
@lru_cache()
def train_clean_mucs_cuts(self) -> CutSet:
logging.info("About to get train-mucs")
return load_manifest_lazy(
self.args.manifest_dir / "mucs_cuts_train.jsonl.gz"
)


@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/conformer.py


@@ -0,0 +1,813 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import MUCSAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_rnn_lm,
rescore_with_whole_lattice,
)
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
get_texts,
load_averaged_model,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=77,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=55,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--method",
type=str,
default="attention-decoder",
help="""Decoding method.
Supported values are:
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
- (5) attention-decoder. Extract n paths from the LM rescored
lattice, the path with the highest score is the decoding result.
- (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
you have trained an RNN LM using ./rnn_lm/train.py
- (7) nbest-oracle. Its WER is the lower bound that any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring methods.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
A smaller value results in more unique paths.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="The lang dir",
)
parser.add_argument(
"--lm-dir",
type=str,
default="data/lm",
help="""The n-gram LM dir.
It should contain either G_4_gram.pt or G_4_gram.fst.txt
""",
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is rnn-lm.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is rnn-lm.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is rnn-lm.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"subsampling_factor": 4,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"num_decoder_layers": 6,
# parameters for decoding
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
rnn_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
batch: dict,
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding without LM rescoring.
- params.method is "nbest", it uses nbest decoding without LM rescoring.
- params.method is "nbest-rescoring", it uses nbest LM rescoring.
- params.method is "whole-lattice-rescoring", it uses whole lattice LM
rescoring.
model:
The neural model.
rnn_lm_model:
The neural model for RNN LM.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
The token ID of the EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return the decoding result. See above description for the format of
the returned dict. Note: If it decodes to nothing, then return None.
"""
if HLG is not None:
device = HLG.device
else:
device = H.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
# nnet_output is (N, T, C)
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // params.subsampling_factor,
supervisions["num_frames"] // params.subsampling_factor,
),
1,
).to(torch.int32)
if H is None:
assert HLG is not None
decoding_graph = HLG
else:
assert HLG is None
assert bpe_model is not None
decoding_graph = H
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=decoding_graph,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "ctc-decoding":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
# Note: `best_path.aux_labels` contains token IDs, not word IDs
# since we are using H, not HLG here.
#
# token_ids is a list-of-lists of IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "ctc-decoding"
return {key: hyps}
if params.method == "nbest-oracle":
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
# is only slightly worse than that of rescored lattices.
best_path = nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
word_table=word_table,
nbest_scale=params.nbest_scale,
oov="<UNK>",
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
return {key: hyps}
if params.method in ["1best", "nbest"]:
if params.method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no_rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale,
)
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
assert params.method in [
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
nbest_scale=params.nbest_scale,
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
elif params.method == "attention-decoder":
# lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
)
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
nbest_scale=params.nbest_scale,
)
elif params.method == "rnn-lm":
# lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
)
best_path_dict = rescore_with_rnn_lm(
lattice=rescored_lattice,
num_paths=params.num_paths,
rnn_lm_model=rnn_lm_model,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
blank_id=0,
nbest_scale=params.nbest_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
ans = dict()
if best_path_dict is not None:
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
else:
ans = None
return ans
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
rnn_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
rnn_lm_model:
The neural model for RNN LM.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
word_table:
It is the word symbol table.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return a dict, whose key may be "no-rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut ID, the reference transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
rnn_lm_model=rnn_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
batch=batch,
word_table=word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
if hyps_dict is not None:
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
else:
assert len(results) > 0, "It should not decode to empty in the first batch!"
this_batch = []
hyp_words = []
for ref_text in texts:
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
for lm_scale in results.keys():
results[lm_scale].extend(this_batch)
num_cuts += len(texts)
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[int], List[int]]]],
):
if params.method in ("attention-decoder", "rnn-lm"):
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}-{params.method}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
MUCSAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
args.lm_dir = Path(args.lm_dir)
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
logging.info("Decoding started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
params.num_classes = num_classes
params.sos_id = sos_id
params.eos_id = eos_id
if params.method == "ctc-decoding":
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
else:
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
if params.method in (
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take few minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set G.properties to None
G.__dict__["_properties"] = None
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
# Save a dummy value so that it can be loaded in C++.
# See https://github.com/pytorch/pytorch/issues/67902
# for why we need to do this.
G.dummy = 1
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)
if params.method in [
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
# G.lm_scores is used to replace HLG.lm_scores during
# LM rescoring.
G.lm_scores = G.scores.clone()
else:
G = None
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
model = load_averaged_model(
params.exp_dir, model, params.epoch, params.avg, device
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
rnn_lm_model = None
if params.method == "rnn-lm":
rnn_lm_model = RnnLmModel(
vocab_size=params.num_classes,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
if params.rnn_lm_avg == 1:
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
else:
rnn_lm_model = load_averaged_model(
params.rnn_lm_exp_dir,
rnn_lm_model,
params.rnn_lm_epoch,
params.rnn_lm_avg,
device,
)
rnn_lm_model.eval()
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = MUCSAsrDataModule(args)
test_clean_cuts = librispeech.test_mucs_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_sets = ["test"]
test_dl = [test_clean_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
rnn_lm_model=rnn_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
word_table=lexicon.word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()


@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/export.py


@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py


@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/pretrained.py


@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/subsampling.py


@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/test_label_smoothing.py


@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/test_subsampling.py


@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/test_transformer.py


@@ -0,0 +1,824 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./conformer_ctc/train.py \
--exp-dir ./conformer_ctc/exp \
--world-size 4 \
--full-libri 1 \
--max-duration 200 \
--num-epochs 20
"""
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import MUCSAsrDataModule
from conformer import Conformer
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
setup_logger,
str2bool,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=78,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
conformer_ctc/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--att-rate",
type=float,
default=0.8,
help="""The attention rate.
The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss
""",
)
parser.add_argument(
"--num-decoder-layers",
type=int,
default=6,
help="""Number of decoder layer of transformer decoder.
Setting this to 0 will not create the decoder at all (pure CTC model)
""",
)
parser.add_argument(
"--lr-factor",
type=float,
default=5.0,
help="The lr_factor for Noam optimizer",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Normalization for the input features, can be a
boolean indicating whether to do batch
normalization, or a float which means just scaling
the input features with this float value.
If given a float value, we will remove batchnorm
layer in `ConvolutionModule` as well.
- attention_dim: Hidden dim for multi-head attention model.
- head: Number of heads of multi-head attention model.
- num_decoder_layers: Number of decoder layers of the transformer decoder.
- beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss
- use_double_scores: It is used in k2.ctc_loss
- weight_decay: The weight_decay for the optimizer.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000,
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
"use_feat_batchnorm": True,
"attention_dim": 512,
"nhead": 8,
# parameters for loss
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
# parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000,
"env_info": get_env_info(),
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
graph_compiler: BpeCtcTrainingGraphCompiler,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Conformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
graph_compiler:
It is used to build a decoding graph from a ctc topo and training
transcript. The training transcript is contained in the given `batch`,
while the ctc topo is built when this compiler is instantiated.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = graph_compiler.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is (N, T, C)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
# Works with a BPE model
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
# Works with a phone lexicon
decoding_graph = graph_compiler.compile(texts)
else:
raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}")
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
if params.att_rate != 0.0:
with torch.set_grad_enabled(is_training):
mmodel = model.module if hasattr(model, "module") else model
# Note: We need to generate an unsorted version of token_ids
# `encode_supervisions()` called above sorts text, but
# encoder_memory and memory_mask are not sorted, so we
# use an unsorted version `supervisions["text"]` to regenerate
# the token_ids
#
# See https://github.com/k2-fsa/icefall/issues/97
# for more details
unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
att_loss = mmodel.decoder_forward(
encoder_memory,
memory_mask,
token_ids=unsorted_token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
else:
loss = ctc_loss
att_loss = torch.tensor([0])
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = supervision_segments[:, 2].sum().item()
info["ctc_loss"] = ctc_loss.detach().cpu().item()
if params.att_rate != 0.0:
info["att_loss"] = att_loss.detach().cpu().item()
info["loss"] = loss.detach().cpu().item()
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
info["utterances"] = feature.size(0)
# averaged input duration in frames over utterances
info["utt_duration"] = supervisions["num_frames"].sum().item()
# averaged padding proportion over utterances
info["utt_pad_proportion"] = (
((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item()
)
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
graph_compiler: BpeCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
graph_compiler: BpeCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
graph_compiler:
It is used to convert transcripts to FSAs.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
# world_size = 2
# params.master_port = 12355
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
if "lang_bpe" in str(params.lang_dir):
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
elif "lang_phone" in str(params.lang_dir):
assert params.att_rate == 0, (
"Attention decoder training does not support phone lang dirs "
"at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
"for pure CTC training when using a phone-based lang dir."
)
assert params.num_decoder_layers == 0, (
"Attention decoder training does not support phone lang dirs "
"at this time due to a missing <sos/eos> symbol. "
"Set --num-decoder-layers=0 for pure CTC training when using "
"a phone-based lang dir."
)
graph_compiler = CtcTrainingGraphCompiler(
lexicon,
device=device,
)
# Manually add the sos/eos ID with their default values
# from the BPE recipe which we're adapting here.
graph_compiler.sos_id = 1
graph_compiler.eos_id = 1
else:
raise ValueError(
f"Unsupported type of lang dir (we expected it to have "
f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
)
logging.info("About to create model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
use_feat_batchnorm=params.use_feat_batchnorm,
)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
optimizer = Noam(
model.parameters(),
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = MUCSAsrDataModule(args)
# params.full_libri = False
# if params.full_libri:
# train_cuts = librispeech.train_all_shuf_cuts()
# else:
train_cuts = librispeech.train_clean_mucs_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_mucs_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
graph_compiler=graph_compiler,
params=params,
)
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
graph_compiler: BpeCtcTrainingGraphCompiler,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
def main():
parser = get_parser()
MUCSAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/transformer.py

167
egs/mucs/ASR/local/compile_hlg.py Executable file
View File

@ -0,0 +1,167 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates HLG from
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_n_gram.fst.txt
The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
return parser.parse_args()
def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data_hi-en/lang_phone or data_hi-en/lang_bpe_400.
lm:
The LM stem name, e.g., G_3_gram.
Return:
An FSA representing HLG.
"""
lexicon = Lexicon(lang_dir)
datapath = str(lang_dir).split('/')[0]
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path(f"{datapath}/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"{datapath}/lm/{lm}.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")
with open(f"{datapath}/lm/{lm}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), f"{datapath}/lm/{lm}.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
LG.labels[LG.labels >= first_token_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set LG.properties to None
LG.__dict__["_properties"] = None
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")
logging.info("Connecting LG")
HLG = k2.connect(HLG)
logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
logging.info(f"HLG.shape: {HLG.shape}")
return HLG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "HLG.pt").is_file():
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
HLG = compile_HLG(lang_dir, args.lm)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,149 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the LibriSpeech dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
from typing import Optional
import sentencepiece as spm
import torch
from filter_cuts import filter_cuts
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to the bpe.model. If not None, we will remove short and
long utterances before extracting features""",
)
parser.add_argument(
"--dataset",
type=str,
help="""Dataset parts to compute fbank. If None, we will use all""",
)
parser.add_argument(
"--manifestpath",
type=str,
help="""Directory containing the lhotse manifests.""",
)
parser.add_argument(
"--fbankpath",
type=str,
help="""Output directory for the computed fbank features.""",
)
return parser.parse_args()
def compute_fbank_mucs(
bpe_model: Optional[str] = None,
dataset: Optional[str] = None,
):
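# Note: the manifest and fbank paths are taken from the module-level `args`
# parsed in the __main__ block below, not from the function arguments.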
src_dir = Path(args.manifestpath)
output_dir = Path(args.fbankpath)
num_jobs = min(48, os.cpu_count())
num_mel_bins = 80
if bpe_model:
logging.info(f"Loading {bpe_model}")
sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
dataset_parts = (
"train",
"test",
"dev",
)
prefix = "mucs"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None, keep_all_channels=False,
)
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(vars(args))
compute_fbank_mucs(bpe_model=args.bpe_model, dataset=args.dataset)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/filter_cuts.py

View File

@ -0,0 +1,87 @@
#!/usr/bin/env perl
# Copyright 2010-2012 Microsoft Corporation
# Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script takes a list of utterance-ids or any file whose first field
# of each line is an utterance-id, and filters an scp
# file (or any file whose "n-th" field is an utterance id), printing
# out only those lines whose "n-th" field is in id_list. The index of
# the "n-th" field is 1, by default, but can be changed by using
# the -f <n> switch
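# Example (hypothetical paths): keep only the wav.scp entries whose ids
# occur in utt2spk:
#   local/filter_scp.pl data/dev/utt2spk data/all/wav.scp > data/dev/wav.scp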
$exclude = 0;
$field = 1;
$shifted = 0;
do {
$shifted=0;
if ($ARGV[0] eq "--exclude") {
$exclude = 1;
shift @ARGV;
$shifted=1;
}
if ($ARGV[0] eq "-f") {
$field = $ARGV[1];
shift @ARGV; shift @ARGV;
$shifted=1
}
} while ($shifted);
if(@ARGV < 1 || @ARGV > 2) {
die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
"Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
"Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
"only the lines that were *not* in id_list.\n" .
"Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
"If your older scripts (written before Oct 2014) stopped working and you used the\n" .
"-f option, add 1 to the argument.\n" .
"See also: utils/filter_scp.pl .\n";
}
$idlist = shift @ARGV;
open(F, "<$idlist") || die "Could not open id-list file $idlist";
while(<F>) {
@A = split;
@A>=1 || die "Invalid id-list file line $_";
$seen{$A[0]} = 1;
}
if ($field == 1) { # Treat this as special case, since it is common.
while(<>) {
$_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
# $1 is what we filter on.
if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
print $_;
}
}
} else {
while(<>) {
@A = split;
@A > 0 || die "Invalid scp file line $_";
@A >= $field || die "Invalid scp file line $_";
if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
print $_;
}
}
}
# tests:
# the following should print "foo 1"
# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo)
# the following should print "bar 2".
# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_bpe.py

View File

@ -0,0 +1,54 @@
#!/usr/bin/env python3
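"""
This script reads the Kaldi-style `text` file of the chosen split and writes
three files to --out-dir: mucs_lexicon.txt (word -> space-separated characters),
mucs_vocab.txt (word list) and mucs_vocab_text.txt (transcripts without
utterance ids). It is invoked from prepare.sh (Stage 0), e.g.:
./local/prepare_lm_files.py --out-dir=download/lm --data-path=download/hi-en --mode="train"
"""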
import argparse
import logging
import os
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--out-dir", type=str, help="Output directory.")
parser.add_argument("--data-path", type=str, help="Input directory.")
parser.add_argument("--mode", type=str, help="Input split")
args = parser.parse_args()
return args
def read_text(path):
with open(path, 'r') as f:
lines = f.read().split('\n')
return [' '.join(l.split(' ')[1:]) for l in lines]
def create_files(text):
lexicon = {}
for line in text:
for word in line.split(' '):
if word.strip() == '': continue
if word not in lexicon:
lexicon[word] = ' '.join(list(word))
with open(os.path.join(args.out_dir, 'mucs_lexicon.txt'), 'w') as f:
for word in lexicon:
f.write(word + '\t' + lexicon[word] + '\n')
with open(os.path.join(args.out_dir, 'mucs_vocab.txt'), 'w') as f:
for word in lexicon:
f.write(word + '\n')
with open(os.path.join(args.out_dir, 'mucs_vocab_text.txt'), 'w') as f:
for line in text:
f.write(line + '\n')
def main():
path = os.path.join(args.data_path, args.mode)
text = read_text(os.path.join(path, "text"))
create_files(text)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(f"out_dir: {args.out_dir}")
logging.info(f"in_dir: {args.data_path}")
main()

View File

@ -0,0 +1,86 @@
#!/usr/bin/env python3
import logging
import os
import sys
from pathlib import Path
from typing import Dict, Optional, Union
import lhotse
from tqdm import tqdm
from lhotse import (
RecordingSet,
SupervisionSet,
validate_recordings_and_supervisions,
)
from lhotse.recipes.utils import manifests_exist, read_manifests_if_cached
from lhotse.utils import Pathlike
def prepare_mucs(
corpus_dir: Pathlike,
output_dir: Optional[Pathlike] = None,
num_jobs: int = 1,
) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]:
"""
Returns the manifests which consist of the Recordings and Supervisions.
When all the manifests are available in the ``output_dir``, it will simply read and return them.
:param corpus_dir: Pathlike, the path of the data dir.
:param dataset_parts: string or sequence of strings representing dataset part names, e.g. 'train-clean-100', 'train-clean-5', 'dev-clean'.
By default we will infer which parts are available in ``corpus_dir``.
:param output_dir: Pathlike, the path where to write the manifests.
:param num_jobs: the number of parallel workers parsing the data.
:param link_previous_utt: If true adds previous utterance id to supervisions.
Useful for reconstructing chains of utterances as they were read.
If previous utterance was skipped from LibriTTS datasets previous_utt label is None.
:return: a Dict whose key is the dataset part, and the value is Dicts with the keys 'audio' and 'supervisions'.
"""
corpus_dir = Path(corpus_dir)
assert corpus_dir.is_dir(), f"No such directory: {corpus_dir}"
dataset_parts = ["train", "test", "dev"]
manifests = {}
if output_dir is not None:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Maybe the manifests already exist: we can read them and save a bit of preparation time.
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts, output_dir=output_dir, prefix="mucs"
)
for part in tqdm(dataset_parts, desc="Preparing mucs parts from espnet files"):
if manifests_exist(part=part, output_dir=output_dir, prefix="mucs"):
logging.info(f"mucs subset: {part} already prepared - skipping.")
continue
recordings, supervisions, _ = lhotse.kaldi.load_kaldi_data_dir(os.path.join(corpus_dir, part), sampling_rate=16000)
validate_recordings_and_supervisions(recordings, supervisions)
if output_dir is not None:
supervisions.to_file(output_dir / f"mucs_supervisions_{part}.jsonl.gz")
recordings.to_file(output_dir / f"mucs_recordings_{part}.jsonl.gz")
manifests[part] = {"recordings": recordings, "supervisions": supervisions}
return manifests
if __name__ == "__main__":
datapath = sys.argv[1]
nj = int(sys.argv[2])
savepath = sys.argv[3]
print(datapath, nj, savepath)
prepare_mucs(datapath, savepath, nj)

View File

@ -0,0 +1,196 @@
#!/usr/bin/env bash
# Copyright 2010-2011 Microsoft Corporation
# 2012-2013 Johns Hopkins University (Author: Daniel Povey)
# Apache 2.0
# This script operates on a data directory, such as in data/train/.
# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data
# for what these directories contain.
# This script creates a subset of that data, consisting of some specified
# number of utterances. (The selected utterances are distributed evenly
# throughout the file, by the program ./subset_scp.pl).
# There are several options, none compatible with any other.
# If you give the --per-spk option, it will attempt to select the supplied
# number of utterances for each speaker (typically you would supply a much
# smaller number in this case).
# If you give the --speakers option, it selects a subset of n randomly
# selected speakers.
# If you give the --shortest option, it will give you the n shortest utterances.
# If you give the --first option, it will just give you the n first utterances.
# If you give the --last option, it will just give you the n last utterances.
# If you give the --spk-list or --utt-list option, it reads the
# speakers/utterances to keep from <speaker-list-file>/<utt-list-file>" (note,
# in this case there is no <num-utt> positional parameter; see usage message.)
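# Example (as used in prepare.sh, Stage -1): hold out the first 1000
# utterances of the train directory as a dev set:
#   ./local/subset_data_dir.sh --first download/hi-en/train 1000 download/hi-en/dev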
shortest=false
perspk=false
speakers=false
first_opt=
spk_list=
utt_list=
expect_args=3
case $1 in
--first|--last) first_opt=$1; shift ;;
--per-spk) perspk=true; shift ;;
--shortest) shortest=true; shift ;;
--speakers) speakers=true; shift ;;
--spk-list) shift; spk_list=$1; shift; expect_args=2 ;;
--utt-list) shift; utt_list=$1; shift; expect_args=2 ;;
--*) echo "$0: invalid option '$1'"; exit 1
esac
if [ $# != $expect_args ]; then
echo "Usage:"
echo " subset_data_dir.sh [--speakers|--shortest|--first|--last|--per-spk] <srcdir> <num-utt> <destdir>"
echo " subset_data_dir.sh [--spk-list <speaker-list-file>] <srcdir> <destdir>"
echo " subset_data_dir.sh [--utt-list <utt-list-file>] <srcdir> <destdir>"
echo "By default, randomly selects <num-utt> utterances from the data directory."
echo "With --speakers, randomly selects enough speakers that we have <num-utt> utterances"
echo "With --per-spk, selects <num-utt> utterances per speaker, if available."
echo "With --first, selects the first <num-utt> utterances"
echo "With --last, selects the last <num-utt> utterances"
echo "With --shortest, selects the shortest <num-utt> utterances."
echo "With --spk-list, reads the speakers to keep from <speaker-list-file>"
echo "With --utt-list, reads the utterances to keep from <utt-list-file>"
exit 1;
fi
srcdir=$1
if [[ $spk_list || $utt_list ]]; then
numutt=
destdir=$2
else
numutt=$2
destdir=$3
fi
export LC_ALL=C
if [ ! -f $srcdir/utt2spk ]; then
echo "$0: no such file $srcdir/utt2spk"
exit 1
fi
if [[ $numutt && $numutt -gt $(wc -l <$srcdir/utt2spk) ]]; then
echo "$0: cannot subset to more utterances than you originally had."
exit 1
fi
if $shortest && [ ! -f $srcdir/feats.scp ]; then
echo "$0: you selected --shortest but no feats.scp exist."
exit 1
fi
mkdir -p $destdir || exit 1
if [[ $spk_list ]]; then
local/filter_scp.pl "$spk_list" $srcdir/spk2utt > $destdir/spk2utt || exit 1;
utils/spk2utt_to_utt2spk.pl < $destdir/spk2utt > $destdir/utt2spk || exit 1;
elif [[ $utt_list ]]; then
local/filter_scp.pl "$utt_list" $srcdir/utt2spk > $destdir/utt2spk || exit 1;
local/utt2spk_to_spk2utt.pl < $destdir/utt2spk > $destdir/spk2utt || exit 1;
elif $speakers; then
utils/shuffle_list.pl < $srcdir/spk2utt |
awk -v numutt=$numutt '{ if (tot < numutt){ print; } tot += (NF-1); }' |
sort > $destdir/spk2utt
utils/spk2utt_to_utt2spk.pl < $destdir/spk2utt > $destdir/utt2spk
elif $perspk; then
awk '{ n='$numutt'; printf("%s ",$1);
skip=1; while(n*(skip+1) <= NF-1) { skip++; }
for(x=2; x<=NF && x <= (n*skip+1); x += skip) { printf("%s ", $x); }
printf("\n"); }' <$srcdir/spk2utt >$destdir/spk2utt
utils/spk2utt_to_utt2spk.pl < $destdir/spk2utt > $destdir/utt2spk
else
if $shortest; then
# Select $numutt shortest utterances.
. ./path.sh
if [ -f $srcdir/utt2num_frames ]; then
ln -sf $(utils/make_absolute.sh $srcdir)/utt2num_frames $destdir/tmp.len
else
feat-to-len scp:$srcdir/feats.scp ark,t:$destdir/tmp.len || exit 1;
fi
sort -n -k2 $destdir/tmp.len |
awk '{print $1}' |
head -$numutt >$destdir/tmp.uttlist
local/filter_scp.pl $destdir/tmp.uttlist $srcdir/utt2spk >$destdir/utt2spk
rm $destdir/tmp.uttlist $destdir/tmp.len
else
# Select $numutt random utterances.
local/subset_scp.pl $first_opt $numutt $srcdir/utt2spk > $destdir/utt2spk || exit 1;
fi
local/utt2spk_to_spk2utt.pl < $destdir/utt2spk > $destdir/spk2utt
fi
# Perform filtering. utt2spk and spk2utt files already exist by this point.
# Filter by utterance.
[ -f $srcdir/feats.scp ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/feats.scp >$destdir/feats.scp
[ -f $srcdir/vad.scp ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/vad.scp >$destdir/vad.scp
[ -f $srcdir/utt2lang ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2lang >$destdir/utt2lang
[ -f $srcdir/utt2dur ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2dur >$destdir/utt2dur
[ -f $srcdir/utt2num_frames ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2num_frames >$destdir/utt2num_frames
[ -f $srcdir/utt2uniq ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2uniq >$destdir/utt2uniq
[ -f $srcdir/wav.scp ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/wav.scp >$destdir/wav.scp
[ -f $srcdir/utt2warp ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2warp >$destdir/utt2warp
[ -f $srcdir/text ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/text >$destdir/text
# Filter by speaker.
[ -f $srcdir/spk2warp ] &&
local/filter_scp.pl $destdir/spk2utt <$srcdir/spk2warp >$destdir/spk2warp
[ -f $srcdir/spk2gender ] &&
local/filter_scp.pl $destdir/spk2utt <$srcdir/spk2gender >$destdir/spk2gender
[ -f $srcdir/cmvn.scp ] &&
local/filter_scp.pl $destdir/spk2utt <$srcdir/cmvn.scp >$destdir/cmvn.scp
# Filter by recording-id.
if [ -f $srcdir/segments ]; then
local/filter_scp.pl $destdir/utt2spk <$srcdir/segments >$destdir/segments
# Recording-ids are in segments.
awk '{print $2}' $destdir/segments | sort | uniq >$destdir/reco
# The next line overrides the command above for wav.scp, which would be incorrect.
[ -f $srcdir/wav.scp ] &&
local/filter_scp.pl $destdir/reco <$srcdir/wav.scp >$destdir/wav.scp
else
# No segments; recording-ids are in wav.scp.
awk '{print $1}' $destdir/wav.scp | sort | uniq >$destdir/reco
fi
[ -f $srcdir/reco2file_and_channel ] &&
local/filter_scp.pl $destdir/reco <$srcdir/reco2file_and_channel >$destdir/reco2file_and_channel
[ -f $srcdir/reco2dur ] &&
local/filter_scp.pl $destdir/reco <$srcdir/reco2dur >$destdir/reco2dur
# Filter the STM file for proper sclite scoring.
# Copy over the comments from STM file.
[ -f $srcdir/stm ] &&
(grep "^;;" $srcdir/stm
local/filter_scp.pl $destdir/reco $srcdir/stm) >$destdir/stm
rm $destdir/reco
# Copy frame_shift if present.
[ -f $srcdir/frame_shift ] && cp $srcdir/frame_shift $destdir
srcutts=$(wc -l <$srcdir/utt2spk)
destutts=$(wc -l <$destdir/utt2spk)
echo "$0: reducing #utt from $srcutts to $destutts"
exit 0

105
egs/mucs/ASR/local/subset_scp.pl Executable file
View File

@ -0,0 +1,105 @@
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program selects a subset of N elements in the scp.
# By default, it selects them evenly from throughout the scp, in order to avoid
# selecting too many from the same speaker. It prints them on the standard
# output.
# With the option --first, it just selects the N first utterances.
# With the option --last, it just selects the N last utterances.
# Last modified by JHU & HKUST @2013
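# Example (hypothetical paths): select 1000 evenly spaced utterances:
#   local/subset_scp.pl 1000 data/train/utt2spk > uttlist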
$quiet = 0;
$first = 0;
$last = 0;
if (@ARGV > 0 && $ARGV[0] eq "--quiet") {
shift;
$quiet = 1;
}
if (@ARGV > 0 && $ARGV[0] eq "--first") {
shift;
$first = 1;
}
if (@ARGV > 0 && $ARGV[0] eq "--last") {
shift;
$last = 1;
}
if(@ARGV < 2 ) {
die "Usage: subset_scp.pl [--quiet][--first|--last] N in.scp\n" .
" --quiet causes it to not die if N < num lines in scp.\n" .
" --first and --last make it equivalent to head or tail.\n" .
"See also: filter_scp.pl\n";
}
$N = shift @ARGV;
if($N == 0) {
die "First command-line parameter to subset_scp.pl must be an integer, got \"$N\"";
}
$inscp = shift @ARGV;
open(I, "<$inscp") || die "Opening input scp file $inscp";
@F = ();
while(<I>) {
push @F, $_;
}
$numlines = @F;
if($N > $numlines) {
if ($quiet) {
$N = $numlines;
} else {
die "You requested from subset_scp.pl more elements than available: $N > $numlines";
}
}
sub select_n {
my ($start,$end,$num_needed) = @_;
my $diff = $end - $start;
if ($num_needed > $diff) {
die "select_n: code error";
}
if ($diff == 1 ) {
if ($num_needed > 0) {
print $F[$start];
}
} else {
my $halfdiff = int($diff/2);
my $halfneeded = int($num_needed/2);
select_n($start, $start+$halfdiff, $halfneeded);
select_n($start+$halfdiff, $end, $num_needed - $halfneeded);
}
}
if ( ! $first && ! $last) {
if ($N > 0) {
select_n(0, $numlines, $N);
}
} else {
if ($first) { # --first option: same as head.
for ($n = 0; $n < $N; $n++) {
print $F[$n];
}
} else { # --last option: same as tail.
for ($n = @F - $N; $n < @F; $n++) {
print $F[$n];
}
}
}

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py

View File

@ -0,0 +1,38 @@
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# converts an utt2spk file to a spk2utt file.
# Takes input from the stdin or from a file argument;
# output goes to the standard out.
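# Example (hypothetical paths):
#   local/utt2spk_to_spk2utt.pl data/train/utt2spk > data/train/spk2utt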
if ( @ARGV > 1 ) {
die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt";
}
while(<>){
@A = split(" ", $_);
@A == 2 || die "Invalid line in utt2spk file: $_";
($u,$s) = @A;
if(!$seen_spk{$s}) {
$seen_spk{$s} = 1;
push @spklist, $s;
}
push (@{$spk_hash{$s}}, "$u");
}
foreach $s (@spklist) {
$l = join(' ',@{$spk_hash{$s}});
print "$s $l\n";
}

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/validate_bpe_lexicon.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/validate_manifest.py

259
egs/mucs/ASR/prepare.sh Executable file
View File

@ -0,0 +1,259 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=60
stage=-1
stop_stage=9
# We assume dl_dir (download dir) contains the following
# directories and files. Download them from https://www.openslr.org/resources/104/
#
# - $dl_dir/hi-en
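# Note: Stage -1 below populates $dl_dir/$dataset/{train,test,dev} with the
# Kaldi-style files (e.g. wav.scp and text) taken from $raw_data_path.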
dl_dir=$PWD/download
mkdir -p $dl_dir
raw_data_path="/data/Database/MUCS/"
dataset="hi-en" #hin-en or bn-en
datadir="data_"$dataset
raw_kaldi_files_path=$dl_dir/$dataset/
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
vocab_size=400
mkdir -p $datadir
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: prepare data files"
mkdir -p $dl_dir/$dataset
for x in train dev test train_all; do
if [ -d "$dl_dir/$dataset/$x" ]; then rm -Rf $dl_dir/$dataset/$x; fi
done
mkdir -p $dl_dir/$dataset/{train,test,dev}
cp -r $raw_data_path/$dataset/"train"/"transcripts"/* $dl_dir/$dataset/"train"
cp -r $raw_data_path/$dataset/"test"/"transcripts"/* $dl_dir/$dataset/"test"
for x in train test
do
cp $dl_dir/$dataset/$x/"wav.scp" $dl_dir/$dataset/$x/"wav.scp_old"
cat $dl_dir/$dataset/$x/"wav.scp" | cut -d' ' -f1 > $dl_dir/$dataset/$x/wav_ids
cat $dl_dir/$dataset/$x/"wav.scp" | cut -d' ' -f2 | awk -v var="$raw_data_path/$dataset/$x/" '{print var$1}' > $dl_dir/$dataset/$x/wav_ids_with_fullpath
paste -d' ' $dl_dir/$dataset/$x/wav_ids $dl_dir/$dataset/$x/wav_ids_with_fullpath > $dl_dir/$dataset/$x/"wav.scp"
rm $dl_dir/$dataset/$x/wav_ids
rm $dl_dir/$dataset/$x/wav_ids_with_fullpath
done
./local/subset_data_dir.sh --first $dl_dir/$dataset/"train" 1000 $dl_dir/$dataset/"dev"
total=$(wc -l $dl_dir/$dataset/"train"/"text" | cut -d' ' -f1)
count=$(expr $total - 1000)
./local/subset_data_dir.sh --first $dl_dir/$dataset/"train" $count $dl_dir/$dataset/"train_reduced"
mv $dl_dir/$dataset/"train" $dl_dir/$dataset/"train_all"
mv $dl_dir/$dataset/"train_reduced" $dl_dir/$dataset/"train"
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: prepare LM files"
mkdir -p $dl_dir/lm
if [ ! -e $dl_dir/lm/.done ]; then
./local/prepare_lm_files.py --out-dir=$dl_dir/lm --data-path=$raw_kaldi_files_path --mode="train"
touch $dl_dir/lm/.done
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare MUCS manifest"
# We assume that you have downloaded the MUCS corpus
# to $dl_dir/
mkdir -p $datadir/manifests
if [ ! -e $datadir/manifests/.mucs.done ]; then
# generate lhotse manifests from kaldi style files
./local/prepare_manifest.py "$raw_kaldi_files_path" $nj $datadir/manifests
touch $datadir/manifests/.mucs.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for mucs"
mkdir -p $datadir/fbank
if [ ! -e $datadir/fbank/.mucs.done ]; then
./local/compute_fbank_mucs.py --manifestpath $datadir/manifests/ --fbankpath $datadir/fbank
touch $datadir/fbank/.mucs.done
fi
if [ ! -e $datadir/fbank/.mucs-validated.done ]; then
log "Validating $datadir/fbank for mucs"
parts=(
train
test
dev
)
for part in ${parts[@]}; do
python3 ./local/validate_manifest.py \
$datadir/fbank/mucs_cuts_${part}.jsonl.gz
done
touch $datadir/fbank/.mucs-validated.done
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
lang_dir=$datadir/lang_phone
mkdir -p $lang_dir
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/mucs_lexicon.txt |
sort | uniq > $lang_dir/lexicon.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
fi
if [ ! -f $lang_dir/L.fst ]; then
log "Converting L.pt to L.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L.pt \
$lang_dir/L.fst
fi
if [ ! -f $lang_dir/L_disambig.fst ]; then
log "Converting L_disambig.pt to L_disambig.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L_disambig.pt \
$lang_dir/L_disambig.fst
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare BPE based lang"
lang_dir=$datadir/lang_bpe_${vocab_size}
mkdir -p $lang_dir
# We reuse words.txt from phone based lexicon
# so that the two can share G.pt later.
cp $datadir/lang_phone/words.txt $lang_dir
if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate data for BPE training"
cp $dl_dir/lm/mucs_vocab_text.txt $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/bpe.model ]; then
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
log "Validating $lang_dir/lexicon.txt"
./local/validate_bpe_lexicon.py \
--lexicon $lang_dir/lexicon.txt \
--bpe-model $lang_dir/bpe.model
fi
if [ ! -f $lang_dir/L.fst ]; then
log "Converting L.pt to L.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L.pt \
$lang_dir/L.fst
fi
if [ ! -f $lang_dir/L_disambig.fst ]; then
log "Converting L_disambig.pt to L_disambig.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L_disambig.pt \
$lang_dir/L_disambig.fst
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Train LM from training data"
lang_dir=$datadir/lang_bpe_${vocab_size}
if [ ! -f $lang_dir/lm_3.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order 3 \
-text $lang_dir/transcript_words.txt \
-lm $lang_dir/lm_3.arpa
fi
if [ ! -f $lang_dir/lm_4.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order 4 \
-text $lang_dir/transcript_words.txt \
-lm $lang_dir/lm_4.arpa
fi
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Prepare G"
# We assume you have installed kaldilm; if not, please install
# it using: pip install kaldilm
mkdir -p $datadir/lm
if [ ! -f $datadir/lm/G_3_gram.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table="$datadir/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
$datadir/lang_bpe_${vocab_size}/lm_3.arpa > $datadir/lm/G_3_gram.fst.txt
fi
if [ ! -f $datadir/lm/G_4_gram.fst.txt ]; then
# It is used for LM rescoring
python3 -m kaldilm \
--read-symbol-table="$datadir/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=4 \
$datadir/lang_bpe_${vocab_size}/lm_4.arpa > $datadir/lm/G_4_gram.fst.txt
fi
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Compile HLG"
lang_dir=$datadir/lang_bpe_${vocab_size}
./local/compile_hlg.py --lang-dir $lang_dir
fi

39
egs/mucs/ASR/run.sh Executable file
View File

@ -0,0 +1,39 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES="0"
set -e
dataset='hi-en'
datadir=data_"$dataset"
bpe=400
# decode_methods="attention-decoder 1best nbest nbest-rescoring ctc-decoding whole-lattice-rescoring"
decode_methods="nbest nbest-rescoring whole-lattice-rescoring"
num_paths=10
max_duration=5
# ./conformer_ctc/train.py \
# --num-epochs 60 \
# --max-duration 300 \
# --exp-dir ./conformer_ctc/exp_"$dataset"_bpe"$bpe" \
# --manifest-dir $datadir/fbank \
# --lang-dir $datadir/lang_bpe_"$bpe" \
# --enable-musan False \
for decode_method in $decode_methods;
do
./conformer_ctc/decode.py \
--epoch 59 \
--avg 10 \
--manifest-dir $datadir/fbank \
--exp-dir ./conformer_ctc/exp_"$dataset"_bpe"$bpe" \
--max-duration $max_duration \
--lang-dir $datadir/lang_bpe_"$bpe" \
--lm-dir $datadir/"lm" \
--method $decode_method \
--num-paths $num_paths
done
exit