update multi_quantization installation (#469)

* update multi_quantization installation

* Update egs/librispeech/ASR/pruned_transducer_stateless6/train.py

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

LIyong.Guo 2022-07-13 21:16:45 +08:00 committed by GitHub
parent bc2882ddcc
commit f8d28f0998
3 changed files with 12 additions and 5 deletions


@@ -77,9 +77,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" ==
   fi
   # Install quantization toolkit:
-  # pip install git+https://github.com/danpovey/quantization.git@master
-  # when testing this code:
-  # commit c17ffe67aa2e6ca6b6855c50fde812f2eed7870b is used.
+  # pip install git+https://github.com/k2-fsa/multi_quantization.git
+  # or
+  # pip install multi_quantization
   has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
   if [ $has_quantization == 'False' ]; then
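For reference, a minimal standalone sketch of the same availability check in plain Python. It is illustrative only and assumes the package installed by "pip install multi_quantization" is importable as multi_quantization, as the new import further below suggests.

# Sketch, not part of the recipe: check for multi_quantization and install it
# from PyPI when it is missing, mirroring the shell one-liner above.
import importlib.util
import subprocess
import sys


def ensure_multi_quantization() -> None:
    # importlib.util.find_spec returns None when the module is not importable.
    if importlib.util.find_spec("multi_quantization") is None:
        # The git URL mentioned above (k2-fsa/multi_quantization) works as well.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "multi_quantization"],
            check=True,
        )


if __name__ == "__main__":
    ensure_multi_quantization()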


@@ -23,7 +23,7 @@ from scaling import ScaledLinear

 from icefall.utils import add_sos
-from quantization.prediction import JointCodebookLoss
+from multi_quantization.prediction import JointCodebookLoss


 class Transducer(nn.Module):
@@ -75,7 +75,9 @@ class Transducer(nn.Module):
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

         if num_codebooks > 0:
             self.codebook_loss_net = JointCodebookLoss(
-                predictor_channels=encoder_dim, num_codebooks=num_codebooks
+                predictor_channels=encoder_dim,
+                num_codebooks=num_codebooks,
+                is_joint=False,
             )

     def forward(
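For context, a rough usage sketch of the JointCodebookLoss module constructed above. The forward-call arguments, tensor shapes (frame-level predictor output of shape (N, T, encoder_dim), integer codebook indexes of shape (N, T, num_codebooks)), and the codebook value range are assumptions, not something this commit specifies.

# Sketch only: exercises the constructor changed above under assumed shapes.
import torch
from multi_quantization.prediction import JointCodebookLoss

encoder_dim = 512
num_codebooks = 16

codebook_loss_net = JointCodebookLoss(
    predictor_channels=encoder_dim,
    num_codebooks=num_codebooks,
    is_joint=False,
)

N, T = 4, 100
encoder_out = torch.randn(N, T, encoder_dim)  # student activations to be matched
codebook_indexes = torch.randint(0, 256, (N, T, num_codebooks))  # teacher-derived targets

codebook_loss = codebook_loss_net(encoder_out, codebook_indexes)
print(codebook_loss)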


@@ -879,6 +879,11 @@ def run(rank, world_size, args):
       The return value of get_parser().parse_args()
     """
     params = get_params()
+
+    # Note: it's better to set --spec-aug-time-warp-factor=-1
+    # when doing distillation with vq.
+    assert args.spec_aug_time_warp_factor < 1
+
     params.update(vars(args))

     if params.full_libri is False:
         params.valid_interval = 1600
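A small illustrative sketch of what the new assertion enforces, assuming an integer --spec-aug-time-warp-factor option like the one the comment refers to; the parser below is only an example, not the recipe's get_parser(), and the default shown is chosen for illustration.

# Illustrative only: a tiny parser exposing the flag the new assertion checks.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--spec-aug-time-warp-factor",
    type=int,
    default=80,  # illustrative default; values < 1 disable time warping
)

# As recommended in the comment above, pass -1 when doing VQ distillation.
args = parser.parse_args(["--spec-aug-time-warp-factor", "-1"])

# Mirrors the assertion added in run(): time warping must be disabled
# when distilling with vector-quantized teacher embeddings.
assert args.spec_aug_time_warp_factor < 1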