mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fix stateless7 training error (#1082)
This commit is contained in:
parent
585e7b224f
commit
dbcf0b41db
@ -56,8 +56,8 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from decoder import Decoder
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decoder import Decoder
|
||||
from gigaspeech import GigaSpeechAsrDataModule
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut, CutSet
|
||||
@ -753,6 +753,7 @@ def compute_loss(
|
||||
# We set allowed_excess_duration_ratio=0.1.
|
||||
max_frames = params.max_duration * 1000 // params.frame_shift_ms
|
||||
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
|
||||
if is_training:
|
||||
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
|
||||
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
|
@ -660,7 +660,7 @@ def compute_loss(
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
# For the uneven-sized batch, the total duration after padding would possibly
|
||||
# cause OOM. Hence, for each batch, which is sorted descendingly by length,
|
||||
# cause OOM. Hence, for each batch, which is sorted in descending order by length,
|
||||
# we simply drop the last few shortest samples, so that the retained total frames
|
||||
# (after padding) would not exceed `allowed_max_frames`:
|
||||
# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
|
||||
@ -668,6 +668,7 @@ def compute_loss(
|
||||
# We set allowed_excess_duration_ratio=0.1.
|
||||
max_frames = params.max_duration * 1000 // params.frame_shift_ms
|
||||
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
|
||||
if is_training:
|
||||
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
|
||||
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
|
@ -1551,7 +1551,7 @@ def is_module_available(*modules: str) -> bool:
|
||||
|
||||
def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
|
||||
"""For the uneven-sized batch, the total duration after padding would possibly
|
||||
cause OOM. Hence, for each batch, which is sorted descendingly by length,
|
||||
cause OOM. Hence, for each batch, which is sorted in descending order by length,
|
||||
we simply drop the last few shortest samples, so that the retained total frames
|
||||
(after padding) would not exceed the given allow_max_frames.
|
||||
|
||||
@ -1567,20 +1567,20 @@ def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
|
||||
|
||||
N, T, _ = features.size()
|
||||
assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max())
|
||||
keep_num_utt = allowed_max_frames // T
|
||||
kept_num_utt = allowed_max_frames // T
|
||||
|
||||
if keep_num_utt >= N:
|
||||
if kept_num_utt >= N or kept_num_utt == 0:
|
||||
return batch
|
||||
|
||||
# Note: we assume the samples in batch is sorted descendingly by length
|
||||
logging.info(
|
||||
f"Filtering uneven-sized batch, original batch size is {N}, "
|
||||
f"retained batch size is {keep_num_utt}."
|
||||
f"retained batch size is {kept_num_utt}."
|
||||
)
|
||||
batch["inputs"] = features[:keep_num_utt]
|
||||
batch["inputs"] = features[:kept_num_utt]
|
||||
for k, v in supervisions.items():
|
||||
assert len(v) == N, (len(v), N)
|
||||
batch["supervisions"][k] = v[:keep_num_utt]
|
||||
batch["supervisions"][k] = v[:kept_num_utt]
|
||||
|
||||
return batch
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user