Merge branch 'k2-fsa:master' into dev/k2ssl

This commit is contained in:
Yifan Yang 2025-04-13 16:36:17 +08:00 committed by GitHub
commit 61458e71e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 11 additions and 8 deletions

View File

@ -1343,8 +1343,7 @@ def main():
run(rank=0, world_size=1, args=args) run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
if __name__ == "__main__":
main() main()

View File

@ -935,8 +935,7 @@ def main():
run(rank=0, world_size=1, args=args) run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
if __name__ == "__main__":
main() main()

View File

@ -160,8 +160,10 @@ class PiecewiseLinear(object):
extra_x_vals.append(extra_x_val) extra_x_vals.append(extra_x_val)
if len(extra_x_vals) > 0: if len(extra_x_vals) > 0:
x_vals = sorted(set(x_vals + extra_x_vals)) x_vals = sorted(set(x_vals + extra_x_vals))
y_vals1 = [self(x) for x in x_vals] y_vals1 = [self(x) for x in x_vals]
y_vals2 = [p(x) for x in x_vals] y_vals2 = [p(x) for x in x_vals]
return ( return (
PiecewiseLinear(*zip(x_vals, y_vals1)), PiecewiseLinear(*zip(x_vals, y_vals1)),
PiecewiseLinear(*zip(x_vals, y_vals2)), PiecewiseLinear(*zip(x_vals, y_vals2)),

View File

@ -593,6 +593,9 @@ def run(rank, world_size, args):
if params.continue_finetune: if params.continue_finetune:
assert params.start_epoch > 0, params.start_epoch assert params.start_epoch > 0, params.start_epoch
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
checkpoints = load_checkpoint_if_available( checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg params=params, model=model, model_avg=model_avg
) )