mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
zipformer BF16 training recipe (#1700)
Support Zipformer AMP +BF16 training
This commit is contained in:
parent
3b434fe83c
commit
a6c02a4d8c
@ -307,6 +307,23 @@ done
|
|||||||
|
|
||||||
To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html).
|
To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html).
|
||||||
|
|
||||||
|
We also support training Zipformer with AMP+bf16 format (requires bf16 support). See [here](https://github.com/k2-fsa/icefall/pull/1700) for more details and pre-trained models. **The same command can be used for decoding and exporting the model.**
|
||||||
|
|
||||||
|
The amp+bf16 training command is:
|
||||||
|
```bash
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
./zipformer/train.py \
|
||||||
|
--world-size 4 \
|
||||||
|
--num-epochs 50 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--use-fp16 0 \
|
||||||
|
--use-bf16 1 \
|
||||||
|
--exp-dir zipformer/exp_amp_bf16 \
|
||||||
|
--causal 0 \
|
||||||
|
--full-libri 1 \
|
||||||
|
--max-duration 1000
|
||||||
|
```
|
||||||
|
|
||||||
##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M
|
##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M
|
||||||
|
|
||||||
The tensorboard log can be found at
|
The tensorboard log can be found at
|
||||||
|
@ -297,7 +297,7 @@ class SoftmaxFunction(torch.autograd.Function):
|
|||||||
# (presumably) that op does not support float16, and autocast
|
# (presumably) that op does not support float16, and autocast
|
||||||
# is enabled.
|
# is enabled.
|
||||||
if torch.is_autocast_enabled():
|
if torch.is_autocast_enabled():
|
||||||
ans = ans.to(torch.float16)
|
ans = ans.to(torch.get_autocast_gpu_dtype())
|
||||||
ctx.save_for_backward(ans)
|
ctx.save_for_backward(ans)
|
||||||
ctx.x_dtype = x.dtype
|
ctx.x_dtype = x.dtype
|
||||||
ctx.dim = dim
|
ctx.dim = dim
|
||||||
@ -1234,7 +1234,7 @@ class DoubleSwishFunction(torch.autograd.Function):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x: Tensor) -> Tensor:
|
def forward(ctx, x: Tensor) -> Tensor:
|
||||||
requires_grad = x.requires_grad
|
requires_grad = x.requires_grad
|
||||||
if x.dtype == torch.float16:
|
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
|
||||||
x = x.to(torch.float32)
|
x = x.to(torch.float32)
|
||||||
|
|
||||||
s = torch.sigmoid(x - 1.0)
|
s = torch.sigmoid(x - 1.0)
|
||||||
@ -1346,7 +1346,7 @@ class SwooshLFunction(torch.autograd.Function):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x: Tensor) -> Tensor:
|
def forward(ctx, x: Tensor) -> Tensor:
|
||||||
requires_grad = x.requires_grad
|
requires_grad = x.requires_grad
|
||||||
if x.dtype == torch.float16:
|
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
|
||||||
x = x.to(torch.float32)
|
x = x.to(torch.float32)
|
||||||
|
|
||||||
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||||
@ -1379,7 +1379,7 @@ class SwooshLFunction(torch.autograd.Function):
|
|||||||
d_int = d_scaled.to(torch.uint8)
|
d_int = d_scaled.to(torch.uint8)
|
||||||
ctx.save_for_backward(d_int)
|
ctx.save_for_backward(d_int)
|
||||||
if x.dtype == torch.float16 or torch.is_autocast_enabled():
|
if x.dtype == torch.float16 or torch.is_autocast_enabled():
|
||||||
y = y.to(torch.float16)
|
y = y.to(torch.get_autocast_gpu_dtype())
|
||||||
return y
|
return y
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -1425,7 +1425,7 @@ class SwooshRFunction(torch.autograd.Function):
|
|||||||
def forward(ctx, x: Tensor) -> Tensor:
|
def forward(ctx, x: Tensor) -> Tensor:
|
||||||
requires_grad = x.requires_grad
|
requires_grad = x.requires_grad
|
||||||
|
|
||||||
if x.dtype == torch.float16:
|
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
|
||||||
x = x.to(torch.float32)
|
x = x.to(torch.float32)
|
||||||
|
|
||||||
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||||
@ -1455,7 +1455,7 @@ class SwooshRFunction(torch.autograd.Function):
|
|||||||
d_int = d_scaled.to(torch.uint8)
|
d_int = d_scaled.to(torch.uint8)
|
||||||
ctx.save_for_backward(d_int)
|
ctx.save_for_backward(d_int)
|
||||||
if x.dtype == torch.float16 or torch.is_autocast_enabled():
|
if x.dtype == torch.float16 or torch.is_autocast_enabled():
|
||||||
y = y.to(torch.float16)
|
y = y.to(torch.get_autocast_gpu_dtype())
|
||||||
return y
|
return y
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -521,6 +521,13 @@ def get_parser():
|
|||||||
help="Whether to use half precision training.",
|
help="Whether to use half precision training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-bf16",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether to use bf16 in AMP.",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -1027,7 +1034,9 @@ def train_one_epoch(
|
|||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
with torch.cuda.amp.autocast(
|
||||||
|
enabled=params.use_autocast, dtype=params.dtype
|
||||||
|
):
|
||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
@ -1047,9 +1056,7 @@ def train_one_epoch(
|
|||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.info(
|
logging.info(f"Caught exception: {e}.")
|
||||||
f"Caught exception: {e}."
|
|
||||||
)
|
|
||||||
save_bad_model()
|
save_bad_model()
|
||||||
display_and_save_batch(batch, params=params, sp=sp)
|
display_and_save_batch(batch, params=params, sp=sp)
|
||||||
raise
|
raise
|
||||||
@ -1090,7 +1097,7 @@ def train_one_epoch(
|
|||||||
rank=rank,
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % 100 == 0 and params.use_fp16:
|
if batch_idx % 100 == 0 and params.use_autocast:
|
||||||
# If the grad scale was less than 1, try increasing it. The _growth_interval
|
# If the grad scale was less than 1, try increasing it. The _growth_interval
|
||||||
# of the grad scaler is configurable, but we can't configure it to have different
|
# of the grad scaler is configurable, but we can't configure it to have different
|
||||||
# behavior depending on the current grad scale.
|
# behavior depending on the current grad scale.
|
||||||
@ -1109,14 +1116,14 @@ def train_one_epoch(
|
|||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = max(scheduler.get_last_lr())
|
cur_lr = max(scheduler.get_last_lr())
|
||||||
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Epoch {params.cur_epoch}, "
|
f"Epoch {params.cur_epoch}, "
|
||||||
f"batch {batch_idx}, loss[{loss_info}], "
|
f"batch {batch_idx}, loss[{loss_info}], "
|
||||||
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||||
f"lr: {cur_lr:.2e}, "
|
f"lr: {cur_lr:.2e}, "
|
||||||
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
+ (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "")
|
||||||
)
|
)
|
||||||
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
@ -1128,7 +1135,7 @@ def train_one_epoch(
|
|||||||
tb_writer, "train/current_", params.batch_idx_train
|
tb_writer, "train/current_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||||
if params.use_fp16:
|
if params.use_autocast:
|
||||||
tb_writer.add_scalar(
|
tb_writer.add_scalar(
|
||||||
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
||||||
)
|
)
|
||||||
@ -1204,9 +1211,25 @@ def run(rank, world_size, args):
|
|||||||
params.ctc_loss_scale = 1.0
|
params.ctc_loss_scale = 1.0
|
||||||
else:
|
else:
|
||||||
assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
|
assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
|
||||||
params.ctc_loss_scale, params.attention_decoder_loss_scale
|
params.ctc_loss_scale,
|
||||||
|
params.attention_decoder_loss_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if params.use_bf16: # amp + bf16
|
||||||
|
assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!"
|
||||||
|
assert not params.use_fp16, "You can only use either fp16 or bf16"
|
||||||
|
params.dtype = torch.bfloat16
|
||||||
|
params.use_autocast = True
|
||||||
|
elif params.use_fp16: # amp + fp16
|
||||||
|
params.dtype = torch.float16
|
||||||
|
params.use_autocast = True
|
||||||
|
else: # fp32
|
||||||
|
params.dtype = torch.float32
|
||||||
|
params.use_autocast = False
|
||||||
|
|
||||||
|
logging.info(f"Using dtype={params.dtype}")
|
||||||
|
logging.info(f"Use AMP={params.use_autocast}")
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
@ -1339,7 +1362,7 @@ def run(rank, world_size, args):
|
|||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
|
|
||||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
|
||||||
if checkpoints and "grad_scaler" in checkpoints:
|
if checkpoints and "grad_scaler" in checkpoints:
|
||||||
logging.info("Loading grad scaler state dict")
|
logging.info("Loading grad scaler state dict")
|
||||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||||
@ -1439,7 +1462,9 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
for criterion, cuts in batches.items():
|
for criterion, cuts in batches.items():
|
||||||
batch = train_dl.dataset[cuts]
|
batch = train_dl.dataset[cuts]
|
||||||
try:
|
try:
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
with torch.cuda.amp.autocast(
|
||||||
|
enabled=params.use_autocast, dtype=params.dtype
|
||||||
|
):
|
||||||
loss, _ = compute_loss(
|
loss, _ = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user