mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00

update

This commit is contained in:
parent 5d59f48193
commit 3a07dbddf0
@@ -299,7 +299,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
if [ -f ../../../../gigaspeech/ASR/data/fbank/XL_split/.split_completed ]; then
ln -svf $(realpath ../../../../gigaspeech/ASR/data/fbank/XL_split) .
else
log "Abort! Please run gigaspeech prepare.sh --stage 5 --stop-stage 6"
log "Abort! Please run ../../gigaspeech/ASR/prepare.sh --stage 5 --stop-stage 6"
exit 1
fi

@@ -315,7 +315,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
ln -svf $(realpath ../../../../commonvoice/ASR/data/en/fbank/cv-en_train_split_1000) .
ln -svf $(realpath ../../../../commonvoice/ASR/data/en/fbank/cv-en_cuts_train.jsonl.gz) .
else
log "Abort! Please run commonvoice prepare.sh --stage 5 --stop-stage 6"
log "Abort! Please run ../../commonvoice/ASR/prepare.sh --stage 5 --stop-stage 6"
exit 1
fi

@@ -330,7 +330,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
if [ -f ../../../../peoples_speech/ASR/data/fbank/.peoples_speech_train.done ]; then
ln -svf $(realpath ../../../../peoples_speech/ASR/data/fbank/peoples_speech_train_split) .
else
log "Abort! Please run commonvoice prepare.sh --stage 5 --stop-stage 6"
log "Abort! Please run ../../peoples_speech/prepare.sh --stage 5 --stop-stage 6"
exit 1
fi

@@ -33,11 +33,11 @@ class MultiDataset:

- librispeech_cuts_train-all-shuf.jsonl.gz
- XL_split_2000/cuts_XL.*.jsonl.gz
- cv-en_cuts_train.jsonl.gz
- peoples_speech_train_split/peoples_speech_cuts_dirty.*.jsonl.gz
- peoples_speech_train_split/peoples_speech_cuts_dirty_sa.*.jsonl.gz
- peoples_speech_train_split/peoples_speech_cuts_clean.*.jsonl.gz
- peoples_speech_train_split/peoples_speech_cuts_clean_sa.*.jsonl.gz
- cv-en_cuts_train.jsonl.gz
"""
self.manifest_dir = Path(manifest_dir)

@@ -45,15 +45,13 @@ class MultiDataset:
logging.info("About to get multidataset train cuts")

# LibriSpeech
logging.info(f"Loading LibriSpeech in lazy mode")
logging.info("Loading LibriSpeech in lazy mode")
librispeech_cuts = load_manifest_lazy(
self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
)

# GigaSpeech
filenames = glob.glob(
f"{self.manifest_dir}/XL_split_2000/cuts_XL.*.jsonl.gz"
)
filenames = glob.glob(f"{self.manifest_dir}/XL_split/cuts_XL.*.jsonl.gz")

pattern = re.compile(r"cuts_XL.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
@@ -68,21 +66,17 @@ class MultiDataset:
)

# CommonVoice
logging.info(f"Loading CommonVoice in lazy mode")
logging.info("Loading CommonVoice in lazy mode")
commonvoice_cuts = load_manifest_lazy(
self.manifest_dir / f"cv-en_cuts_train.jsonl.gz"
)

# People's Speech
filenames = glob.glob(
f"{self.manifest_dir}/peoples_speech_train_split/peoples_speech_cuts_*.*.jsonl.gz"
sorted_filenames = sorted(
glob.glob(
f"{self.manifest_dir}/peoples_speech_train_split/peoples_speech_cuts_*[yna].*.jsonl.gz"
)
)

pattern = re.compile(r"peoples_speech_cuts.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])

sorted_filenames = [f[1] for f in idx_filenames]

logging.info(
f"Loading People's Speech {len(sorted_filenames)} splits in lazy mode"

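Side note on the GigaSpeech branch above: the regex-based sort is kept because a plain lexicographic sort of the split manifests would order cuts_XL.10 before cuts_XL.2. A minimal, self-contained sketch of that numeric-index sort; the filenames here are made up for illustration:

    import re

    # Hypothetical split names; in the recipe they come from glob.glob() over XL_split/.
    filenames = ["cuts_XL.10.jsonl.gz", "cuts_XL.2.jsonl.gz", "cuts_XL.1.jsonl.gz"]

    pattern = re.compile(r"cuts_XL\.([0-9]+)\.jsonl\.gz")
    # Pair each file with its integer split index, then sort numerically.
    indexed = sorted((int(pattern.search(f).group(1)), f) for f in filenames)
    sorted_filenames = [f for _, f in indexed]

    print(sorted(filenames))  # lexicographic order: 1, 10, 2
    print(sorted_filenames)   # numeric order: 1, 2, 10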
egs/librispeech/ASR/zipformer/multidataset.py (Symbolic link)
@@ -0,0 +1 @@
../pruned_transducer_stateless7/multidataset.py

@@ -62,20 +62,21 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from zipformer import Zipformer2
from scaling import ScheduledFloat
from decoder import Decoder
from joiner import Joiner
from subsampling import Conv2dSubsampling
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from multidataset import MultiDataset
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -84,40 +85,38 @@ from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
get_parameter_groups_with_lrs
)

LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


def get_adjusted_batch_count(
params: AttributeDict) -> float:
def get_adjusted_batch_count(params: AttributeDict) -> float:
# returns the number of batches we would have used so far if we had used the reference
# duration. This is for purposes of set_batch_count().
return (params.batch_idx_train * (params.max_duration * params.world_size) /
params.ref_duration)
return (
params.batch_idx_train
* (params.max_duration * params.world_size)
/ params.ref_duration
)


def set_batch_count(
model: Union[nn.Module, DDP], batch_count: float
) -> None:
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
for name, module in model.named_modules():
if hasattr(module, 'batch_count'):
if hasattr(module, "batch_count"):
module.batch_count = batch_count
if hasattr(module, 'name'):
if hasattr(module, "name"):
module.name = name

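For intuition, the reformatted get_adjusted_batch_count() above scales the raw step count by how much audio each step actually covers relative to the reference duration. A quick worked example with made-up numbers:

    # Illustrative values only: 4 GPUs, max_duration of 1000 seconds of audio per
    # GPU per batch, ref_duration 600 (the default shown further down), and 3000
    # optimizer steps taken so far.
    batch_idx_train = 3000
    max_duration = 1000.0
    world_size = 4
    ref_duration = 600.0

    adjusted = batch_idx_train * (max_duration * world_size) / ref_duration
    print(adjusted)  # 20000.0 -> passed to set_batch_count(), which writes module.batch_count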
@@ -154,35 +153,35 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--encoder-dim",
type=str,
default="192,256,384,512,384,256",
help="Embedding dimension in encoder stacks: a single int or comma-separated list."
help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
)

parser.add_argument(
"--query-head-dim",
type=str,
default="32",
help="Query/key dimension per head in encoder stacks: a single int or comma-separated list."
help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
)

parser.add_argument(
"--value-head-dim",
type=str,
default="12",
help="Value dimension per head in encoder stacks: a single int or comma-separated list."
help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
)

parser.add_argument(
"--pos-head-dim",
type=str,
default="4",
help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
)

parser.add_argument(
"--pos-dim",
type=int,
default="48",
help="Positional-encoding embedding dimension"
help="Positional-encoding embedding dimension",
)

parser.add_argument(
@@ -190,7 +189,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
type=str,
default="192,192,256,256,256,192",
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"A single int or comma-separated list. Must be <= each corresponding encoder_dim."
"A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
)

parser.add_argument(
@@ -230,7 +229,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
type=str,
default="16,32,64,-1",
help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
" Must be just -1 if --causal=False"
" Must be just -1 if --causal=False",
)

parser.add_argument(
@@ -239,7 +238,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default="64,128,256,-1",
help="Maximum left-contexts for causal training, measured in frames which will "
"be converted to a number of chunks. If splitting into chunks, "
"chunk left-context frames will be chosen randomly from this list; else not relevant."
"chunk left-context frames will be chosen randomly from this list; else not relevant.",
)

@@ -313,10 +312,7 @@ def get_parser():
)

parser.add_argument(
"--base-lr",
type=float,
default=0.045,
help="The base learning rate."
"--base-lr", type=float, default=0.045, help="The base learning rate."
)

parser.add_argument(
@@ -340,15 +336,14 @@ def get_parser():
type=float,
default=600,
help="Reference batch duration for purposes of adjusting batch counts for setting various "
"schedules inside the model"
"schedules inside the model",
)

parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)

parser.add_argument(
||||
@ -371,8 +366,7 @@ def get_parser():
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network)"
|
||||
"part.",
|
||||
help="The scale to smooth the loss with am (output of encoder network)" "part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -450,6 +444,13 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--use-multidataset",
type=str2bool,
default=False,
help="Whether to use multidataset to train.",
)

add_model_arguments(parser)

return parser
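The new --use-multidataset switch uses str2bool (imported from icefall.utils above), so it can be passed on the command line as e.g. --use-multidataset True. A rough sketch of how such a converter typically behaves; this is an illustration, not icefall's exact implementation:

    import argparse

    def str2bool(v):
        # Hypothetical stand-in for icefall.utils.str2bool.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--use-multidataset", type=str2bool, default=False)
    print(parser.parse_args(["--use-multidataset", "True"]).use_multidataset)  # True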
@@ -522,7 +523,7 @@ def get_params() -> AttributeDict:


def _to_int_tuple(s: str):
return tuple(map(int, s.split(',')))
return tuple(map(int, s.split(",")))


def get_encoder_embed(params: AttributeDict) -> nn.Module:
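For reference, _to_int_tuple() (only its quote style changes here) is what turns the comma-separated hyperparameter strings from add_model_arguments() into integer tuples:

    def _to_int_tuple(s: str):
        return tuple(map(int, s.split(",")))

    assert _to_int_tuple("192,256,384,512,384,256") == (192, 256, 384, 512, 384, 256)
    assert _to_int_tuple("32") == (32,)  # a single int works too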
@@ -537,7 +538,7 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
encoder_embed = Conv2dSubsampling(
in_channels=params.feature_dim,
out_channels=_to_int_tuple(params.encoder_dim)[0],
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
)
return encoder_embed

@@ -596,7 +597,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(','))),
encoder_dim=int(max(params.encoder_dim.split(","))),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
@@ -745,11 +746,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@@ -779,27 +776,24 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s if batch_idx_train >= warm_step
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)

loss = (
simple_loss_scale * simple_loss +
pruned_loss_scale * pruned_loss
)
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

assert loss.requires_grad == is_training

info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
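The warm-up behaviour of the two scales reformatted above, written out as a tiny standalone helper; warm_step and s (params.simple_loss_scale) values below are illustrative:

    def loss_scales(batch_idx_train: int, warm_step: int, s: float):
        # Mirrors the expressions in compute_loss() above.
        if batch_idx_train >= warm_step:
            return s, 1.0
        frac = batch_idx_train / warm_step
        return 1.0 - frac * (1.0 - s), 0.1 + 0.9 * frac

    print(loss_scales(0, 2000, 0.5))     # (1.0, 0.1)  -> rely mostly on the simple loss at first
    print(loss_scales(1000, 2000, 0.5))  # (0.75, 0.55)
    print(loss_scales(2000, 2000, 0.5))  # (0.5, 1.0)  -> fully warmed up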
@@ -895,7 +889,8 @@ def train_one_epoch(
saved_bad_model = False

def save_bad_model(suffix: str = ""):
save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
save_checkpoint_impl(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
model_avg=model_avg,
params=params,
@@ -903,7 +898,8 @@ def train_one_epoch(
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=0)
rank=0,
)

for batch_idx, batch in enumerate(train_dl):
if batch_idx % 10 == 0:
@@ -988,7 +984,9 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)

if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
@@ -998,8 +996,8 @@ def train_one_epoch(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, " +
(f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)

if tb_writer is not None:
@@ -1010,9 +1008,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
@@ -1029,7 +1025,9 @@ def train_one_epoch(
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
@@ -1103,12 +1101,10 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank],
find_unused_parameters=True)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)

optimizer = ScaledAdam(
get_parameter_groups_with_lrs(
model, lr=params.base_lr, include_names=True),
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
lr=params.base_lr, # should have no effect
clipping_scale=2.0,
)
@@ -1138,6 +1134,10 @@ def run(rank, world_size, args):

librispeech = LibriSpeechAsrDataModule(args)

if params.use_multidataset:
multidataset = MultiDataset(params.manifest_dir)
train_cuts = multidataset.train_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
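MultiDataset.train_cuts() (see the multidataset.py hunks above) is expected to return a single lazy CutSet covering all of the listed manifests. As a rough illustration of how several lazy cut sets can be combined with lhotse, not necessarily what this recipe does internally, one could write something like the following; the paths and weights are assumptions:

    from lhotse import CutSet, load_manifest_lazy

    # Assumed manifest locations, following the names in the MultiDataset docstring.
    librispeech_cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz")
    commonvoice_cuts = load_manifest_lazy("data/en/fbank/cv-en_cuts_train.jsonl.gz")

    # CutSet.mux lazily interleaves the sources; the weights here are illustrative only.
    train_cuts = CutSet.mux(librispeech_cuts, commonvoice_cuts, weights=[0.7, 0.3])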
@@ -1197,7 +1197,7 @@ def run(rank, world_size, args):
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)

if not params.print_diagnostics:
if not params.use_multidataset and not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
@@ -1206,8 +1206,7 @@ def run(rank, world_size, args):
params=params,
)

scaler = GradScaler(enabled=params.use_fp16,
init_scale=1.0)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
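For context, the GradScaler configured above is part of the standard torch.cuda.amp mixed-precision pattern. A generic sketch of that pattern, not this script's actual training loop:

    from torch.cuda.amp import GradScaler, autocast

    scaler = GradScaler(enabled=True, init_scale=1.0)

    def training_step(model, optimizer, batch, use_fp16=True):
        optimizer.zero_grad()
        with autocast(enabled=use_fp16):
            loss = model(batch)        # assumes the model returns a scalar loss
        scaler.scale(loss).backward()  # scale the loss to limit fp16 underflow
        scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
        scaler.update()                # grows/shrinks the scale for the next iteration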
@@ -1328,7 +1327,9 @@ def scan_pessimistic_batches_for_oom(
)
display_and_save_batch(batch, params=params, sp=sp)
raise
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)


def main():