Update train.py

This commit is contained in:
jinzr 2023-11-17 17:10:04 +08:00
parent 9bfaff59e3
commit fe35141e7e

View File

@ -1172,7 +1172,7 @@ def run(rank, world_size, args):
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( opts = diagnostics.TensorDiagnosticOptions(
2**22 512
) # allow 4 megabytes per sub-module ) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts) diagnostic = diagnostics.attach_diagnostics(model, opts)