mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 02:22:17 +00:00
* Remove ReLU in attention * Adding diagnostics code... * Refactor/simplify ConformerEncoder * First version of rand-combine iterated-training-like idea. * Improvements to diagnostics (RE those with 1 dim * Add pelu to this good-performing setup.. * Small bug fixes/imports * Add baseline for the PeLU expt, keeping only the small normalization-related changes. * pelu_base->expscale, add 2xExpScale in subsampling, and in feedforward units. * Double learning rate of exp-scale units * Combine ExpScale and swish for memory reduction * Add import * Fix backprop bug * Fix bug in diagnostics * Increase scale on Scale from 4 to 20 * Increase scale from 20 to 50. * Fix duplicate Swish; replace norm+swish with swish+exp-scale in convolution module * Reduce scale from 50 to 20 * Add deriv-balancing code * Double the threshold in brelu; slightly increase max_factor. * Fix exp dir * Convert swish nonlinearities to ReLU * Replace relu with swish-squared. * Restore ConvolutionModule to state before changes; change all Swish,Swish(Swish) to SwishOffset. * Replace norm on input layer with scale of 0.1. * Extensions to diagnostics code * Update diagnostics * Add BasicNorm module * Replace most normalizations with scales (still have norm in conv) * Change exp dir * Replace norm in ConvolutionModule with a scaling factor. * use nonzero threshold in DerivBalancer * Add min-abs-value 0.2 * Fix dirname * Change min-abs threshold from 0.2 to 0.5 * Scale up pos_bias_u and pos_bias_v before use. * Reduce max_factor to 0.01 * Fix q*scaling logic * Change max_factor in DerivBalancer from 0.025 to 0.01; fix scaling code. * init 1st conv module to smaller variance * Change how scales are applied; fix residual bug * Reduce min_abs from 0.5 to 0.2 * Introduce in_scale=0.5 for SwishExpScale * Fix scale from 0.5 to 2.0 as I really intended.. * Set scaling on SwishExpScale * Add identity pre_norm_final for diagnostics. 
* Add learnable post-scale for mha * Fix self.post-scale-mha * Another rework, use scales on linear/conv * Change dir name * Reduce initial scaling of modules * Bug-fix RE bias * Cosmetic change * Reduce initial_scale. * Replace ExpScaleRelu with DoubleSwish() * DoubleSwish fix * Use learnable scales for joiner and decoder * Add max-abs-value constraint in DerivBalancer * Add max-abs-value * Change dir name * Remove ExpScale in feedforward layes. * Reduce max-abs limit from 1000 to 100; introduce 2 DerivBalancer modules in conv layer. * Make DoubleSwish more memory efficient * Reduce constraints from deriv-balancer in ConvModule. * Add warmup mode * Remove max-positive constraint in deriv-balancing; add second DerivBalancer in conv module. * Add some extra info to diagnostics * Add deriv-balancer at output of embedding. * Add more stats. * Make epsilon in BasicNorm learnable, optionally. * Draft of 0mean changes.. * Rework of initialization * Fix typo * Remove dead code * Modifying initialization from normal->uniform; add initial_scale when initializing * bug fix re sqrt * Remove xscale from pos_embedding * Remove some dead code. * Cosmetic changes/renaming things * Start adding some files.. * Add more files.. * update decode.py file type * Add remaining files in pruned_transducer_stateless2 * Fix diagnostics-getting code * Scale down pruned loss in warmup mode * Reduce warmup scale on pruned loss form 0.1 to 0.01. * Remove scale_speed, make swish deriv more efficient. * Cosmetic changes to swish * Double warm_step * Fix bug with import * Change initial std from 0.05 to 0.025. * Set also scale for embedding to 0.025. * Remove logging code that broke with newer Lhotse; fix bug with pruned_loss * Add norm+balancer to VggSubsampling * Incorporate changes from master into pruned_transducer_stateless2. 
* Add max-abs=6, debugged version * Change 0.025,0.05 to 0.01 in initializations * Fix balancer code * Whitespace fix * Reduce initial pruned_loss scale from 0.01 to 0.0 * Increase warm_step (and valid_interval) * Change max-abs from 6 to 10 * Change how warmup works. * Add changes from master to decode.py, train.py * Simplify the warmup code; max_abs 10->6 * Make warmup work by scaling layer contributions; leave residual layer-drop * Fix bug * Fix test mode with random layer dropout * Add random-number-setting function in dataloader * Fix/patch how fix_random_seed() is imported. * Reduce layer-drop prob * Reduce layer-drop prob after warmup to 1 in 100 * Change power of lr-schedule from -0.5 to -0.333 * Increase model_warm_step to 4k * Change max-keep-prob to 0.95 * Refactoring and simplifying conformer and frontend * Rework conformer, remove some code. * Reduce 1st conv channels from 64 to 32 * Add another convolutional layer * Fix padding bug * Remove dropout in output layer * Reduce speed of some components * Initial refactoring to remove unnecessary vocab_size * Fix RE identity * Bug-fix * Add final dropout to conformer * Remove some un-used code * Replace nn.Linear with ScaledLinear in simple joiner * Make 2 projections.. * Reduce initial_speed * Use initial_speed=0.5 * Reduce initial_speed further from 0.5 to 0.25 * Reduce initial_speed from 0.5 to 0.25 * Change how warmup is applied. * Bug fix to warmup_scale * Fix test-mode * Remove final dropout * Make layer dropout rate 0.075, was 0.1. * First draft of model rework * Various bug fixes * Change learning speed of simple_lm_proj * Revert transducer_stateless/ to state in upstream/master * Fix to joiner to allow different dims * Some cleanups * Make training more efficient, avoid redoing some projections. * Change how warm-step is set * First draft of new approach to learning rates + init * Some fixes.. 
* Change initialization to 0.25 * Fix type of parameter * Fix weight decay formula by adding 1/1-beta * Fix weight decay formula by adding 1/1-beta * Fix checkpoint-writing * Fix to reading scheudler from optim * Simplified optimizer, rework somet things.. * Reduce model_warm_step from 4k to 3k * Fix bug in lambda * Bug-fix RE sign of target_rms * Changing initial_speed from 0.25 to 01 * Change some defaults in LR-setting rule. * Remove initial_speed * Set new scheduler * Change exponential part of lrate to be epoch based * Fix bug * Set 2n rule.. * Implement 2o schedule * Make lrate rule more symmetric * Implement 2p version of learning rate schedule. * Refactor how learning rate is set. * Fix import * Modify init (#301) * update icefall/__init__.py to import more common functions. * update icefall/__init__.py * make imports style consistent. * exclude black check for icefall/__init__.py in pyproject.toml. * Minor fixes for logging (#296) * Minor fixes for logging * Minor fix * Fix dir names * Modify beam search to be efficient with current joienr * Fix adding learning rate to tensorboard * Fix docs in optim.py * Support mix precision training on the reworked model (#305) * Add mix precision support * Minor fixes * Minor fixes * Minor fixes * Tedlium3 pruned transducer stateless (#261) * update tedlium3-pruned-transducer-stateless-codes * update README.md * update README.md * add fast beam search for decoding * do a change for RESULTS.md * do a change for RESULTS.md * do a fix * do some changes for pruned RNN-T * Add mix precision support * Minor fixes * Minor fixes * Updating RESULTS.md; fix in beam_search.py * Fix rebase * Code style check for librispeech pruned transducer stateless2 (#308) * Update results for tedlium3 pruned RNN-T (#307) * Update README.md * Fix CI errors. 
(#310) * Add more results * Fix tensorboard log location * Add one more epoch of full expt * fix comments * Add results for mixed precision with max-duration 300 * Changes for pretrained.py (tedlium3 pruned RNN-T) (#311) * GigaSpeech recipe (#120) * initial commit * support download, data prep, and fbank * on-the-fly feature extraction by default * support BPE based lang * support HLG for BPE * small fix * small fix * chunked feature extraction by default * Compute features for GigaSpeech by splitting the manifest. * Fixes after review. * Split manifests into 2000 pieces. * set audio duration mismatch tolerance to 0.01 * small fix * add conformer training recipe * Add conformer.py without pre-commit checking * lazy loading and use SingleCutSampler * DynamicBucketingSampler * use KaldifeatFbank to compute fbank for musan * use pretrained language model and lexicon * use 3gram to decode, 4gram to rescore * Add decode.py * Update .flake8 * Delete compute_fbank_gigaspeech.py * Use BucketingSampler for valid and test dataloader * Update params in train.py * Use bpe_500 * update params in decode.py * Decrease num_paths while CUDA OOM * Added README * Update RESULTS * black * Decrease num_paths while CUDA OOM * Decode with post-processing * Update results * Remove lazy_load option * Use default `storage_type` * Keep the original tolerance * Use split-lazy * black * Update pretrained model Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> * Add LG decoding (#277) * Add LG decoding * Add log weight pushing * Minor fixes * Support computing RNN-T loss with torchaudio (#316) * Update results for torchaudio RNN-T. (#322) * Fix some typos. (#329) * fix fp16 option in example usage (#332) * Support averaging models with weight tying. (#333) * Support specifying iteration number of checkpoints for decoding. (#336) See also #289 * Modified conformer with multi datasets (#312) * Copy files for editing. * Use librispeech + gigaspeech with modified conformer. 
* Support specifying number of workers for on-the-fly feature extraction. * Feature extraction code for GigaSpeech. * Combine XL splits lazily during training. * Fix warnings in decoding. * Add decoding code for GigaSpeech. * Fix decoding the gigaspeech dataset. We have to use the decoder/joiner networks for the GigaSpeech dataset. * Disable speed perturbe for XL subset. * Compute the Nbest oracle WER for RNN-T decoding. * Minor fixes. * Minor fixes. * Add results. * Update results. * Update CI. * Update results. * Fix style issues. * Update results. * Fix style issues. * Update results. (#340) * Update results. * Typo fixes. * Validate generated manifest files. (#338) * Validate generated manifest files. (#338) * Save batch to disk on OOM. (#343) * Save batch to disk on OOM. * minor fixes * Fixes after review. * Fix style issues. * Fix decoding for gigaspeech in the libri + giga setup. (#345) * Model average (#344) * First upload of model average codes. * minor fix * update decode file * update .flake8 * rename pruned_transducer_stateless3 to pruned_transducer_stateless4 * change epoch number counter starting from 1 instead of 0 * minor fix of pruned_transducer_stateless4/train.py * refactor the checkpoint.py * minor fix, update docs, and modify the epoch number to count from 1 in the pruned_transducer_stateless4/decode.py * update author info * add docs of the scaling in function average_checkpoints_with_averaged_model * Save batch to disk on exception. 
(#350) * Bug fix (#352) * Keep model_avg on cpu (#348) * keep model_avg on cpu * explicitly convert model_avg to cpu * minor fix * remove device convertion for model_avg * modify usage of the model device in train.py * change model.device to next(model.parameters()).device for decoding * assert params.start_epoch>0 * assert params.start_epoch>0, params.start_epoch * Do some changes for aishell/ASR/transducer stateless/export.py (#347) * do some changes for aishell/ASR/transducer_stateless/export.py * Support decoding with averaged model when using --iter (#353) * support decoding with averaged model when using --iter * minor fix * monir fix of copyright date * Stringify torch.__version__ before serializing it. (#354) * Run decode.py in GitHub actions. (#356) * Ignore padding frames during RNN-T decoding. (#358) * Ignore padding frames during RNN-T decoding. * Fix outdated decoding code. * Minor fixes. * Support --iter in export.py (#360) * GigaSpeech RNN-T experiments (#318) * Copy RNN-T recipe from librispeech * flake8 * flake8 * Update params * gigaspeech decode * black * Update results * syntax highlight * Update RESULTS.md * typo * Update decoding script for gigaspeech and remove duplicate files. (#361) * Validate that there are no OOV tokens in BPE-based lexicons. (#359) * Validate that there are no OOV tokens in BPE-based lexicons. * Typo fixes. * Decode gigaspeech in GitHub actions (#362) * Add CI for gigaspeech. * Update results for libri+giga multi dataset setup. (#363) * Update results for libri+giga multi dataset setup. 
* Update GigaSpeech reults (#364) * Update decode.py * Update export.py * Update results * Update README.md * Fix GitHub CI for decoding GigaSpeech dev/test datasets (#366) * modify .flake8 * minor fix * minor fix Co-authored-by: Daniel Povey <dpovey@gmail.com> Co-authored-by: Wei Kang <wkang@pku.org.cn> Co-authored-by: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> Co-authored-by: Guo Liyong <guonwpu@qq.com> Co-authored-by: Wang, Guanbo <wgb14@outlook.com> Co-authored-by: whsqkaak <whsqkaak@naver.com> Co-authored-by: pehonnet <pe.honnet@gmail.com>
716 lines
23 KiB
Python
Executable File
716 lines
23 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
|
|
# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import argparse
|
|
import logging
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import k2
|
|
import sentencepiece as spm
|
|
import torch
|
|
import torch.nn as nn
|
|
from asr_datamodule import GigaSpeechAsrDataModule
|
|
from conformer import Conformer
|
|
from gigaspeech_scoring import asr_text_post_processing
|
|
|
|
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
|
from icefall.decode import (
|
|
get_lattice,
|
|
nbest_decoding,
|
|
nbest_oracle,
|
|
one_best_decoding,
|
|
rescore_with_attention_decoder,
|
|
rescore_with_n_best_list,
|
|
rescore_with_whole_lattice,
|
|
)
|
|
from icefall.env import get_env_info
|
|
from icefall.lexicon import Lexicon
|
|
from icefall.utils import (
|
|
AttributeDict,
|
|
get_texts,
|
|
setup_logger,
|
|
store_transcripts,
|
|
write_error_stats,
|
|
)
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the options owned by this script.

    The parser uses ``argparse.ArgumentDefaultsHelpFormatter`` so that every
    option's default value is shown in ``--help``. Only the decoding options
    below are registered here; callers may add further options to the
    returned parser before parsing.

    Returns:
      The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=0,
        # Adjacent literals are concatenated verbatim, so the first one must
        # end with a space; otherwise the help reads "decoding.Note:".
        help="It specifies the checkpoint to use for decoding. "
        "Note: Epoch counts from 0.",
    )
    parser.add_argument(
        "--avg",
        type=int,
        default=1,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="attention-decoder",
        help="""Decoding method.
        Supported values are:
            - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
              model, i.e., lang_dir/bpe.model, to convert word pieces to words.
              It needs neither a lexicon nor an n-gram LM.
            - (1) 1best. Extract the best path from the decoding lattice as the
              decoding result.
            - (2) nbest. Extract n paths from the decoding lattice; the path
              with the highest score is the decoding result.
            - (3) nbest-rescoring. Extract n paths from the decoding lattice,
              rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
              the highest score is the decoding result.
            - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
              n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
              is the decoding result.
            - (5) attention-decoder. Extract n paths from the LM rescored
              lattice, the path with the highest score is the decoding result.
            - (6) nbest-oracle. Its WER is the lower bound of any n-best
              rescoring method can achieve. Useful for debugging n-best
              rescoring method.
        """,
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=1000,
        help="""Number of paths for n-best based decoding method.
        Used only when "method" is one of the following values:
        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
        """,
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""The scale to be applied to `lattice.scores`.
        It's needed if you use any kinds of n-best based rescoring.
        Used only when "method" is one of the following values:
        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
        A smaller value results in more unique paths.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="conformer_ctc/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_bpe_500",
        help="The lang dir",
    )

    parser.add_argument(
        "--lm-dir",
        type=str,
        default="data/lm",
        help="""The LM dir.
        It should contain either G_4_gram.pt or G_4_gram.fst.txt
        """,
    )

    return parser
|
|
|
|
|
|
def get_params() -> AttributeDict:
    """Return the fixed settings of this script as an ``AttributeDict``.

    The result groups the hyper-parameters passed to the conformer model,
    the search settings passed to lattice generation, and the environment
    information reported by ``get_env_info()``.
    """
    # Settings consumed when constructing the conformer model.
    conformer_settings = {
        "subsampling_factor": 4,
        "vgg_frontend": False,
        "use_feat_batchnorm": True,
        "feature_dim": 80,
        "nhead": 8,
        "attention_dim": 512,
        "num_decoder_layers": 6,
    }
    # Settings consumed when generating and searching the lattice.
    decoding_settings = {
        "search_beam": 20,
        "output_beam": 8,
        "min_active_states": 30,
        "max_active_states": 10000,
        "use_double_scores": True,
    }
    return AttributeDict(
        {
            **conformer_settings,
            **decoding_settings,
            "env_info": get_env_info(),
        }
    )
|
|
|
|
|
|
def post_processing(
    results: List[Tuple[List[str], List[str]]],
) -> List[Tuple[List[str], List[str]]]:
    """Apply ``asr_text_post_processing`` to every reference/hypothesis pair.

    Args:
      results:
        A list of ``(ref_words, hyp_words)`` tuples, each element being a
        list of words.
    Returns:
      A new list of the same length. Each word list is joined with single
      spaces, passed through ``asr_text_post_processing`` and split on
      whitespace again. The input list is not modified.
    """

    def _clean(words: List[str]) -> List[str]:
        # The helper works on a sentence string, so join first, split after.
        return asr_text_post_processing(" ".join(words)).split()

    return [(_clean(ref), _clean(hyp)) for ref, hyp in results]
|
|
|
|
|
|
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    batch: dict,
    word_table: k2.SymbolTable,
    sos_id: int,
    eos_id: int,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode a single batch with the method selected by ``params.method``.

    The returned dict maps a *setting name* to the hypotheses obtained with
    that setting:

        - key: identifies the decoding setting. It is `ctc-decoding` for
               CTC decoding, `no_rescore` for 1best, a string starting with
               `no_rescore-nbest-scale-` for nbest, a string starting with
               `oracle_` for nbest-oracle, and otherwise whatever keys the
               rescoring function returns (e.g. `lm_scale_0.7`).
        - value: one entry per utterance of the batch; `value[i]` is the
                 list of decoded words of the i-th utterance.

    Args:
      params:
        The script parameters (see :func:`get_params`), updated with the
        command-line options. The fields read here are `method`,
        `subsampling_factor`, `search_beam`, `output_beam`,
        `min_active_states`, `max_active_states`, `use_double_scores`,
        `num_paths` and `nbest_scale`.
      model:
        The neural model. Calling it must return the tuple
        `(nnet_output, memory, memory_key_padding_mask)`.
      HLG:
        The decoding graph. Must be None exactly when `H` is given.
      H:
        The CTC topology, used for ctc-decoding. When it is given,
        `bpe_model` must be given too.
      bpe_model:
        The BPE model used to turn token IDs into text for ctc-decoding.
      batch:
        A batch as produced by iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`; the entries `inputs`
        and `supervisions` are used.
      word_table:
        The word symbol table used to map word IDs to words.
      sos_id:
        The token ID of SOS, forwarded to the attention decoder rescoring.
      eos_id:
        The token ID of EOS, forwarded to the attention decoder rescoring.
      G:
        The LM used by nbest-rescoring, whole-lattice-rescoring and
        attention-decoder.
    Returns:
      The dict described above, or None when the rescoring step returns
      None instead of a dict of best paths.
    """

    def ids_to_words(paths) -> List[List[str]]:
        # Look up every word ID found on the given best paths.
        return [[word_table[i] for i in ids] for ids in get_texts(paths)]

    # Run on whichever graph the caller supplied.
    device = HLG.device if HLG is not None else H.device

    feature = batch["inputs"]
    assert feature.ndim == 3
    # feature is (N, T, C) here
    feature = feature.to(device)

    supervisions = batch["supervisions"]

    # nnet_output is (N, T, C)
    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)

    # Frame positions are expressed in input frames; divide them by the
    # subsampling factor before handing them to lattice generation.
    factor = params.subsampling_factor
    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            supervisions["start_frame"] // factor,
            supervisions["num_frames"] // factor,
        ),
        1,
    ).to(torch.int32)

    if H is not None:
        assert HLG is None
        assert bpe_model is not None
        decoding_graph = H
    else:
        assert HLG is not None
        decoding_graph = HLG

    lattice = get_lattice(
        nnet_output=nnet_output,
        decoding_graph=decoding_graph,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
        subsampling_factor=params.subsampling_factor,
    )

    method = params.method

    if method == "ctc-decoding":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        # With H (rather than HLG) as the decoding graph, the aux_labels of
        # `best_path` are token IDs, not word IDs.
        # token_ids is a list-of-list of IDs
        token_ids = get_texts(best_path)

        # The BPE model yields one string per utterance, e.g.
        # ['xxx yyy zzz', ...]; split each into its words.
        sentences = bpe_model.decode(token_ids)
        return {"ctc-decoding": [sentence.split() for sentence in sentences]}

    if method == "nbest-oracle":
        # Rescored lattices could be used here as well. The HLG decoded
        # lattice is chosen for speed, since its oracle WER is only
        # slightly worse than that of rescored lattices.
        best_path = nbest_oracle(
            lattice=lattice,
            num_paths=params.num_paths,
            ref_texts=supervisions["text"],
            word_table=word_table,
            nbest_scale=params.nbest_scale,
            oov="<UNK>",
        )
        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
        return {key: ids_to_words(best_path)}

    if method == "1best":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        return {"no_rescore": ids_to_words(best_path)}

    if method == "nbest":
        best_path = nbest_decoding(
            lattice=lattice,
            num_paths=params.num_paths,
            use_double_scores=params.use_double_scores,
            nbest_scale=params.nbest_scale,
        )
        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
        return {key: ids_to_words(best_path)}

    # Everything below needs the external LM `G`.
    assert method in [
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder",
    ]

    lm_scale_list = [
        0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
        0.8, 0.9, 1.0, 1.1, 1.2, 1.3,
        1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0,
    ]  # fmt: skip

    if method == "nbest-rescoring":
        best_path_dict = rescore_with_n_best_list(
            lattice=lattice,
            G=G,
            num_paths=params.num_paths,
            lm_scale_list=lm_scale_list,
            nbest_scale=params.nbest_scale,
        )
    elif method == "whole-lattice-rescoring":
        best_path_dict = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=lm_scale_list,
        )
    elif method == "attention-decoder":
        # The lattice was built with a 3-gram LM; rescore it with the
        # 4-gram LM before applying the attention decoder.
        rescored_lattice = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=None,
        )
        # TODO: pass `lattice` instead of `rescored_lattice` to
        # `rescore_with_attention_decoder`

        best_path_dict = rescore_with_attention_decoder(
            lattice=rescored_lattice,
            num_paths=params.num_paths,
            model=model,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
            sos_id=sos_id,
            eos_id=eos_id,
            nbest_scale=params.nbest_scale,
        )
    else:
        assert False, f"Unsupported decoding method: {params.method}"

    if best_path_dict is None:
        return None

    return {
        lm_scale_str: ids_to_words(best_path)
        for lm_scale_str, best_path in best_path_dict.items()
    }
|
|
|
|
|
|
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    word_table: k2.SymbolTable,
    sos_id: int,
    eos_id: int,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode every batch of a dataloader and collect reference/hypothesis pairs.

    Args:
      dl:
        The dataloader yielding the batches to decode.
      params:
        The script parameters; forwarded to :func:`decode_one_batch`.
      model:
        The neural model.
      HLG:
        The decoding graph, forwarded to :func:`decode_one_batch`.
      H:
        The CTC topology, forwarded to :func:`decode_one_batch`.
      bpe_model:
        The BPE model, forwarded to :func:`decode_one_batch`.
      word_table:
        The word symbol table.
      sos_id:
        The token ID for SOS.
      eos_id:
        The token ID for EOS.
      G:
        The LM used by the rescoring methods, forwarded to
        :func:`decode_one_batch`.
    Returns:
      A dict keyed by the setting names returned from
      :func:`decode_one_batch` (e.g. `no_rescore` or `lm_scale_0.7`).
      Each value is a list with one `(ref_words, hyp_words)` tuple per
      utterance. For a batch that decoded to nothing, an empty hypothesis
      is recorded under every setting seen so far.
    """
    num_cuts = 0

    # Some dataloaders have no length; fall back to a placeholder for logs.
    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            batch=batch,
            word_table=word_table,
            sos_id=sos_id,
            eos_id=eos_id,
            G=G,
        )

        if hyps_dict is None:
            # Without earlier results the set of settings is unknown, so
            # there is nothing the empty hypotheses could be filed under.
            assert (
                len(results) > 0
            ), "It should not decode to empty in the first batch!"
            hyp_words = []
            this_batch = [(ref_text.split(), hyp_words) for ref_text in texts]

            for lm_scale in results.keys():
                results[lm_scale].extend(this_batch)
        else:
            for lm_scale, hyps in hyps_dict.items():
                assert len(hyps) == len(texts)
                this_batch = [
                    (ref_text.split(), hyp_words)
                    for hyp_words, ref_text in zip(hyps, texts)
                ]
                results[lm_scale].extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )
    return results
|
|
|
|
|
|
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for one test set.

    For every setting in `results_dict` the results are post-processed with
    :func:`post_processing`, then written to
    `recogs-<test_set_name>-<key>.txt` and `errs-<test_set_name>-<key>.txt`
    inside `params.exp_dir`. Finally the WERs of all settings, sorted in
    ascending order, are written to `wer-summary-<test_set_name>.txt` and
    logged.

    Args:
      params:
        The script parameters; `exp_dir` and `method` are read.
      test_set_name:
        The name of the test set, used in file names and log messages.
      results_dict:
        Maps a setting name to its list of `(ref_words, hyp_words)` tuples.
    """
    # The attention decoder produces too many logs, so per-setting logging
    # is switched off for it.
    enable_log = params.method != "attention-decoder"

    wer_by_setting = dict()
    for key, results in results_dict.items():
        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
        results = post_processing(results)
        store_transcripts(filename=recog_path, texts=results)
        if enable_log:
            logging.info(f"The transcripts are stored in {recog_path}")

        # Write the WER, per-word error statistics and aligned ref/hyp
        # pairs for this setting.
        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
        with open(errs_filename, "w") as f:
            wer_by_setting[key] = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=enable_log
            )

        if enable_log:
            logging.info(
                "Wrote detailed error stats to {}".format(errs_filename)
            )

    # Lowest WER first.
    ranked = sorted(wer_by_setting.items(), key=lambda item: item[1])

    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in ranked:
            print("{}\t{}".format(key, val), file=f)

    pieces = [
        "\nFor {}, WER of different settings are:\n".format(test_set_name)
    ]
    for rank, (key, val) in enumerate(ranked):
        # Only the first (best) entry carries the annotation.
        if rank == 0:
            note = "\tbest for {}".format(test_set_name)
        else:
            note = ""
        pieces.append("{}\t{}{}\n".format(key, val, note))
    logging.info("".join(pieces))
|
|
|
|
|
|
@torch.no_grad()
def main():
    """Entry point: decode the GigaSpeech dev and test sets.

    Steps performed:
      1. Parse the command line (script options plus the options added by
         ``GigaSpeechAsrDataModule``) and merge it into the fixed params.
      2. Load the lexicon and build either the CTC topology ``H`` (for
         ctc-decoding) or load ``HLG.pt`` from the lang dir.
      3. For the rescoring methods, load ``G_4_gram.pt`` from the LM dir, or
         build it from ``G_4_gram.fst.txt`` and save it for later runs.
      4. Build the conformer, load one checkpoint (``--avg 1``) or the
         average of several consecutive checkpoints ending at ``--epoch``.
      5. Decode both sets and save the results with :func:`save_results`.

    The whole function runs under ``torch.no_grad()``.
    """
    parser = get_parser()
    GigaSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)
    args.lm_dir = Path(args.lm_dir)

    params = get_params()
    params.update(vars(args))

    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
    logging.info("Decoding started")
    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    num_classes = max_token_id + 1  # +1 for the blank

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    # Only used here to obtain the SOS/EOS token IDs.
    graph_compiler = BpeCtcTrainingGraphCompiler(
        params.lang_dir,
        device=device,
        sos_token="<sos/eos>",
        eos_token="<sos/eos>",
    )
    sos_id = graph_compiler.sos_id
    eos_id = graph_compiler.eos_id

    if params.method == "ctc-decoding":
        HLG = None
        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )
        bpe_model = spm.SentencePieceProcessor()
        bpe_model.load(str(params.lang_dir / "bpe.model"))
    else:
        H = None
        bpe_model = None
        HLG = k2.Fsa.from_dict(
            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
        )
        assert HLG.requires_grad is False

        if not hasattr(HLG, "lm_scores"):
            HLG.lm_scores = HLG.scores.clone()

    if params.method in (
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder",
    ):
        if not (params.lm_dir / "G_4_gram.pt").is_file():
            logging.info("Loading G_4_gram.fst.txt")
            logging.warning("It may take 8 minutes.")
            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
                first_word_disambig_id = lexicon.word_table["#0"]

                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
                # G.aux_labels is not needed in later computations, so
                # remove it here.
                del G.aux_labels
                # CAUTION: The following line is crucial.
                # Arcs entering the back-off state have label equal to #0.
                # We have to change it to 0 here.
                G.labels[G.labels >= first_word_disambig_id] = 0
                # See https://github.com/k2-fsa/k2/issues/874
                # for why we need to set G.properties to None
                G.__dict__["_properties"] = None
                G = k2.Fsa.from_fsas([G]).to(device)
                G = k2.arc_sort(G)
                # Save a dummy value so that it can be loaded in C++.
                # See https://github.com/pytorch/pytorch/issues/67902
                # for why we need to do this.
                G.dummy = 1

                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
        else:
            logging.info("Loading pre-compiled G_4_gram.pt")
            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
            G = k2.Fsa.from_dict(d)

        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)

        # G.lm_scores is used to replace HLG.lm_scores during
        # LM rescoring.
        G.lm_scores = G.scores.clone()
    else:
        G = None

    model = Conformer(
        num_features=params.feature_dim,
        nhead=params.nhead,
        d_model=params.attention_dim,
        num_classes=num_classes,
        subsampling_factor=params.subsampling_factor,
        num_decoder_layers=params.num_decoder_layers,
        vgg_frontend=params.vgg_frontend,
        use_feat_batchnorm=params.use_feat_batchnorm,
    )

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        # Epochs count from 0, so skip negative indices individually. The
        # check must be on `i`, not on the loop-invariant `start`: otherwise
        # a too-large --avg yields an empty list of checkpoints.
        filenames = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(start, params.epoch + 1)
            if i >= 0
        ]
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))

    model.to(device)
    model.eval()
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    gigaspeech = GigaSpeechAsrDataModule(args)

    dev_cuts = gigaspeech.dev_cuts()
    test_cuts = gigaspeech.test_cuts()

    dev_dl = gigaspeech.test_dataloaders(dev_cuts)
    test_dl = gigaspeech.test_dataloaders(test_cuts)

    test_sets = ["dev", "test"]
    test_dls = [dev_dl, test_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            word_table=lexicon.word_table,
            G=G,
            sos_id=sos_id,
            eos_id=eos_id,
        )

        save_results(
            params=params, test_set_name=test_set, results_dict=results_dict
        )

    logging.info("Done!")
|
|
|
|
|
|
# Module-level side effect: limit PyTorch to a single intra-op thread and a
# single inter-op thread for this process. This runs at import time, before
# main() is invoked.
# NOTE(review): presumably meant to avoid CPU over-subscription next to the
# dataloader workers -- confirm.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# Run the decoding only when this file is executed as a script.
if __name__ == "__main__":
    main()
|