mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-31 04:34:18 +00:00
fix too long audios
This commit is contained in:
parent
b76cd65abf
commit
1600f7db95
@ -441,6 +441,9 @@ def compute_loss(
|
|||||||
assert feature.ndim == 3
|
assert feature.ndim == 3
|
||||||
feature = feature.to(device)
|
feature = feature.to(device)
|
||||||
feature = feature.transpose(1, 2) # (N, C, T)
|
feature = feature.transpose(1, 2) # (N, C, T)
|
||||||
|
# make sure feature T no more than 3000, otherwise cut it
|
||||||
|
if feature.shape[2] > 3000:
|
||||||
|
feature = feature[:, :, :3000]
|
||||||
|
|
||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
@ -604,6 +607,18 @@ def train_one_epoch(
|
|||||||
valid_info.write_summary(
|
valid_info.write_summary(
|
||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
if params.deepspeed:
|
||||||
|
model.save_checkpoint(
|
||||||
|
save_dir=params.exp_dir,
|
||||||
|
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||||
|
client_state={},
|
||||||
|
)
|
||||||
|
if rank == 0:
|
||||||
|
convert_zero_checkpoint_to_fp32_state_dict(
|
||||||
|
params.exp_dir,
|
||||||
|
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
||||||
|
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user