Zipformer output length (#686)

* add assertion for output length

* add comment in filter_cuts

* add length filter to Zipformer recipes
Desh Raj 2022-11-15 22:29:45 -05:00 committed by GitHub
parent 855c76655b
commit c8ce243255
4 changed files with 116 additions and 51 deletions

View File

@@ -101,6 +101,9 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor):
# Note: for ./lstm_transducer_stateless/lstm.py, the formula is
# T = ((num_frames - 3) // 2 - 1) // 2
+ # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is
+ # T = ((num_frames - 7) // 2 + 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
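To make the formulas above concrete, here is a small standalone sketch (not part of the patch) that evaluates the two documented subsampling formulas; the helper name subsampled_frames is hypothetical.

def subsampled_frames(num_frames: int, encoder: str = "zipformer") -> int:
    # Post-subsampling length T for the formulas documented in filter_cuts.
    if encoder == "zipformer":
        # ./pruned_transducer_stateless7/zipformer.py
        return ((num_frames - 7) // 2 + 1) // 2
    if encoder == "lstm":
        # ./lstm_transducer_stateless/lstm.py
        return ((num_frames - 3) // 2 - 1) // 2
    raise ValueError(f"unknown encoder: {encoder}")

# Example: a 2-second cut at the typical 100 frames/second (10 ms shift) has 200 frames.
assert subsampled_frames(200, "zipformer") == 48
assert subsampled_frames(200, "lstm") == 48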

View File

@@ -59,7 +59,6 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
- from zipformer import Zipformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -71,6 +70,7 @@ from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
+ from zipformer import Zipformer
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -79,9 +79,9 @@ from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
- from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+ from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[
@@ -89,14 +89,12 @@ LRSchedulerType = Union[
]
- def set_batch_count(
- model: Union[nn.Module, DDP], batch_count: float
- ) -> None:
+ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
for module in model.modules():
- if hasattr(module, 'batch_count'):
+ if hasattr(module, "batch_count"):
module.batch_count = batch_count
@@ -126,7 +124,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--encoder-dims",
type=str,
default="384,384,384,384,384",
- help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated"
+ help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
)
parser.add_argument(
@@ -134,7 +132,7 @@
type=str,
default="192,192,192,192,192",
help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
- not the same as embedding dimension."""
+ not the same as embedding dimension.""",
)
parser.add_argument(
@@ -143,7 +141,7 @@
default="256,256,256,256,256",
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
- " worse."
+ " worse.",
)
parser.add_argument(
@@ -248,10 +246,7 @@ def get_parser():
)
parser.add_argument(
- "--base-lr",
- type=float,
- default=0.05,
- help="The base learning rate."
+ "--base-lr", type=float, default=0.05, help="The base learning rate."
)
parser.add_argument(
@@ -451,11 +446,14 @@ def get_params() -> AttributeDict:
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Zipformer and Transformer
def to_int_tuple(s: str):
- return tuple(map(int, s.split(',')))
+ return tuple(map(int, s.split(",")))
encoder = Zipformer(
num_features=params.feature_dim,
output_downsampling_factor=2,
- zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors),
+ zipformer_downsampling_factors=to_int_tuple(
+ params.zipformer_downsampling_factors
+ ),
encoder_dims=to_int_tuple(params.encoder_dims),
attention_dim=to_int_tuple(params.attention_dims),
encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
@@ -479,7 +477,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
- encoder_dim=int(params.encoder_dims.split(',')[-1]),
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
@@ -496,7 +494,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
- encoder_dim=int(params.encoder_dims.split(',')[-1]),
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
@@ -682,18 +680,17 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
- s if batch_idx_train >= warm_step
+ s
+ if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
- 1.0 if batch_idx_train >= warm_step
+ 1.0
+ if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
- loss = (
- simple_loss_scale * simple_loss +
- pruned_loss_scale * pruned_loss
- )
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
@@ -873,12 +870,16 @@ def train_one_epoch(
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
- if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ if cur_grad_scale < 1.0 or (
+ cur_grad_scale < 8.0 and batch_idx % 400 == 0
+ ):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
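For context on the comment in the hunk above, here is a sketch (not code from this commit): torch.cuda.amp.GradScaler only exposes a fixed growth_factor/growth_interval at construction time, so its built-in growth cannot depend on the current scale, and passing an explicit value to update() is the escape hatch the training loop uses.

from torch.cuda.amp import GradScaler

# Built-in growth is configured once, at construction time, and applied uniformly:
scaler = GradScaler(init_scale=1.0, growth_factor=2.0, growth_interval=2000)

def grow_scale_while_small(scaler: GradScaler, batch_idx: int) -> None:
    # Mirrors the heuristic above: double the scale manually while it is still
    # small, since the built-in growth schedule cannot depend on the scale.
    cur_grad_scale = scaler._scale.item()  # _scale is set once scaler.scale() has run
    if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
        scaler.update(cur_grad_scale * 2.0)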
@@ -888,8 +889,12 @@ def train_one_epoch(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
- f"lr: {cur_lr:.2e}, " +
- (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ f"lr: {cur_lr:.2e}, "
+ + (
+ f"grad_scale: {scaler._scale.item()}"
+ if params.use_fp16
+ else ""
+ )
)
if tb_writer is not None:
@@ -905,12 +910,15 @@ def train_one_epoch(
)
if params.use_fp16:
tb_writer.add_scalar(
- "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
)
- if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ if (
+ batch_idx % params.valid_interval == 0
+ and not params.print_diagnostics
+ ):
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
@@ -921,7 +929,9 @@ def train_one_epoch(
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
- logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
@@ -997,12 +1007,11 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
- model = DDP(model, device_ids=[rank],
- find_unused_parameters=True)
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
- optimizer = ScaledAdam(model.parameters(),
- lr=params.base_lr,
- clipping_scale=2.0)
+ optimizer = ScaledAdam(
+ model.parameters(), lr=params.base_lr, clipping_scale=2.0
+ )
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
@@ -1043,7 +1052,34 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
- return 1.0 <= c.duration <= 20.0
+ if c.duration < 1.0 or c.duration > 20.0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Duration: {c.duration}"
+ )
+ return False
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+ return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)
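As a worked example of the new filter (assuming the usual 10 ms frame shift, i.e. 100 feature frames per second; not part of the patch):

num_frames = 1000                      # a 10-second utterance at 100 frames/s
T = ((num_frames - 7) // 2 + 1) // 2   # Zipformer subsampling
assert T == 248
# Pruned RNN-T training needs T >= S, so a transcript that encodes to more
# than 248 BPE tokens would make remove_short_and_long_utt drop this cut.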
@@ -1071,8 +1107,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16,
- init_scale=1.0)
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1193,7 +1228,9 @@ def scan_pessimistic_batches_for_oom(
)
display_and_save_batch(batch, params=params, sp=sp)
raise
- logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
def main():

View File

@@ -1828,6 +1828,7 @@ def _test_zipformer_main():
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
)
+ assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1])
f[0].sum().backward()
c.eval()
f = c(
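The new assertion ties the test's output length back to the formula documented in filter_cuts. A quick standalone check with hypothetical sequence lengths (the actual seq_len used by _test_zipformer_main is not shown in this hunk):

for seq_len in (100, 200, 500, 1000):
    expected_t = ((seq_len - 7) // 2 + 1) // 2
    # the test asserts that f[0].shape[1] == expected_t for the Zipformer output
    print(f"seq_len={seq_len} -> expected output frames={expected_t}")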

View File

@@ -90,12 +90,7 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
- from icefall.utils import (
- AttributeDict,
- MetricsTracker,
- setup_logger,
- str2bool,
- )
+ from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -1045,7 +1040,9 @@ def train_one_epoch(
params.best_train_loss = params.train_loss
- def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
+ def filter_short_and_long_utterances(
+ cuts: CutSet, sp: spm.SentencePieceProcessor
+ ) -> CutSet:
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
@@ -1055,7 +1052,34 @@ def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
- return 1.0 <= c.duration <= 20.0
+ if c.duration < 1.0 or c.duration > 20.0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Duration: {c.duration}"
+ )
+ return False
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+ return True
cuts = cuts.filter(remove_short_and_long_utt)
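To exercise the same predicate outside lhotse and SentencePiece, here is a hypothetical stand-in (FakeCut and the pre-tokenized transcript are illustrative only; the real code operates on lhotse Cut objects and sp.encode):

from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeCut:               # hypothetical stand-in for a lhotse Cut
    id: str
    duration: float          # seconds
    num_frames: int
    tokens: List[str] = field(default_factory=list)  # pretend output of sp.encode

def keep(c: FakeCut) -> bool:
    if c.duration < 1.0 or c.duration > 20.0:
        return False
    T = ((c.num_frames - 7) // 2 + 1) // 2
    return T >= len(c.tokens)

assert keep(FakeCut("ok", 5.0, 500, ["▁HE", "LLO"]))               # T = 123 >= 2
assert not keep(FakeCut("too-short", 0.5, 50, ["▁HI"]))            # duration < 1 s
assert not keep(FakeCut("too-many-tokens", 1.0, 100, ["x"] * 40))  # T = 23 < 40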
@@ -1162,7 +1186,7 @@ def run(rank, world_size, args):
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
- train_cuts = filter_short_and_long_utterances(train_cuts)
+ train_cuts = filter_short_and_long_utterances(train_cuts, sp)
gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
# XL 10k hours
@@ -1179,7 +1203,7 @@ def run(rank, world_size, args):
logging.info("Using the S subset of GigaSpeech (250 hours)")
train_giga_cuts = gigaspeech.train_S_cuts()
- train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts)
+ train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts, sp)
train_giga_cuts = train_giga_cuts.repeat(times=None)
if args.enable_musan: