Fix bug in batching code for scalars

This commit is contained in:
Daniel Povey 2022-07-12 08:36:45 +08:00
parent 25cb8308d5
commit 8c44ff26f7

View File

@ -85,7 +85,7 @@ class BatchedOptimizer(Optimizer):
yield p_stacked, state # <-- calling code will do the actual optimization here! yield p_stacked, state # <-- calling code will do the actual optimization here!
# Now un-stack the parameter changes # Now un-stack the parameter changes
for i,p in enumerate(batch): for i,p in enumerate(batch):
p[:] = p_stacked[i] p.copy_(p_stacked[i])