Filter uneven-sized batch (#843)

* add filter_uneven_sized_batch function

* set --filter-uneven-sized-batch=True as default
This commit is contained in:
Zengwei Yao 2023-01-16 20:15:35 +08:00 committed by GitHub
parent 5c8e9628cc
commit 2a463a420d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 68 additions and 1 deletions

View File

@ -82,7 +82,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
filter_uneven_sized_batch,
setup_logger,
str2bool,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -368,6 +374,21 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--filter-uneven-sized-batch",
type=str2bool,
default=True,
help="""Whether to filter uneven-sized minibatch.
For the uneven-sized batch, the total duration after padding would possibly
cause OOM. Hence, for each batch, which is sorted descendingly by length,
we simply drop the last few shortest samples, so that the retained total frames
(after padding) would not exceed `allowed_max_frames`:
`allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
where `max_frames = max_duration * 1000 // frame_shift_ms`.
We set allowed_excess_duration_ratio=0.1.
""",
)
add_model_arguments(parser)
return parser
@ -420,6 +441,9 @@ def get_params() -> AttributeDict:
"""
params = AttributeDict(
{
"frame_shift_ms": 10.0,
# only used when params.filter_uneven_sized_batch is True
"allowed_excess_duration_ratio": 0.1,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
@ -642,6 +666,13 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
if params.filter_uneven_sized_batch:
max_frames = params.max_duration * 1000 // params.frame_shift_ms
allowed_max_frames = int(
max_frames * (1.0 + params.allowed_excess_duration_ratio)
)
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)

View File

@ -1395,3 +1395,39 @@ def is_module_available(*modules: str) -> bool:
import importlib
return all(importlib.util.find_spec(m) is not None for m in modules)
def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
    """Drop the shortest samples of a batch so its padded size stays bounded.

    For an uneven-sized batch, the total duration after padding would possibly
    cause OOM. Hence, for each batch, which is assumed to be sorted
    descendingly by length, we simply keep the first few samples and drop the
    last few shortest ones, so that the retained total frames (after padding),
    i.e. `num_kept_samples * T`, would not exceed the given
    `allowed_max_frames`.

    At least one sample is always kept: if the padded length `T` alone already
    exceeds `allowed_max_frames`, only the first sample is retained instead of
    producing an empty batch.

    Note: `batch` is modified in place when samples are dropped.

    Args:
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it. `batch["inputs"]` must be a 3-D tensor of shape
        (N, T, C), and every value in `batch["supervisions"]` must have
        length N and support slicing.
      allowed_max_frames:
        The allowed max number of frames in batch.

    Returns:
      The same `batch` object, truncated to its first few samples if needed.
    """
    features = batch["inputs"]
    supervisions = batch["supervisions"]

    N, T, _ = features.size()
    assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max())

    # Never let the retained batch become empty: when T > allowed_max_frames
    # the integer division yields 0, which would slice away every sample.
    keep_num_utt = max(1, allowed_max_frames // T)

    if keep_num_utt >= N:
        return batch

    # Note: we assume the samples in batch is sorted descendingly by length
    logging.info(
        f"Filtering uneven-sized batch, original batch size is {N}, "
        f"retained batch size is {keep_num_utt}."
    )
    batch["inputs"] = features[:keep_num_utt]
    # Only existing keys are reassigned, so the dict size does not change
    # while iterating over it.
    for k, v in supervisions.items():
        assert len(v) == N, (len(v), N)
        batch["supervisions"][k] = v[:keep_num_utt]

    return batch