Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-23 00:36:14 +00:00)
commit 2e3ff0b31f
parent eb25b173dc

    update
@@ -166,6 +166,41 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Number of entries in the memory for the Emformer",
     )
 
+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable distillation.",
+    )
+
+    parser.add_argument(
+        "--distillation-layer",
+        type=int,
+        default=8,
+        help="On which encoder layer to perform KD",
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=16,
+        help="Number of codebooks",
+    )
+
+    # distillation related args
+    parser.add_argument(
+        "--distil-delta",
+        type=int,
+        default=None,
+        help="Offset when doing KD",
+    )
+
+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="The scale of codebook loss.",
+    )
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
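[Review note] A minimal, self-contained smoke test of the flags added above, for anyone skimming the hunk. The inline `str2bool` is a simplified stand-in for `icefall.utils.str2bool`, which the real script imports; the defaults mirror the diff.

```python
# Hypothetical check of the new distillation flags (not part of the commit).
import argparse


def str2bool(v: str) -> bool:
    # Simplified stand-in for icefall.utils.str2bool.
    return v.lower() in ("yes", "true", "t", "y", "1")


parser = argparse.ArgumentParser()
parser.add_argument("--enable-distillation", type=str2bool, default=True)
parser.add_argument("--distillation-layer", type=int, default=8)
parser.add_argument("--num-codebooks", type=int, default=16)
parser.add_argument("--distil-delta", type=int, default=None)
parser.add_argument("--codebook-loss-scale", type=float, default=0.1)

args = parser.parse_args(["--enable-distillation", "false", "--num-codebooks", "32"])
print(args.enable_distillation, args.num_codebooks)  # False 32
```

`type=str2bool` matters here: argparse's `type=bool` would turn any non-empty string, including "false", into True.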
@@ -358,41 +393,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    parser.add_argument(
-        "--enable-distillation",
-        type=str2bool,
-        default=True,
-        help="Whether to eanble distillation.",
-    )
-
-    parser.add_argument(
-        "--distillation-layer",
-        type=int,
-        default=8,
-        help="On which encoder layer to perform KD"
-    )
-
-    parser.add_argument(
-        "--num-codebooks",
-        type=int,
-        default=16,
-        help="Number of codebooks"
-    )
-
-    parser.add_argument(
-        "--distil-delta",
-        type=int,
-        default=None,
-        help="Offset when doing KD"
-    )
-
-    parser.add_argument(
-        "--codebook-loss-scale",
-        type=float,
-        default=0.1,
-        help="The scale of codebook loss.",
-    )
-
     add_model_arguments(parser)
 
     return parser
@@ -444,6 +444,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
+            "frame_shift_ms": 10.0,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
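[Review note] `frame_shift_ms` is presumably used to convert frame counts to time when lining the student's frames up with the teacher's codebook indexes. A toy conversion under the 10 ms default set above:

```python
# Frames -> seconds under the 10 ms frame shift added in get_params.
frame_shift_ms = 10.0
num_frames = 250
print(num_frames * frame_shift_ms / 1000.0)  # 2.5 (seconds)
```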
@@ -652,6 +653,9 @@ def extract_codebook_indexes(batch):
     cuts_pre_mixed = [
         c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
     ]
+    for cut in cuts_pre_mixed:
+        cb = cut.codebook_indexes
+    print("All cuts have codebook indexes")
     codebook_indexes, codebook_indexes_lens = collate_custom_field(
         cuts_pre_mixed, "codebook_indexes", pad_value=-100
     )
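[Review note] For readers new to lhotse: `collate_custom_field` gathers each cut's `codebook_indexes` into one padded batch tensor plus per-cut lengths. A rough sketch of that padding, assuming each cut stores a `(num_frames, num_codebooks)` integer array; `pad_codebook_indexes` is a hypothetical helper, not lhotse API:

```python
# Hypothetical re-implementation of the padding that collate_custom_field
# performs for this field (illustration only).
from typing import List, Tuple

import torch


def pad_codebook_indexes(
    arrays: List[torch.Tensor], pad_value: int = -100
) -> Tuple[torch.Tensor, torch.Tensor]:
    # arrays: one (T_i, num_codebooks) LongTensor per cut.
    lens = torch.tensor([a.shape[0] for a in arrays])
    out = arrays[0].new_full(
        (len(arrays), int(lens.max()), arrays[0].shape[1]), pad_value
    )
    for i, a in enumerate(arrays):
        out[i, : a.shape[0]] = a  # pad along the time axis with pad_value
    return out, lens


batch, lens = pad_codebook_indexes(
    [torch.zeros(5, 16, dtype=torch.long), torch.ones(3, 16, dtype=torch.long)]
)
print(batch.shape, lens.tolist())  # torch.Size([2, 5, 16]) [5, 3]
```

`pad_value=-100` matches the default `ignore_index` of PyTorch's cross-entropy losses, which is presumably why padded frames can be ignored by the codebook loss.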
@@ -969,6 +973,11 @@ def run(rank, world_size, args):
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
 
+    # Note: it's better to set --spec-aug-time-warp-factor=-1
+    # when doing distillation with vq.
+    if params.enable_distillation:
+        assert args.spec_aug_time_warp_factor < 1, "You need to disable time warp in MVQ KD"
+
     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
     else:
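[Review note] The assertion enforces the comment's advice: MVQ KD matches student frames to teacher codebook indexes one-to-one, so SpecAugment's time warp would misalign them. In this recipe a factor below 1 disables time warping, hence the suggested -1. A hypothetical standalone form of the same guard (names illustrative, not from the script):

```python
# Hypothetical function form of the assertion above.
def check_time_warp_disabled(
    enable_distillation: bool, spec_aug_time_warp_factor: int
) -> None:
    # A factor below 1 disables time warping in the recipe's SpecAugment.
    if enable_distillation and spec_aug_time_warp_factor >= 1:
        raise ValueError("Set --spec-aug-time-warp-factor=-1 when doing MVQ KD")


check_time_warp_disabled(True, -1)   # OK: time warp disabled
# check_time_warp_disabled(True, 80) would raise ValueError
```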
@@ -1034,10 +1043,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    if params.full_libri:
-        train_cuts = librispeech.train_all_shuf_cuts()
-    else:
-        train_cuts = librispeech.train_clean_100_cuts()
+    train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -1067,14 +1076,14 @@ def run(rank, world_size, args):
         valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )
 
     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints: