update multi_quantization installation (#469)
* update multi_quantization installation

* Update egs/librispeech/ASR/pruned_transducer_stateless6/train.py

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
parent bc2882ddcc
commit f8d28f0998
@@ -77,9 +77,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" ==
   fi
 
   # Install quantization toolkit:
-  # pip install git+https://github.com/danpovey/quantization.git@master
-  # when testing this code:
-  # commit c17ffe67aa2e6ca6b6855c50fde812f2eed7870b is used.
+  # pip install git+https://github.com/k2-fsa/multi_quantization.git
+  # or
+  # pip install multi_quantization
 
   has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
   if [ $has_quantization == 'False' ]; then
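Note that the has_quantization probe above still looks for a module named quantization, even though the new toolkit installs as multi_quantization; this hunk does not touch that check. A standalone sketch of the same importlib probe pointed at the new name (an assumption about the import name, not part of the commit):

    import importlib.util

    # find_spec returns None when the module cannot be imported, so this
    # mirrors the shell script's python3 one-liner check.
    has_multi_quantization = importlib.util.find_spec("multi_quantization") is not None
    print(has_multi_quantization)  # True once multi_quantization is installed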
@@ -23,7 +23,7 @@ from scaling import ScaledLinear
 
 from icefall.utils import add_sos
 
-from quantization.prediction import JointCodebookLoss
+from multi_quantization.prediction import JointCodebookLoss
 
 
 class Transducer(nn.Module):
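If a checkout has to run against either package layout during the transition, a guarded import is one option; this shim is a sketch, not part of the commit:

    # Hypothetical compatibility shim: prefer the renamed package and fall
    # back to the legacy danpovey/quantization install if that is all there is.
    try:
        from multi_quantization.prediction import JointCodebookLoss
    except ImportError:
        from quantization.prediction import JointCodebookLoss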
@@ -75,7 +75,9 @@ class Transducer(nn.Module):
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
         if num_codebooks > 0:
             self.codebook_loss_net = JointCodebookLoss(
-                predictor_channels=encoder_dim, num_codebooks=num_codebooks
+                predictor_channels=encoder_dim,
+                num_codebooks=num_codebooks,
+                is_joint=False,
             )
 
     def forward(
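Besides reflowing the arguments one keyword per line, the change now passes is_joint=False explicitly. A rough sketch of how such a module is driven during distillation; the call convention here (frame-level encoder embeddings plus teacher codebook indexes in, codebook-prediction loss out) is an assumption inferred from this constructor, not documented API:

    import torch
    from multi_quantization.prediction import JointCodebookLoss

    encoder_dim, num_codebooks = 512, 8
    loss_net = JointCodebookLoss(
        predictor_channels=encoder_dim,
        num_codebooks=num_codebooks,
        is_joint=False,  # as set in this commit
    )

    # Hypothetical shapes: (batch, time, encoder_dim) embeddings from an
    # intermediate encoder layer, and (batch, time, num_codebooks) integer
    # codebook indexes distilled from a teacher model.
    embeddings = torch.randn(2, 100, encoder_dim)
    codebook_indexes = torch.randint(0, 256, (2, 100, num_codebooks))
    codebook_loss = loss_net(embeddings, codebook_indexes)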
@@ -879,6 +879,11 @@ def run(rank, world_size, args):
         The return value of get_parser().parse_args()
     """
     params = get_params()
+
+    # Note: it's better to set --spec-aug-time-warp-factor=-1
+    # when doing distillation with vq.
+    assert args.spec_aug_time_warp_factor < 1
+
     params.update(vars(args))
     if params.full_libri is False:
         params.valid_interval = 1600
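The new assert makes vq distillation refuse to run with time warping enabled: in icefall's data module, a --spec-aug-time-warp-factor value below 1 disables time warping, and warped student features would no longer be frame-aligned with the teacher's precomputed codebook indexes. A sketch of where the flag ends up, assuming it is forwarded to lhotse's SpecAugment as in the LibriSpeech data module:

    from lhotse.dataset import SpecAugment

    # time_warp_factor < 1 skips the time-warp step entirely, so frame
    # alignment with the teacher's codebook indexes is preserved.
    spec_aug = SpecAugment(time_warp_factor=-1)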