add enable-distillation argument option, fix minor typos
This commit is contained in: parent f8541b3ab1, commit 7833da7d2b
@@ -188,7 +188,9 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
     --spec-aug-time-warp-factor -1 \
     --max-duration 300 \
     --world-size ${WORLD_SIZE} \
-    --num-epochs 20
+    --num-epochs 20 \
+    --exp-dir $exp_dir \
+    --enable-distillation True
 fi

 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
@@ -200,5 +202,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
     --epoch 20 \
     --avg 10 \
     --max-duration 200 \
-    --exp-dir ./pruned_transducer_stateless6/exp
+    --exp-dir $exp_dir \
+    --enable-distillation True
 fi
@@ -128,7 +128,7 @@ def get_parser():
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
-        default=False,
+        default=True,
         help="Whether to load averaged model. Currently it only supports "
         "using --epoch. If True, it would decode with the averaged model "
         "over the epoch range from `epoch-avg` (excluded) to `epoch`."
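With the new default, decoding uses parameters averaged over the epoch range `(epoch-avg, epoch]` unless averaging is explicitly turned off. As a rough illustration only (this is not icefall's actual `--use-averaged-model` logic; file names and keys below are assumptions), averaging a set of epoch checkpoints can be pictured like this:

import torch

# Illustrative sketch: average the "model" state dicts of several checkpoints.
def average_state_dicts(state_dicts):
    avg = {k: v.to(torch.float64).clone() for k, v in state_dicts[0].items()}
    for sd in state_dicts[1:]:
        for k, v in sd.items():
            avg[k] += v.to(torch.float64)
    n = len(state_dicts)
    return {k: (v / n).to(state_dicts[0][k].dtype) for k, v in avg.items()}

# e.g. --epoch 20 --avg 10 would correspond to epochs 11..20:
# state_dicts = [torch.load(f"epoch-{i}.pt")["model"] for i in range(11, 21)]
# model.load_state_dict(average_state_dicts(state_dicts))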
@@ -143,6 +143,13 @@ def get_parser():
         help="The experiment dir",
     )

+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable distillation.",
+    )
+
     parser.add_argument(
         "--bpe-model",
         type=str,
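The new `--enable-distillation` option uses the `str2bool` argument type, so `--enable-distillation True` / `False` on the command line becomes a Python bool. The converter below is a minimal sketch of how such a type function typically behaves, written for illustration rather than copied from icefall:

import argparse

def str2bool(v):
    # Illustrative converter: map common true/false spellings to a bool.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--enable-distillation", type=str2bool, default=True)
args = parser.parse_args(["--enable-distillation", "True"])
print(args.enable_distillation)  # True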
@@ -41,7 +41,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --full-libri 1 \
   --max-duration 550

-# For distiallation with codebook_indexes:
+# For distillation with codebook_indexes:

 ./pruned_transducer_stateless6/train.py \
   --manifest-dir ./data/vq_fbank \
@@ -300,6 +300,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable distillation.",
+    )
+
     return parser


@@ -372,7 +379,6 @@ def get_params() -> AttributeDict:
             "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
             # parameters for distillation with codebook indexes.
-            "enable_distiallation": True,
             "distillation_layer": 5,  # 0-based index
             # Since output rate of hubert is 50, while that of encoder is 8,
             # two successive codebook_index are concatenated together.
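With the misspelled hard-coded `enable_distiallation` entry removed from `get_params()`, the setting now has to come from the new command-line flag. Recipes typically merge the parsed arguments into `params` before training starts; the snippet below sketches that pattern with a simplified stand-in for `AttributeDict` (the class and merging code here are assumptions for illustration, not the file's actual code):

import argparse

class AttributeDict(dict):
    # Simplified stand-in: a dict whose keys are also readable as attributes.
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

def str2bool(v):
    # Same illustrative converter as in the sketch above.
    return str(v).lower() in ("yes", "true", "t", "y", "1")

parser = argparse.ArgumentParser()
parser.add_argument("--enable-distillation", type=str2bool, default=True)
args = parser.parse_args(["--enable-distillation", "False"])

params = AttributeDict({"distillation_layer": 5})  # defaults from get_params()
params.update(vars(args))                          # CLI flags are merged in
print(params.enable_distillation)  # False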
@@ -394,7 +400,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         middle_output_layer=params.distillation_layer
-        if params.enable_distiallation
+        if params.enable_distillation
         else None,
     )
     return encoder
@@ -434,7 +440,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
         num_codebooks=params.num_codebooks
-        if params.enable_distiallation
+        if params.enable_distillation
         else 0,
     )
     return model
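Both construction sites above follow the same pattern: when distillation is disabled, the distillation-specific argument collapses to a neutral value (`None` for the middle output layer, `0` codebooks), so no extra head is built. The toy module below illustrates that pattern; the class and dimensions are hypothetical, not icefall's actual model code:

import torch.nn as nn

class TinyJoiner(nn.Module):
    # Hypothetical module: the codebook prediction head is only created
    # when num_codebooks > 0, mirroring the conditional argument in the diff.
    def __init__(self, encoder_dim: int, vocab_size: int, num_codebooks: int = 0):
        super().__init__()
        self.output = nn.Linear(encoder_dim, vocab_size)
        self.codebook_head = (
            nn.Linear(encoder_dim, num_codebooks) if num_codebooks > 0 else None
        )

enable_distillation = False
joiner = TinyJoiner(
    encoder_dim=512,       # hypothetical sizes
    vocab_size=500,
    num_codebooks=16 if enable_distillation else 0,
)
print(joiner.codebook_head)  # None when distillation is disabled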
@@ -615,7 +621,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     info = MetricsTracker()
-    if is_training and params.enable_distiallation:
+    if is_training and params.enable_distillation:
         codebook_indexes, _ = extract_codebook_indexes(batch)
         codebook_indexes = codebook_indexes.to(device)
     else:
@@ -645,7 +651,7 @@ def compute_loss(
             params.simple_loss_scale * simple_loss
             + pruned_loss_scale * pruned_loss
         )
-        if is_training and params.enable_distiallation:
+        if is_training and params.enable_distillation:
             assert codebook_loss is not None
             loss += params.codebook_loss_scale * codebook_loss

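The hunk above only renames the flag; the loss itself remains a weighted sum to which the codebook (distillation) term is added only during training with distillation enabled. A small numeric sketch of that combination, using made-up scale and loss values rather than the recipe's actual `params.*_scale` settings:

import torch

# Made-up scales and losses, purely to show how the terms combine.
simple_loss_scale = 0.5
pruned_loss_scale = 1.0
codebook_loss_scale = 0.1

simple_loss = torch.tensor(2.0)
pruned_loss = torch.tensor(1.5)
codebook_loss = torch.tensor(4.0)

is_training = True
enable_distillation = True

loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if is_training and enable_distillation:
    assert codebook_loss is not None
    loss = loss + codebook_loss_scale * codebook_loss
print(loss)  # tensor(2.9000)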
@@ -661,7 +667,7 @@ def compute_loss(
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
-    if is_training and params.enable_distiallation:
+    if is_training and params.enable_distillation:
         info["codebook_loss"] = codebook_loss.detach().cpu().item()

     return loss, info