Fix stateless7 training error (#1082)

Fangjun Kuang 2023-05-23 12:52:02 +08:00 committed by GitHub
parent 585e7b224f
commit dbcf0b41db
3 changed files with 22 additions and 20 deletions

View File

@@ -56,8 +56,8 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
-from decoder import Decoder
from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
from gigaspeech import GigaSpeechAsrDataModule
from joiner import Joiner
from lhotse.cut import Cut, CutSet
@@ -124,9 +124,9 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
default=None,
help="""
Modules to be initialized. It matches all parameters starting with
a specific key. The keys are given as comma-separated values. If None,
all modules will be initialised. For example, if you only want to
initialise all parameters starting with "encoder", use "encoder";
if you want to initialise parameters starting with encoder or decoder,
use "encoder,joiner".
""",
@@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
type=str,
default="256,256,256,256,256",
help="""Unmasked dimensions in the encoders, relates to augmentation
during training. Must be <= each of encoder_dims. Empirically, less
than 256 seems to make performance worse.
""",
)
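
A minimal check of the constraint stated above (the encoder dims here are example values, not necessarily this script's defaults):

encoder_dims = [int(x) for x in "384,384,384,384,384".split(",")]
encoder_unmasked_dims = [int(x) for x in "256,256,256,256,256".split(",")]
# Each unmasked dim must not exceed the corresponding encoder dim.
assert all(u <= d for u, d in zip(encoder_unmasked_dims, encoder_dims))
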
@@ -288,7 +288,7 @@ def get_parser():
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="""Path to the BPE model.
help="""Path to the BPE model.
This should be the bpe model of the original model
""",
)
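
A short usage sketch with the sentencepiece API (the path is this argument's default); reusing the original model's BPE model matters because its vocabulary size must match the pretrained model, otherwise the decoder and joiner output shapes will not line up:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")
vocab_size = sp.get_piece_size()  # should equal the pretrained model's vocab size
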
@@ -302,8 +302,8 @@ def get_parser():
type=float,
default=100000,
help="""Number of steps that affects how rapidly the learning rate
decreases. During fine-tuning, we set this very large so that the
learning rate slowly decays with number of batches. You may tune
its value by yourself.
""",
)
@@ -312,9 +312,9 @@ def get_parser():
"--lr-epochs",
type=float,
default=100,
help="""Number of epochs that affects how rapidly the learning rate
decreases. During fine-tuning, we set this very large so that the
learning rate slowly decays with number of batches. You may tune
help="""Number of epochs that affects how rapidly the learning rate
decreases. During fine-tuning, we set this very large so that the
learning rate slowly decays with number of batches. You may tune
its value by yourself.
""",
)
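
Both --lr-batches and --lr-epochs feed an Eden-style schedule; the sketch below is an assumed form of its decay factor (the recipe's actual scheduler is defined in its optim.py) and shows why very large values keep the learning rate almost flat during fine-tuning:

def eden_like_factor(batch, epoch, lr_batches, lr_epochs):
    # Assumed Eden-like decay: each term stays near 1.0 while the
    # batch/epoch count is small relative to lr_batches/lr_epochs.
    batch_term = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
    epoch_term = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    return batch_term * epoch_term

print(eden_like_factor(batch=5000, epoch=2, lr_batches=100000, lr_epochs=100))
# roughly 0.999: after 5000 fine-tuning steps the learning rate has barely decayed
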
@@ -753,7 +753,8 @@ def compute_loss(
# We set allowed_excess_duration_ratio=0.1.
max_frames = params.max_duration * 1000 // params.frame_shift_ms
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
-batch = filter_uneven_sized_batch(batch, allowed_max_frames)
+if is_training:
+    batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]

View File

@@ -660,7 +660,7 @@ def compute_loss(
values >= 1.0 are fully warmed up and have all modules present.
"""
# For the uneven-sized batch, the total duration after padding would possibly
-# cause OOM. Hence, for each batch, which is sorted descendingly by length,
+# cause OOM. Hence, for each batch, which is sorted in descending order by length,
# we simply drop the last few shortest samples, so that the retained total frames
# (after padding) would not exceed `allowed_max_frames`:
# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
@@ -668,7 +668,8 @@ def compute_loss(
# We set allowed_excess_duration_ratio=0.1.
max_frames = params.max_duration * 1000 // params.frame_shift_ms
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
-batch = filter_uneven_sized_batch(batch, allowed_max_frames)
+if is_training:
+    batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
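
A worked example of the threshold computed just above, with assumed values (not necessarily the recipe defaults):

max_duration = 600   # seconds of audio per batch
frame_shift_ms = 10
allowed_excess_duration_ratio = 0.1
max_frames = max_duration * 1000 // frame_shift_ms  # 60000 frames
allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))  # 66000

With the new is_training guard, this filtering is applied only to training batches, so validation batches are passed through unchanged.
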

View File

@@ -1551,7 +1551,7 @@ def is_module_available(*modules: str) -> bool:
def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
"""For the uneven-sized batch, the total duration after padding would possibly
-cause OOM. Hence, for each batch, which is sorted descendingly by length,
+cause OOM. Hence, for each batch, which is sorted in descending order by length,
we simply drop the last few shortest samples, so that the retained total frames
(after padding) would not exceed the given allowed_max_frames.
@@ -1567,20 +1567,20 @@ def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
N, T, _ = features.size()
assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max())
-keep_num_utt = allowed_max_frames // T
+kept_num_utt = allowed_max_frames // T
-if keep_num_utt >= N:
+if kept_num_utt >= N or kept_num_utt == 0:
return batch
# Note: we assume the samples in the batch are sorted in descending order by length
logging.info(
f"Filtering uneven-sized batch, original batch size is {N}, "
f"retained batch size is {keep_num_utt}."
f"retained batch size is {kept_num_utt}."
)
batch["inputs"] = features[:keep_num_utt]
batch["inputs"] = features[:kept_num_utt]
for k, v in supervisions.items():
assert len(v) == N, (len(v), N)
batch["supervisions"][k] = v[:keep_num_utt]
batch["supervisions"][k] = v[:kept_num_utt]
return batch
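
To see what the added kept_num_utt == 0 guard prevents: if even the longest padded utterance exceeds allowed_max_frames, then allowed_max_frames // T is 0, and the old code sliced the batch down to zero utterances, which is presumably the stateless7 training error this commit fixes. A usage sketch of the fixed function above, with synthetic tensors (real batches carry more supervision keys):

import torch

T = 70000                                # padded length of the longest utterance
allowed_max_frames = 66000               # e.g. int(60000 * 1.1)
batch = {
    "inputs": torch.zeros(4, T, 80),     # (N, T, num_features)
    "supervisions": {"num_frames": torch.tensor([T, 65000, 60000, 50000])},
}
assert allowed_max_frames // T == 0      # previously: features[:0], an empty batch
out = filter_uneven_sized_batch(batch, allowed_max_frames)
assert out["inputs"].size(0) == 4        # now: the batch is returned unchanged
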