diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py
index 77035318d..910a4dff8 100644
--- a/egs/aishell/ASR/whisper/train.py
+++ b/egs/aishell/ASR/whisper/train.py
@@ -430,9 +430,9 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 1,
+            "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 50,  # For the 100h subset, use 800
+            "valid_interval": 99999999999,  # For the 100h subset, use 800
             # parameters for zipformer
             "feature_dim": 80,
             "subsampling_factor": 4,  # not passed in, this is fixed.
@@ -632,8 +632,8 @@ def compute_loss(
     feature = feature.to(device)
     feature = feature.transpose(1, 2)  # (N, C, T)
     # pad feature from B,80,T to B,80,3000
-    feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
-    print(feature.shape, 23333333)
+    #feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
+    #print(feature.shape, 23333333)
 
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
@@ -783,24 +783,24 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-        #     logging.info("Computing validation loss")
-        #     valid_info = compute_validation_loss(
-        #         params=params,
-        #         tokenizer=tokenizer,
-        #         model=model,
-        #         valid_dl=valid_dl,
-        #         world_size=world_size,
-        #     )
-        #     model.train()
-        #     logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-        #     logging.info(
-        #         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-        #     )
-        #     if tb_writer is not None:
-        #         valid_info.write_summary(
-        #             tb_writer, "train/valid_", params.batch_idx_train
-        #         )
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                tokenizer=tokenizer,
+                model=model,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
@@ -967,8 +967,10 @@ def run(rank, world_size, args):
 
     logging.info("About to create model")
-    model = whisper.load_model("medium")
-    #model = load_model("medium")
+    #model = whisper.load_model("medium")
+    # TODO download model only on rank 0
+    # TODO may change compute validation loss using multiple cards
+    model = load_model("medium")
     del model.alignment_heads
 
     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual, language="zh", task="transcribe"