mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
hardcode --filter-uneven-sized-batch (#854)
This commit is contained in:
parent
f5ff7a18eb
commit
6b1ab71dc9
@ -374,21 +374,6 @@ def get_parser():
|
|||||||
help="Whether to use half precision training.",
|
help="Whether to use half precision training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--filter-uneven-sized-batch",
|
|
||||||
type=str2bool,
|
|
||||||
default=True,
|
|
||||||
help="""Whether to filter uneven-sized minibatch.
|
|
||||||
For the uneven-sized batch, the total duration after padding would possibly
|
|
||||||
cause OOM. Hence, for each batch, which is sorted descendingly by length,
|
|
||||||
we simply drop the last few shortest samples, so that the retained total frames
|
|
||||||
(after padding) would not exceed `allowed_max_frames`:
|
|
||||||
`allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
|
|
||||||
where `max_frames = max_duration * 1000 // frame_shift_ms`.
|
|
||||||
We set allowed_excess_duration_ratio=0.1.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -442,7 +427,6 @@ def get_params() -> AttributeDict:
|
|||||||
params = AttributeDict(
|
params = AttributeDict(
|
||||||
{
|
{
|
||||||
"frame_shift_ms": 10.0,
|
"frame_shift_ms": 10.0,
|
||||||
# only used when params.filter_uneven_sized_batch is True
|
|
||||||
"allowed_excess_duration_ratio": 0.1,
|
"allowed_excess_duration_ratio": 0.1,
|
||||||
"best_train_loss": float("inf"),
|
"best_train_loss": float("inf"),
|
||||||
"best_valid_loss": float("inf"),
|
"best_valid_loss": float("inf"),
|
||||||
@ -666,12 +650,16 @@ def compute_loss(
|
|||||||
warmup: a floating point value which increases throughout training;
|
warmup: a floating point value which increases throughout training;
|
||||||
values >= 1.0 are fully warmed up and have all modules present.
|
values >= 1.0 are fully warmed up and have all modules present.
|
||||||
"""
|
"""
|
||||||
if params.filter_uneven_sized_batch:
|
# For the uneven-sized batch, the total duration after padding would possibly
|
||||||
max_frames = params.max_duration * 1000 // params.frame_shift_ms
|
# cause OOM. Hence, for each batch, which is sorted descendingly by length,
|
||||||
allowed_max_frames = int(
|
# we simply drop the last few shortest samples, so that the retained total frames
|
||||||
max_frames * (1.0 + params.allowed_excess_duration_ratio)
|
# (after padding) would not exceed `allowed_max_frames`:
|
||||||
)
|
# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
|
||||||
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
|
# where `max_frames = max_duration * 1000 // frame_shift_ms`.
|
||||||
|
# We set allowed_excess_duration_ratio=0.1.
|
||||||
|
max_frames = params.max_duration * 1000 // params.frame_shift_ms
|
||||||
|
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
|
||||||
|
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
|
||||||
|
|
||||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user