Yifan Yang 2023-06-01 11:35:33 +08:00
parent 5d59f48193
commit 3a07dbddf0
4 changed files with 94 additions and 98 deletions

File 1 of 4 (data preparation script, stage 9)

@@ -299,7 +299,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
if [ -f ../../../../gigaspeech/ASR/data/fbank/XL_split/.split_completed ]; then
ln -svf $(realpath ../../../../gigaspeech/ASR/data/fbank/XL_split) .
else
log "Abort! Please run gigaspeech prepare.sh --stage 5 --stop-stage 6"
log "Abort! Please run ../../gigaspeech/ASR/prepare.sh --stage 5 --stop-stage 6"
exit 1
fi
@@ -315,7 +315,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
ln -svf $(realpath ../../../../commonvoice/ASR/data/en/fbank/cv-en_train_split_1000) .
ln -svf $(realpath ../../../../commonvoice/ASR/data/en/fbank/cv-en_cuts_train.jsonl.gz) .
else
log "Abort! Please run commonvoice prepare.sh --stage 5 --stop-stage 6"
log "Abort! Please run ../../commonvoice/ASR/prepare.sh --stage 5 --stop-stage 6"
exit 1
fi
@@ -330,7 +330,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
if [ -f ../../../../peoples_speech/ASR/data/fbank/.peoples_speech_train.done ]; then
ln -svf $(realpath ../../../../peoples_speech/ASR/data/fbank/peoples_speech_train_split) .
else
log "Abort! Please run commonvoice prepare.sh --stage 5 --stop-stage 6"
log "Abort! Please run ../../peoples_speech/prepare.sh --stage 5 --stop-stage 6"
exit 1
fi

File 2 of 4 (multidataset.py)

@@ -33,11 +33,11 @@ class MultiDataset:
- librispeech_cuts_train-all-shuf.jsonl.gz
- XL_split_2000/cuts_XL.*.jsonl.gz
- cv-en_cuts_train.jsonl.gz
- peoples_speech_train_split/peoples_speech_cuts_dirty.*.jsonl.gz
- peoples_speech_train_split/peoples_speech_cuts_dirty_sa.*.jsonl.gz
- peoples_speech_train_split/peoples_speech_cuts_clean.*.jsonl.gz
- peoples_speech_train_split/peoples_speech_cuts_clean_sa.*.jsonl.gz
- cv-en_cuts_train.jsonl.gz
"""
self.manifest_dir = Path(manifest_dir)
@@ -45,15 +45,13 @@ class MultiDataset:
logging.info("About to get multidataset train cuts")
# LibriSpeech
logging.info(f"Loading LibriSpeech in lazy mode")
logging.info("Loading LibriSpeech in lazy mode")
librispeech_cuts = load_manifest_lazy(
self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
)
# GigaSpeech
filenames = glob.glob(
f"{self.manifest_dir}/XL_split_2000/cuts_XL.*.jsonl.gz"
)
filenames = glob.glob(f"{self.manifest_dir}/XL_split/cuts_XL.*.jsonl.gz")
pattern = re.compile(r"cuts_XL.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
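For readers skimming the diff: the pattern/idx_filenames lines above implement a numeric sort of the GigaSpeech split manifests. A minimal, self-contained sketch of that idiom (the helper name and regex below are illustrative, not part of the recipe):

    import glob
    import re

    def sorted_split_manifests(pattern: str) -> list:
        """Return split manifests ordered by their numeric index, so that
        cuts_XL.10.jsonl.gz comes after cuts_XL.9.jsonl.gz rather than before it."""
        idx_re = re.compile(r"\.([0-9]+)\.jsonl\.gz$")
        return sorted(glob.glob(pattern), key=lambda f: int(idx_re.search(f).group(1)))

    # e.g. sorted_split_manifests("data/fbank/XL_split/cuts_XL.*.jsonl.gz")

A plain sorted() over the file names, as the People's Speech branch below uses, orders the splits lexically instead (.10. before .2.); both orderings are deterministic, only the numeric one matches the split numbering.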
@@ -68,21 +66,17 @@ class MultiDataset:
)
# CommonVoice
logging.info(f"Loading CommonVoice in lazy mode")
logging.info("Loading CommonVoice in lazy mode")
commonvoice_cuts = load_manifest_lazy(
self.manifest_dir / f"cv-en_cuts_train.jsonl.gz"
)
# People's Speech
filenames = glob.glob(
f"{self.manifest_dir}/peoples_speech_train_split/peoples_speech_cuts_*.*.jsonl.gz"
sorted_filenames = sorted(
glob.glob(
f"{self.manifest_dir}/peoples_speech_train_split/peoples_speech_cuts_*[yna].*.jsonl.gz"
)
)
pattern = re.compile(r"peoples_speech_cuts.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(
f"Loading People's Speech {len(sorted_filenames)} splits in lazy mode"

File 3 of 4 (new symlink)

@@ -0,0 +1 @@
../pruned_transducer_stateless7/multidataset.py

File 4 of 4 (training script)

@@ -62,20 +62,21 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from zipformer import Zipformer2
from scaling import ScheduledFloat
from decoder import Decoder
from joiner import Joiner
from subsampling import Conv2dSubsampling
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from multidataset import MultiDataset
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -84,40 +85,38 @@ from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
get_parameter_groups_with_lrs
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def get_adjusted_batch_count(
params: AttributeDict) -> float:
def get_adjusted_batch_count(params: AttributeDict) -> float:
# returns the number of batches we would have used so far if we had used the reference
# duration. This is for purposes of set_batch_count().
return (params.batch_idx_train * (params.max_duration * params.world_size) /
params.ref_duration)
return (
params.batch_idx_train
* (params.max_duration * params.world_size)
/ params.ref_duration
)
def set_batch_count(
model: Union[nn.Module, DDP], batch_count: float
) -> None:
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
for name, module in model.named_modules():
if hasattr(module, 'batch_count'):
if hasattr(module, "batch_count"):
module.batch_count = batch_count
if hasattr(module, 'name'):
if hasattr(module, "name"):
module.name = name
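A worked example of get_adjusted_batch_count above, using assumed values rather than the recipe defaults (apart from ref_duration, whose default of 600 appears later in this diff); the result is what the training loop feeds to set_batch_count:

    # Assumed: 1000 s of audio per batch per GPU, 4 GPUs, 600 s reference duration.
    max_duration, world_size, ref_duration = 1000, 4, 600
    batch_idx_train = 300
    adjusted = batch_idx_train * (max_duration * world_size) / ref_duration
    print(adjusted)  # 2000.0: schedules behave as if 2000 reference-sized batches had run

set_batch_count then writes that number into every submodule exposing a batch_count attribute, which is how ScheduledFloat values (such as the Conv2dSubsampling dropout schedule further down) track training progress independently of the actual batch size.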
@@ -154,35 +153,35 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--encoder-dim",
type=str,
default="192,256,384,512,384,256",
help="Embedding dimension in encoder stacks: a single int or comma-separated list."
help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--query-head-dim",
type=str,
default="32",
help="Query/key dimension per head in encoder stacks: a single int or comma-separated list."
help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--value-head-dim",
type=str,
default="12",
help="Value dimension per head in encoder stacks: a single int or comma-separated list."
help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--pos-head-dim",
type=str,
default="4",
help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--pos-dim",
type=int,
default="48",
help="Positional-encoding embedding dimension"
help="Positional-encoding embedding dimension",
)
parser.add_argument(
@@ -190,7 +189,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
type=str,
default="192,192,256,256,256,192",
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"A single int or comma-separated list. Must be <= each corresponding encoder_dim."
"A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
)
parser.add_argument(
@@ -230,7 +229,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
type=str,
default="16,32,64,-1",
help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
" Must be just -1 if --causal=False"
" Must be just -1 if --causal=False",
)
parser.add_argument(
@@ -239,7 +238,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default="64,128,256,-1",
help="Maximum left-contexts for causal training, measured in frames which will "
"be converted to a number of chunks. If splitting into chunks, "
"chunk left-context frames will be chosen randomly from this list; else not relevant."
"chunk left-context frames will be chosen randomly from this list; else not relevant.",
)
@@ -313,10 +312,7 @@ def get_parser():
)
parser.add_argument(
"--base-lr",
type=float,
default=0.045,
help="The base learning rate."
"--base-lr", type=float, default=0.045, help="The base learning rate."
)
parser.add_argument(
@@ -340,15 +336,14 @@ def get_parser():
type=float,
default=600,
help="Reference batch duration for purposes of adjusting batch counts for setting various "
"schedules inside the model"
"schedules inside the model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@@ -371,8 +366,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
help="The scale to smooth the loss with am (output of encoder network)" "part.",
)
parser.add_argument(
@@ -450,6 +444,13 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--use-multidataset",
type=str2bool,
default=False,
help="Whether to use multidataset to train.",
)
add_model_arguments(parser)
return parser
@@ -522,7 +523,7 @@ def get_params() -> AttributeDict:
def _to_int_tuple(s: str):
return tuple(map(int, s.split(',')))
return tuple(map(int, s.split(",")))
def get_encoder_embed(params: AttributeDict) -> nn.Module:
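For reference, the reformatted helper above is what turns the comma-separated CLI defaults (e.g. --encoder-dim, --query-head-dim) into integer tuples:

    def _to_int_tuple(s: str):
        return tuple(map(int, s.split(",")))

    assert _to_int_tuple("192,256,384,512,384,256") == (192, 256, 384, 512, 384, 256)
    assert _to_int_tuple("32") == (32,)  # a single int is also accepted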
@@ -537,7 +538,7 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
encoder_embed = Conv2dSubsampling(
in_channels=params.feature_dim,
out_channels=_to_int_tuple(params.encoder_dim)[0],
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
)
return encoder_embed
@@ -596,7 +597,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(','))),
encoder_dim=int(max(params.encoder_dim.split(","))),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
@@ -745,11 +746,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@@ -779,27 +776,24 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s if batch_idx_train >= warm_step
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss = (
simple_loss_scale * simple_loss +
pruned_loss_scale * pruned_loss
)
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
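The warm-up blending reflowed above can be checked with a small standalone sketch; the s and warm_step values here are assumptions for illustration, not necessarily the recipe defaults:

    s, warm_step = 0.5, 2000  # assumed simple-loss scale and warm-up length

    def loss_scales(batch_idx_train: int):
        # Early on, lean on the (cheap, stable) simple loss; by warm_step the
        # pruned loss has ramped up to full weight.
        simple = s if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
        pruned = 1.0 if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step)
        return simple, pruned

    print(loss_scales(0))     # (1.0, 0.1)
    print(loss_scales(1000))  # (0.75, 0.55)
    print(loss_scales(2000))  # (0.5, 1.0)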
@@ -895,7 +889,8 @@ def train_one_epoch(
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
save_checkpoint_impl(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
model_avg=model_avg,
params=params,
@@ -903,7 +898,8 @@ def train_one_epoch(
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=0)
rank=0,
)
for batch_idx, batch in enumerate(train_dl):
if batch_idx % 10 == 0:
@@ -988,7 +984,9 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
@@ -998,8 +996,8 @@ def train_one_epoch(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, " +
(f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
@@ -1010,9 +1008,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
@@ -1029,7 +1025,9 @@ def train_one_epoch(
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
@@ -1103,12 +1101,10 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank],
find_unused_parameters=True)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = ScaledAdam(
get_parameter_groups_with_lrs(
model, lr=params.base_lr, include_names=True),
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
lr=params.base_lr, # should have no effect
clipping_scale=2.0,
)
@@ -1138,6 +1134,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
if params.use_multidataset:
multidataset = MultiDataset(params.manifest_dir)
train_cuts = multidataset.train_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
@@ -1197,7 +1197,7 @@ def run(rank, world_size, args):
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
if not params.use_multidataset and not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
@@ -1206,8 +1206,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16,
init_scale=1.0)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
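As a side note on the GradScaler line reflowed above, here is a minimal, self-contained sketch of the AMP pattern it participates in; the toy model and optimizer are illustrative and not from the recipe:

    import torch
    from torch.cuda.amp import GradScaler, autocast

    use_fp16 = torch.cuda.is_available()
    device = "cuda" if use_fp16 else "cpu"

    model = torch.nn.Linear(10, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # init_scale=1.0 starts loss scaling at 1 instead of PyTorch's large default (65536).
    scaler = GradScaler(enabled=use_fp16, init_scale=1.0)

    x = torch.randn(4, 10, device=device)
    with autocast(enabled=use_fp16):
        loss = model(x).pow(2).mean()

    scaler.scale(loss).backward()  # scale the loss before backward when fp16 is on
    scaler.step(optimizer)         # unscale grads; skip the step if they overflowed
    scaler.update()                # adjust the scale for the next iteration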
@@ -1328,7 +1327,9 @@ def scan_pessimistic_batches_for_oom(
)
display_and_save_batch(batch, params=params, sp=sp)
raise
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
def main():