mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-31 04:34:18 +00:00
fix too long audios
This commit is contained in:
parent
b76cd65abf
commit
1600f7db95
@ -441,6 +441,9 @@ def compute_loss(
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
feature = feature.transpose(1, 2) # (N, C, T)
|
||||
# make sure feature T no more than 3000, otherwise cut it
|
||||
if feature.shape[2] > 3000:
|
||||
feature = feature[:, :, :3000]
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
@ -604,6 +607,18 @@ def train_one_epoch(
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
if params.deepspeed:
|
||||
model.save_checkpoint(
|
||||
save_dir=params.exp_dir,
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
client_state={},
|
||||
)
|
||||
if rank == 0:
|
||||
convert_zero_checkpoint_to_fp32_state_dict(
|
||||
params.exp_dir,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
)
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
|
Loading…
x
Reference in New Issue
Block a user