Zipformer output length (#686)

* add assertion for output length

* add comment in filter_cuts

* add length filter to Zipformer recipes
This commit is contained in:
Desh Raj 2022-11-15 22:29:45 -05:00 committed by GitHub
parent 855c76655b
commit c8ce243255
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 116 additions and 51 deletions

View File

@ -101,6 +101,9 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor):
# Note: for ./lstm_transducer_stateless/lstm.py, the formula is
# T = ((num_frames - 3) // 2 - 1) // 2
# Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is
# T = ((num_frames - 7) // 2 + 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):

View File

@ -59,7 +59,6 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from zipformer import Zipformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@ -71,6 +70,7 @@ from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
@ -79,9 +79,9 @@ from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[
@ -89,14 +89,12 @@ LRSchedulerType = Union[
]
def set_batch_count(
model: Union[nn.Module, DDP], batch_count: float
) -> None:
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
for module in model.modules():
if hasattr(module, 'batch_count'):
if hasattr(module, "batch_count"):
module.batch_count = batch_count
@ -126,7 +124,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--encoder-dims",
type=str,
default="384,384,384,384,384",
help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated"
help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
)
parser.add_argument(
@ -134,7 +132,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
type=str,
default="192,192,192,192,192",
help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
not the same as embedding dimension."""
not the same as embedding dimension.""",
)
parser.add_argument(
@ -143,7 +141,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default="256,256,256,256,256",
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
" worse."
" worse.",
)
parser.add_argument(
@ -248,10 +246,7 @@ def get_parser():
)
parser.add_argument(
"--base-lr",
type=float,
default=0.05,
help="The base learning rate."
"--base-lr", type=float, default=0.05, help="The base learning rate."
)
parser.add_argument(
@ -451,11 +446,14 @@ def get_params() -> AttributeDict:
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Zipformer and Transformer
def to_int_tuple(s: str):
return tuple(map(int, s.split(',')))
return tuple(map(int, s.split(",")))
encoder = Zipformer(
num_features=params.feature_dim,
output_downsampling_factor=2,
zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors),
zipformer_downsampling_factors=to_int_tuple(
params.zipformer_downsampling_factors
),
encoder_dims=to_int_tuple(params.encoder_dims),
attention_dim=to_int_tuple(params.attention_dims),
encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
@ -479,7 +477,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
encoder_dim=int(params.encoder_dims.split(',')[-1]),
encoder_dim=int(params.encoder_dims.split(",")[-1]),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
@ -496,7 +494,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(params.encoder_dims.split(',')[-1]),
encoder_dim=int(params.encoder_dims.split(",")[-1]),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
@ -682,18 +680,17 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s if batch_idx_train >= warm_step
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss = (
simple_loss_scale * simple_loss +
pruned_loss_scale * pruned_loss
)
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
@ -873,12 +870,16 @@ def train_one_epoch(
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
if cur_grad_scale < 1.0 or (
cur_grad_scale < 8.0 and batch_idx % 400 == 0
):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
@ -888,8 +889,12 @@ def train_one_epoch(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, " +
(f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
f"lr: {cur_lr:.2e}, "
+ (
f"grad_scale: {scaler._scale.item()}"
if params.use_fp16
else ""
)
)
if tb_writer is not None:
@ -905,12 +910,15 @@ def train_one_epoch(
)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
"train/grad_scale",
cur_grad_scale,
params.batch_idx_train,
)
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
if (
batch_idx % params.valid_interval == 0
and not params.print_diagnostics
):
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
@ -921,7 +929,9 @@ def train_one_epoch(
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
@ -997,12 +1007,11 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank],
find_unused_parameters=True)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = ScaledAdam(model.parameters(),
lr=params.base_lr,
clipping_scale=2.0)
optimizer = ScaledAdam(
model.parameters(), lr=params.base_lr, clipping_scale=2.0
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
@ -1043,7 +1052,34 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)
@ -1071,8 +1107,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16,
init_scale=1.0)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1193,7 +1228,9 @@ def scan_pessimistic_batches_for_oom(
)
display_and_save_batch(batch, params=params, sp=sp)
raise
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
def main():

View File

@ -1828,6 +1828,7 @@ def _test_zipformer_main():
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
)
assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1])
f[0].sum().backward()
c.eval()
f = c(

View File

@ -90,12 +90,7 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
)
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@ -1045,7 +1040,9 @@ def train_one_epoch(
params.best_train_loss = params.train_loss
def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
def filter_short_and_long_utterances(
cuts: CutSet, sp: spm.SentencePieceProcessor
) -> CutSet:
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
@ -1055,7 +1052,34 @@ def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
cuts = cuts.filter(remove_short_and_long_utt)
@ -1162,7 +1186,7 @@ def run(rank, world_size, args):
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_cuts = filter_short_and_long_utterances(train_cuts)
train_cuts = filter_short_and_long_utterances(train_cuts, sp)
gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
# XL 10k hours
@ -1179,7 +1203,7 @@ def run(rank, world_size, args):
logging.info("Using the S subset of GigaSpeech (250 hours)")
train_giga_cuts = gigaspeech.train_S_cuts()
train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts)
train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts, sp)
train_giga_cuts = train_giga_cuts.repeat(times=None)
if args.enable_musan: