diff --git a/egs/mucs/ASR/RESULTS.md b/egs/mucs/ASR/RESULTS.md
new file mode 100644
index 000000000..4c0d3291d
--- /dev/null
+++ b/egs/mucs/ASR/RESULTS.md
@@ -0,0 +1,74 @@
+# Results for mucs hi-en and bn-en
+
+This page shows the WERs for the code-switched test corpus of MUCS hi-en and bn-en.
+
+## Using conformer ctc
+
+The following results are obtained with `run.sh`.
+
+Specify the language through the `dataset` variable (`hi-en` or `bn-en`).
+The LM is trained on the training corpus using KenLM.
+
+Here are the results with different decoding methods:
+
+bn-en
+| decoding method         | test  |
+|-------------------------|-------|
+| ctc decoding            | 31.72 |
+| 1best                   | 28.05 |
+| nbest                   | 27.92 |
+| nbest-rescoring         | 27.22 |
+| whole-lattice-rescoring | 27.24 |
+| attention-decoder       | 26.46 |
+
+hi-en
+| decoding method         | test  |
+|-------------------------|-------|
+| ctc decoding            | 31.43 |
+| 1best                   | 28.48 |
+| nbest                   | 28.55 |
+| nbest-rescoring         | 28.23 |
+| whole-lattice-rescoring | 28.77 |
+| attention-decoder       | 28.16 |
+
+The training command for reproducing these results is given below:
+```bash
+cd egs/mucs/ASR/
+./prepare.sh
+
+dataset="hi-en" # hi-en or bn-en
+bpe=400
+datadir=data_"$dataset"
+./conformer_ctc/train.py \
+  --num-epochs 60 \
+  --max-duration 300 \
+  --exp-dir ./conformer_ctc/exp_"$dataset"_bpe"$bpe" \
+  --manifest-dir $datadir/fbank \
+  --lang-dir $datadir/lang_bpe_"$bpe" \
+  --enable-musan False
+```
+
+The decoding command is given below:
+```bash
+dataset="hi-en" # hi-en or bn-en
+bpe=400
+datadir=data_"$dataset"
+num_paths=10
+max_duration=10
+decode_methods="attention-decoder 1best nbest nbest-rescoring ctc-decoding whole-lattice-rescoring"
+
+for decode_method in $decode_methods;
+do
+  ./conformer_ctc/decode.py \
+    --epoch 59 \
+    --avg 10 \
+    --manifest-dir $datadir/fbank \
+    --exp-dir ./conformer_ctc/exp_"$dataset"_bpe"$bpe" \
+    --max-duration $max_duration \
+    --lang-dir $datadir/lang_bpe_"$bpe" \
+    --lm-dir $datadir/"lm" \
+    --method $decode_method \
+    --num-paths $num_paths
+
+done
+```
\ No newline at end of file
diff --git a/egs/mucs/ASR/conformer_ctc/README.md b/egs/mucs/ASR/conformer_ctc/README.md
new file mode 100644
index 000000000..37ace4204
--- /dev/null
+++ b/egs/mucs/ASR/conformer_ctc/README.md
@@ -0,0 +1,75 @@
+## Introduction
+
+Please visit
+
+for how to run this recipe.
+
+## How to compute framewise alignment information
+
+### Step 1: Train a model
+
+Please use `conformer_ctc/train.py` to train a model.
+See
+for how to do it.
+
+### Step 2: Compute framewise alignment
+
+Run
+
+```
+# Choose a checkpoint and determine the number of checkpoints to average
+epoch=30
+avg=15
+./conformer_ctc/ali.py \
+  --epoch $epoch \
+  --avg $avg \
+  --max-duration 500 \
+  --bucketing-sampler 0 \
+  --full-libri 1 \
+  --exp-dir conformer_ctc/exp \
+  --lang-dir data/lang_bpe_500 \
+  --ali-dir data/ali_500
+```
+and you will get four files inside the folder `data/ali_500`:
+
+```
+$ ls -lh data/ali_500
+total 546M
+-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:06 test_clean.pt
+-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:07 test_other.pt
+-rw-r--r-- 1 kuangfangjun root 542M Sep 28 11:36 train-960.pt
+-rw-r--r-- 1 kuangfangjun root 2.1M Sep 28 11:38 valid.pt
+```
+
+**Note**: It can take more than 3 hours to compute the alignment
+for the training dataset, which contains 960 * 3 = 2880 hours of data.
+
+**Caution**: The model parameters in `conformer_ctc/ali.py` have to match those
+in `conformer_ctc/train.py`.
+
+**Caution**: You have to set the parameter `preserve_id` to `True` for `CutMix`.
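+As a minimal sketch (using the MUSAN manifest path from this recipe's
+data module), the transform is constructed like this:
+
+```
+from lhotse import load_manifest
+from lhotse.dataset import CutMix
+
+cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz")
+# preserve_id=True keeps the original cut IDs after mixing, so the
+# computed alignments can still be looked up by cut ID.
+transform = CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+```
+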
+Search `./conformer_ctc/asr_datamodule.py` for `preserve_id`. + +### Step 3: Check your extracted alignments + +There is a file `test_ali.py` in `icefall/test` that can be used to test your +alignments. It uses pre-computed alignments to modify a randomly generated +`nnet_output` and it checks that we can decode the correct transcripts +from the resulting `nnet_output`. + +You should get something like the following if you run that script: + +``` +$ ./test/test_ali.py +['THE GOOD NATURED AUDIENCE IN PITY TO FALLEN MAJESTY SHOWED FOR ONCE GREATER DEFERENCE TO THE KING THAN TO THE MINISTER AND SUNG THE PSALM WHICH THE FORMER HAD CALLED FOR', 'THE OLD SERVANT TOLD HIM QUIETLY AS THEY CREPT BACK TO DWELL THAT THIS PASSAGE THAT LED FROM THE HUT IN THE PLEASANCE TO SHERWOOD AND THAT GEOFFREY FOR THE TIME WAS HIDING WITH THE OUTLAWS IN THE FOREST', 'FOR A WHILE SHE LAY IN HER CHAIR IN HAPPY DREAMY PLEASURE AT SUN AND BIRD AND TREE', "BUT THE ESSENCE OF LUTHER'S LECTURES IS THERE"] +['THE GOOD NATURED AUDIENCE IN PITY TO FALLEN MAJESTY SHOWED FOR ONCE GREATER DEFERENCE TO THE KING THAN TO THE MINISTER AND SUNG THE PSALM WHICH THE FORMER HAD CALLED FOR', 'THE OLD SERVANT TOLD HIM QUIETLY AS THEY CREPT BACK TO GAMEWELL THAT THIS PASSAGE WAY LED FROM THE HUT IN THE PLEASANCE TO SHERWOOD AND THAT GEOFFREY FOR THE TIME WAS HIDING WITH THE OUTLAWS IN THE FOREST', 'FOR A WHILE SHE LAY IN HER CHAIR IN HAPPY DREAMY PLEASURE AT SUN AND BIRD AND TREE', "BUT THE ESSENCE OF LUTHER'S LECTURES IS THERE"] +``` + +### Step 4: Use your alignments in training + +Please refer to `conformer_mmi/train.py` for usage. Some useful +functions are: + +- `load_alignments()`, it loads alignment saved by `conformer_ctc/ali.py` +- `convert_alignments_to_tensor()`, it converts alignments to PyTorch tensors +- `lookup_alignments()`, it returns the alignments of utterances by giving the cut ID of the utterances. diff --git a/egs/mucs/ASR/conformer_ctc/__init__.py b/egs/mucs/ASR/conformer_ctc/__init__.py new file mode 120000 index 000000000..0fd1b73f3 --- /dev/null +++ b/egs/mucs/ASR/conformer_ctc/__init__.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/__init__.py \ No newline at end of file diff --git a/egs/mucs/ASR/conformer_ctc/ali.py b/egs/mucs/ASR/conformer_ctc/ali.py new file mode 120000 index 000000000..71ca217cb --- /dev/null +++ b/egs/mucs/ASR/conformer_ctc/ali.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/ali.py \ No newline at end of file diff --git a/egs/mucs/ASR/conformer_ctc/asr_datamodule.py b/egs/mucs/ASR/conformer_ctc/asr_datamodule.py new file mode 100644 index 000000000..b1031370d --- /dev/null +++ b/egs/mucs/ASR/conformer_ctc/asr_datamodule.py @@ -0,0 +1,419 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
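+
+# This module defines MUCSAsrDataModule, which turns the Lhotse cut
+# manifests produced by this recipe's prepare.sh (mucs_cuts_train/dev/
+# test.jsonl.gz under --manifest-dir) into PyTorch DataLoaders, with
+# optional MUSAN mixing, SpecAugment, and dynamic bucketing.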
+ + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class MUCSAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train, valid dataloader, and one test loader + This modified from librispeech asrmodule + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. 
Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. 
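+            # The check below reads the default of num_frame_masks from the
+            # installed Lhotse's SpecAugment.__init__ signature: if it is
+            # still the old default of 1, use 2 masks, otherwise use 10.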
+ num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
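+        # Each dataloader worker then re-seeds itself with seed + worker_id
+        # (see _SeedWorkers above), so workers draw different random
+        # augmentations while the run stays reproducible.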
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "mucs_cuts_train.jsonl.gz" + ) + @lru_cache() + def dev_mucs_cuts(self) -> CutSet: + logging.info("About to get valid-mucs") + return load_manifest_lazy( + self.args.manifest_dir / "mucs_cuts_dev.jsonl.gz" + ) + + @lru_cache() + def test_mucs_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "mucs_cuts_test.jsonl.gz" + ) + @lru_cache() + def train_clean_mucs_cuts(self) -> CutSet: + logging.info("About to get train-mucs") + return load_manifest_lazy( + self.args.manifest_dir / "mucs_cuts_train.jsonl.gz" + ) diff --git a/egs/mucs/ASR/conformer_ctc/conformer.py b/egs/mucs/ASR/conformer_ctc/conformer.py new file mode 120000 index 000000000..d1f4209d7 --- /dev/null +++ b/egs/mucs/ASR/conformer_ctc/conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/conformer.py \ No newline at end of file diff --git a/egs/mucs/ASR/conformer_ctc/decode.py b/egs/mucs/ASR/conformer_ctc/decode.py new file mode 100755 index 000000000..112d46f76 --- /dev/null +++ b/egs/mucs/ASR/conformer_ctc/decode.py @@ -0,0 +1,813 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import MUCSAsrDataModule +from conformer import Conformer + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_rnn_lm, + rescore_with_whole_lattice, +) +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import ( + AttributeDict, + get_texts, + load_averaged_model, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=77, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=55, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--method", + type=str, + default="attention-decoder", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (5) attention-decoder. Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume + you have trained an RNN LM using ./rnn_lm/train.py + - (7) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. 
+ Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, + ) + + parser.add_argument( + "--rnn-lm-epoch", + type=int, + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, + ) + + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. + """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. + HLG: + The decoding graph. 
Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
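+        # nbest_oracle() samples num_paths paths from the lattice and, for
+        # each utterance, keeps the path closest to the reference text, so
+        # the reported WER lower-bounds any n-best rescoring method.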
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.method == "attention-decoder": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + nbest_scale=params.nbest_scale, + ) + elif params.method == "rnn-lm": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_rnn_lm( + lattice=rescored_lattice, + num_paths=params.num_paths, + rnn_lm_model=rnn_lm_model, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + blank_id=0, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. 
+ HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert len(results) > 0, "It should not decode to empty in the first batch!" + this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[int], List[int]]]], +): + if params.method in ("attention-decoder", "rnn-lm"): + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
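+        # write_error_stats() also returns the WER for this decoding
+        # setting; it is collected in test_set_wers to build the summary
+        # file written below.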
+ errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}-{params.method}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + MUCSAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + params.num_classes = num_classes + params.sos_id = sos_id + params.eos_id = eos_id + + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take few minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. 
+ G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + model = load_averaged_model( + params.exp_dir, model, params.epoch, params.avg, device + ) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + rnn_lm_model = None + if params.method == "rnn-lm": + rnn_lm_model = RnnLmModel( + vocab_size=params.num_classes, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + if params.rnn_lm_avg == 1: + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + else: + rnn_lm_model = load_averaged_model( + params.rnn_lm_exp_dir, + rnn_lm_model, + params.rnn_lm_epoch, + params.rnn_lm_avg, + device, + ) + rnn_lm_model.eval() + + # we need cut ids to display recognition results. 
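+    # return_cuts=True makes each batch carry batch["supervisions"]["cut"],
+    # from which decode_dataset() reads the utterance IDs.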
+ args.return_cuts = True + librispeech = MUCSAsrDataModule(args) + + test_clean_cuts = librispeech.test_mucs_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + + test_sets = ["test"] + test_dl = [test_clean_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/mucs/ASR/conformer_ctc/export.py b/egs/mucs/ASR/conformer_ctc/export.py new file mode 120000 index 000000000..60e314d9d --- /dev/null +++ b/egs/mucs/ASR/conformer_ctc/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/export.py \ No newline at end of file diff --git a/egs/mucs/ASR/conformer_ctc/label_smoothing.py b/egs/mucs/ASR/conformer_ctc/label_smoothing.py new file mode 120000 index 000000000..e9d239fff --- /dev/null +++ b/egs/mucs/ASR/conformer_ctc/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/mucs/ASR/conformer_ctc/pretrained.py b/egs/mucs/ASR/conformer_ctc/pretrained.py new file mode 120000 index 000000000..526bc9678 --- /dev/null +++ b/egs/mucs/ASR/conformer_ctc/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/pretrained.py \ No newline at end of file diff --git a/egs/mucs/ASR/conformer_ctc/subsampling.py b/egs/mucs/ASR/conformer_ctc/subsampling.py new file mode 120000 index 000000000..16354dc73 --- /dev/null +++ b/egs/mucs/ASR/conformer_ctc/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/mucs/ASR/conformer_ctc/test_label_smoothing.py b/egs/mucs/ASR/conformer_ctc/test_label_smoothing.py new file mode 120000 index 000000000..04b959ecf --- /dev/null +++ b/egs/mucs/ASR/conformer_ctc/test_label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/test_label_smoothing.py \ No newline at end of file diff --git a/egs/mucs/ASR/conformer_ctc/test_subsampling.py b/egs/mucs/ASR/conformer_ctc/test_subsampling.py new file mode 120000 index 000000000..98c3be3e6 --- /dev/null +++ b/egs/mucs/ASR/conformer_ctc/test_subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/test_subsampling.py \ No newline at end of file diff --git a/egs/mucs/ASR/conformer_ctc/test_transformer.py b/egs/mucs/ASR/conformer_ctc/test_transformer.py new file mode 120000 index 000000000..8b0990ec6 --- /dev/null +++ b/egs/mucs/ASR/conformer_ctc/test_transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/test_transformer.py \ No newline at end of file diff --git a/egs/mucs/ASR/conformer_ctc/train.py b/egs/mucs/ASR/conformer_ctc/train.py new file mode 100755 index 000000000..5bfc8b830 --- /dev/null +++ b/egs/mucs/ASR/conformer_ctc/train.py @@ -0,0 +1,824 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./conformer_ctc/train.py \ + --exp-dir ./conformer_ctc/exp \ + --world-size 4 \ + --full-libri 1 \ + --max-duration 200 \ + --num-epochs 20 +""" + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import MUCSAsrDataModule +from conformer import Conformer +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.graph_compiler import CtcTrainingGraphCompiler +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=78, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_ctc/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=6, + help="""Number of decoder layer of transformer decoder. 
+ Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Normalization for the input features, can be a + boolean indicating whether to do batch + normalization, or a float which means just scaling + the input features with this float value. + If given a float value, we will remove batchnorm + layer in `ConvolutionModule` as well. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - weight_decay: The weight_decay for the optimizer. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for Noam + "weight_decay": 1e-6, + "warm_step": 80000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. 
+ + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
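+    Returns:
+      Return a tuple of two elements: the loss tensor, i.e.,
+      (1 - att_rate) * ctc_loss + att_rate * att_loss when att_rate is
+      nonzero (otherwise the plain CTC loss), and a MetricsTracker with
+      statistics for logging.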
+ """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler): + # Works with a BPE model + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + elif isinstance(graph_compiler, CtcTrainingGraphCompiler): + # Works with a phone lexicon + decoding_graph = graph_compiler.compile(texts) + else: + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + if params.att_rate != 0.0: + with torch.set_grad_enabled(is_training): + mmodel = model.module if hasattr(model, "module") else model + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + else: + loss = ctc_loss + att_loss = torch.tensor([0]) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + info["loss"] = loss.detach().cpu().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = supervisions["num_frames"].sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ) + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if 
world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+    logging.info(params)
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+
+    if "lang_bpe" in str(params.lang_dir):
+        graph_compiler = BpeCtcTrainingGraphCompiler(
+            params.lang_dir,
+            device=device,
+            sos_token="<sos/eos>",
+            eos_token="<sos/eos>",
+        )
+    elif "lang_phone" in str(params.lang_dir):
+        assert params.att_rate == 0, (
+            "Attention decoder training does not support phone lang dirs "
+            "at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
+            "for pure CTC training when using a phone-based lang dir."
+        )
+        assert params.num_decoder_layers == 0, (
+            "Attention decoder training does not support phone lang dirs "
+            "at this time due to a missing <sos/eos> symbol. "
+            "Set --num-decoder-layers=0 for pure CTC training when using "
+            "a phone-based lang dir."
+        )
+        graph_compiler = CtcTrainingGraphCompiler(
+            lexicon,
+            device=device,
+        )
+        # Manually add the sos/eos ID with their default values
+        # from the BPE recipe which we're adapting here.
+        graph_compiler.sos_id = 1
+        graph_compiler.eos_id = 1
+    else:
+        raise ValueError(
+            f"Unsupported type of lang dir (we expected it to have "
+            f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
+        )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=False,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+    model.to(device)
+    if world_size > 1:
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Noam(
+        model.parameters(),
+        model_size=params.attention_dim,
+        factor=params.lr_factor,
+        warm_step=params.warm_step,
+        weight_decay=params.weight_decay,
+    )
+
+    if checkpoints:
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    mucs = MUCSAsrDataModule(args)
+    train_cuts = mucs.train_clean_mucs_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here.
+        # Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold.
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    train_dl = mucs.train_dataloaders(train_cuts)
+
+    valid_cuts = mucs.dev_mucs_cuts()
+    valid_dl = mucs.valid_dataloaders(valid_cuts)
+
+    scan_pessimistic_batches_for_oom(
+        model=model,
+        train_dl=train_dl,
+        optimizer=optimizer,
+        graph_compiler=graph_compiler,
+        params=params,
+    )
+
+    for epoch in range(params.start_epoch, params.num_epochs):
+        fix_random_seed(params.seed + epoch)
+        train_dl.sampler.set_epoch(epoch)
+
+        cur_lr = optimizer._rate
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        if rank == 0:
+            logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            graph_compiler=graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            tb_writer=tb_writer,
+            world_size=world_size,
+        )
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: nn.Module,
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            optimizer.zero_grad()
+            loss, _ = compute_loss(
+                params=params,
+                model=model,
+                batch=batch,
+                graph_compiler=graph_compiler,
+                is_training=True,
+            )
+            loss.backward()
+            clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            raise
+
+
+def main():
+    parser = get_parser()
+    MUCSAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
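+
+# Note (documentation sketch): compute_loss() above uses
+# reduction=params.reduction ("sum" in this recipe), so the logged "loss"
+# is summed over utterances and MetricsTracker reports it normalized by
+# "frames". With --att-rate > 0 the objective is
+#   loss = (1 - att_rate) * ctc_loss + att_rate * att_loss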
\ No newline at end of file
diff --git a/egs/mucs/ASR/conformer_ctc/transformer.py b/egs/mucs/ASR/conformer_ctc/transformer.py
new file mode 120000
index 000000000..1c3f43fcf
--- /dev/null
+++ b/egs/mucs/ASR/conformer_ctc/transformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/transformer.py
\ No newline at end of file
diff --git a/egs/mucs/ASR/local/compile_hlg.py b/egs/mucs/ASR/local/compile_hlg.py
new file mode 100755
index 000000000..7a5a47163
--- /dev/null
+++ b/egs/mucs/ASR/local/compile_hlg.py
@@ -0,0 +1,167 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input lang_dir and generates HLG from
+
+    - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
+    - L, the lexicon, built from lang_dir/L_disambig.pt
+
+      Caution: We use a lexicon that contains disambiguation symbols
+
+    - G, the LM, built from $datadir/lm/G_n_gram.fst.txt
+
+The generated HLG is saved in $lang_dir/HLG.pt
+"""
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+from icefall.lexicon import Lexicon
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lm",
+        type=str,
+        default="G_3_gram",
+        help="""Stem name of the LM used in HLG compiling, e.g. G_3_gram.
+        """,
+    )
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        """,
+    )
+
+    return parser.parse_args()
+
+
+def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
+    """
+    Args:
+      lang_dir:
+        The language directory, e.g., data_hi-en/lang_phone or
+        data_hi-en/lang_bpe_400.
+      lm:
+        The stem name of the LM, e.g., G_3_gram.
+
+    Return:
+      An FSA representing HLG.
+    """
+    lexicon = Lexicon(lang_dir)
+    # The LM lives in <datadir>/lm, where <datadir> is the first path
+    # component of lang_dir (e.g., data_hi-en).
+    datapath = str(lang_dir).split('/')[0]
+    max_token_id = max(lexicon.tokens)
+    logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
+    H = k2.ctc_topo(max_token_id)
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+
+    if Path(f"{datapath}/lm/{lm}.pt").is_file():
+        logging.info(f"Loading pre-compiled {lm}")
+        d = torch.load(f"{datapath}/lm/{lm}.pt")
+        G = k2.Fsa.from_dict(d)
+    else:
+        logging.info(f"Loading {lm}.fst.txt")
+        with open(f"{datapath}/lm/{lm}.fst.txt") as f:
+            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+            torch.save(G.as_dict(), f"{datapath}/lm/{lm}.pt")
+
+    first_token_disambig_id = lexicon.token_table["#0"]
+    first_word_disambig_id = lexicon.word_table["#0"]
+
+    L = k2.arc_sort(L)
+    G = k2.arc_sort(G)
+
+    logging.info("Intersecting L and G")
+    LG = k2.compose(L, G)
+    logging.info(f"LG shape: {LG.shape}")
+
+    logging.info("Connecting LG")
+    LG = k2.connect(LG)
+    logging.info(f"LG shape after k2.connect: {LG.shape}")
+
+    logging.info("Determinizing LG")
+    LG = k2.determinize(LG)
+
+    logging.info("Connecting LG after k2.determinize")
+    LG = k2.connect(LG)
+
+    logging.info("Removing disambiguation symbols on LG")
+    LG.labels[LG.labels >= first_token_disambig_id] = 0
+    # See https://github.com/k2-fsa/k2/issues/874
+    # for why we need to set LG.properties to None
+    LG.__dict__["_properties"] = None
+
+    assert isinstance(LG.aux_labels, k2.RaggedTensor)
+    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
+
+    LG = k2.remove_epsilon(LG)
+    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
+
+    LG = k2.connect(LG)
+    LG.aux_labels = LG.aux_labels.remove_values_eq(0)
+
+    logging.info("Arc sorting LG")
+    LG = k2.arc_sort(LG)
+
+    logging.info("Composing H and LG")
+    # CAUTION: The name of the inner_labels is fixed
+    # to `tokens`. If you want to change it, please
+    # also change other places in icefall that are using
+    # it.
+    HLG = k2.compose(H, LG, inner_labels="tokens")
+
+    logging.info("Connecting HLG")
+    HLG = k2.connect(HLG)
+
+    logging.info("Arc sorting HLG")
+    HLG = k2.arc_sort(HLG)
+    logging.info(f"HLG.shape: {HLG.shape}")
+
+    return HLG
+
+
+def main():
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+
+    if (lang_dir / "HLG.pt").is_file():
+        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
+        return
+
+    logging.info(f"Processing {lang_dir}")
+
+    HLG = compile_HLG(lang_dir, args.lm)
+    logging.info(f"Saving HLG.pt to {lang_dir}")
+    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/mucs/ASR/local/compute_fbank_mucs.py b/egs/mucs/ASR/local/compute_fbank_mucs.py
new file mode 100755
index 000000000..0f47c4a71
--- /dev/null
+++ b/egs/mucs/ASR/local/compute_fbank_mucs.py
@@ -0,0 +1,149 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the MUCS dataset.
+It looks for manifests in the directory given by --manifestpath.
+The generated fbank features are saved in the directory given by --fbankpath.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import sentencepiece as spm
+import torch
+from filter_cuts import filter_cuts
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to the bpe.model. If not None, we will remove short and
+        long utterances before extracting features""",
+    )
+
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        help="""Dataset parts to compute fbank. If None, we will use all""",
+    )
+    parser.add_argument(
+        "--manifestpath",
+        type=str,
+        help="""Directory containing the lhotse recording/supervision
+        manifests""",
+    )
+    parser.add_argument(
+        "--fbankpath",
+        type=str,
+        help="""Output directory for the extracted fbank features and cut
+        manifests""",
+    )
+
+    return parser.parse_args()
+
+
+def compute_fbank_mucs(
+    bpe_model: Optional[str] = None,
+    dataset: Optional[str] = None,
+):
+    # Note: this function reads the module-level `args` set in __main__.
+    src_dir = Path(args.manifestpath)
+    output_dir = Path(args.fbankpath)
+    num_jobs = min(48, os.cpu_count())
+    num_mel_bins = 80
+
+    if bpe_model:
+        # Note: the BPE-based filtering of short/long utterances is not
+        # currently applied below; the model is only loaded here.
+        logging.info(f"Loading {bpe_model}")
+        sp = spm.SentencePieceProcessor()
+        sp.load(bpe_model)
+
+    dataset_parts = (
+        "train",
+        "test",
+        "dev",
+    )
+
+    prefix = "mucs"
+    suffix = "jsonl.gz"
+    manifests = read_manifests_if_cached(
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix=prefix,
+        suffix=suffix,
+    )
+    assert manifests is not None
+
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+    with get_executor() as ex:  # Initialize the executor only once.
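+        # For each dataset part we build a CutSet from the recording and
+        # supervision manifests, compute the fbank features (stored with
+        # LilcomChunkyWriter), trim the cuts to supervision boundaries, and
+        # save the resulting cuts manifest next to the features.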
+        for partition, m in manifests.items():
+            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+            if (output_dir / cuts_filename).is_file():
+                logging.info(f"{partition} already exists - skipping.")
+                continue
+            logging.info(f"Processing {partition}")
+            cut_set = CutSet.from_manifests(
+                recordings=m["recordings"],
+                supervisions=m["supervisions"],
+            )
+
+            cut_set = cut_set.compute_and_store_features(
+                extractor=extractor,
+                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+                # when an executor is specified, make more partitions
+                num_jobs=num_jobs if ex is None else 80,
+                executor=ex,
+                storage_type=LilcomChunkyWriter,
+            )
+            cut_set = cut_set.trim_to_supervisions(
+                keep_overlapping=False,
+                min_duration=None,
+                keep_all_channels=False,
+            )
+            cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    args = get_args()
+    logging.info(vars(args))
+    compute_fbank_mucs(bpe_model=args.bpe_model, dataset=args.dataset)
\ No newline at end of file
diff --git a/egs/mucs/ASR/local/convert_transcript_words_to_tokens.py b/egs/mucs/ASR/local/convert_transcript_words_to_tokens.py
new file mode 120000
index 000000000..2ce13fd69
--- /dev/null
+++ b/egs/mucs/ASR/local/convert_transcript_words_to_tokens.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
\ No newline at end of file
diff --git a/egs/mucs/ASR/local/filter_cuts.py b/egs/mucs/ASR/local/filter_cuts.py
new file mode 120000
index 000000000..27aca1729
--- /dev/null
+++ b/egs/mucs/ASR/local/filter_cuts.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/filter_cuts.py
\ No newline at end of file
diff --git a/egs/mucs/ASR/local/filter_scp.pl b/egs/mucs/ASR/local/filter_scp.pl
new file mode 100755
index 000000000..b76d37f41
--- /dev/null
+++ b/egs/mucs/ASR/local/filter_scp.pl
@@ -0,0 +1,87 @@
+#!/usr/bin/env perl
+# Copyright 2010-2012 Microsoft Corporation
+#                     Johns Hopkins University (author: Daniel Povey)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This script takes a list of utterance-ids or any file whose first field
+# of each line is an utterance-id, and filters an scp
+# file (or any file whose "n-th" field is an utterance id), printing
+# out only those lines whose "n-th" field is in id_list. The index of
+# the "n-th" field is 1, by default, but can be changed by using
+# the -f <n> switch
+
+$exclude = 0;
+$field = 1;
+$shifted = 0;
+
+do {
+  $shifted = 0;
+  if ($ARGV[0] eq "--exclude") {
+    $exclude = 1;
+    shift @ARGV;
+    $shifted = 1;
+  }
+  if ($ARGV[0] eq "-f") {
+    $field = $ARGV[1];
+    shift @ARGV; shift @ARGV;
+    $shifted = 1;
+  }
+} while ($shifted);
+
+if (@ARGV < 1 || @ARGV > 2) {
+  die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
+      "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
+ "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . + "only the lines that were *not* in id_list.\n" . + "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . + "If your older scripts (written before Oct 2014) stopped working and you used the\n" . + "-f option, add 1 to the argument.\n" . + "See also: utils/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while() { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. + while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. + if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/egs/mucs/ASR/local/prepare_lang.py b/egs/mucs/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/mucs/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/mucs/ASR/local/prepare_lang_bpe.py b/egs/mucs/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/mucs/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/mucs/ASR/local/prepare_lm_files.py b/egs/mucs/ASR/local/prepare_lm_files.py new file mode 100755 index 000000000..d6234e6e7 --- /dev/null +++ b/egs/mucs/ASR/local/prepare_lm_files.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 + +import argparse +import gzip +import logging +import os +import shutil +from pathlib import Path + +from tqdm.auto import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--out-dir", type=str, help="Output directory.") + parser.add_argument("--data-path", type=str, help="Input directory.") + parser.add_argument("--mode", type=str, help="Input split") + args = parser.parse_args() + return args + +def read_text(path): + with open(path, 'r') as f: + lines = f.read().split('\n') + return [' '.join(l.split(' ')[1:]) for l in lines] + +def create_files(text): + lexicon = {} + for line in text: + for word in line.split(' '): + if word.strip() == '': continue + if word not in lexicon: + lexicon[word] = ' '.join(list(word)) + with open(os.path.join(args.out_dir, 'mucs_lexicon.txt'), 'w') as f: + for word in lexicon: + f.write(word + '\t' + lexicon[word] + '\n') + with open(os.path.join(args.out_dir, 'mucs_vocab.txt'), 'w') as f: + for word in lexicon: + f.write(word + '\n') + with open(os.path.join(args.out_dir, 'mucs_vocab_text.txt'), 'w') as f: + for line in text: + f.write(line + '\n') + +def main(): + path = os.path.join(args.data_path, args.mode) + text = read_text(os.path.join(path, "text")) + create_files(text) + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + 
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    args = get_args()
+    logging.info(f"out_dir: {args.out_dir}")
+    logging.info(f"in_dir: {args.data_path}")
+    main()
diff --git a/egs/mucs/ASR/local/prepare_manifest.py b/egs/mucs/ASR/local/prepare_manifest.py
new file mode 100755
index 000000000..9dee0d1a9
--- /dev/null
+++ b/egs/mucs/ASR/local/prepare_manifest.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+
+import logging
+import os
+import sys
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+import lhotse
+from tqdm import tqdm
+
+from lhotse import (
+    RecordingSet,
+    SupervisionSet,
+    validate_recordings_and_supervisions,
+)
+from lhotse.recipes.utils import manifests_exist, read_manifests_if_cached
+from lhotse.utils import Pathlike
+
+
+def prepare_mucs(
+    corpus_dir: Pathlike,
+    output_dir: Optional[Pathlike] = None,
+    num_jobs: int = 1,
+) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]:
+    """
+    Returns the manifests which consist of the Recordings and Supervisions.
+    When all the manifests are available in the ``output_dir``, it will
+    simply read and return them.
+    :param corpus_dir: Pathlike, the path of the data dir.
+    :param output_dir: Pathlike, the path where to write the manifests.
+    :param num_jobs: the number of parallel workers parsing the data.
+    :return: a Dict whose key is the dataset part ("train", "test" or "dev"),
+        and the value is a Dict with the keys 'recordings' and 'supervisions'.
+    """
+    corpus_dir = Path(corpus_dir)
+    assert corpus_dir.is_dir(), f"No such directory: {corpus_dir}"
+
+    dataset_parts = ["train", "test", "dev"]
+
+    manifests = {}
+
+    if output_dir is not None:
+        output_dir = Path(output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        # Maybe the manifests already exist: we can read them and save a bit of preparation time.
+        manifests = read_manifests_if_cached(
+            dataset_parts=dataset_parts, output_dir=output_dir, prefix="mucs"
+        )
+
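+    # Each part is expected to be a Kaldi-style data directory
+    # (wav.scp, text, utt2spk, ...) created by prepare.sh stage -1;
+    # lhotse.kaldi.load_kaldi_data_dir converts it into Recording and
+    # Supervision manifests.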
+    for part in tqdm(dataset_parts, desc="Preparing mucs parts from kaldi-style files"):
+        if manifests_exist(part=part, output_dir=output_dir, prefix="mucs"):
+            logging.info(f"mucs subset: {part} already prepared - skipping.")
+            continue
+        recordings, supervisions, _ = lhotse.kaldi.load_kaldi_data_dir(
+            os.path.join(corpus_dir, part), sampling_rate=16000
+        )
+        validate_recordings_and_supervisions(recordings, supervisions)
+
+        if output_dir is not None:
+            supervisions.to_file(output_dir / f"mucs_supervisions_{part}.jsonl.gz")
+            recordings.to_file(output_dir / f"mucs_recordings_{part}.jsonl.gz")
+
+        manifests[part] = {"recordings": recordings, "supervisions": supervisions}
+
+    return manifests
+
+
+if __name__ == "__main__":
+    datapath = sys.argv[1]
+    nj = int(sys.argv[2])
+    savepath = sys.argv[3]
+    print(datapath, nj, savepath)
+    prepare_mucs(datapath, savepath, nj)
\ No newline at end of file
diff --git a/egs/mucs/ASR/local/subset_data_dir.sh b/egs/mucs/ASR/local/subset_data_dir.sh
new file mode 100755
index 000000000..2dd48bedd
--- /dev/null
+++ b/egs/mucs/ASR/local/subset_data_dir.sh
@@ -0,0 +1,196 @@
+#!/usr/bin/env bash
+# Copyright 2010-2011 Microsoft Corporation
+#           2012-2013 Johns Hopkins University (Author: Daniel Povey)
+# Apache 2.0
+
+
+# This script operates on a data directory, such as in data/train/.
+# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data
+# for what these directories contain.
+
+# This script creates a subset of that data, consisting of some specified
+# number of utterances. (The selected utterances are distributed evenly
+# throughout the file, by the program ./subset_scp.pl).
+
+# There are six options, none compatible with any other.
+
+# If you give the --per-spk option, it will attempt to select the supplied
+# number of utterances for each speaker (typically you would supply a much
+# smaller number in this case).
+
+# If you give the --speakers option, it selects a subset of n randomly
+# selected speakers.
+
+# If you give the --shortest option, it will give you the n shortest utterances.
+
+# If you give the --first option, it will just give you the n first utterances.
+
+# If you give the --last option, it will just give you the n last utterances.
+
+# If you give the --spk-list or --utt-list option, it reads the
+# speakers/utterances to keep from <speaker-list-file>/<utt-list-file>
+# (note, in this case there is no <num-utt> positional parameter; see
+# usage message.)
+
+
+shortest=false
+perspk=false
+speakers=false
+first_opt=
+spk_list=
+utt_list=
+
+expect_args=3
+case $1 in
+  --first|--last) first_opt=$1; shift ;;
+  --per-spk) perspk=true; shift ;;
+  --shortest) shortest=true; shift ;;
+  --speakers) speakers=true; shift ;;
+  --spk-list) shift; spk_list=$1; shift; expect_args=2 ;;
+  --utt-list) shift; utt_list=$1; shift; expect_args=2 ;;
+  --*) echo "$0: invalid option '$1'"; exit 1
+esac
+
+if [ $# != $expect_args ]; then
+  echo "Usage:"
+  echo "  subset_data_dir.sh [--speakers|--shortest|--first|--last|--per-spk] <srcdir> <num-utt> <destdir>"
+  echo "  subset_data_dir.sh [--spk-list <speaker-list-file>] <srcdir> <destdir>"
+  echo "  subset_data_dir.sh [--utt-list <utt-list-file>] <srcdir> <destdir>"
+  echo "By default, randomly selects <num-utt> utterances from the data directory."
+  echo "With --speakers, randomly selects enough speakers that we have <num-utt> utterances"
+  echo "With --per-spk, selects <num-utt> utterances per speaker, if available."
+  echo "With --first, selects the first <num-utt> utterances"
+  echo "With --last, selects the last <num-utt> utterances"
+  echo "With --shortest, selects the <num-utt> shortest utterances."
+ echo "With --spk-list, reads the speakers to keep from " + echo "With --utt-list, reads the utterances to keep from " + exit 1; +fi + +srcdir=$1 +if [[ $spk_list || $utt_list ]]; then + numutt= + destdir=$2 +else + numutt=$2 + destdir=$3 +fi + +export LC_ALL=C + +if [ ! -f $srcdir/utt2spk ]; then + echo "$0: no such file $srcdir/utt2spk" + exit 1 +fi + +if [[ $numutt && $numutt -gt $(wc -l <$srcdir/utt2spk) ]]; then + echo "$0: cannot subset to more utterances than you originally had." + exit 1 +fi + +if $shortest && [ ! -f $srcdir/feats.scp ]; then + echo "$0: you selected --shortest but no feats.scp exist." + exit 1 +fi + +mkdir -p $destdir || exit 1 + +if [[ $spk_list ]]; then + local/filter_scp.pl "$spk_list" $srcdir/spk2utt > $destdir/spk2utt || exit 1; + utils/spk2utt_to_utt2spk.pl < $destdir/spk2utt > $destdir/utt2spk || exit 1; +elif [[ $utt_list ]]; then + local/filter_scp.pl "$utt_list" $srcdir/utt2spk > $destdir/utt2spk || exit 1; + local/utt2spk_to_spk2utt.pl < $destdir/utt2spk > $destdir/spk2utt || exit 1; +elif $speakers; then + utils/shuffle_list.pl < $srcdir/spk2utt | + awk -v numutt=$numutt '{ if (tot < numutt){ print; } tot += (NF-1); }' | + sort > $destdir/spk2utt + utils/spk2utt_to_utt2spk.pl < $destdir/spk2utt > $destdir/utt2spk +elif $perspk; then + awk '{ n='$numutt'; printf("%s ",$1); + skip=1; while(n*(skip+1) <= NF-1) { skip++; } + for(x=2; x<=NF && x <= (n*skip+1); x += skip) { printf("%s ", $x); } + printf("\n"); }' <$srcdir/spk2utt >$destdir/spk2utt + utils/spk2utt_to_utt2spk.pl < $destdir/spk2utt > $destdir/utt2spk +else + if $shortest; then + # Select $numutt shortest utterances. + . ./path.sh + if [ -f $srcdir/utt2num_frames ]; then + ln -sf $(utils/make_absolute.sh $srcdir)/utt2num_frames $destdir/tmp.len + else + feat-to-len scp:$srcdir/feats.scp ark,t:$destdir/tmp.len || exit 1; + fi + sort -n -k2 $destdir/tmp.len | + awk '{print $1}' | + head -$numutt >$destdir/tmp.uttlist + local/filter_scp.pl $destdir/tmp.uttlist $srcdir/utt2spk >$destdir/utt2spk + rm $destdir/tmp.uttlist $destdir/tmp.len + else + # Select $numutt random utterances. + local/subset_scp.pl $first_opt $numutt $srcdir/utt2spk > $destdir/utt2spk || exit 1; + fi + local/utt2spk_to_spk2utt.pl < $destdir/utt2spk > $destdir/spk2utt +fi + +# Perform filtering. utt2spk and spk2utt files already exist by this point. +# Filter by utterance. +[ -f $srcdir/feats.scp ] && + local/filter_scp.pl $destdir/utt2spk <$srcdir/feats.scp >$destdir/feats.scp +[ -f $srcdir/vad.scp ] && + local/filter_scp.pl $destdir/utt2spk <$srcdir/vad.scp >$destdir/vad.scp +[ -f $srcdir/utt2lang ] && + local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2lang >$destdir/utt2lang +[ -f $srcdir/utt2dur ] && + local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2dur >$destdir/utt2dur +[ -f $srcdir/utt2num_frames ] && + local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2num_frames >$destdir/utt2num_frames +[ -f $srcdir/utt2uniq ] && + local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2uniq >$destdir/utt2uniq +[ -f $srcdir/wav.scp ] && + local/filter_scp.pl $destdir/utt2spk <$srcdir/wav.scp >$destdir/wav.scp +[ -f $srcdir/utt2warp ] && + local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2warp >$destdir/utt2warp +[ -f $srcdir/text ] && + local/filter_scp.pl $destdir/utt2spk <$srcdir/text >$destdir/text + +# Filter by speaker. 
+[ -f $srcdir/spk2warp ] && + local/filter_scp.pl $destdir/spk2utt <$srcdir/spk2warp >$destdir/spk2warp +[ -f $srcdir/spk2gender ] && + local/filter_scp.pl $destdir/spk2utt <$srcdir/spk2gender >$destdir/spk2gender +[ -f $srcdir/cmvn.scp ] && + local/filter_scp.pl $destdir/spk2utt <$srcdir/cmvn.scp >$destdir/cmvn.scp + +# Filter by recording-id. +if [ -f $srcdir/segments ]; then + local/filter_scp.pl $destdir/utt2spk <$srcdir/segments >$destdir/segments + # Recording-ids are in segments. + awk '{print $2}' $destdir/segments | sort | uniq >$destdir/reco + # The next line overrides the command above for wav.scp, which would be incorrect. + [ -f $srcdir/wav.scp ] && + local/filter_scp.pl $destdir/reco <$srcdir/wav.scp >$destdir/wav.scp +else + # No segments; recording-ids are in wav.scp. + awk '{print $1}' $destdir/wav.scp | sort | uniq >$destdir/reco +fi + +[ -f $srcdir/reco2file_and_channel ] && + local/filter_scp.pl $destdir/reco <$srcdir/reco2file_and_channel >$destdir/reco2file_and_channel +[ -f $srcdir/reco2dur ] && + local/filter_scp.pl $destdir/reco <$srcdir/reco2dur >$destdir/reco2dur + +# Filter the STM file for proper sclite scoring. +# Copy over the comments from STM file. +[ -f $srcdir/stm ] && + (grep "^;;" $srcdir/stm + local/filter_scp.pl $destdir/reco $srcdir/stm) >$destdir/stm + +rm $destdir/reco + +# Copy frame_shift if present. +[ -f $srcdir/frame_shift ] && cp $srcdir/frame_shift $destdir + +srcutts=$(wc -l <$srcdir/utt2spk) +destutts=$(wc -l <$destdir/utt2spk) +echo "$0: reducing #utt from $srcutts to $destutts" +exit 0 diff --git a/egs/mucs/ASR/local/subset_scp.pl b/egs/mucs/ASR/local/subset_scp.pl new file mode 100755 index 000000000..11fddc09a --- /dev/null +++ b/egs/mucs/ASR/local/subset_scp.pl @@ -0,0 +1,105 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# This program selects a subset of N elements in the scp. + +# By default, it selects them evenly from throughout the scp, in order to avoid +# selecting too many from the same speaker. It prints them on the standard +# output. +# With the option --first, it just selects the N first utterances. +# With the option --last, it just selects the N last utterances. + +# Last modified by JHU & HKUST @2013 + + +$quiet = 0; +$first = 0; +$last = 0; + +if (@ARGV > 0 && $ARGV[0] eq "--quiet") { + shift; + $quiet = 1; +} +if (@ARGV > 0 && $ARGV[0] eq "--first") { + shift; + $first = 1; +} +if (@ARGV > 0 && $ARGV[0] eq "--last") { + shift; + $last = 1; +} + +if(@ARGV < 2 ) { + die "Usage: subset_scp.pl [--quiet][--first|--last] N in.scp\n" . + " --quiet causes it to not die if N < num lines in scp.\n" . + " --first and --last make it equivalent to head or tail.\n" . 
+ "See also: filter_scp.pl\n"; +} + +$N = shift @ARGV; +if($N == 0) { + die "First command-line parameter to subset_scp.pl must be an integer, got \"$N\""; +} +$inscp = shift @ARGV; +open(I, "<$inscp") || die "Opening input scp file $inscp"; + +@F = (); +while() { + push @F, $_; +} +$numlines = @F; +if($N > $numlines) { + if ($quiet) { + $N = $numlines; + } else { + die "You requested from subset_scp.pl more elements than available: $N > $numlines"; + } +} + +sub select_n { + my ($start,$end,$num_needed) = @_; + my $diff = $end - $start; + if ($num_needed > $diff) { + die "select_n: code error"; + } + if ($diff == 1 ) { + if ($num_needed > 0) { + print $F[$start]; + } + } else { + my $halfdiff = int($diff/2); + my $halfneeded = int($num_needed/2); + select_n($start, $start+$halfdiff, $halfneeded); + select_n($start+$halfdiff, $end, $num_needed - $halfneeded); + } +} + +if ( ! $first && ! $last) { + if ($N > 0) { + select_n(0, $numlines, $N); + } +} else { + if ($first) { # --first option: same as head. + for ($n = 0; $n < $N; $n++) { + print $F[$n]; + } + } else { # --last option: same as tail. + for ($n = @F - $N; $n < @F; $n++) { + print $F[$n]; + } + } +} diff --git a/egs/mucs/ASR/local/train_bpe_model.py b/egs/mucs/ASR/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/mucs/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/mucs/ASR/local/utt2spk_to_spk2utt.pl b/egs/mucs/ASR/local/utt2spk_to_spk2utt.pl new file mode 100755 index 000000000..6e0e438ca --- /dev/null +++ b/egs/mucs/ASR/local/utt2spk_to_spk2utt.pl @@ -0,0 +1,38 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# converts an utt2spk file to a spk2utt file. +# Takes input from the stdin or from a file argument; +# output goes to the standard out. 
+
+if ( @ARGV > 1 ) {
+  die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt";
+}
+
+while(<>){
+  @A = split(" ", $_);
+  @A == 2 || die "Invalid line in utt2spk file: $_";
+  ($u,$s) = @A;
+  if(!$seen_spk{$s}) {
+    $seen_spk{$s} = 1;
+    push @spklist, $s;
+  }
+  push (@{$spk_hash{$s}}, "$u");
+}
+foreach $s (@spklist) {
+  $l = join(' ',@{$spk_hash{$s}});
+  print "$s $l\n";
+}
diff --git a/egs/mucs/ASR/local/validate_bpe_lexicon.py b/egs/mucs/ASR/local/validate_bpe_lexicon.py
new file mode 120000
index 000000000..721bb48e7
--- /dev/null
+++ b/egs/mucs/ASR/local/validate_bpe_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_bpe_lexicon.py
\ No newline at end of file
diff --git a/egs/mucs/ASR/local/validate_manifest.py b/egs/mucs/ASR/local/validate_manifest.py
new file mode 120000
index 000000000..0a9725e87
--- /dev/null
+++ b/egs/mucs/ASR/local/validate_manifest.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_manifest.py
\ No newline at end of file
diff --git a/egs/mucs/ASR/prepare.sh b/egs/mucs/ASR/prepare.sh
new file mode 100755
index 000000000..40fa7ffc5
--- /dev/null
+++ b/egs/mucs/ASR/prepare.sh
@@ -0,0 +1,259 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+nj=60
+stage=-1
+stop_stage=9
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. Download them from
+# https://www.openslr.org/resources/104/
+#
+#  - $dl_dir/hi-en
+
+dl_dir=$PWD/download
+mkdir -p $dl_dir
+
+# Path to the raw MUCS corpus; adjust this to your local setup.
+raw_data_path="/data/Database/MUCS/"
+dataset="hi-en" # hi-en or bn-en
+
+datadir="data_"$dataset
+raw_kaldi_files_path=$dl_dir/$dataset/
+
+
+. shared/parse_options.sh || exit 1
+
+# vocab size for sentence piece models.
+vocab_size=400
+
+
+mkdir -p $datadir
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+  log "Stage -1: prepare data files"
+
+  mkdir -p $dl_dir/$dataset
+  for x in train dev test train_all; do
+    if [ -d "$dl_dir/$dataset/$x" ]; then rm -Rf $dl_dir/$dataset/$x; fi
+  done
+  mkdir -p $dl_dir/$dataset/{train,test,dev}
+
+  cp -r $raw_data_path/$dataset/"train"/"transcripts"/* $dl_dir/$dataset/"train"
+  cp -r $raw_data_path/$dataset/"test"/"transcripts"/* $dl_dir/$dataset/"test"
+
+  # Rewrite wav.scp so that each entry points to the absolute path of
+  # the audio file in the raw corpus.
+  for x in train test
+  do
+    cp $dl_dir/$dataset/$x/"wav.scp" $dl_dir/$dataset/$x/"wav.scp_old"
+    cat $dl_dir/$dataset/$x/"wav.scp" | cut -d' ' -f1 > $dl_dir/$dataset/$x/wav_ids
+    cat $dl_dir/$dataset/$x/"wav.scp" | cut -d' ' -f2 | awk -v var="$raw_data_path/$dataset/$x/" '{print var$1}' > $dl_dir/$dataset/$x/wav_ids_with_fullpath
+    paste -d' ' $dl_dir/$dataset/$x/wav_ids $dl_dir/$dataset/$x/wav_ids_with_fullpath > $dl_dir/$dataset/$x/"wav.scp"
+    rm $dl_dir/$dataset/$x/wav_ids
+    rm $dl_dir/$dataset/$x/wav_ids_with_fullpath
+  done
+
+  # Hold out the first 1000 train utterances as the dev set, and keep
+  # the last $count utterances (disjoint from dev) as the training set.
+  ./local/subset_data_dir.sh --first $dl_dir/$dataset/"train" 1000 $dl_dir/$dataset/"dev"
+  total=$(wc -l $dl_dir/$dataset/"train"/"text" | cut -d' ' -f1)
+  count=$(expr $total - 1000)
+
+  ./local/subset_data_dir.sh --last $dl_dir/$dataset/"train" $count $dl_dir/$dataset/"train_reduced"
+  mv $dl_dir/$dataset/"train" $dl_dir/$dataset/"train_all"
+  mv $dl_dir/$dataset/"train_reduced" $dl_dir/$dataset/"train"
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: prepare LM files"
+  mkdir -p $dl_dir/lm
+  if [ ! -e $dl_dir/lm/.done ]; then
+    ./local/prepare_lm_files.py --out-dir=$dl_dir/lm --data-path=$raw_kaldi_files_path --mode="train"
+    touch $dl_dir/lm/.done
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare MUCS manifest"
+  # We assume that you have downloaded the MUCS corpus
+  # to $dl_dir/
+  mkdir -p $datadir/manifests
+  if [ ! -e $datadir/manifests/.mucs.done ]; then
+    # generate lhotse manifests from kaldi style files
+    ./local/prepare_manifest.py "$raw_kaldi_files_path" $nj $datadir/manifests
+    touch $datadir/manifests/.mucs.done
+  fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Compute fbank for mucs"
+  mkdir -p $datadir/fbank
+  if [ ! -e $datadir/fbank/.mucs.done ]; then
+    ./local/compute_fbank_mucs.py --manifestpath $datadir/manifests/ --fbankpath $datadir/fbank
+    touch $datadir/fbank/.mucs.done
+  fi
+
+  if [ ! -e $datadir/fbank/.mucs-validated.done ]; then
+    log "Validating $datadir/fbank for mucs"
+    parts=(
+      train
+      test
+      dev
+    )
+    for part in ${parts[@]}; do
+      python3 ./local/validate_manifest.py \
+        $datadir/fbank/mucs_cuts_${part}.jsonl.gz
+    done
+    touch $datadir/fbank/.mucs-validated.done
+  fi
+fi
+
+
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare phone based lang"
+  lang_dir=$datadir/lang_phone
+  mkdir -p $lang_dir
+
+  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+    cat - $dl_dir/lm/mucs_lexicon.txt |
+    sort | uniq > $lang_dir/lexicon.txt
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_dir
+  fi
+
+  if [ ! -f $lang_dir/L.fst ]; then
+    log "Converting L.pt to L.fst"
+    ./shared/convert-k2-to-openfst.py \
+      --olabels aux_labels \
+      $lang_dir/L.pt \
+      $lang_dir/L.fst
+  fi
+
+  if [ ! -f $lang_dir/L_disambig.fst ]; then
+    log "Converting L_disambig.pt to L_disambig.fst"
+    ./shared/convert-k2-to-openfst.py \
+      --olabels aux_labels \
+      $lang_dir/L_disambig.pt \
+      $lang_dir/L_disambig.fst
+  fi
+fi
+
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare BPE based lang"
+
+  lang_dir=$datadir/lang_bpe_${vocab_size}
+  mkdir -p $lang_dir
+  # We reuse words.txt from phone based lexicon
+  # so that the two can share G.pt later.
+  cp $datadir/lang_phone/words.txt $lang_dir
+
+  if [ ! -f $lang_dir/transcript_words.txt ]; then
+    log "Generate data for BPE training"
+    cp $dl_dir/lm/mucs_vocab_text.txt $lang_dir/transcript_words.txt
+  fi
+
+  if [ ! -f $lang_dir/bpe.model ]; then
+    ./local/train_bpe_model.py \
+      --lang-dir $lang_dir \
+      --vocab-size $vocab_size \
+      --transcript $lang_dir/transcript_words.txt
+  fi
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+
+    log "Validating $lang_dir/lexicon.txt"
+    ./local/validate_bpe_lexicon.py \
+      --lexicon $lang_dir/lexicon.txt \
+      --bpe-model $lang_dir/bpe.model
+  fi
+
+  if [ ! -f $lang_dir/L.fst ]; then
+    log "Converting L.pt to L.fst"
+    ./shared/convert-k2-to-openfst.py \
+      --olabels aux_labels \
+      $lang_dir/L.pt \
+      $lang_dir/L.fst
+  fi
+
+  if [ ! -f $lang_dir/L_disambig.fst ]; then
+    log "Converting L_disambig.pt to L_disambig.fst"
+    ./shared/convert-k2-to-openfst.py \
+      --olabels aux_labels \
+      $lang_dir/L_disambig.pt \
+      $lang_dir/L_disambig.fst
+  fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Train LM from training data"
+
+  lang_dir=$datadir/lang_bpe_${vocab_size}
+
+  if [ ! -f $lang_dir/lm_3.arpa ]; then
+    ./shared/make_kn_lm.py \
+      -ngram-order 3 \
+      -text $lang_dir/transcript_words.txt \
+      -lm $lang_dir/lm_3.arpa
+  fi
+
+  if [ ! -f $lang_dir/lm_4.arpa ]; then
+    ./shared/make_kn_lm.py \
+      -ngram-order 4 \
+      -text $lang_dir/transcript_words.txt \
+      -lm $lang_dir/lm_4.arpa
+  fi
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare G"
+  # We assume you have installed kaldilm; if not, please install
+  # it using: pip install kaldilm
+
+  mkdir -p $datadir/lm
+  if [ ! -f $datadir/lm/G_3_gram.fst.txt ]; then
+    # It is used in building HLG
+    python3 -m kaldilm \
+      --read-symbol-table="$datadir/lang_phone/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=3 \
+      $datadir/lang_bpe_${vocab_size}/lm_3.arpa > $datadir/lm/G_3_gram.fst.txt
+  fi
+
+  if [ ! -f $datadir/lm/G_4_gram.fst.txt ]; then
+    # It is used for LM rescoring
+    python3 -m kaldilm \
+      --read-symbol-table="$datadir/lang_phone/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=4 \
+      $datadir/lang_bpe_${vocab_size}/lm_4.arpa > $datadir/lm/G_4_gram.fst.txt
+  fi
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Compile HLG"
+
+  lang_dir=$datadir/lang_bpe_${vocab_size}
+  ./local/compile_hlg.py --lang-dir $lang_dir
+fi
diff --git a/egs/mucs/ASR/run.sh b/egs/mucs/ASR/run.sh
new file mode 100755
index 000000000..4739b7850
--- /dev/null
+++ b/egs/mucs/ASR/run.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES="0"
+
+set -e
+dataset='hi-en' # hi-en or bn-en
+datadir=data_"$dataset"
+bpe=400
+# decode_methods="attention-decoder 1best nbest nbest-rescoring ctc-decoding whole-lattice-rescoring"
+decode_methods="nbest nbest-rescoring whole-lattice-rescoring"
+
+num_paths=10
+max_duration=5
+
+# Uncomment to train the model before decoding:
+# ./conformer_ctc/train.py \
+#   --num-epochs 60 \
+#   --max-duration 300 \
+#   --exp-dir ./conformer_ctc/exp_"$dataset"_bpe"$bpe" \
+#   --manifest-dir $datadir/fbank \
+#   --lang-dir $datadir/lang_bpe_"$bpe" \
+#   --enable-musan False
+
+for decode_method in $decode_methods; do
+  ./conformer_ctc/decode.py \
+    --epoch 59 \
+    --avg 10 \
+    --manifest-dir $datadir/fbank \
+    --exp-dir ./conformer_ctc/exp_"$dataset"_bpe"$bpe" \
+    --max-duration $max_duration \
+    --lang-dir $datadir/lang_bpe_"$bpe" \
+    --lm-dir $datadir/"lm" \
+    --method $decode_method \
+    --num-paths $num_paths
+done
+exit