hardcode --filter-uneven-sized-batch (#854)

This commit is contained in:
Zengwei Yao 2023-01-27 21:24:12 +08:00 committed by GitHub
parent f5ff7a18eb
commit 6b1ab71dc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -374,21 +374,6 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--filter-uneven-sized-batch",
type=str2bool,
default=True,
help="""Whether to filter uneven-sized minibatch.
For the uneven-sized batch, the total duration after padding would possibly
cause OOM. Hence, for each batch, which is sorted in descending order of length,
we simply drop the last few shortest samples, so that the retained total frames
(after padding) would not exceed `allowed_max_frames`:
`allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
where `max_frames = max_duration * 1000 // frame_shift_ms`.
We set allowed_excess_duration_ratio=0.1.
""",
)
add_model_arguments(parser)
return parser
@ -442,7 +427,6 @@ def get_params() -> AttributeDict:
params = AttributeDict(
{
"frame_shift_ms": 10.0,
# only used when params.filter_uneven_sized_batch is True
"allowed_excess_duration_ratio": 0.1,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
@ -666,11 +650,15 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
if params.filter_uneven_sized_batch: # For the uneven-sized batch, the total duration after padding would possibly
# cause OOM. Hence, for each batch, which is sorted in descending order of length,
# we simply drop the last few shortest samples, so that the retained total frames
# (after padding) would not exceed `allowed_max_frames`:
# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
# where `max_frames = max_duration * 1000 // frame_shift_ms`.
# We set allowed_excess_duration_ratio=0.1.
max_frames = params.max_duration * 1000 // params.frame_shift_ms
allowed_max_frames = int( allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
max_frames * (1.0 + params.allowed_excess_duration_ratio)
)
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device