marcoyang 2023-01-04 10:51:36 +08:00
parent eb25b173dc
commit 2e3ff0b31f


@@ -166,6 +166,41 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Number of entries in the memory for the Emformer",
)
# Distillation-related arguments
parser.add_argument(
"--enable-distillation",
type=str2bool,
default=True,
help="Whether to eanble distillation.",
)
parser.add_argument(
"--distillation-layer",
type=int,
default=8,
help="On which encoder layer to perform KD"
)
parser.add_argument(
"--num-codebooks",
type=int,
default=16,
help="Number of codebooks"
)
parser.add_argument(
"--distil-delta",
type=int,
default=None,
help="Offset when doing KD"
)
parser.add_argument(
"--codebook-loss-scale",
type=float,
default=0.1,
help="The scale of codebook loss.",
)
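For reference, here is a minimal, self-contained sketch of how these distillation flags parse on their own. The local str2bool is only an illustrative stand-in for the helper the recipe imports, and sketch_parser/opts are hypothetical names used for the example.

import argparse


def str2bool(v):
    # Illustrative stand-in for the recipe's imported str2bool helper.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


sketch_parser = argparse.ArgumentParser()
sketch_parser.add_argument("--enable-distillation", type=str2bool, default=True)
sketch_parser.add_argument("--distillation-layer", type=int, default=8)
sketch_parser.add_argument("--num-codebooks", type=int, default=16)
sketch_parser.add_argument("--distil-delta", type=int, default=None)
sketch_parser.add_argument("--codebook-loss-scale", type=float, default=0.1)

opts = sketch_parser.parse_args(
    ["--enable-distillation", "false", "--num-codebooks", "8"]
)
print(opts.enable_distillation, opts.distillation_layer, opts.num_codebooks)
# -> False 8 8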
def get_parser():
parser = argparse.ArgumentParser(
@@ -358,41 +393,6 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--enable-distillation",
type=str2bool,
default=True,
help="Whether to eanble distillation.",
)
parser.add_argument(
"--distillation-layer",
type=int,
default=8,
help="On which encoder layer to perform KD"
)
parser.add_argument(
"--num-codebooks",
type=int,
default=16,
help="Number of codebooks"
)
parser.add_argument(
"--distil-delta",
type=int,
default=None,
help="Offset when doing KD"
)
parser.add_argument(
"--codebook-loss-scale",
type=float,
default=0.1,
help="The scale of codebook loss.",
)
add_model_arguments(parser)
return parser
@@ -444,6 +444,7 @@ def get_params() -> AttributeDict:
"""
params = AttributeDict(
{
"frame_shift_ms": 10.0,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
@@ -652,6 +653,9 @@ def extract_codebook_indexes(batch):
cuts_pre_mixed = [
c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
]
# Sanity check: every cut must carry codebook indexes for distillation.
for cut in cuts_pre_mixed:
    assert hasattr(
        cut, "codebook_indexes"
    ), f"Cut {cut.id} has no codebook indexes"
codebook_indexes, codebook_indexes_lens = collate_custom_field(
cuts_pre_mixed, "codebook_indexes", pad_value=-100
)
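collate_custom_field pads each cut's codebook indexes to a common length, and the -100 pad value matches PyTorch's default ignore_index, so padded frames can simply be masked out of a cross-entropy style codebook loss. The toy snippet below illustrates that masking; it is not the recipe's actual distillation loss.

import torch
import torch.nn.functional as F

# (batch, frames) of codebook indexes; -100 marks padding on the shorter cut.
codebook_indexes = torch.tensor([[3, 7, 1, -100], [5, 2, 9, 4]])
# Hypothetical per-frame logits over a codebook of size 16.
logits = torch.randn(2, 4, 16)

loss = F.cross_entropy(
    logits.reshape(-1, 16),
    codebook_indexes.reshape(-1),
    ignore_index=-100,  # padded frames contribute nothing to the loss
)
print(loss.item())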
@@ -969,6 +973,11 @@ def run(rank, world_size, args):
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
# Note: it's better to set --spec-aug-time-warp-factor=-1
# when doing distillation with vq.
if params.enable_distillation:
    assert (
        args.spec_aug_time_warp_factor < 1
    ), "You need to disable time warping for MVQ KD"
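A standalone, purely illustrative version of this guard (check_distillation_args is a hypothetical name), presumably needed because the pre-computed codebook indexes are frame-aligned with un-warped features:

def check_distillation_args(enable_distillation, spec_aug_time_warp_factor):
    # Illustrative guard: MVQ KD targets are extracted from un-warped
    # features, so time warping must stay off (factor < 1) during KD.
    if enable_distillation and spec_aug_time_warp_factor >= 1:
        raise ValueError(
            "Set --spec-aug-time-warp-factor=-1 (or any value < 1) "
            "when --enable-distillation=True"
        )


check_distillation_args(True, -1)  # passes
# check_distillation_args(True, 80)  # would raise ValueError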
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
@@ -1034,10 +1043,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts = librispeech.train_all_shuf_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
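Because the rendered diff interleaves removed and added lines here, the resulting selection logic is easier to read as a sketch (using the same LibriSpeechAsrDataModule helpers as above):

# After this change, --full-libri uses the pre-shuffled full training set
# instead of concatenating the three subsets on the fly.
if params.full_libri:
    train_cuts = librispeech.train_all_shuf_cuts()
else:
    train_cuts = librispeech.train_clean_100_cuts()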
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
@@ -1067,14 +1076,14 @@
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
# if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer=optimizer,
# sp=sp,
# params=params,
# )
scaler = GradScaler(enabled=params.use_fp16)
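For context, a minimal sketch of how a GradScaler created this way pairs with autocast for mixed-precision steps; it needs a GPU and is not the recipe's actual training loop.

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=True)

x = torch.randn(8, 10, device="cuda")
y = torch.randint(0, 2, (8,), device="cuda")

optimizer.zero_grad()
with autocast(enabled=True):
    loss = torch.nn.functional.cross_entropy(model(x), y)
scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
scaler.step(optimizer)         # unscales grads, then runs optimizer.step()
scaler.update()                # adjust the scale for the next iteration
print(loss.item())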
if checkpoints and "grad_scaler" in checkpoints: