mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Fix stateless7 training error (#1082)
This commit is contained in:
parent
585e7b224f
commit
dbcf0b41db
@ -56,8 +56,8 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from decoder import Decoder
|
||||||
from gigaspeech import GigaSpeechAsrDataModule
|
from gigaspeech import GigaSpeechAsrDataModule
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
from lhotse.cut import Cut, CutSet
|
from lhotse.cut import Cut, CutSet
|
||||||
@ -124,9 +124,9 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
|
|||||||
default=None,
|
default=None,
|
||||||
help="""
|
help="""
|
||||||
Modules to be initialized. It matches all parameters starting with
|
Modules to be initialized. It matches all parameters starting with
|
||||||
a specific key. The keys are given with Comma seperated. If None,
|
a specific key. The keys are given with Comma seperated. If None,
|
||||||
all modules will be initialised. For example, if you only want to
|
all modules will be initialised. For example, if you only want to
|
||||||
initialise all parameters staring with "encoder", use "encoder";
|
initialise all parameters staring with "encoder", use "encoder";
|
||||||
if you want to initialise parameters starting with encoder or decoder,
|
if you want to initialise parameters starting with encoder or decoder,
|
||||||
use "encoder,joiner".
|
use "encoder,joiner".
|
||||||
""",
|
""",
|
||||||
@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
type=str,
|
type=str,
|
||||||
default="256,256,256,256,256",
|
default="256,256,256,256,256",
|
||||||
help="""Unmasked dimensions in the encoders, relates to augmentation
|
help="""Unmasked dimensions in the encoders, relates to augmentation
|
||||||
during training. Must be <= each of encoder_dims. Empirically, less
|
during training. Must be <= each of encoder_dims. Empirically, less
|
||||||
than 256 seems to make performance worse.
|
than 256 seems to make performance worse.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -288,7 +288,7 @@ def get_parser():
|
|||||||
"--bpe-model",
|
"--bpe-model",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/bpe.model",
|
||||||
help="""Path to the BPE model.
|
help="""Path to the BPE model.
|
||||||
This should be the bpe model of the original model
|
This should be the bpe model of the original model
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -302,8 +302,8 @@ def get_parser():
|
|||||||
type=float,
|
type=float,
|
||||||
default=100000,
|
default=100000,
|
||||||
help="""Number of steps that affects how rapidly the learning rate
|
help="""Number of steps that affects how rapidly the learning rate
|
||||||
decreases. During fine-tuning, we set this very large so that the
|
decreases. During fine-tuning, we set this very large so that the
|
||||||
learning rate slowly decays with number of batches. You may tune
|
learning rate slowly decays with number of batches. You may tune
|
||||||
its value by yourself.
|
its value by yourself.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -312,9 +312,9 @@ def get_parser():
|
|||||||
"--lr-epochs",
|
"--lr-epochs",
|
||||||
type=float,
|
type=float,
|
||||||
default=100,
|
default=100,
|
||||||
help="""Number of epochs that affects how rapidly the learning rate
|
help="""Number of epochs that affects how rapidly the learning rate
|
||||||
decreases. During fine-tuning, we set this very large so that the
|
decreases. During fine-tuning, we set this very large so that the
|
||||||
learning rate slowly decays with number of batches. You may tune
|
learning rate slowly decays with number of batches. You may tune
|
||||||
its value by yourself.
|
its value by yourself.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -753,7 +753,8 @@ def compute_loss(
|
|||||||
# We set allowed_excess_duration_ratio=0.1.
|
# We set allowed_excess_duration_ratio=0.1.
|
||||||
max_frames = params.max_duration * 1000 // params.frame_shift_ms
|
max_frames = params.max_duration * 1000 // params.frame_shift_ms
|
||||||
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
|
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
|
||||||
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
|
if is_training:
|
||||||
|
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
|
||||||
|
|
||||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
|
@ -660,7 +660,7 @@ def compute_loss(
|
|||||||
values >= 1.0 are fully warmed up and have all modules present.
|
values >= 1.0 are fully warmed up and have all modules present.
|
||||||
"""
|
"""
|
||||||
# For the uneven-sized batch, the total duration after padding would possibly
|
# For the uneven-sized batch, the total duration after padding would possibly
|
||||||
# cause OOM. Hence, for each batch, which is sorted descendingly by length,
|
# cause OOM. Hence, for each batch, which is sorted in descending order by length,
|
||||||
# we simply drop the last few shortest samples, so that the retained total frames
|
# we simply drop the last few shortest samples, so that the retained total frames
|
||||||
# (after padding) would not exceed `allowed_max_frames`:
|
# (after padding) would not exceed `allowed_max_frames`:
|
||||||
# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
|
# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
|
||||||
@ -668,7 +668,8 @@ def compute_loss(
|
|||||||
# We set allowed_excess_duration_ratio=0.1.
|
# We set allowed_excess_duration_ratio=0.1.
|
||||||
max_frames = params.max_duration * 1000 // params.frame_shift_ms
|
max_frames = params.max_duration * 1000 // params.frame_shift_ms
|
||||||
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
|
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
|
||||||
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
|
if is_training:
|
||||||
|
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
|
||||||
|
|
||||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
|
@ -1551,7 +1551,7 @@ def is_module_available(*modules: str) -> bool:
|
|||||||
|
|
||||||
def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
|
def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
|
||||||
"""For the uneven-sized batch, the total duration after padding would possibly
|
"""For the uneven-sized batch, the total duration after padding would possibly
|
||||||
cause OOM. Hence, for each batch, which is sorted descendingly by length,
|
cause OOM. Hence, for each batch, which is sorted in descending order by length,
|
||||||
we simply drop the last few shortest samples, so that the retained total frames
|
we simply drop the last few shortest samples, so that the retained total frames
|
||||||
(after padding) would not exceed the given allow_max_frames.
|
(after padding) would not exceed the given allow_max_frames.
|
||||||
|
|
||||||
@ -1567,20 +1567,20 @@ def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
|
|||||||
|
|
||||||
N, T, _ = features.size()
|
N, T, _ = features.size()
|
||||||
assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max())
|
assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max())
|
||||||
keep_num_utt = allowed_max_frames // T
|
kept_num_utt = allowed_max_frames // T
|
||||||
|
|
||||||
if keep_num_utt >= N:
|
if kept_num_utt >= N or kept_num_utt == 0:
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
# Note: we assume the samples in batch is sorted descendingly by length
|
# Note: we assume the samples in batch is sorted descendingly by length
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Filtering uneven-sized batch, original batch size is {N}, "
|
f"Filtering uneven-sized batch, original batch size is {N}, "
|
||||||
f"retained batch size is {keep_num_utt}."
|
f"retained batch size is {kept_num_utt}."
|
||||||
)
|
)
|
||||||
batch["inputs"] = features[:keep_num_utt]
|
batch["inputs"] = features[:kept_num_utt]
|
||||||
for k, v in supervisions.items():
|
for k, v in supervisions.items():
|
||||||
assert len(v) == N, (len(v), N)
|
assert len(v) == N, (len(v), N)
|
||||||
batch["supervisions"][k] = v[:keep_num_utt]
|
batch["supervisions"][k] = v[:kept_num_utt]
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user