minor changes

This commit is contained in:
marcoyang 2024-08-21 16:35:32 +08:00
parent dc353dcc7b
commit b585e14de3
2 changed files with 51 additions and 23 deletions

View File

@@ -307,6 +307,23 @@ done
To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html).
We also support training Zipformer with AMP+bf16 (this requires a GPU with bf16 support). See [here](https://github.com/k2-fsa/icefall/pull/1700) for more details and pre-trained models. **The same command can be used for decoding and exporting the model.**
The AMP+bf16 training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/train.py \
  --world-size 4 \
  --num-epochs 50 \
  --start-epoch 1 \
  --use-fp16 0 \
  --use-bf16 1 \
  --exp-dir zipformer/exp_amp_bf16 \
  --causal 0 \
  --full-libri 1 \
  --max-duration 1000
```
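Under the hood, bf16 training still goes through `torch.cuda.amp.autocast`, only with `dtype=torch.bfloat16` instead of `torch.float16`, and the `GradScaler` stays enabled whenever autocast is used (see the `train.py` changes below). The following is a minimal PyTorch sketch of that pattern, not the actual recipe; the model, optimizer, and data are placeholders:
```python
import torch
from torch.cuda.amp import GradScaler, autocast

# Placeholder model and optimizer; the real recipe uses the Zipformer model.
model = torch.nn.Linear(80, 500).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

use_bf16, use_fp16 = True, False  # mirrors --use-bf16 1 --use-fp16 0
dtype = torch.bfloat16 if use_bf16 else (torch.float16 if use_fp16 else torch.float32)
use_autocast = use_bf16 or use_fp16

# The scaler is enabled whenever autocast is on; with bf16 and init_scale=1.0 it
# mainly acts as an inf/nan guard, since bf16 has the same exponent range as fp32.
scaler = GradScaler(enabled=use_autocast, init_scale=1.0)

x = torch.randn(16, 80, device="cuda")
y = torch.randn(16, 500, device="cuda")

with autocast(enabled=use_autocast, dtype=dtype):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```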
##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M
The tensorboard log can be found at

View File

@@ -1034,7 +1034,9 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16, dtype=params.dtype):
with torch.cuda.amp.autocast(
enabled=params.use_autocast, dtype=params.dtype
):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1054,9 +1056,7 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except Exception as e:
logging.info(
f"Caught exception: {e}."
)
logging.info(f"Caught exception: {e}.")
save_bad_model()
display_and_save_batch(batch, params=params, sp=sp)
raise
@@ -1097,7 +1097,7 @@ def train_one_epoch(
rank=rank,
)
if batch_idx % 100 == 0 and params.use_fp16:
if batch_idx % 100 == 0 and params.use_autocast:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
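(The scale-growing code itself is outside this hunk. As a rough illustration of the idea described in the comment above, and not part of this commit, PyTorch's `GradScaler.update()` accepts an explicit `new_scale`, so the scale can be doubled manually when it has dropped below 1; the threshold and growth factor below are illustrative only.)
```python
# Illustrative sketch: grow the AMP grad scale manually when it has collapsed
# below 1, instead of waiting for _growth_interval consecutive good steps.
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(init_scale=1.0)

for batch_idx in range(300):
    x = torch.randn(4, 8, device="cuda")
    with autocast(enabled=True, dtype=torch.float16):
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

    if batch_idx % 100 == 0:
        cur_grad_scale = scaler.get_scale()
        if cur_grad_scale < 1.0:
            # GradScaler.update() accepts an explicit new scale value.
            scaler.update(cur_grad_scale * 2.0)
```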
@@ -1116,14 +1116,14 @@ def train_one_epoch(
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "")
)
if tb_writer is not None:
@@ -1135,7 +1135,7 @@ def train_one_epoch(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
if params.use_autocast:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
@@ -1211,15 +1211,24 @@ def run(rank, world_size, args):
params.ctc_loss_scale = 1.0
else:
assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
params.ctc_loss_scale, params.attention_decoder_loss_scale
params.ctc_loss_scale,
params.attention_decoder_loss_scale,
)
if params.use_bf16:
assert torch.cuda.is_bf16_supported(), f"Your GPU does not support bf16!"
if params.use_bf16: # amp + bf16
assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!"
assert not params.use_fp16, "You can only use either fp16 or bf16"
params.dtype = torch.bfloat16
else:
params.use_autocast = True
elif params.use_fp16: # amp + fp16
params.dtype = torch.float16
params.use_autocast = True
else: # fp32
params.dtype = torch.float32
params.use_autocast = False
logging.info(f"Using dtype={params.dtype}")
logging.info(f"Use AMP={params.use_autocast}")
logging.info(params)
@@ -1344,16 +1353,16 @@ def run(rank, world_size, args):
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
# if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer=optimizer,
# sp=sp,
# params=params,
# )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1453,7 +1462,9 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16, dtype=params.dtype):
with torch.cuda.amp.autocast(
enabled=params.use_autocast, dtype=params.dtype
):
loss, _ = compute_loss(
params=params,
model=model,