mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Filter uneven-sized batch (#843)
* add filter_uneven_sized_batch function * set --filter-uneven-sized-batch=True as default
This commit is contained in:
parent
5c8e9628cc
commit
2a463a420d
@ -82,7 +82,13 @@ from icefall.checkpoint import (
|
|||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
MetricsTracker,
|
||||||
|
filter_uneven_sized_batch,
|
||||||
|
setup_logger,
|
||||||
|
str2bool,
|
||||||
|
)
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
|
|
||||||
@ -368,6 +374,21 @@ def get_parser():
|
|||||||
help="Whether to use half precision training.",
|
help="Whether to use half precision training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--filter-uneven-sized-batch",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="""Whether to filter uneven-sized minibatch.
|
||||||
|
For the uneven-sized batch, the total duration after padding would possibly
|
||||||
|
cause OOM. Hence, for each batch, which is sorted descendingly by length,
|
||||||
|
we simply drop the last few shortest samples, so that the retained total frames
|
||||||
|
(after padding) would not exceed `allowed_max_frames`:
|
||||||
|
`allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
|
||||||
|
where `max_frames = max_duration * 1000 // frame_shift_ms`.
|
||||||
|
We set allowed_excess_duration_ratio=0.1.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -420,6 +441,9 @@ def get_params() -> AttributeDict:
|
|||||||
"""
|
"""
|
||||||
params = AttributeDict(
|
params = AttributeDict(
|
||||||
{
|
{
|
||||||
|
"frame_shift_ms": 10.0,
|
||||||
|
# only used when params.filter_uneven_sized_batch is True
|
||||||
|
"allowed_excess_duration_ratio": 0.1,
|
||||||
"best_train_loss": float("inf"),
|
"best_train_loss": float("inf"),
|
||||||
"best_valid_loss": float("inf"),
|
"best_valid_loss": float("inf"),
|
||||||
"best_train_epoch": -1,
|
"best_train_epoch": -1,
|
||||||
@ -642,6 +666,13 @@ def compute_loss(
|
|||||||
warmup: a floating point value which increases throughout training;
|
warmup: a floating point value which increases throughout training;
|
||||||
values >= 1.0 are fully warmed up and have all modules present.
|
values >= 1.0 are fully warmed up and have all modules present.
|
||||||
"""
|
"""
|
||||||
|
if params.filter_uneven_sized_batch:
|
||||||
|
max_frames = params.max_duration * 1000 // params.frame_shift_ms
|
||||||
|
allowed_max_frames = int(
|
||||||
|
max_frames * (1.0 + params.allowed_excess_duration_ratio)
|
||||||
|
)
|
||||||
|
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
|
||||||
|
|
||||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
# at entry, feature is (N, T, C)
|
# at entry, feature is (N, T, C)
|
||||||
|
@ -1395,3 +1395,39 @@ def is_module_available(*modules: str) -> bool:
|
|||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
return all(importlib.util.find_spec(m) is not None for m in modules)
|
return all(importlib.util.find_spec(m) is not None for m in modules)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int) -> dict:
    """Drop the shortest utterances of a batch so its padded size stays bounded.

    For an uneven-sized batch, the total duration after padding could
    cause OOM. Hence, for each batch, which is assumed to be sorted
    descendingly by length, we simply drop the last few (shortest) samples,
    so that the retained total frames (after padding) do not exceed the given
    `allowed_max_frames`.

    At least one utterance is always retained: if even the longest utterance
    alone is longer than `allowed_max_frames`, the batch is reduced to that
    single utterance instead of becoming empty.

    Note that `batch` is modified in place (its "inputs" entry and every entry
    of its "supervisions" dict are replaced by truncated versions) and the
    same object is returned.

    Args:
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it. It must contain "inputs", whose `size()` is
        (N, T, C), and "supervisions", a dict whose values all have length N
        and which has a "num_frames" entry.
      allowed_max_frames:
        The allowed max number of frames in batch.

    Returns:
      The (possibly truncated) batch.
    """
    features = batch["inputs"]
    supervisions = batch["supervisions"]

    N, T, _ = features.size()
    assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max())

    # After padding, every retained utterance occupies T frames, so at most
    # allowed_max_frames // T utterances fit. Clamp to 1 so that a batch whose
    # longest utterance exceeds the limit is not turned into an empty batch.
    keep_num_utt = max(1, allowed_max_frames // T)

    if keep_num_utt >= N:
        # Nothing to drop; return the batch untouched.
        return batch

    # Note: we assume the samples in batch is sorted descendingly by length,
    # so keeping the leading entries drops the shortest utterances.
    logging.info(
        f"Filtering uneven-sized batch, original batch size is {N}, "
        f"retained batch size is {keep_num_utt}."
    )
    batch["inputs"] = features[:keep_num_utt]
    # Only existing keys are re-assigned, so the dict does not change size
    # while it is being iterated.
    for k, v in supervisions.items():
        assert len(v) == N, (len(v), N)
        batch["supervisions"][k] = v[:keep_num_utt]

    return batch
|
||||||
|
Loading…
x
Reference in New Issue
Block a user