This commit is contained in:
Guo Liyong 2022-05-07 16:18:40 +08:00
parent ac9655c450
commit 2dc86e2eda
5 changed files with 11 additions and 12 deletions

View File

@ -98,7 +98,7 @@ if [ $stage -eq 4 ]; then
# Note: order of split manfiests is 1-based, while gpu is 0-based.
export CUDA_VISIBLE_DEVICES=`(expr $1 + 5)`
./vq_pruned_transducer_stateless2/hubert_code_indices.py \
--memory-layer=${memory_layer}
--memory-layer=${memory_layer} \
--num-splits $num_jobs \
--subset=$2 \
--manifest-idx $1 \
@ -117,7 +117,7 @@ if [ $stage -eq 4 ]; then
wait
fi
cdidx_manifests_dir=`pwd`/data/globalrandom-scaledquantizer-refine_iter-5-${num_utts}-$model_id-${mem_layer}layer-${quantizer_id}-bytes_per_frame-${bytes_per_frame}-enable-refine-True
cdidx_manifests_dir=`pwd`/data/globalrandom-scaledquantizer-refine_iter-5-${num_utts}-$model_id-${memory_layer}layer-${quantizer_id}-bytes_per_frame-${bytes_per_frame}-enable-refine-True
if [ $stage -eq 5 ]; then
for subset in ${train_subsets}; do
combined_list=`find $cdidx_manifests_dir/splits$num_jobs/ -name cuts_train-${sbuset}*`

View File

@ -444,6 +444,7 @@ def main():
params = get_params()
params.update(vars(args))
params.extra_output_layer=None
assert params.decoding_method in (
"greedy_search",

View File

@ -159,11 +159,11 @@ def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = AttributeDict()
params.update(vars(args))
params.update(vq_config)
params.exp_dir = Path(params.exp_dir)
setup_logger(f"{params.exp_dir}/log-ctc_greedy_search/log-decode")
logging.info("Decoding started")

View File

@ -29,15 +29,12 @@ from icefall.utils import (
setup_logger,
)
from hubert_utils import extract_layers_result, load_hubert_model, vq_config
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
return parser
from hubert_utils import (
extract_layers_result,
load_hubert_model,
get_parser,
vq_config,
)
def compute_memory(
model: torch.nn.Module,

View File

@ -33,6 +33,7 @@ from omegaconf import OmegaConf
vq_config = {
# TODO: Maybe better to convert this class to yaml driven config.
# parameters about hubert model inference.
"exp_dir": "./vq_pruned_transducer_stateless2/exp/",
"model_dir": "./vq_pruned_transducer_stateless2/exp/hubert_models/",
"input_strategy": "AudioSamples",
"enable_spec_aug": False,