mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
remove duplicated torch autocast
This commit is contained in:
parent
5fbeed9f96
commit
9939c2b72d
@ -501,6 +501,8 @@ def compute_loss(
|
|||||||
|
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
assert feature.ndim == 3
|
assert feature.ndim == 3
|
||||||
|
if params.use_fp16:
|
||||||
|
feature = feature.half()
|
||||||
|
|
||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"]
|
feature_lens = supervisions["num_frames"]
|
||||||
@ -559,7 +561,6 @@ def compute_validation_loss(
|
|||||||
tot_loss = MetricsTracker()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(valid_dl):
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
|
||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@ -680,7 +681,6 @@ def train_one_epoch(
|
|||||||
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
|
||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user