mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Merge branch 'k2-fsa:master' into dev/k2ssl
This commit is contained in:
commit
61458e71e5
@ -1343,8 +1343,7 @@ def main():
|
|||||||
run(rank=0, world_size=1, args=args)
|
run(rank=0, world_size=1, args=args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
main()
|
||||||
|
@ -935,8 +935,7 @@ def main():
|
|||||||
run(rank=0, world_size=1, args=args)
|
run(rank=0, world_size=1, args=args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
main()
|
||||||
|
@ -160,8 +160,10 @@ class PiecewiseLinear(object):
|
|||||||
extra_x_vals.append(extra_x_val)
|
extra_x_vals.append(extra_x_val)
|
||||||
if len(extra_x_vals) > 0:
|
if len(extra_x_vals) > 0:
|
||||||
x_vals = sorted(set(x_vals + extra_x_vals))
|
x_vals = sorted(set(x_vals + extra_x_vals))
|
||||||
|
|
||||||
y_vals1 = [self(x) for x in x_vals]
|
y_vals1 = [self(x) for x in x_vals]
|
||||||
y_vals2 = [p(x) for x in x_vals]
|
y_vals2 = [p(x) for x in x_vals]
|
||||||
|
|
||||||
return (
|
return (
|
||||||
PiecewiseLinear(*zip(x_vals, y_vals1)),
|
PiecewiseLinear(*zip(x_vals, y_vals1)),
|
||||||
PiecewiseLinear(*zip(x_vals, y_vals2)),
|
PiecewiseLinear(*zip(x_vals, y_vals2)),
|
||||||
|
@ -593,6 +593,9 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
if params.continue_finetune:
|
if params.continue_finetune:
|
||||||
assert params.start_epoch > 0, params.start_epoch
|
assert params.start_epoch > 0, params.start_epoch
|
||||||
|
if rank == 0:
|
||||||
|
# model_avg is only used with rank 0
|
||||||
|
model_avg = copy.deepcopy(model).to(torch.float64)
|
||||||
checkpoints = load_checkpoint_if_available(
|
checkpoints = load_checkpoint_if_available(
|
||||||
params=params, model=model, model_avg=model_avg
|
params=params, model=model, model_avg=model_avg
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user