For older PyTorch versions

Replace the device-generic torch.autocast(device_type=...) context
manager with torch.cuda.amp.autocast(...), which is also available in
older PyTorch releases.
commit 34c380cd8b
parent 29d186118f
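The diff below makes the same swap in four places. If one wanted to support
both APIs at once instead of pinning to the older one, a small helper could
pick whichever is available. This is a hypothetical sketch, not part of
icefall; the name compat_autocast is invented here:

import torch

def compat_autocast(device_type: str = "cuda", enabled: bool = True):
    # Newer PyTorch exposes the device-generic torch.autocast;
    # older releases only provide torch.cuda.amp.autocast.
    if hasattr(torch, "autocast"):
        return torch.autocast(device_type=device_type, enabled=enabled)
    # Fallback for older PyTorch; note this path is CUDA-only.
    return torch.cuda.amp.autocast(enabled=enabled)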
@@ -128,7 +128,7 @@ class Transducer(nn.Module):
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
-        with torch.autocast(device_type=x.device.type, enabled=False):
+        with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=decoder_out.float(),
                 am=encoder_out.float(),
@@ -158,7 +158,7 @@ class Transducer(nn.Module):
         # logits : [B, T, prune_range, C]
         logits = self.joiner(am_pruned, lm_pruned)
 
-        with torch.autocast(device_type=x.device.type, enabled=False):
+        with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,
@@ -647,9 +647,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.autocast(
-            device_type=model.device.type, enabled=params.use_fp16
-        ):
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -920,9 +918,7 @@ def scan_pessimistic_batches_for_oom(
         batch = train_dl.dataset[cuts]
         try:
             optimizer.zero_grad()
-            with torch.autocast(
-                device_type=model.device.type, enabled=params.use_fp16
-            ):
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
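A note on the enabled=False regions in the first two hunks: they keep the k2
RNN-T losses in full precision even when the surrounding training step runs
under fp16 autocast. A minimal standalone sketch of that pattern, with a
placeholder model and loss rather than icefall code:

import torch

model = torch.nn.Linear(10, 10).cuda()
x = torch.randn(4, 10, device="cuda")

with torch.cuda.amp.autocast(enabled=True):
    y = model(x)  # computed in half precision under autocast
    with torch.cuda.amp.autocast(enabled=False):
        # Numerically sensitive parts run in fp32; tensors produced
        # under autocast must be cast back explicitly, hence .float(),
        # matching the .float() calls on the k2 loss inputs above.
        loss = (y.float() ** 2).mean()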