Fix stateless7 training error (#1082)

Fangjun Kuang 2023-05-23 12:52:02 +08:00 committed by GitHub
parent 585e7b224f
commit dbcf0b41db
3 changed files with 22 additions and 20 deletions

View File

@@ -56,8 +56,8 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from decoder import Decoder
 from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
 from gigaspeech import GigaSpeechAsrDataModule
 from joiner import Joiner
 from lhotse.cut import Cut, CutSet
@@ -753,6 +753,7 @@ def compute_loss(
     # We set allowed_excess_duration_ratio=0.1.
     max_frames = params.max_duration * 1000 // params.frame_shift_ms
     allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
-    batch = filter_uneven_sized_batch(batch, allowed_max_frames)
+    if is_training:
+        batch = filter_uneven_sized_batch(batch, allowed_max_frames)
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
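For a sense of scale, the frame budget computed just above works out as follows. The numbers below are hypothetical values chosen only to illustrate the arithmetic; they are not the recipe defaults.

# Hypothetical values, only to illustrate the frame-budget arithmetic above.
max_duration = 600                    # seconds of audio per batch (assumed value)
frame_shift_ms = 10                   # 10 ms frame shift => 100 feature frames per second
allowed_excess_duration_ratio = 0.1   # the value mentioned in the comment above

max_frames = max_duration * 1000 // frame_shift_ms                            # 60000 frames
allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))  # 66000 frames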

View File

@@ -660,7 +660,7 @@ def compute_loss(
         values >= 1.0 are fully warmed up and have all modules present.
     """
     # For the uneven-sized batch, the total duration after padding would possibly
-    # cause OOM. Hence, for each batch, which is sorted descendingly by length,
+    # cause OOM. Hence, for each batch, which is sorted in descending order by length,
     # we simply drop the last few shortest samples, so that the retained total frames
     # (after padding) would not exceed `allowed_max_frames`:
     # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
@@ -668,6 +668,7 @@ def compute_loss(
     # We set allowed_excess_duration_ratio=0.1.
     max_frames = params.max_duration * 1000 // params.frame_shift_ms
     allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
-    batch = filter_uneven_sized_batch(batch, allowed_max_frames)
+    if is_training:
+        batch = filter_uneven_sized_batch(batch, allowed_max_frames)
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
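The same `is_training` guard is applied in this file: only the training pass trims the batch, while a validation pass leaves it untouched. The sketch below shows how the flag typically flows into `compute_loss`; the exact signature differs between recipes, so the keyword arguments here are illustrative, not the actual call sites of this commit.

# Illustrative only: how the is_training flag usually reaches compute_loss.
loss, info = compute_loss(params=params, model=model, batch=batch, is_training=True)
# training step: the batch may be trimmed by filter_uneven_sized_batch

with torch.no_grad():
    val_loss, val_info = compute_loss(params=params, model=model, batch=batch, is_training=False)
    # validation step: after this fix, the batch is not filtered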

View File

@@ -1551,7 +1551,7 @@ def is_module_available(*modules: str) -> bool:
 def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
     """For the uneven-sized batch, the total duration after padding would possibly
-    cause OOM. Hence, for each batch, which is sorted descendingly by length,
+    cause OOM. Hence, for each batch, which is sorted in descending order by length,
     we simply drop the last few shortest samples, so that the retained total frames
     (after padding) would not exceed the given allow_max_frames.
@@ -1567,20 +1567,20 @@ def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
     N, T, _ = features.size()
     assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max())
-    keep_num_utt = allowed_max_frames // T
-    if keep_num_utt >= N:
+    kept_num_utt = allowed_max_frames // T
+    if kept_num_utt >= N or kept_num_utt == 0:
         return batch
     # Note: we assume the samples in batch is sorted descendingly by length
     logging.info(
         f"Filtering uneven-sized batch, original batch size is {N}, "
-        f"retained batch size is {keep_num_utt}."
+        f"retained batch size is {kept_num_utt}."
     )
-    batch["inputs"] = features[:keep_num_utt]
+    batch["inputs"] = features[:kept_num_utt]
     for k, v in supervisions.items():
         assert len(v) == N, (len(v), N)
-        batch["supervisions"][k] = v[:keep_num_utt]
+        batch["supervisions"][k] = v[:kept_num_utt]
     return batch
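The new `kept_num_utt == 0` check covers the case where even a single padded utterance exceeds the frame budget: the integer division then yields zero, and slicing with `[:0]` would have emptied the batch entirely. A toy illustration of that edge case, using synthetic tensors rather than a real icefall batch:

import torch

# Toy "batch": 2 utterances, each padded to T = 2000 frames, 80-dim features.
features = torch.zeros(2, 2000, 80)
allowed_max_frames = 1500               # budget smaller than a single utterance

N, T, _ = features.size()
kept_num_utt = allowed_max_frames // T  # 1500 // 2000 == 0
print(kept_num_utt)                     # 0
print(features[:kept_num_utt].shape)    # torch.Size([0, 2000, 80]) -> an empty batch
# With the added check, kept_num_utt == 0 now returns the batch unchanged
# instead of producing an empty one.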