Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-04 14:44:18 +00:00)
Use torch.cuda.amp.autocast to support older PyTorch versions
commit 34c380cd8b
parent 29d186118f
@@ -128,7 +128,7 @@ class Transducer(nn.Module):
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens

-        with torch.autocast(device_type=x.device.type, enabled=False):
+        with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=decoder_out.float(),
                 am=encoder_out.float(),
@@ -158,7 +158,7 @@ class Transducer(nn.Module):
         # logits : [B, T, prune_range, C]
         logits = self.joiner(am_pruned, lm_pruned)

-        with torch.autocast(device_type=x.device.type, enabled=False):
+        with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,
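Both Transducer hunks above make the same substitution: torch.autocast with an explicit device_type argument only exists in newer PyTorch releases, while torch.cuda.amp.autocast is also available in older ones. Either way autocast is disabled and the k2 losses receive .float() inputs, so they are computed in float32 even when the surrounding forward pass runs under mixed precision. A minimal compatibility sketch, in case one wants a single call site that works on both old and new PyTorch (the helper name amp_autocast is hypothetical, not part of this commit):

import torch


def amp_autocast(device_type: str = "cuda", enabled: bool = True):
    # Version guard: newer PyTorch exposes the device-agnostic torch.autocast;
    # older releases only ship the CUDA-specific torch.cuda.amp.autocast.
    if hasattr(torch, "autocast"):
        return torch.autocast(device_type=device_type, enabled=enabled)
    return torch.cuda.amp.autocast(enabled=enabled)


# Usage mirroring the first hunk (k2 call arguments elided):
# with amp_autocast(device_type=x.device.type, enabled=False):
#     simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(...)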
@@ -647,9 +647,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.autocast(
-            device_type=model.device.type, enabled=params.use_fp16
-        ):
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -920,9 +918,7 @@ def scan_pessimistic_batches_for_oom(
         batch = train_dl.dataset[cuts]
         try:
             optimizer.zero_grad()
-            with torch.autocast(
-                device_type=model.device.type, enabled=params.use_fp16
-            ):
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
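The train_one_epoch and scan_pessimistic_batches_for_oom hunks apply the same replacement where autocast is switched on or off by params.use_fp16. On the older torch.cuda.amp API, the usual companion for fp16 training is torch.cuda.amp.GradScaler; the sketch below shows that generic pattern with a toy model and loss, not this repository's compute_loss:

import torch

# Self-contained fp16 training-loop sketch on the older torch.cuda.amp API
# (toy model and data; `use_fp16` stands in for params.use_fp16).
use_fp16 = torch.cuda.is_available()
device = torch.device("cuda" if use_fp16 else "cpu")

model = torch.nn.Linear(80, 500).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

for _ in range(3):
    x = torch.randn(8, 80, device=device)
    y = torch.randint(0, 500, (8,), device=device)

    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_fp16):
        loss = torch.nn.functional.cross_entropy(model(x), y)

    # Scale the loss so fp16 gradients do not underflow; scaler.step()
    # unscales the gradients before calling optimizer.step().
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()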