Fix bug in batching code for scalars

This commit is contained in:
Daniel Povey 2022-07-12 08:36:45 +08:00
parent 25cb8308d5
commit 8c44ff26f7

View File

@ -85,7 +85,7 @@ class BatchedOptimizer(Optimizer):
yield p_stacked, state # <-- calling code will do the actual optimization here!
# Now un-stack the parameter changes
for i,p in enumerate(batch):
p[:] = p_stacked[i]
p.copy_(p_stacked[i])