add enable-distillation argument option, fix minor typos

yaozengwei 2022-06-17 11:05:50 +08:00
parent f8541b3ab1
commit 7833da7d2b
3 changed files with 26 additions and 10 deletions


@@ -188,7 +188,9 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 --spec-aug-time-warp-factor -1 \
 --max-duration 300 \
 --world-size ${WORLD_SIZE} \
---num-epochs 20
+--num-epochs 20 \
+--exp-dir $exp_dir \
+--enable-distillation True
 fi
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
@@ -200,5 +202,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 --epoch 20 \
 --avg 10 \
 --max-duration 200 \
---exp-dir ./pruned_transducer_stateless6/exp
+--exp-dir $exp_dir \
+--enable-distillation True
 fi


@@ -128,7 +128,7 @@ def get_parser():
 parser.add_argument(
 "--use-averaged-model",
 type=str2bool,
-default=False,
+default=True,
 help="Whether to load averaged model. Currently it only supports "
 "using --epoch. If True, it would decode with the averaged model "
 "over the epoch range from `epoch-avg` (excluded) to `epoch`."
@@ -143,6 +143,13 @@ def get_parser():
 help="The experiment dir",
 )
+parser.add_argument(
+"--enable-distillation",
+type=str2bool,
+default=True,
+help="Whether to enable distillation.",
+)
 parser.add_argument(
 "--bpe-model",
 type=str,
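The new --enable-distillation flag is parsed with str2bool, so it can be set from the command line with string values such as True or False. Below is a minimal sketch of how such a helper typically behaves; the actual implementation lives in icefall.utils and may differ in detail:

import argparse

def str2bool(v):
    # Hypothetical stand-in for icefall.utils.str2bool: map common
    # boolean-like strings to bool, since argparse has no built-in bool parsing.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")

With a parser built this way, passing --enable-distillation False at decode time switches the distillation-specific model setup off.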


@@ -41,7 +41,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 --full-libri 1 \
 --max-duration 550
-# For distiallation with codebook_indexes:
+# For distillation with codebook_indexes:
 ./pruned_transducer_stateless6/train.py \
 --manifest-dir ./data/vq_fbank \
@@ -300,6 +300,13 @@ def get_parser():
 help="Whether to use half precision training.",
 )
+parser.add_argument(
+"--enable-distillation",
+type=str2bool,
+default=True,
+help="Whether to enable distillation.",
+)
 return parser
@@ -372,7 +379,6 @@ def get_params() -> AttributeDict:
 "model_warm_step": 3000, # arg given to model, not for lrate
 "env_info": get_env_info(),
 # parameters for distillation with codebook indexes.
-"enable_distiallation": True,
 "distillation_layer": 5, # 0-based index
 # Since output rate of hubert is 50, while that of encoder is 8,
 # two successive codebook_index are concatenated together.
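Dropping the hard-coded "enable_distiallation" entry from get_params() works because icefall recipes merge the parsed command-line arguments into params (typically via params.update(vars(args))), so the setting now comes from the new --enable-distillation flag. A rough, self-contained sketch of that pattern; the AttributeDict below is a simplified stand-in for icefall.utils.AttributeDict and the defaults shown are illustrative:

import argparse

class AttributeDict(dict):
    # Simplified stand-in: attribute access backed by dict lookup.
    def __getattr__(self, key):
        return self[key]

parser = argparse.ArgumentParser()
# Crude bool parsing for this sketch; the real code uses str2bool.
parser.add_argument("--enable-distillation", type=lambda s: s.lower() == "true", default=True)
args = parser.parse_args(["--enable-distillation", "False"])

params = AttributeDict({"distillation_layer": 5})  # subset of get_params() defaults
params.update(vars(args))  # the CLI flag now lives at params.enable_distillation

print(params.enable_distillation)  # -> False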
@@ -394,7 +400,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
 dim_feedforward=params.dim_feedforward,
 num_encoder_layers=params.num_encoder_layers,
 middle_output_layer=params.distillation_layer
-if params.enable_distiallation
+if params.enable_distillation
 else None,
 )
 return encoder
@@ -434,7 +440,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
 joiner_dim=params.joiner_dim,
 vocab_size=params.vocab_size,
 num_codebooks=params.num_codebooks
-if params.enable_distiallation
+if params.enable_distillation
 else 0,
 )
 return model
@@ -615,7 +621,7 @@ def compute_loss(
 y = k2.RaggedTensor(y).to(device)
 info = MetricsTracker()
-if is_training and params.enable_distiallation:
+if is_training and params.enable_distillation:
 codebook_indexes, _ = extract_codebook_indexes(batch)
 codebook_indexes = codebook_indexes.to(device)
 else:
@@ -645,7 +651,7 @@ def compute_loss(
 params.simple_loss_scale * simple_loss
 + pruned_loss_scale * pruned_loss
 )
-if is_training and params.enable_distiallation:
+if is_training and params.enable_distillation:
 assert codebook_loss is not None
 loss += params.codebook_loss_scale * codebook_loss
@@ -661,7 +667,7 @@ def compute_loss(
 info["loss"] = loss.detach().cpu().item()
 info["simple_loss"] = simple_loss.detach().cpu().item()
 info["pruned_loss"] = pruned_loss.detach().cpu().item()
-if is_training and params.enable_distiallation:
+if is_training and params.enable_distillation:
 info["codebook_loss"] = codebook_loss.detach().cpu().item()
 return loss, info