mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
remove batch shaving
This commit is contained in:
parent
f5d2aa1f5d
commit
46b9be31cc
@ -71,7 +71,6 @@ from icefall.env import get_env_info
|
|||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
MetricsTracker,
|
MetricsTracker,
|
||||||
filter_uneven_sized_batch,
|
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
)
|
)
|
||||||
@ -278,13 +277,6 @@ def compute_loss(
|
|||||||
Returns:
|
Returns:
|
||||||
Return a tuple of two elements. The first element is the loss tensor.
|
Return a tuple of two elements. The first element is the loss tensor.
|
||||||
"""
|
"""
|
||||||
# For the uneven-sized batch, the total duration after padding would possibly
|
|
||||||
# cause OOM. Hence, for each batch, which is sorted descendingly by length,
|
|
||||||
# we simply drop the last few shortest samples, so that the retained total frames
|
|
||||||
# (after padding) would not exceed `allowed_max_frames`:
|
|
||||||
# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
|
|
||||||
# where `max_frames = max_duration * 1000 // frame_shift_ms`.
|
|
||||||
# We set allowed_excess_duration_ratio=0.1.
|
|
||||||
|
|
||||||
def preprocess(
|
def preprocess(
|
||||||
messages,
|
messages,
|
||||||
@ -375,10 +367,6 @@ def compute_loss(
|
|||||||
text = text.replace("?", "")
|
text = text.replace("?", "")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
max_frames = params.max_duration * 1000 // params.frame_shift_ms
|
|
||||||
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
|
|
||||||
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
|
|
||||||
|
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user