From 92895f774f533b3a82ac3b7c2b45166ca079d89a Mon Sep 17 00:00:00 2001
From: Yuekai Zhang
Date: Thu, 11 Jan 2024 16:45:05 +0800
Subject: [PATCH] clean up codes

---
 egs/aishell/ASR/whisper/train.py | 229 ++-----------------------------
 1 file changed, 10 insertions(+), 219 deletions(-)

diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py
index 910a4dff8..edff5edfc 100644
--- a/egs/aishell/ASR/whisper/train.py
+++ b/egs/aishell/ASR/whisper/train.py
@@ -103,84 +103,14 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
             module.batch_count = batch_count
 
 
-def add_model_arguments(parser: argparse.ArgumentParser):
+def add_deepspeed_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
-        "--num-encoder-layers",
+        "--deepspeed-config",
         type=str,
-        default="2,4,3,2,4",
-        help="Number of zipformer encoder layers, comma separated.",
+        default=None,
+        help="Path to deepspeed json config file.",
     )
 
-    parser.add_argument(
-        "--feedforward-dims",
-        type=str,
-        default="1024,1024,2048,2048,1024",
-        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
-    )
-
-    parser.add_argument(
-        "--nhead",
-        type=str,
-        default="8,8,8,8,8",
-        help="Number of attention heads in the zipformer encoder layers.",
-    )
-
-    parser.add_argument(
-        "--encoder-dims",
-        type=str,
-        default="384,384,384,384,384",
-        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
-    )
-
-    parser.add_argument(
-        "--attention-dims",
-        type=str,
-        default="192,192,192,192,192",
-        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
-        not the same as embedding dimension.""",
-    )
-
-    parser.add_argument(
-        "--encoder-unmasked-dims",
-        type=str,
-        default="256,256,256,256,256",
-        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
-        "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
-        " worse.",
-    )
-
-    parser.add_argument(
-        "--zipformer-downsampling-factors",
-        type=str,
-        default="1,2,4,8,2",
-        help="Downsampling factor for each stack of encoder layers.",
-    )
-
-    parser.add_argument(
-        "--cnn-module-kernels",
-        type=str,
-        default="31,31,31,31,31",
-        help="Sizes of kernels in convolution modules",
-    )
-
-    parser.add_argument(
-        "--decoder-dim",
-        type=int,
-        default=512,
-        help="Embedding dimension in the decoder model.",
-    )
-
-    parser.add_argument(
-        "--joiner-dim",
-        type=int,
-        default=512,
-        help="""Dimension used in the joiner model.
-        Outputs from the encoder and decoder model are projected
-        to this dimension before adding.
-        """,
-    )
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -203,7 +133,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=30,
+        default=10,
         help="Number of epochs to train.",
     )
 
@@ -237,17 +167,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lang-dir",
-        type=str,
-        default="data/lang_char",
-        help="""The lang dir
-        It contains language related input files such as
-        "lexicon.txt"
-        """,
-    )
-
-    parser.add_argument(
-        "--base-lr", type=float, default=0.05, help="The base learning rate."
+        "--base-lr", type=float, default=1e-5, help="The base learning rate."
     )
 
     parser.add_argument(
@@ -266,46 +186,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--context-size",
-        type=int,
-        default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
-    )
-
-    parser.add_argument(
-        "--prune-range",
-        type=int,
-        default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
-    )
-
-    parser.add_argument(
-        "--lm-scale",
-        type=float,
-        default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
-    )
-
-    parser.add_argument(
-        "--am-scale",
-        type=float,
-        default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network) part.",
-    )
-
-    parser.add_argument(
-        "--simple-loss-scale",
-        type=float,
-        default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
-    )
-
     parser.add_argument(
         "--seed",
         type=int,
@@ -371,7 +251,7 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    add_model_arguments(parser)
+    add_deepspeed_arguments(parser)
 
     return parser
 
@@ -443,24 +323,6 @@ def get_params() -> AttributeDict:
     return params
 
-
-# def get_transducer_model(params: AttributeDict) -> nn.Module:
-#     encoder = get_encoder_model(params)
-#     decoder = get_decoder_model(params)
-#     joiner = get_joiner_model(params)
-
-#     model = Transducer(
-#         encoder=encoder,
-#         decoder=decoder,
-#         joiner=joiner,
-#         encoder_dim=int(params.encoder_dims.split(",")[-1]),
-#         decoder_dim=params.decoder_dim,
-#         joiner_dim=params.joiner_dim,
-#         vocab_size=params.vocab_size,
-#     )
-#     return model
-
-
 
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
@@ -648,12 +510,6 @@ def compute_loss(
     # convert it to torch tensor
     text_tokens_list = [torch.LongTensor(text_tokens) for text_tokens in text_tokens_list]
 
-    # prev_outputs_tokens = _batch_tensors(
-    #     [tokens[:-1] for tokens in text_tokens_list], pad_value=tokenizer.eot
-    # )
-    # target_tokens = _batch_tensors(
-    #     [tokens[1:] for tokens in text_tokens_list], pad_value=tokenizer.eot
-    # )
     prev_outputs_tokens = _batch_tensors(
         [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
     )
@@ -664,11 +520,6 @@ def compute_loss(
         [tokens.shape[0] - 1 for tokens in text_tokens_list]
     )
 
-    #print(prev_outputs_tokens.shape, prev_outputs_tokens)
-    #print(target_tokens.shape, target_tokens)
-    #print(target_lengths.shape, target_lengths)
-    #print(text_tokens_list)
-    #print("==========================================")
     decoder_criterion = LabelSmoothingLoss(ignore_index=50256, label_smoothing=0.1, reduction="sum")
     ignore_prefix_size = 3
     with torch.set_grad_enabled(is_training):
@@ -678,11 +529,6 @@ def compute_loss(
         loss = decoder_criterion(text_logits, target_tokens.to(device))
         text_logits = text_logits[:, ignore_prefix_size:, :]
         target_tokens = target_tokens[:, ignore_prefix_size:]
-        #print(text_logits.shape)
-        # print greedy results of text_logits
-        #print(text_logits.argmax(dim=-1))
-        # convert it to list of list then decode
-        #print([tokenizer.decode(tokens) for tokens in text_logits.argmax(dim=-1).tolist()])
 
     assert loss.requires_grad == is_training
 
@@ -903,24 +749,6 @@ def train_one_epoch(
                     params.batch_idx_train,
                 )
 
-        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-        #     logging.info("Computing validation loss")
-        #     valid_info = compute_validation_loss(
-        #         params=params,
-        #         tokenizer=tokenizer,
-        #         model=model,
-        #         valid_dl=valid_dl,
-        #         world_size=world_size,
-        #     )
-        #     model.train()
-        #     logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-        #     logging.info(
-        #         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-        #     )
-        #     if tb_writer is not None:
-        #         valid_info.write_summary(
-        #             tb_writer, "train/valid_", params.batch_idx_train
-        #         )
 
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
@@ -945,9 +773,7 @@ def run(rank, world_size, args):
     params.update(vars(args))
 
     fix_random_seed(params.seed)
-    # rank = get_rank()
-    # world_size = get_world_size()
-    # setup_dist(rank, world_size, use_ddp_launch=True)
+    setup_dist(use_ddp_launch=True)
 
     setup_logger(f"{params.exp_dir}/log/log-train")
 
@@ -996,22 +822,6 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    #parameters_names = []
-    #parameters_names.append(
-    #    [name_param_pair[0] for name_param_pair in model.named_parameters()]
-    #)
-    # optimizer = ScaledAdam(
-    #     model.parameters(),
-    #     lr=params.base_lr,
-    #     clipping_scale=2.0,
-    #     parameters_names=parameters_names,
-    # )
-    # optimizer = ScaledAdam(
-    #     model.parameters(),
-    #     lr=params.base_lr,
-    #     clipping_scale=2.0,
-    # )
-
     optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
@@ -1073,18 +883,11 @@ def run(rank, world_size, args):
 
         return True
 
-    #aishell = AIShell(manifest_dir=args.manifest_dir)
-    #train_cuts = aishell.train_cuts()
-    #asr_datamodule = AishellAsrDataModule(args)
+    aishell = AishellAsrDataModule(args)
 
-    # train_cuts = asr_datamodule.train_cuts()
-    # train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
-    # if args.enable_musan:
-    #     cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
-    # else:
-    #     cuts_musan = None
+
 
@@ -1095,15 +898,7 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None
 
-    # train_dl = asr_datamodule.train_dataloaders(
-    #     train_cuts,
-    #     on_the_fly_feats=False,
-    #     cuts_musan=cuts_musan,
-    #     sampler_state_dict=sampler_state_dict,
-    # )
-
-    # valid_cuts = aishell.valid_cuts()
-    # valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
-
+    train_dl = aishell.train_dataloaders(aishell.train_cuts())
+    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
 
     # if not params.print_diagnostics:
@@ -1192,10 +987,6 @@ def display_and_save_batch(
 
     logging.info(f"features shape: {features.shape}")
 
-    # y = graph_compiler.texts_to_ids(supervisions["text"])
-    # num_tokens = sum(len(i) for i in y)
-    # logging.info(f"num tokens: {num_tokens}")
-
 
 # def scan_pessimistic_batches_for_oom(
 #     model: Union[nn.Module, DDP],
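
Note on the loss path this patch keeps: compute_loss shifts each token sequence by one for teacher forcing (tokens[:-1] as decoder input, tokens[1:] as target), pads both sides with the hard-coded id 50256, and scores the logits with a label-smoothed criterion (label_smoothing=0.1, ignore_index=50256, reduction="sum"). The snippet below is a minimal standalone sketch of that preparation, not code from this recipe: batch_tensors is a stand-in for the script's _batch_tensors helper, random logits replace the Whisper decoder output, torch.nn.functional.cross_entropy approximates icefall's LabelSmoothingLoss, and vocab_size=51865 (Whisper's multilingual vocabulary size) is used only to shape the toy logits.

import torch
import torch.nn.functional as F


def batch_tensors(tensors, pad_value):
    # Right-pad a list of 1-D LongTensors to a common length
    # (stand-in for the script's _batch_tensors helper).
    max_len = max(t.shape[0] for t in tensors)
    out = torch.full((len(tensors), max_len), pad_value, dtype=torch.long)
    for i, t in enumerate(tensors):
        out[i, : t.shape[0]] = t
    return out


def toy_decoder_loss(text_tokens_list, vocab_size=51865):
    pad_id = 50256  # padding/ignore id hard-coded in the patch

    # Teacher forcing: feed tokens[:-1] to the decoder, predict tokens[1:].
    prev_outputs_tokens = batch_tensors([t[:-1] for t in text_tokens_list], pad_id)
    target_tokens = batch_tensors([t[1:] for t in text_tokens_list], pad_id)

    # Random logits stand in for model(feature, prev_outputs_tokens).
    logits = torch.randn(*prev_outputs_tokens.shape, vocab_size)

    # Label-smoothed cross entropy, summed over tokens; padded positions are skipped.
    return F.cross_entropy(
        logits.reshape(-1, vocab_size),
        target_tokens.reshape(-1),
        ignore_index=pad_id,
        label_smoothing=0.1,
        reduction="sum",
    )


if __name__ == "__main__":
    toy_batch = [
        torch.tensor([1, 2, 3, 100, 200, 300, 4]),
        torch.tensor([1, 2, 3, 400, 500, 4]),
    ]
    print(toy_decoder_loss(toy_batch))

Running the file prints a single scalar loss for the two toy sequences; in the real script the same shifted/padded tensors are fed to the Whisper model and to LabelSmoothingLoss instead.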