mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
update multi_quantization installation (#469)
* update multi_quantization installation * Update egs/librispeech/ASR/pruned_transducer_stateless6/train.py Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
This commit is contained in:
parent
bc2882ddcc
commit
f8d28f0998
@ -77,9 +77,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" ==
|
||||
fi
|
||||
|
||||
# Install quantization toolkit:
|
||||
# pip install git+https://github.com/danpovey/quantization.git@master
|
||||
# when testing this code:
|
||||
# commit c17ffe67aa2e6ca6b6855c50fde812f2eed7870b is used.
|
||||
# pip install git+https://github.com/k2-fsa/multi_quantization.git
|
||||
# or
|
||||
# pip install multi_quantization
|
||||
|
||||
has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
|
||||
if [ $has_quantization == 'False' ]; then
|
||||
|
@ -23,7 +23,7 @@ from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
from quantization.prediction import JointCodebookLoss
|
||||
from multi_quantization.prediction import JointCodebookLoss
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -75,7 +75,9 @@ class Transducer(nn.Module):
|
||||
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
|
||||
if num_codebooks > 0:
|
||||
self.codebook_loss_net = JointCodebookLoss(
|
||||
predictor_channels=encoder_dim, num_codebooks=num_codebooks
|
||||
predictor_channels=encoder_dim,
|
||||
num_codebooks=num_codebooks,
|
||||
is_joint=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
@ -879,6 +879,11 @@ def run(rank, world_size, args):
|
||||
The return value of get_parser().parse_args()
|
||||
"""
|
||||
params = get_params()
|
||||
|
||||
# Note: it's better to set --spec-aug-time-warpi-factor=-1
|
||||
# when doing distillation with vq.
|
||||
assert args.spec_aug_time_warp_factor < 1
|
||||
|
||||
params.update(vars(args))
|
||||
if params.full_libri is False:
|
||||
params.valid_interval = 1600
|
||||
|
Loading…
x
Reference in New Issue
Block a user