mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Merge branch 'k2-fsa:master' into dev/k2ssl
This commit is contained in:
commit
61458e71e5
@ -1343,8 +1343,7 @@ def main():
|
||||
run(rank=0, world_size=1, args=args)
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
main()
|
||||
|
@ -935,8 +935,7 @@ def main():
|
||||
run(rank=0, world_size=1, args=args)
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
main()
|
||||
|
@ -160,8 +160,10 @@ class PiecewiseLinear(object):
|
||||
extra_x_vals.append(extra_x_val)
|
||||
if len(extra_x_vals) > 0:
|
||||
x_vals = sorted(set(x_vals + extra_x_vals))
|
||||
|
||||
y_vals1 = [self(x) for x in x_vals]
|
||||
y_vals2 = [p(x) for x in x_vals]
|
||||
|
||||
return (
|
||||
PiecewiseLinear(*zip(x_vals, y_vals1)),
|
||||
PiecewiseLinear(*zip(x_vals, y_vals2)),
|
||||
|
@ -593,6 +593,9 @@ def run(rank, world_size, args):
|
||||
|
||||
if params.continue_finetune:
|
||||
assert params.start_epoch > 0, params.start_epoch
|
||||
if rank == 0:
|
||||
# model_avg is only used with rank 0
|
||||
model_avg = copy.deepcopy(model).to(torch.float64)
|
||||
checkpoints = load_checkpoint_if_available(
|
||||
params=params, model=model, model_avg=model_avg
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user