Clean up code

Yuekai Zhang 2024-01-11 16:45:05 +08:00
parent 98d11abedb
commit 92895f774f


@@ -103,84 +103,14 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
         module.batch_count = batch_count


-def add_model_arguments(parser: argparse.ArgumentParser):
+def add_deepspeed_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
-        "--num-encoder-layers",
+        "--deepspeed-config",
         type=str,
-        default="2,4,3,2,4",
-        help="Number of zipformer encoder layers, comma separated.",
+        default=None,
+        help="Path to deepspeed json config file.",
     )
-    parser.add_argument(
-        "--feedforward-dims",
-        type=str,
-        default="1024,1024,2048,2048,1024",
-        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
-    )
-    parser.add_argument(
-        "--nhead",
-        type=str,
-        default="8,8,8,8,8",
-        help="Number of attention heads in the zipformer encoder layers.",
-    )
-    parser.add_argument(
-        "--encoder-dims",
-        type=str,
-        default="384,384,384,384,384",
-        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
-    )
-    parser.add_argument(
-        "--attention-dims",
-        type=str,
-        default="192,192,192,192,192",
-        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
-        not the same as embedding dimension.""",
-    )
-    parser.add_argument(
-        "--encoder-unmasked-dims",
-        type=str,
-        default="256,256,256,256,256",
-        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
-        "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
-        " worse.",
-    )
-    parser.add_argument(
-        "--zipformer-downsampling-factors",
-        type=str,
-        default="1,2,4,8,2",
-        help="Downsampling factor for each stack of encoder layers.",
-    )
-    parser.add_argument(
-        "--cnn-module-kernels",
-        type=str,
-        default="31,31,31,31,31",
-        help="Sizes of kernels in convolution modules",
-    )
-    parser.add_argument(
-        "--decoder-dim",
-        type=int,
-        default=512,
-        help="Embedding dimension in the decoder model.",
-    )
-    parser.add_argument(
-        "--joiner-dim",
-        type=int,
-        default=512,
-        help="""Dimension used in the joiner model.
-        Outputs from the encoder and decoder model are projected
-        to this dimension before adding.
-        """,
-    )


 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
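
The zipformer model flags are gone; the only new flag is --deepspeed-config, which points at a DeepSpeed JSON file. How the recipe consumes the flag is not visible in this hunk; the following is only a minimal sketch of the usual pattern, where the Linear model and "ds_config.json" are placeholders, not code from this repository:

# Hedged sketch: typical consumption of a --deepspeed-config flag.
import argparse

import deepspeed
import torch.nn as nn

parser = argparse.ArgumentParser()
parser.add_argument("--deepspeed-config", type=str, default=None,
                    help="Path to deepspeed json config file.")
args = parser.parse_args(["--deepspeed-config", "ds_config.json"])

model = nn.Linear(8, 8)  # stand-in for the real network
# deepspeed.initialize reads optimizer, fp16/bf16 and ZeRO settings from the
# JSON config and returns a wrapped engine (recent DeepSpeed accepts the path
# via the config= keyword; older versions read args.deepspeed_config).
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=args.deepspeed_config,
)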
@@ -203,7 +133,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=30,
+        default=10,
         help="Number of epochs to train.",
     )
@@ -237,17 +167,7 @@ def get_parser():
     )

-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        default="data/lang_char",
-        help="""The lang dir
-        It contains language related input files such as
-        "lexicon.txt"
-        """,
-    )
-
     parser.add_argument(
-        "--base-lr", type=float, default=0.05, help="The base learning rate."
+        "--base-lr", type=float, default=1e-5, help="The base learning rate."
     )

     parser.add_argument(
@@ -266,46 +186,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--context-size",
-        type=int,
-        default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
-    )
-
-    parser.add_argument(
-        "--prune-range",
-        type=int,
-        default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
-    )
-
-    parser.add_argument(
-        "--lm-scale",
-        type=float,
-        default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
-    )
-
-    parser.add_argument(
-        "--am-scale",
-        type=float,
-        default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network) part.",
-    )
-
-    parser.add_argument(
-        "--simple-loss-scale",
-        type=float,
-        default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
-    )
-
     parser.add_argument(
         "--seed",
         type=int,
@@ -371,7 +251,7 @@ def get_parser():
         help="Whether to use half precision training.",
     )

-    add_model_arguments(parser)
+    add_deepspeed_arguments(parser)

     return parser
@@ -443,24 +323,6 @@ def get_params() -> AttributeDict:
     return params


-# def get_transducer_model(params: AttributeDict) -> nn.Module:
-#     encoder = get_encoder_model(params)
-#     decoder = get_decoder_model(params)
-#     joiner = get_joiner_model(params)
-
-#     model = Transducer(
-#         encoder=encoder,
-#         decoder=decoder,
-#         joiner=joiner,
-#         encoder_dim=int(params.encoder_dims.split(",")[-1]),
-#         decoder_dim=params.decoder_dim,
-#         joiner_dim=params.joiner_dim,
-#         vocab_size=params.vocab_size,
-#     )
-#     return model
-
-
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
@@ -648,12 +510,6 @@ def compute_loss(
     # convert it to torch tensor
     text_tokens_list = [torch.LongTensor(text_tokens) for text_tokens in text_tokens_list]

-    # prev_outputs_tokens = _batch_tensors(
-    #     [tokens[:-1] for tokens in text_tokens_list], pad_value=tokenizer.eot
-    # )
-    # target_tokens = _batch_tensors(
-    #     [tokens[1:] for tokens in text_tokens_list], pad_value=tokenizer.eot
-    # )
     prev_outputs_tokens = _batch_tensors(
         [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
     )
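
The live code now pads with the literal 50256 (the <|endoftext|> id in the GPT-2/Whisper vocabulary) rather than tokenizer.eot. A self-contained illustration of what a _batch_tensors-style helper plus the shift-by-one produces; batch_tensors below is an assumed stand-in, not the recipe's actual helper:

from typing import List

import torch


def batch_tensors(tensors: List[torch.Tensor], pad_value: int) -> torch.Tensor:
    """Right-pad 1-D long tensors to a common length with pad_value."""
    max_len = max(t.shape[0] for t in tensors)
    out = torch.full((len(tensors), max_len), pad_value, dtype=torch.long)
    for i, t in enumerate(tensors):
        out[i, : t.shape[0]] = t
    return out


eot = 50256  # <|endoftext|> in the GPT-2/Whisper vocabulary
tokens = [torch.LongTensor([1, 2, 3, eot]), torch.LongTensor([4, 5, eot])]
# Teacher forcing: the decoder consumes tokens[:-1] and predicts tokens[1:].
prev_outputs = batch_tensors([t[:-1] for t in tokens], pad_value=eot)
targets = batch_tensors([t[1:] for t in tokens], pad_value=eot)
print(prev_outputs)  # tensor([[1, 2, 3], [4, 5, 50256]])
print(targets)       # tensor([[2, 3, 50256], [5, 50256, 50256]])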
@@ -664,11 +520,6 @@ def compute_loss(
         [tokens.shape[0] - 1 for tokens in text_tokens_list]
     )

-    #print(prev_outputs_tokens.shape, prev_outputs_tokens)
-    #print(target_tokens.shape, target_tokens)
-    #print(target_lengths.shape, target_lengths)
-    #print(text_tokens_list)
-    #print("==========================================")
     decoder_criterion = LabelSmoothingLoss(ignore_index=50256, label_smoothing=0.1, reduction="sum")

     ignore_prefix_size = 3
     with torch.set_grad_enabled(is_training):
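
LabelSmoothingLoss here sums the smoothed cross entropy and skips positions whose target is the 50256 pad id. For intuition only, torch.nn.functional.cross_entropy can express roughly the same computation; the icefall class may differ in details such as normalization, and the vocabulary size below is illustrative:

import torch
import torch.nn.functional as F

eot = 50256
vocab_size = 51865                        # illustrative vocab size
logits = torch.randn(2, 6, vocab_size)    # (batch, time, vocab)
targets = torch.randint(0, 1000, (2, 6))
targets[1, 4:] = eot                      # padded tail of a shorter utterance

# cross_entropy wants the class dim second, so transpose (B, T, V) -> (B, V, T).
loss = F.cross_entropy(
    logits.transpose(1, 2),
    targets,
    ignore_index=eot,       # padding contributes nothing to the loss
    label_smoothing=0.1,
    reduction="sum",
)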
@@ -678,11 +529,6 @@ def compute_loss(
         loss = decoder_criterion(text_logits, target_tokens.to(device))
         text_logits = text_logits[:, ignore_prefix_size:, :]
         target_tokens = target_tokens[:, ignore_prefix_size:]
-        #print(text_logits.shape)
-        # print greedy results of text_logits
-        #print(text_logits.argmax(dim=-1))
-        # convert it to list of list then decode
-        #print([tokenizer.decode(tokens) for tokens in text_logits.argmax(dim=-1).tolist()])

     assert loss.requires_grad == is_training
@@ -903,24 +749,6 @@ def train_one_epoch(
                     params.batch_idx_train,
                 )

-        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-        #     logging.info("Computing validation loss")
-        #     valid_info = compute_validation_loss(
-        #         params=params,
-        #         tokenizer=tokenizer,
-        #         model=model,
-        #         valid_dl=valid_dl,
-        #         world_size=world_size,
-        #     )
-        #     model.train()
-        #     logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-        #     logging.info(
-        #         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-        #     )
-        #     if tb_writer is not None:
-        #         valid_info.write_summary(
-        #             tb_writer, "train/valid_", params.batch_idx_train
-        #         )

     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
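
The epoch-end summary divides the accumulated summed loss by the accumulated frame count, so params.train_loss is a per-frame average rather than a per-batch one. A toy version of that bookkeeping (icefall's real MetricsTracker does the same accumulation, with extra reduction across workers):

# Toy stand-in for the tot_loss bookkeeping above; the numbers are made up.
tot_loss = {"loss": 0.0, "frames": 0}
for batch_loss, batch_frames in [(250.0, 100), (230.0, 95)]:
    tot_loss["loss"] += batch_loss     # summed (not averaged) batch loss
    tot_loss["frames"] += batch_frames

loss_value = tot_loss["loss"] / tot_loss["frames"]  # 480 / 195 ~= 2.46 per frame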
@@ -945,9 +773,7 @@ def run(rank, world_size, args):
     params.update(vars(args))

     fix_random_seed(params.seed)
-    # rank = get_rank()
-    # world_size = get_world_size()
-    # setup_dist(rank, world_size, use_ddp_launch=True)
+    setup_dist(use_ddp_launch=True)

     setup_logger(f"{params.exp_dir}/log/log-train")
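
With use_ddp_launch=True, rank and world size come from the launcher's environment rather than from in-process spawning. The real setup_dist lives in icefall's dist utilities; the sketch below only shows the general shape of that pattern under the assumption of a torchrun/deepspeed-style launcher:

import os

import torch
import torch.distributed as dist


def setup_dist_from_env() -> None:
    # torchrun and the deepspeed launcher export these variables per process.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)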
@@ -996,22 +822,6 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)

-    #parameters_names = []
-    #parameters_names.append(
-    #    [name_param_pair[0] for name_param_pair in model.named_parameters()]
-    #)
-    # optimizer = ScaledAdam(
-    #     model.parameters(),
-    #     lr=params.base_lr,
-    #     clipping_scale=2.0,
-    #     parameters_names=parameters_names,
-    # )
-    # optimizer = ScaledAdam(
-    #     model.parameters(),
-    #     lr=params.base_lr,
-    #     clipping_scale=2.0,
-    # )
     optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
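
Stock AdamW replaces the commented-out ScaledAdam, and the --base-lr default earlier in this commit drops from 0.05 to 1e-5 to match. In isolation the new setup looks like the sketch below; the Eden import path, the placeholder model, and the scheduler values are assumptions, not taken from this diff:

import torch
import torch.nn as nn

from optim import Eden  # recipe-local scheduler module; path is an assumption

model = nn.Linear(8, 8)  # stand-in for the whisper model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # new base LR
scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=6)   # example values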
@@ -1073,18 +883,11 @@ def run(rank, world_size, args):
         return True

-    #aishell = AIShell(manifest_dir=args.manifest_dir)
-    #train_cuts = aishell.train_cuts()
-    #asr_datamodule = AishellAsrDataModule(args)
+    aishell = AishellAsrDataModule(args)

-    # train_cuts = asr_datamodule.train_cuts()
-    # train_cuts = train_cuts.filter(remove_short_and_long_utt)
-    # if args.enable_musan:
-    #     cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
-    # else:
-    #     cuts_musan = None
@@ -1095,15 +898,7 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    # train_dl = asr_datamodule.train_dataloaders(
-    #     train_cuts,
-    #     on_the_fly_feats=False,
-    #     cuts_musan=cuts_musan,
-    #     sampler_state_dict=sampler_state_dict,
-    # )
-    # valid_cuts = aishell.valid_cuts()
-    # valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
     train_dl = aishell.train_dataloaders(aishell.train_cuts())
     valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())

     # if not params.print_diagnostics:
@@ -1192,10 +987,6 @@ def display_and_save_batch(
     logging.info(f"features shape: {features.shape}")

-    # y = graph_compiler.texts_to_ids(supervisions["text"])
-    # num_tokens = sum(len(i) for i in y)
-    # logging.info(f"num tokens: {num_tokens}")


 # def scan_pessimistic_batches_for_oom(
 #     model: Union[nn.Module, DDP],