diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py
index f9d3d4c06..578546bbe 100644
--- a/egs/aishell/ASR/whisper/decode.py
+++ b/egs/aishell/ASR/whisper/decode.py
@@ -390,7 +390,9 @@ def main():
             )
         )
     else:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
+        model.load_state_dict(checkpoint, strict=True)
+        #load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py
index 5ae261335..158ad9443 100644
--- a/egs/aishell/ASR/whisper/train.py
+++ b/egs/aishell/ASR/whisper/train.py
@@ -159,6 +159,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default="large-v2",
+        choices=["large-v2", "large-v3", "medium", "small", "tiny"],
+        help="""The model name to use.
+        """,
+    )
+
     parser.add_argument(
         "--base-lr", type=float, default=1e-5, help="The base learning rate."
     )
@@ -305,7 +314,7 @@ def get_params() -> AttributeDict:
            "batch_idx_train": 0,
            "log_interval": 50,
            "reset_interval": 200,
-            "valid_interval": 99999999999,  # For the 100h subset, use 800
+            "valid_interval": 999999999999999999,  # For the 100h subset, use 800
            # parameters for zipformer
            "feature_dim": 80,
            "subsampling_factor": 4,  # not passed in, this is fixed.
@@ -548,13 +557,14 @@ def compute_validation_loss(
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(valid_dl):
-        loss, loss_info = compute_loss(
-            params=params,
-            tokenizer=tokenizer,
-            model=model,
-            batch=batch,
-            is_training=False,
-        )
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                tokenizer=tokenizer,
+                model=model,
+                batch=batch,
+                is_training=False,
+            )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
 
@@ -621,24 +631,24 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-        #     logging.info("Computing validation loss")
-        #     valid_info = compute_validation_loss(
-        #         params=params,
-        #         tokenizer=tokenizer,
-        #         model=model,
-        #         valid_dl=valid_dl,
-        #         world_size=world_size,
-        #     )
-        #     model.train()
-        #     logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-        #     logging.info(
-        #         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-        #     )
-        #     if tb_writer is not None:
-        #         valid_info.write_summary(
-        #             tb_writer, "train/valid_", params.batch_idx_train
-        #         )
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                tokenizer=tokenizer,
+                model=model,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
 
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
@@ -780,8 +790,7 @@ def run(rank, world_size, args):
     logging.info("About to create model")
     # TODO download model only on rank 0
     # TODO may change compute validation loss using multiple cards
-    # model = load_model("medium")
-    model = load_model("large-v2")
+    model = load_model(params.model_name)
     del model.alignment_heads
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -900,9 +909,10 @@ def run(rank, world_size, args):
             model.save_checkpoint(save_dir=params.exp_dir,
                                   tag=f"epoch-{params.cur_epoch}",
                                   client_state={})
-            convert_zero_checkpoint_to_fp32_state_dict(
-                params.exp_dir, f"epoch-{params.cur_epoch}.pt",
-                tag=f"epoch-{params.cur_epoch}")
+            if rank == 0:
+                convert_zero_checkpoint_to_fp32_state_dict(
+                    params.exp_dir, f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
+                    tag=f"epoch-{params.cur_epoch}")
         else:
             save_checkpoint(
                 params=params,
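For reference, here is a minimal standalone sketch (not part of the patch) of the decode-side loading path the diff switches to: reading the DeepSpeed-converted fp32 checkpoint with plain `torch.load` and `load_state_dict`. The `exp_dir` and `epoch` values are hypothetical placeholders standing in for `--exp-dir` and `--epoch`, and the sketch assumes the converted `epoch-N.pt` is a bare state dict whose keys match a freshly loaded Whisper model.

```python
# Sketch only: mirrors the new checkpoint-loading code in decode.py.
import torch
from whisper import load_model

exp_dir = "whisper/exp"  # hypothetical; corresponds to --exp-dir
epoch = 10               # hypothetical; corresponds to --epoch

# Same model name that train.py now accepts via --model-name.
model = load_model("large-v2")

# The DeepSpeed-converted checkpoint is assumed to hold a plain fp32 state dict.
checkpoint = torch.load(f"{exp_dir}/epoch-{epoch}.pt", map_location="cpu")
model.load_state_dict(checkpoint, strict=True)
model.eval()
```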