mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
remove duplicated torch autocast
This commit is contained in:
parent
5fbeed9f96
commit
9939c2b72d
@ -501,6 +501,8 @@ def compute_loss(
|
||||
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
if params.use_fp16:
|
||||
feature = feature.half()
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"]
|
||||
@ -559,14 +561,13 @@ def compute_validation_loss(
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
tokenizer=tokenizer,
|
||||
model=model,
|
||||
batch=batch,
|
||||
is_training=False,
|
||||
)
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
tokenizer=tokenizer,
|
||||
model=model,
|
||||
batch=batch,
|
||||
is_training=False,
|
||||
)
|
||||
assert loss.requires_grad is False
|
||||
tot_loss = tot_loss + loss_info
|
||||
|
||||
@ -680,14 +681,13 @@ def train_one_epoch(
|
||||
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||
)
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
tokenizer=tokenizer,
|
||||
model=model,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
)
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
tokenizer=tokenizer,
|
||||
model=model,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
)
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user