mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 15:44:17 +00:00
update deepspeed model loading
This commit is contained in:
parent
b6418acda2
commit
fa7ad4dc72
@ -390,7 +390,9 @@ def main():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
|
||||||
|
model.load_state_dict(checkpoint, strict=True)
|
||||||
|
#load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
@ -159,6 +159,15 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name",
|
||||||
|
type=str,
|
||||||
|
default="large-v2",
|
||||||
|
choices=["large-v2", "large-v3", "medium", "small", "tiny"],
|
||||||
|
help="""The model name to use.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--base-lr", type=float, default=1e-5, help="The base learning rate."
|
"--base-lr", type=float, default=1e-5, help="The base learning rate."
|
||||||
)
|
)
|
||||||
@ -305,7 +314,7 @@ def get_params() -> AttributeDict:
|
|||||||
"batch_idx_train": 0,
|
"batch_idx_train": 0,
|
||||||
"log_interval": 50,
|
"log_interval": 50,
|
||||||
"reset_interval": 200,
|
"reset_interval": 200,
|
||||||
"valid_interval": 99999999999, # For the 100h subset, use 800
|
"valid_interval": 999999999999999999, # For the 100h subset, use 800
|
||||||
# parameters for zipformer
|
# parameters for zipformer
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"subsampling_factor": 4, # not passed in, this is fixed.
|
"subsampling_factor": 4, # not passed in, this is fixed.
|
||||||
@ -548,13 +557,14 @@ def compute_validation_loss(
|
|||||||
tot_loss = MetricsTracker()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(valid_dl):
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
loss, loss_info = compute_loss(
|
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||||
params=params,
|
loss, loss_info = compute_loss(
|
||||||
tokenizer=tokenizer,
|
params=params,
|
||||||
model=model,
|
tokenizer=tokenizer,
|
||||||
batch=batch,
|
model=model,
|
||||||
is_training=False,
|
batch=batch,
|
||||||
)
|
is_training=False,
|
||||||
|
)
|
||||||
assert loss.requires_grad is False
|
assert loss.requires_grad is False
|
||||||
tot_loss = tot_loss + loss_info
|
tot_loss = tot_loss + loss_info
|
||||||
|
|
||||||
@ -621,24 +631,24 @@ def train_one_epoch(
|
|||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
# if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
||||||
# logging.info("Computing validation loss")
|
logging.info("Computing validation loss")
|
||||||
# valid_info = compute_validation_loss(
|
valid_info = compute_validation_loss(
|
||||||
# params=params,
|
params=params,
|
||||||
# tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
# model=model,
|
model=model,
|
||||||
# valid_dl=valid_dl,
|
valid_dl=valid_dl,
|
||||||
# world_size=world_size,
|
world_size=world_size,
|
||||||
# )
|
)
|
||||||
# model.train()
|
model.train()
|
||||||
# logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||||
# logging.info(
|
logging.info(
|
||||||
# f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||||
# )
|
)
|
||||||
# if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
# valid_info.write_summary(
|
valid_info.write_summary(
|
||||||
# tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
# )
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||||
@ -780,8 +790,7 @@ def run(rank, world_size, args):
|
|||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
# TODO download model only on rank 0
|
# TODO download model only on rank 0
|
||||||
# TODO may change compute validation loss using multiple cards
|
# TODO may change compute validation loss using multiple cards
|
||||||
# model = load_model("medium")
|
model = load_model(params.model_name)
|
||||||
model = load_model("large-v2")
|
|
||||||
del model.alignment_heads
|
del model.alignment_heads
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
@ -900,9 +909,10 @@ def run(rank, world_size, args):
|
|||||||
model.save_checkpoint(save_dir=params.exp_dir,
|
model.save_checkpoint(save_dir=params.exp_dir,
|
||||||
tag=f"epoch-{params.cur_epoch}",
|
tag=f"epoch-{params.cur_epoch}",
|
||||||
client_state={})
|
client_state={})
|
||||||
convert_zero_checkpoint_to_fp32_state_dict(
|
if rank == 0:
|
||||||
params.exp_dir, f"epoch-{params.cur_epoch}.pt",
|
convert_zero_checkpoint_to_fp32_state_dict(
|
||||||
tag=f"epoch-{params.cur_epoch}")
|
params.exp_dir, f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
|
||||||
|
tag=f"epoch-{params.cur_epoch}")
|
||||||
else:
|
else:
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
params=params,
|
params=params,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user