mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
add enable-distillation argument option, fix monir typos
This commit is contained in:
parent
f8541b3ab1
commit
7833da7d2b
@ -188,7 +188,9 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
|||||||
--spec-aug-time-warp-factor -1 \
|
--spec-aug-time-warp-factor -1 \
|
||||||
--max-duration 300 \
|
--max-duration 300 \
|
||||||
--world-size ${WORLD_SIZE} \
|
--world-size ${WORLD_SIZE} \
|
||||||
--num-epochs 20
|
--num-epochs 20 \
|
||||||
|
--exp-dir $exp_dir \
|
||||||
|
--enable-distillation True
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
@ -200,5 +202,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
|||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--max-duration 200 \
|
--max-duration 200 \
|
||||||
--exp-dir ./pruned_transducer_stateless6/exp
|
--exp-dir $exp_dir \
|
||||||
|
--enable-distillation True
|
||||||
fi
|
fi
|
||||||
|
@ -128,7 +128,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-averaged-model",
|
"--use-averaged-model",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=True,
|
||||||
help="Whether to load averaged model. Currently it only supports "
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
"using --epoch. If True, it would decode with the averaged model "
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
@ -143,6 +143,13 @@ def get_parser():
|
|||||||
help="The experiment dir",
|
help="The experiment dir",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-distillation",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to eanble distillation.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--bpe-model",
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -41,7 +41,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
--max-duration 550
|
--max-duration 550
|
||||||
|
|
||||||
# For distiallation with codebook_indexes:
|
# For distillation with codebook_indexes:
|
||||||
|
|
||||||
./pruned_transducer_stateless6/train.py \
|
./pruned_transducer_stateless6/train.py \
|
||||||
--manifest-dir ./data/vq_fbank \
|
--manifest-dir ./data/vq_fbank \
|
||||||
@ -300,6 +300,13 @@ def get_parser():
|
|||||||
help="Whether to use half precision training.",
|
help="Whether to use half precision training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-distillation",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to eanble distillation.",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -372,7 +379,6 @@ def get_params() -> AttributeDict:
|
|||||||
"model_warm_step": 3000, # arg given to model, not for lrate
|
"model_warm_step": 3000, # arg given to model, not for lrate
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
# parameters for distillation with codebook indexes.
|
# parameters for distillation with codebook indexes.
|
||||||
"enable_distiallation": True,
|
|
||||||
"distillation_layer": 5, # 0-based index
|
"distillation_layer": 5, # 0-based index
|
||||||
# Since output rate of hubert is 50, while that of encoder is 8,
|
# Since output rate of hubert is 50, while that of encoder is 8,
|
||||||
# two successive codebook_index are concatenated together.
|
# two successive codebook_index are concatenated together.
|
||||||
@ -394,7 +400,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
dim_feedforward=params.dim_feedforward,
|
dim_feedforward=params.dim_feedforward,
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
middle_output_layer=params.distillation_layer
|
middle_output_layer=params.distillation_layer
|
||||||
if params.enable_distiallation
|
if params.enable_distillation
|
||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
@ -434,7 +440,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|||||||
joiner_dim=params.joiner_dim,
|
joiner_dim=params.joiner_dim,
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
num_codebooks=params.num_codebooks
|
num_codebooks=params.num_codebooks
|
||||||
if params.enable_distiallation
|
if params.enable_distillation
|
||||||
else 0,
|
else 0,
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
@ -615,7 +621,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
info = MetricsTracker()
|
info = MetricsTracker()
|
||||||
if is_training and params.enable_distiallation:
|
if is_training and params.enable_distillation:
|
||||||
codebook_indexes, _ = extract_codebook_indexes(batch)
|
codebook_indexes, _ = extract_codebook_indexes(batch)
|
||||||
codebook_indexes = codebook_indexes.to(device)
|
codebook_indexes = codebook_indexes.to(device)
|
||||||
else:
|
else:
|
||||||
@ -645,7 +651,7 @@ def compute_loss(
|
|||||||
params.simple_loss_scale * simple_loss
|
params.simple_loss_scale * simple_loss
|
||||||
+ pruned_loss_scale * pruned_loss
|
+ pruned_loss_scale * pruned_loss
|
||||||
)
|
)
|
||||||
if is_training and params.enable_distiallation:
|
if is_training and params.enable_distillation:
|
||||||
assert codebook_loss is not None
|
assert codebook_loss is not None
|
||||||
loss += params.codebook_loss_scale * codebook_loss
|
loss += params.codebook_loss_scale * codebook_loss
|
||||||
|
|
||||||
@ -661,7 +667,7 @@ def compute_loss(
|
|||||||
info["loss"] = loss.detach().cpu().item()
|
info["loss"] = loss.detach().cpu().item()
|
||||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||||
if is_training and params.enable_distiallation:
|
if is_training and params.enable_distillation:
|
||||||
info["codebook_loss"] = codebook_loss.detach().cpu().item()
|
info["codebook_loss"] = codebook_loss.detach().cpu().item()
|
||||||
|
|
||||||
return loss, info
|
return loss, info
|
||||||
|
Loading…
x
Reference in New Issue
Block a user