From f8d28f09987d44f017a8741f8dac84357c1da6d6 Mon Sep 17 00:00:00 2001 From: "LIyong.Guo" <839019390@qq.com> Date: Wed, 13 Jul 2022 21:16:45 +0800 Subject: [PATCH] update multi_quantization installation (#469) * update multi_quantization installation * Update egs/librispeech/ASR/pruned_transducer_stateless6/train.py Co-authored-by: Fangjun Kuang Co-authored-by: Fangjun Kuang --- egs/librispeech/ASR/distillation_with_hubert.sh | 6 +++--- egs/librispeech/ASR/pruned_transducer_stateless6/model.py | 6 ++++-- egs/librispeech/ASR/pruned_transducer_stateless6/train.py | 5 +++++ 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh index 3d4c4856a..9c47e8eae 100755 --- a/egs/librispeech/ASR/distillation_with_hubert.sh +++ b/egs/librispeech/ASR/distillation_with_hubert.sh @@ -77,9 +77,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" == fi # Install quantization toolkit: - # pip install git+https://github.com/danpovey/quantization.git@master - # when testing this code: - # commit c17ffe67aa2e6ca6b6855c50fde812f2eed7870b is used. 
+ # pip install git+https://github.com/k2-fsa/multi_quantization.git + # or + # pip install multi_quantization has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)") if [ $has_quantization == 'False' ]; then diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 66bb33e8d..1ed5636c8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -23,7 +23,7 @@ from scaling import ScaledLinear from icefall.utils import add_sos -from quantization.prediction import JointCodebookLoss +from multi_quantization.prediction import JointCodebookLoss class Transducer(nn.Module): @@ -75,7 +75,9 @@ class Transducer(nn.Module): self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if num_codebooks > 0: self.codebook_loss_net = JointCodebookLoss( - predictor_channels=encoder_dim, num_codebooks=num_codebooks + predictor_channels=encoder_dim, + num_codebooks=num_codebooks, + is_joint=False, ) def forward( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index b904e1e59..c054527ca 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -879,6 +879,11 @@ def run(rank, world_size, args): The return value of get_parser().parse_args() """ params = get_params() + + # Note: it's better to set --spec-aug-time-warp-factor=-1 + # when doing distillation with vq. + assert args.spec_aug_time_warp_factor < 1 + params.update(vars(args)) if params.full_libri is False: params.valid_interval = 1600