Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-03 22:24:19 +00:00)
clean up codes

commit 92895f774f
parent 98d11abedb
@@ -103,84 +103,14 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
         module.batch_count = batch_count
 
 
-def add_model_arguments(parser: argparse.ArgumentParser):
+def add_deepspeed_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
-        "--num-encoder-layers",
+        "--deepspeed-config",
         type=str,
-        default="2,4,3,2,4",
-        help="Number of zipformer encoder layers, comma separated.",
+        default=None,
+        help="Path to deepspeed json config file.",
     )
-
-    parser.add_argument(
-        "--feedforward-dims",
-        type=str,
-        default="1024,1024,2048,2048,1024",
-        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
-    )
-
-    parser.add_argument(
-        "--nhead",
-        type=str,
-        default="8,8,8,8,8",
-        help="Number of attention heads in the zipformer encoder layers.",
-    )
-
-    parser.add_argument(
-        "--encoder-dims",
-        type=str,
-        default="384,384,384,384,384",
-        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
-    )
-
-    parser.add_argument(
-        "--attention-dims",
-        type=str,
-        default="192,192,192,192,192",
-        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
-        not the same as embedding dimension.""",
-    )
-
-    parser.add_argument(
-        "--encoder-unmasked-dims",
-        type=str,
-        default="256,256,256,256,256",
-        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
-        "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
-        " worse.",
-    )
-
-    parser.add_argument(
-        "--zipformer-downsampling-factors",
-        type=str,
-        default="1,2,4,8,2",
-        help="Downsampling factor for each stack of encoder layers.",
-    )
-
-    parser.add_argument(
-        "--cnn-module-kernels",
-        type=str,
-        default="31,31,31,31,31",
-        help="Sizes of kernels in convolution modules",
-    )
-
-    parser.add_argument(
-        "--decoder-dim",
-        type=int,
-        default=512,
-        help="Embedding dimension in the decoder model.",
-    )
-
-    parser.add_argument(
-        "--joiner-dim",
-        type=int,
-        default=512,
-        help="""Dimension used in the joiner model.
-        Outputs from the encoder and decoder model are projected
-        to this dimension before adding.
-        """,
-    )
 
 
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
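
Note on the new --deepspeed-config option: the diff only adds the flag, not the file it points to. A minimal sketch of the kind of DeepSpeed JSON such a flag could consume is below; the keys are standard DeepSpeed options, but the values are assumptions, not taken from this commit.

    import json

    # Illustrative DeepSpeed config; values are placeholders, not from this recipe.
    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": 1},
    }

    with open("ds_config.json", "w") as f:
        json.dump(ds_config, f, indent=2)
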
@@ -203,7 +133,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=30,
+        default=10,
         help="Number of epochs to train.",
     )
 
@@ -237,17 +167,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lang-dir",
-        type=str,
-        default="data/lang_char",
-        help="""The lang dir
-        It contains language related input files such as
-        "lexicon.txt"
-        """,
-    )
-
-    parser.add_argument(
-        "--base-lr", type=float, default=0.05, help="The base learning rate."
+        "--base-lr", type=float, default=1e-5, help="The base learning rate."
     )
 
     parser.add_argument(
@@ -266,46 +186,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--context-size",
-        type=int,
-        default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
-    )
-
-    parser.add_argument(
-        "--prune-range",
-        type=int,
-        default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
-    )
-
-    parser.add_argument(
-        "--lm-scale",
-        type=float,
-        default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
-    )
-
-    parser.add_argument(
-        "--am-scale",
-        type=float,
-        default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network) part.",
-    )
-
-    parser.add_argument(
-        "--simple-loss-scale",
-        type=float,
-        default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
-    )
-
     parser.add_argument(
         "--seed",
         type=int,
@@ -371,7 +251,7 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    add_model_arguments(parser)
+    add_deepspeed_arguments(parser)
 
     return parser
 
@@ -443,24 +323,6 @@ def get_params() -> AttributeDict:
 
     return params
 
-
-# def get_transducer_model(params: AttributeDict) -> nn.Module:
-#     encoder = get_encoder_model(params)
-#     decoder = get_decoder_model(params)
-#     joiner = get_joiner_model(params)
-
-#     model = Transducer(
-#         encoder=encoder,
-#         decoder=decoder,
-#         joiner=joiner,
-#         encoder_dim=int(params.encoder_dims.split(",")[-1]),
-#         decoder_dim=params.decoder_dim,
-#         joiner_dim=params.joiner_dim,
-#         vocab_size=params.vocab_size,
-#     )
-#     return model
-
-
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
@@ -648,12 +510,6 @@ def compute_loss(
     # convert it to torch tensor
     text_tokens_list = [torch.LongTensor(text_tokens) for text_tokens in text_tokens_list]
 
-    # prev_outputs_tokens = _batch_tensors(
-    #     [tokens[:-1] for tokens in text_tokens_list], pad_value=tokenizer.eot
-    # )
-    # target_tokens = _batch_tensors(
-    #     [tokens[1:] for tokens in text_tokens_list], pad_value=tokenizer.eot
-    # )
     prev_outputs_tokens = _batch_tensors(
         [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
    )
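
Note on the hard-coded pad_value=50256: it replaces the earlier tokenizer.eot expression and corresponds to the <|endoftext|> id of the GPT-2 BPE tokenizer; whether that matches the tokenizer this recipe actually loads is an assumption. _batch_tensors itself is defined elsewhere in the file; a padding helper with the same effect might look roughly like this sketch (hypothetical name, not the recipe's implementation):

    from typing import List

    import torch


    def pad_token_batch(tensors: List[torch.Tensor], pad_value: int) -> torch.Tensor:
        # Right-pad a list of 1-D LongTensors to a common length with pad_value.
        max_len = max(t.shape[0] for t in tensors)
        out = torch.full((len(tensors), max_len), pad_value, dtype=torch.long)
        for i, t in enumerate(tensors):
            out[i, : t.shape[0]] = t
        return out
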
@@ -664,11 +520,6 @@ def compute_loss(
         [tokens.shape[0] - 1 for tokens in text_tokens_list]
     )
 
-    #print(prev_outputs_tokens.shape, prev_outputs_tokens)
-    #print(target_tokens.shape, target_tokens)
-    #print(target_lengths.shape, target_lengths)
-    #print(text_tokens_list)
-    #print("==========================================")
     decoder_criterion = LabelSmoothingLoss(ignore_index=50256, label_smoothing=0.1, reduction="sum")
     ignore_prefix_size = 3
     with torch.set_grad_enabled(is_training):
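
Note on the loss setup kept by this hunk: LabelSmoothingLoss comes from icefall, ignore_index=50256 makes padded positions contribute nothing, and ignore_prefix_size=3 drops the first three decoder positions (the special/prompt tokens) from the loss. A rough torch-only equivalent, shown only to illustrate the semantics and possibly differing from icefall's class in normalization details:

    import torch
    import torch.nn.functional as F


    def smoothed_ce_sum(
        logits: torch.Tensor,   # (batch, time, vocab)
        targets: torch.Tensor,  # (batch, time)
        ignore_index: int = 50256,
        smoothing: float = 0.1,
    ) -> torch.Tensor:
        # Label-smoothed cross entropy summed over non-ignored positions.
        return F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=ignore_index,
            label_smoothing=smoothing,
            reduction="sum",
        )
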
@@ -678,11 +529,6 @@ def compute_loss(
         loss = decoder_criterion(text_logits, target_tokens.to(device))
         text_logits = text_logits[:, ignore_prefix_size:, :]
         target_tokens = target_tokens[:, ignore_prefix_size:]
-        #print(text_logits.shape)
-        # print greedy results of text_logits
-        #print(text_logits.argmax(dim=-1))
-        # convert it to list of list then decode
-        #print([tokenizer.decode(tokens) for tokens in text_logits.argmax(dim=-1).tolist()])
 
     assert loss.requires_grad == is_training
 
@@ -903,24 +749,6 @@ def train_one_epoch(
                     params.batch_idx_train,
                 )
 
-        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-        #     logging.info("Computing validation loss")
-        #     valid_info = compute_validation_loss(
-        #         params=params,
-        #         tokenizer=tokenizer,
-        #         model=model,
-        #         valid_dl=valid_dl,
-        #         world_size=world_size,
-        #     )
-        #     model.train()
-        #     logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-        #     logging.info(
-        #         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-        #     )
-        #     if tb_writer is not None:
-        #         valid_info.write_summary(
-        #             tb_writer, "train/valid_", params.batch_idx_train
-        #         )
 
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
@@ -945,9 +773,7 @@ def run(rank, world_size, args):
    params.update(vars(args))
 
    fix_random_seed(params.seed)
-    # rank = get_rank()
-    # world_size = get_world_size()
-    # setup_dist(rank, world_size, use_ddp_launch=True)
+    setup_dist(use_ddp_launch=True)
 
    setup_logger(f"{params.exp_dir}/log/log-train")
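
Note on setup_dist(use_ddp_launch=True): rank and world size are no longer passed explicitly, so they presumably come from the launcher's environment (torchrun or the deepspeed launcher). A sketch of the kind of initialization this relies on, not icefall's actual helper:

    import os

    import torch
    import torch.distributed as dist


    def init_from_launcher_env() -> int:
        # RANK / WORLD_SIZE / LOCAL_RANK are exported by torchrun or `deepspeed`.
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(local_rank)
        return local_rank
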
@@ -996,22 +822,6 @@ def run(rank, world_size, args):
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    #parameters_names = []
-    #parameters_names.append(
-    #    [name_param_pair[0] for name_param_pair in model.named_parameters()]
-    #)
-    # optimizer = ScaledAdam(
-    #     model.parameters(),
-    #     lr=params.base_lr,
-    #     clipping_scale=2.0,
-    #     parameters_names=parameters_names,
-    # )
-    # optimizer = ScaledAdam(
-    #     model.parameters(),
-    #     lr=params.base_lr,
-    #     clipping_scale=2.0,
-    # )
-
    optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 
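
Note on the optimizer change: ScaledAdam is dropped in favor of plain AdamW (consistent with the much smaller --base-lr default above), while the Eden scheduler is kept. From memory of icefall's optim.py, Eden scales the base LR roughly as sketched below; treat the formula and the placeholder defaults as approximate rather than authoritative:

    def eden_lr_factor(
        batch: int,
        epoch: float,
        lr_batches: float = 5000.0,  # placeholder, see --lr-batches
        lr_epochs: float = 6.0,      # placeholder, see --lr-epochs
    ) -> float:
        # LR multiplier that decays smoothly once batch > lr_batches and epoch > lr_epochs.
        batch_factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
        epoch_factor = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
        return batch_factor * epoch_factor
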
@@ -1073,18 +883,11 @@ def run(rank, world_size, args):
 
        return True
 
-    #aishell = AIShell(manifest_dir=args.manifest_dir)
-    #train_cuts = aishell.train_cuts()
-    #asr_datamodule = AishellAsrDataModule(args)
-
-
    aishell = AishellAsrDataModule(args)
-    # train_cuts = asr_datamodule.train_cuts()
-    # train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
-    # if args.enable_musan:
-    #     cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
-    # else:
-    #     cuts_musan = None
 
 
@@ -1095,15 +898,7 @@ def run(rank, world_size, args):
    else:
        sampler_state_dict = None
 
-    # train_dl = asr_datamodule.train_dataloaders(
-    #     train_cuts,
-    #     on_the_fly_feats=False,
-    #     cuts_musan=cuts_musan,
-    #     sampler_state_dict=sampler_state_dict,
-    # )
-
-    # valid_cuts = aishell.valid_cuts()
-    # valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
    train_dl = aishell.train_dataloaders(aishell.train_cuts())
    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
 
    # if not params.print_diagnostics:
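
Note on the data loaders: AishellAsrDataModule is the lhotse-based data module, and train_dataloaders()/valid_dataloaders() return ordinary PyTorch DataLoaders. Assuming the usual lhotse K2SpeechRecognitionDataset batch layout (an assumption, since the dataset class is not shown in this diff), consuming a batch looks roughly like:

    # Assumed batch layout:
    #   batch["inputs"]                     -> (N, T, F) feature tensor
    #   batch["supervisions"]["text"]       -> list of N transcripts
    #   batch["supervisions"]["num_frames"] -> per-utterance frame counts
    for batch in train_dl:
        features = batch["inputs"]
        texts = batch["supervisions"]["text"]
        break  # peek at the first batch only
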
@@ -1192,10 +987,6 @@ def display_and_save_batch(
 
    logging.info(f"features shape: {features.shape}")
 
-    # y = graph_compiler.texts_to_ids(supervisions["text"])
-    # num_tokens = sum(len(i) for i in y)
-    # logging.info(f"num tokens: {num_tokens}")
-
 
 # def scan_pessimistic_batches_for_oom(
 #     model: Union[nn.Module, DDP],