update deepspeed model loading

This commit is contained in:
Yuekai Zhang 2024-01-12 17:29:24 +08:00
parent b6418acda2
commit fa7ad4dc72
2 changed files with 44 additions and 32 deletions

View File

@@ -390,7 +390,9 @@ def main():
             )
         )
     else:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
+        model.load_state_dict(checkpoint, strict=True)
+        #load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])

View File

@@ -159,6 +159,15 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default="large-v2",
+        choices=["large-v2", "large-v3", "medium", "small", "tiny"],
+        help="""The model name to use.
+        """,
+    )
+
     parser.add_argument(
         "--base-lr", type=float, default=1e-5, help="The base learning rate."
     )
@@ -305,7 +314,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 99999999999,  # For the 100h subset, use 800
+            "valid_interval": 999999999999999999,  # For the 100h subset, use 800
             # parameters for zipformer
             "feature_dim": 80,
             "subsampling_factor": 4,  # not passed in, this is fixed.
@@ -548,6 +557,7 @@ def compute_validation_loss(
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(valid_dl):
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 tokenizer=tokenizer,
@@ -621,24 +631,24 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-        #     logging.info("Computing validation loss")
-        #     valid_info = compute_validation_loss(
-        #         params=params,
-        #         tokenizer=tokenizer,
-        #         model=model,
-        #         valid_dl=valid_dl,
-        #         world_size=world_size,
-        #     )
-        #     model.train()
-        #     logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-        #     logging.info(
-        #         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-        #     )
-        #     if tb_writer is not None:
-        #         valid_info.write_summary(
-        #             tb_writer, "train/valid_", params.batch_idx_train
-        #         )
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                tokenizer=tokenizer,
+                model=model,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )

         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
@@ -780,8 +790,7 @@ def run(rank, world_size, args):
     logging.info("About to create model")
     # TODO download model only on rank 0
     # TODO may change compute validation loss using multiple cards
-    # model = load_model("medium")
-    model = load_model("large-v2")
+    model = load_model(params.model_name)
     del model.alignment_heads
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -900,8 +909,9 @@ def run(rank, world_size, args):
             model.save_checkpoint(save_dir=params.exp_dir,
                                   tag=f"epoch-{params.cur_epoch}",
                                   client_state={})
-            convert_zero_checkpoint_to_fp32_state_dict(
-                params.exp_dir, f"epoch-{params.cur_epoch}.pt",
-                tag=f"epoch-{params.cur_epoch}")
+            if rank == 0:
+                convert_zero_checkpoint_to_fp32_state_dict(
+                    params.exp_dir, f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
+                    tag=f"epoch-{params.cur_epoch}")
         else:
             save_checkpoint(