For older pytorch version

This commit is contained in:
pkufool 2022-04-05 07:50:52 +08:00
parent 29d186118f
commit 34c380cd8b
2 changed files with 4 additions and 8 deletions

View File

@ -128,7 +128,7 @@ class Transducer(nn.Module):
boundary[:, 2] = y_lens boundary[:, 2] = y_lens
boundary[:, 3] = x_lens boundary[:, 3] = x_lens
with torch.autocast(device_type=x.device.type, enabled=False): with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=decoder_out.float(), lm=decoder_out.float(),
am=encoder_out.float(), am=encoder_out.float(),
@ -158,7 +158,7 @@ class Transducer(nn.Module):
# logits : [B, T, prune_range, C] # logits : [B, T, prune_range, C]
logits = self.joiner(am_pruned, lm_pruned) logits = self.joiner(am_pruned, lm_pruned)
with torch.autocast(device_type=x.device.type, enabled=False): with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,

View File

@ -647,9 +647,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.autocast( with torch.cuda.amp.autocast(enabled=params.use_fp16):
device_type=model.device.type, enabled=params.use_fp16
):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -920,9 +918,7 @@ def scan_pessimistic_batches_for_oom(
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
optimizer.zero_grad() optimizer.zero_grad()
with torch.autocast( with torch.cuda.amp.autocast(enabled=params.use_fp16):
device_type=model.device.type, enabled=params.use_fp16
):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,