diff --git a/egs/librispeech/ASR/codebook_index_extraction.sh b/egs/librispeech/ASR/codebook_index_extraction.sh index b962d0081..91e780c08 100644 --- a/egs/librispeech/ASR/codebook_index_extraction.sh +++ b/egs/librispeech/ASR/codebook_index_extraction.sh @@ -98,7 +98,7 @@ if [ $stage -eq 4 ]; then -# Note: order of split manfiests is 1-based, while gpu is 0-based. +# Note: order of split manifests is 1-based, while gpu is 0-based. export CUDA_VISIBLE_DEVICES=`(expr $1 + 5)` ./vq_pruned_transducer_stateless2/hubert_code_indices.py \ - --memory-layer=${memory_layer} + --memory-layer=${memory_layer} \ --num-splits $num_jobs \ --subset=$2 \ --manifest-idx $1 \ @@ -117,7 +117,7 @@ if [ $stage -eq 4 ]; then wait fi -cdidx_manifests_dir=`pwd`/data/globalrandom-scaledquantizer-refine_iter-5-${num_utts}-$model_id-${mem_layer}layer-${quantizer_id}-bytes_per_frame-${bytes_per_frame}-enable-refine-True +cdidx_manifests_dir=`pwd`/data/globalrandom-scaledquantizer-refine_iter-5-${num_utts}-$model_id-${memory_layer}layer-${quantizer_id}-bytes_per_frame-${bytes_per_frame}-enable-refine-True if [ $stage -eq 5 ]; then for subset in ${train_subsets}; do -combined_list=`find $cdidx_manifests_dir/splits$num_jobs/ -name cuts_train-${sbuset}*` +combined_list=`find $cdidx_manifests_dir/splits$num_jobs/ -name cuts_train-${subset}*` diff --git a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/decode.py index bf4b2c248..085a27ce2 100755 --- a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/decode.py @@ -444,6 +444,7 @@ def main(): params = get_params() params.update(vars(args)) + params.extra_output_layer = None assert params.decoding_method in ( "greedy_search", diff --git a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_decode.py b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_decode.py index b4a290525..5e2552d22 100755 --- a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_decode.py +++ b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_decode.py @@ -159,11 
+159,11 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) params = AttributeDict() params.update(vars(args)) params.update(vq_config) + params.exp_dir = Path(params.exp_dir) setup_logger(f"{params.exp_dir}/log-ctc_greedy_search/log-decode") logging.info("Decoding started") diff --git a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_memory_embeddings.py b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_memory_embeddings.py index efe7f39a6..9805cfd76 100755 --- a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_memory_embeddings.py +++ b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_memory_embeddings.py @@ -29,15 +29,12 @@ from icefall.utils import ( setup_logger, ) -from hubert_utils import extract_layers_result, load_hubert_model, vq_config - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - return parser - +from hubert_utils import ( + extract_layers_result, + load_hubert_model, + get_parser, + vq_config, +) def compute_memory( model: torch.nn.Module, diff --git a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_utils.py b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_utils.py index 31cf28219..67721ad3f 100644 --- a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_utils.py +++ b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_utils.py @@ -33,6 +33,7 @@ from omegaconf import OmegaConf vq_config = { # TODO: Maybe better to convert this class to yaml driven config. # parameters about hubert model inference. + "exp_dir": "./vq_pruned_transducer_stateless2/exp/", "model_dir": "./vq_pruned_transducer_stateless2/exp/hubert_models/", "input_strategy": "AudioSamples", "enable_spec_aug": False,