mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
remove duplicated torch autocast
This commit is contained in:
parent
5fbeed9f96
commit
9939c2b72d
@ -501,6 +501,8 @@ def compute_loss(
|
|||||||
|
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
assert feature.ndim == 3
|
assert feature.ndim == 3
|
||||||
|
if params.use_fp16:
|
||||||
|
feature = feature.half()
|
||||||
|
|
||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"]
|
feature_lens = supervisions["num_frames"]
|
||||||
@ -559,14 +561,13 @@ def compute_validation_loss(
|
|||||||
tot_loss = MetricsTracker()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(valid_dl):
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
loss, loss_info = compute_loss(
|
||||||
loss, loss_info = compute_loss(
|
params=params,
|
||||||
params=params,
|
tokenizer=tokenizer,
|
||||||
tokenizer=tokenizer,
|
model=model,
|
||||||
model=model,
|
batch=batch,
|
||||||
batch=batch,
|
is_training=False,
|
||||||
is_training=False,
|
)
|
||||||
)
|
|
||||||
assert loss.requires_grad is False
|
assert loss.requires_grad is False
|
||||||
tot_loss = tot_loss + loss_info
|
tot_loss = tot_loss + loss_info
|
||||||
|
|
||||||
@ -680,14 +681,13 @@ def train_one_epoch(
|
|||||||
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
loss, loss_info = compute_loss(
|
||||||
loss, loss_info = compute_loss(
|
params=params,
|
||||||
params=params,
|
tokenizer=tokenizer,
|
||||||
tokenizer=tokenizer,
|
model=model,
|
||||||
model=model,
|
batch=batch,
|
||||||
batch=batch,
|
is_training=True,
|
||||||
is_training=True,
|
)
|
||||||
)
|
|
||||||
# summary stats
|
# summary stats
|
||||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user