icefall/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py
Zengwei Yao a9dccdc33f
Streaming merge (#369)
* Remove ReLU in attention

* Adding diagnostics code...

* Refactor/simplify ConformerEncoder

* First version of rand-combine iterated-training-like idea.

* Improvements to diagnostics (RE those with 1 dim

* Add pelu to this good-performing setup..

* Small bug fixes/imports

* Add baseline for the PeLU expt, keeping only the small normalization-related changes.

* pelu_base->expscale, add 2xExpScale in subsampling, and in feedforward units.

* Double learning rate of exp-scale units

* Combine ExpScale and swish for memory reduction

* Add import

* Fix backprop bug

* Fix bug in diagnostics

* Increase scale on Scale from 4 to 20

* Increase scale from 20 to 50.

* Fix duplicate Swish; replace norm+swish with swish+exp-scale in convolution module

* Reduce scale from 50 to 20

* Add deriv-balancing code

* Double the threshold in brelu; slightly increase max_factor.

* Fix exp dir

* Convert swish nonlinearities to ReLU

* Replace relu with swish-squared.

* Restore ConvolutionModule to state before changes; change all Swish,Swish(Swish) to SwishOffset.

* Replace norm on input layer with scale of 0.1.

* Extensions to diagnostics code

* Update diagnostics

* Add BasicNorm module

* Replace most normalizations with scales (still have norm in conv)

* Change exp dir

* Replace norm in ConvolutionModule with a scaling factor.

* use nonzero threshold in DerivBalancer

* Add min-abs-value 0.2

* Fix dirname

* Change min-abs threshold from 0.2 to 0.5

* Scale up pos_bias_u and pos_bias_v before use.

* Reduce max_factor to 0.01

* Fix q*scaling logic

* Change max_factor in DerivBalancer from 0.025 to 0.01; fix scaling code.

* init 1st conv module to smaller variance

* Change how scales are applied; fix residual bug

* Reduce min_abs from 0.5 to 0.2

* Introduce in_scale=0.5 for SwishExpScale

* Fix scale from 0.5 to 2.0 as I really intended..

* Set scaling on SwishExpScale

* Add identity pre_norm_final for diagnostics.

* Add learnable post-scale for mha

* Fix self.post-scale-mha

* Another rework, use scales on linear/conv

* Change dir name

* Reduce initial scaling of modules

* Bug-fix RE bias

* Cosmetic change

* Reduce initial_scale.

* Replace ExpScaleRelu with DoubleSwish()

* DoubleSwish fix

* Use learnable scales for joiner and decoder

* Add max-abs-value constraint in DerivBalancer

* Add max-abs-value

* Change dir name

* Remove ExpScale in feedforward layes.

* Reduce max-abs limit from 1000 to 100; introduce 2 DerivBalancer modules in conv layer.

* Make DoubleSwish more memory efficient

* Reduce constraints from deriv-balancer in ConvModule.

* Add warmup mode

* Remove max-positive constraint in deriv-balancing; add second DerivBalancer in conv module.

* Add some extra info to diagnostics

* Add deriv-balancer at output of embedding.

* Add more stats.

* Make epsilon in BasicNorm learnable, optionally.

* Draft of 0mean changes..

* Rework of initialization

* Fix typo

* Remove dead code

* Modifying initialization from normal->uniform; add initial_scale when initializing

* bug fix re sqrt

* Remove xscale from pos_embedding

* Remove some dead code.

* Cosmetic changes/renaming things

* Start adding some files..

* Add more files..

* update decode.py file type

* Add remaining files in pruned_transducer_stateless2

* Fix diagnostics-getting code

* Scale down pruned loss in warmup mode

* Reduce warmup scale on pruned loss form 0.1 to 0.01.

* Remove scale_speed, make swish deriv more efficient.

* Cosmetic changes to swish

* Double warm_step

* Fix bug with import

* Change initial std from 0.05 to 0.025.

* Set also scale for embedding to 0.025.

* Remove logging code that broke with newer Lhotse; fix bug with pruned_loss

* Add norm+balancer to VggSubsampling

* Incorporate changes from master into pruned_transducer_stateless2.

* Add max-abs=6, debugged version

* Change 0.025,0.05 to 0.01 in initializations

* Fix balancer code

* Whitespace fix

* Reduce initial pruned_loss scale from 0.01 to 0.0

* Increase warm_step (and valid_interval)

* Change max-abs from 6 to 10

* Change how warmup works.

* Add changes from master to decode.py, train.py

* Simplify the warmup code; max_abs 10->6

* Make warmup work by scaling layer contributions; leave residual layer-drop

* Fix bug

* Fix test mode with random layer dropout

* Add random-number-setting function in dataloader

* Fix/patch how fix_random_seed() is imported.

* Reduce layer-drop prob

* Reduce layer-drop prob after warmup to 1 in 100

* Change power of lr-schedule from -0.5 to -0.333

* Increase model_warm_step to 4k

* Change max-keep-prob to 0.95

* Refactoring and simplifying conformer and frontend

* Rework conformer, remove some code.

* Reduce 1st conv channels from 64 to 32

* Add another convolutional layer

* Fix padding bug

* Remove dropout in output layer

* Reduce speed of some components

* Initial refactoring to remove unnecessary vocab_size

* Fix RE identity

* Bug-fix

* Add final dropout to conformer

* Remove some un-used code

* Replace nn.Linear with ScaledLinear in simple joiner

* Make 2 projections..

* Reduce initial_speed

* Use initial_speed=0.5

* Reduce initial_speed further from 0.5 to 0.25

* Reduce initial_speed from 0.5 to 0.25

* Change how warmup is applied.

* Bug fix to warmup_scale

* Fix test-mode

* Remove final dropout

* Make layer dropout rate 0.075, was 0.1.

* First draft of model rework

* Various bug fixes

* Change learning speed of simple_lm_proj

* Revert transducer_stateless/ to state in upstream/master

* Fix to joiner to allow different dims

* Some cleanups

* Make training more efficient, avoid redoing some projections.

* Change how warm-step is set

* First draft of new approach to learning rates + init

* Some fixes..

* Change initialization to 0.25

* Fix type of parameter

* Fix weight decay formula by adding 1/1-beta

* Fix weight decay formula by adding 1/1-beta

* Fix checkpoint-writing

* Fix to reading scheudler from optim

* Simplified optimizer, rework somet things..

* Reduce model_warm_step from 4k to 3k

* Fix bug in lambda

* Bug-fix RE sign of target_rms

* Changing initial_speed from 0.25 to 01

* Change some defaults in LR-setting rule.

* Remove initial_speed

* Set new scheduler

* Change exponential part of lrate to be epoch based

* Fix bug

* Set 2n rule..

* Implement 2o schedule

* Make lrate rule more symmetric

* Implement 2p version of learning rate schedule.

* Refactor how learning rate is set.

* Fix import

* Modify init (#301)

* update icefall/__init__.py to import more common functions.

* update icefall/__init__.py

* make imports style consistent.

* exclude black check for icefall/__init__.py in pyproject.toml.

* Minor fixes for logging (#296)

* Minor fixes for logging

* Minor fix

* Fix dir names

* Modify beam search to be efficient with current joienr

* Fix adding learning rate to tensorboard

* Fix docs in optim.py

* Support mix precision training on the reworked model (#305)

* Add mix precision support

* Minor fixes

* Minor fixes

* Minor fixes

* Tedlium3 pruned transducer stateless (#261)

* update tedlium3-pruned-transducer-stateless-codes

* update README.md

* update README.md

* add fast beam search for decoding

* do a change for RESULTS.md

* do a change for RESULTS.md

* do a fix

* do some changes for pruned RNN-T

* Add mix precision support

* Minor fixes

* Minor fixes

* Updating RESULTS.md; fix in beam_search.py

* Fix rebase

* Code style check for librispeech pruned transducer stateless2 (#308)

* Update results for tedlium3 pruned RNN-T (#307)

* Update README.md

* Fix CI errors. (#310)

* Add more results

* Fix tensorboard log location

* Add one more epoch of full expt

* fix comments

* Add results for mixed precision with max-duration 300

* Changes for pretrained.py (tedlium3 pruned RNN-T) (#311)

* GigaSpeech recipe (#120)

* initial commit

* support download, data prep, and fbank

* on-the-fly feature extraction by default

* support BPE based lang

* support HLG for BPE

* small fix

* small fix

* chunked feature extraction by default

* Compute features for GigaSpeech by splitting the manifest.

* Fixes after review.

* Split manifests into 2000 pieces.

* set audio duration mismatch tolerance to 0.01

* small fix

* add conformer training recipe

* Add conformer.py without pre-commit checking

* lazy loading and use SingleCutSampler

* DynamicBucketingSampler

* use KaldifeatFbank to compute fbank for musan

* use pretrained language model and lexicon

* use 3gram to decode, 4gram to rescore

* Add decode.py

* Update .flake8

* Delete compute_fbank_gigaspeech.py

* Use BucketingSampler for valid and test dataloader

* Update params in train.py

* Use bpe_500

* update params in decode.py

* Decrease num_paths while CUDA OOM

* Added README

* Update RESULTS

* black

* Decrease num_paths while CUDA OOM

* Decode with post-processing

* Update results

* Remove lazy_load option

* Use default `storage_type`

* Keep the original tolerance

* Use split-lazy

* black

* Update pretrained model

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* Add LG decoding (#277)

* Add LG decoding

* Add log weight pushing

* Minor fixes

* Support computing RNN-T loss with torchaudio (#316)

* Update results for torchaudio RNN-T. (#322)

* Fix some typos. (#329)

* fix fp16 option in example usage (#332)

* Support averaging models with weight tying. (#333)

* Support specifying iteration number of checkpoints for decoding. (#336)

See also #289

* Modified conformer with multi datasets (#312)

* Copy files for editing.

* Use librispeech + gigaspeech with modified conformer.

* Support specifying number of workers for on-the-fly feature extraction.

* Feature extraction code for GigaSpeech.

* Combine XL splits lazily during training.

* Fix warnings in decoding.

* Add decoding code for GigaSpeech.

* Fix decoding the gigaspeech dataset.

We have to use the decoder/joiner networks for the GigaSpeech dataset.

* Disable speed perturbe for XL subset.

* Compute the Nbest oracle WER for RNN-T decoding.

* Minor fixes.

* Minor fixes.

* Add results.

* Update results.

* Update CI.

* Update results.

* Fix style issues.

* Update results.

* Fix style issues.

* Update results. (#340)

* Update results.

* Typo fixes.

* Validate generated manifest files. (#338)

* Validate generated manifest files. (#338)

* Save batch to disk on OOM. (#343)

* Save batch to disk on OOM.

* minor fixes

* Fixes after review.

* Fix style issues.

* Fix decoding for gigaspeech in the libri + giga setup. (#345)

* Model average (#344)

* First upload of model average codes.

* minor fix

* update decode file

* update .flake8

* rename pruned_transducer_stateless3 to pruned_transducer_stateless4

* change epoch number counter starting from 1 instead of 0

* minor fix of pruned_transducer_stateless4/train.py

* refactor the checkpoint.py

* minor fix, update docs, and modify the epoch number to count from 1 in the pruned_transducer_stateless4/decode.py

* update author info

* add docs of the scaling in function average_checkpoints_with_averaged_model

* Save batch to disk on exception. (#350)

* Bug fix (#352)

* Keep model_avg on cpu (#348)

* keep model_avg on cpu

* explicitly convert model_avg to cpu

* minor fix

* remove device convertion for model_avg

* modify usage of the model device in train.py

* change model.device to next(model.parameters()).device for decoding

* assert params.start_epoch>0

* assert params.start_epoch>0, params.start_epoch

* Do some changes for aishell/ASR/transducer stateless/export.py (#347)

* do some changes for aishell/ASR/transducer_stateless/export.py

* Support decoding with averaged model when using --iter (#353)

* support decoding with averaged model when using --iter

* minor fix

* monir fix of copyright date

* Stringify torch.__version__ before serializing it. (#354)

* Run decode.py in GitHub actions. (#356)

* Ignore padding frames during RNN-T decoding. (#358)

* Ignore padding frames during RNN-T decoding.

* Fix outdated decoding code.

* Minor fixes.

* Support --iter in export.py (#360)

* GigaSpeech RNN-T experiments (#318)

* Copy RNN-T recipe from librispeech

* flake8

* flake8

* Update params

* gigaspeech decode

* black

* Update results

* syntax highlight

* Update RESULTS.md

* typo

* Update decoding script for gigaspeech and remove duplicate files. (#361)

* Validate that there are no OOV tokens in BPE-based lexicons. (#359)

* Validate that there are no OOV tokens in BPE-based lexicons.

* Typo fixes.

* Decode gigaspeech in GitHub actions (#362)

* Add CI for gigaspeech.

* Update results for libri+giga multi dataset setup. (#363)

* Update results for libri+giga multi dataset setup.

* Update GigaSpeech reults (#364)

* Update decode.py

* Update export.py

* Update results

* Update README.md

* Fix GitHub CI for decoding GigaSpeech dev/test datasets (#366)

* modify .flake8

* minor fix

* minor fix

Co-authored-by: Daniel Povey <dpovey@gmail.com>
Co-authored-by: Wei Kang <wkang@pku.org.cn>
Co-authored-by: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com>
Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
Co-authored-by: Guo Liyong <guonwpu@qq.com>
Co-authored-by: Wang, Guanbo <wgb14@outlook.com>
Co-authored-by: whsqkaak <whsqkaak@naver.com>
Co-authored-by: pehonnet <pe.honnet@gmail.com>
2022-05-15 21:08:30 +08:00

170 lines
5.0 KiB
Python

#!/usr/bin/env python3
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from datetime import datetime
from pathlib import Path
import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--num-workers",
type=int,
default=20,
help="Number of dataloading workers used for reading the audio.",
)
parser.add_argument(
"--batch-duration",
type=float,
default=600.0,
help="The maximum number of audio seconds in a batch."
"Determines batch size dynamically.",
)
parser.add_argument(
"--num-splits",
type=int,
required=True,
help="The number of splits of the XL subset",
)
parser.add_argument(
"--start",
type=int,
default=0,
help="Process pieces starting from this number (inclusive).",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
help="Stop processing pieces until this number (exclusive).",
)
return parser
def compute_fbank_gigaspeech_splits(args):
num_splits = args.num_splits
output_dir = f"data/fbank/XL_split_{num_splits}"
output_dir = Path(output_dir)
assert output_dir.exists(), f"{output_dir} does not exist!"
num_digits = len(str(num_splits))
start = args.start
stop = args.stop
if stop < start:
stop = num_splits
stop = min(stop, num_splits)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}")
num_digits = 8 # num_digits is fixed by lhotse split-lazy
for i in range(start, stop):
idx = f"{i + 1}".zfill(num_digits)
logging.info(f"Processing {idx}/{num_splits}")
cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz"
if not raw_cuts_path.is_file():
logging.info(f"{raw_cuts_path} does not exist - skipping it")
continue
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
logging.info("Computing features")
if (output_dir / f"feats_XL_{idx}.lca").exists():
logging.info(f"Removing {output_dir}/feats_XL_{idx}.lca")
os.remove(output_dir / f"feats_XL_{idx}.lca")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/feats_XL_{idx}",
num_workers=args.num_workers,
batch_duration=args.batch_duration,
)
logging.info("About to split cuts into smaller chunks.")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path)
logging.info(f"Saved to {cuts_path}")
def main():
now = datetime.now()
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
log_filename = "log-compute_fbank_gigaspeech_splits"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
log_filename = f"{log_filename}-{date_time}"
logging.basicConfig(
filename=log_filename,
format=formatter,
level=logging.INFO,
filemode="w",
)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(logging.Formatter(formatter))
logging.getLogger("").addHandler(console)
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
compute_fbank_gigaspeech_splits(args)
if __name__ == "__main__":
main()