update train.py

Fangjun Kuang 2022-11-14 16:26:53 +08:00
parent ab38f4a926
commit 1d494556fc


@@ -27,8 +27,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --num-epochs 30 \
   --start-epoch 1 \
   --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 300
+  --max-duration 750 \
+  --training-subset L
 # For mix precision training:
@@ -38,9 +38,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 550
+  --max-duration 750
 """
@@ -54,12 +52,10 @@ from typing import Any, Dict, Optional, Tuple, Union
 import k2
 import optim
-import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
-from zipformer import Zipformer
+from asr_datamodule import WenetSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -71,17 +67,20 @@ from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
 from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
-from icefall.hooks import register_inf_check_hooks
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 LRSchedulerType = Union[
@@ -89,14 +88,12 @@ LRSchedulerType = Union[
 ]
-def set_batch_count(
-    model: Union[nn.Module, DDP], batch_count: float
-) -> None:
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
     if isinstance(model, DDP):
         # get underlying nn.Module
         model = model.module
     for module in model.modules():
-        if hasattr(module, 'batch_count'):
+        if hasattr(module, "batch_count"):
             module.batch_count = batch_count
@@ -126,7 +123,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--encoder-dims",
         type=str,
         default="384,384,384,384,384",
-        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated"
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
     )
     parser.add_argument(
@@ -134,7 +131,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         type=str,
         default="192,192,192,192,192",
         help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
-        not the same as embedding dimension."""
+        not the same as embedding dimension.""",
     )
     parser.add_argument(
@@ -143,7 +140,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         default="256,256,256,256,256",
         help="Unmasked dimensions in the encoders, relates to augmentation during training. "
         "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
-        " worse."
+        " worse.",
     )
     parser.add_argument(
@@ -241,17 +238,17 @@ def get_parser():
     )
     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
     )
     parser.add_argument(
-        "--base-lr",
-        type=float,
-        default=0.05,
-        help="The base learning rate."
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
     )
     parser.add_argument(
@@ -451,11 +448,14 @@ def get_params() -> AttributeDict:
 def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Zipformer and Transformer
     def to_int_tuple(s: str):
-        return tuple(map(int, s.split(',')))
+        return tuple(map(int, s.split(",")))
     encoder = Zipformer(
         num_features=params.feature_dim,
         output_downsampling_factor=2,
-        zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors),
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
         encoder_dims=to_int_tuple(params.encoder_dims),
         attention_dim=to_int_tuple(params.attention_dims),
         encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
@@ -479,7 +479,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=int(params.encoder_dims.split(',')[-1]),
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -496,7 +496,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=int(params.encoder_dims.split(',')[-1]),
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -567,9 +567,6 @@ def load_checkpoint_if_available(
     if "cur_epoch" in saved_params:
         params["start_epoch"] = saved_params["cur_epoch"]
-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]
     return saved_params
@@ -626,7 +623,7 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     batch: dict,
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
@@ -665,7 +662,8 @@ def compute_loss(
     warm_step = params.warm_step
     texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
+    y = graph_compiler.texts_to_ids(texts)
     y = k2.RaggedTensor(y).to(device)
     with torch.set_grad_enabled(is_training):
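Note on the change above: graph_compiler.texts_to_ids replaces sp.encode, turning each transcript into a list of per-character token IDs that k2.RaggedTensor then batches as rows of unequal length. The snippet below is only an illustrative sketch of that data flow; the token2id table and toy_texts_to_ids helper are made-up stand-ins, not the real CharCtcTrainingGraphCompiler, whose mapping comes from the lang dir (data/lang_char).

import k2

# Hypothetical character-to-id table; the real IDs come from the lexicon.
token2id = {"<blk>": 0, "你": 1, "好": 2, "世": 3, "界": 4}


def toy_texts_to_ids(texts):
    # One list of token ids per utterance, mirroring what texts_to_ids returns.
    return [[token2id[ch] for ch in text] for text in texts]


texts = ["你好", "你好世界"]
y = toy_texts_to_ids(texts)  # [[1, 2], [1, 2, 3, 4]]
y = k2.RaggedTensor(y)       # ragged rows of unequal length, ready for the model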
@@ -682,18 +680,17 @@ def compute_loss(
         # take down the scale on the simple loss from 1.0 at the start
         # to params.simple_loss scale by warm_step.
         simple_loss_scale = (
-            s if batch_idx_train >= warm_step
+            s
+            if batch_idx_train >= warm_step
             else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
         )
         pruned_loss_scale = (
-            1.0 if batch_idx_train >= warm_step
+            1.0
+            if batch_idx_train >= warm_step
             else 0.1 + 0.9 * (batch_idx_train / warm_step)
         )
-        loss = (
-            simple_loss_scale * simple_loss +
-            pruned_loss_scale * pruned_loss
-        )
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
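Note on the loss formula above: the reformatting does not change the warm-up schedule. The simple-loss weight decays from 1.0 down to params.simple_loss_scale over warm_step batches, while the pruned-loss weight grows from 0.1 up to 1.0. A small self-contained sketch of that schedule (the function name and the numbers below are illustrative, not part of the recipe):

def loss_scales(batch_idx_train: int, warm_step: int, s: float):
    # s plays the role of params.simple_loss_scale; both weights reach
    # their final values once batch_idx_train >= warm_step.
    simple_loss_scale = (
        s
        if batch_idx_train >= warm_step
        else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    )
    pruned_loss_scale = (
        1.0
        if batch_idx_train >= warm_step
        else 0.1 + 0.9 * (batch_idx_train / warm_step)
    )
    return simple_loss_scale, pruned_loss_scale


print(loss_scales(0, 2000, 0.5))     # (1.0, 0.1)   start of training
print(loss_scales(1000, 2000, 0.5))  # (0.75, 0.55) halfway through warm-up
print(loss_scales(2000, 2000, 0.5))  # (0.5, 1.0)   after warm-up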
@@ -715,7 +712,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -728,7 +725,7 @@ def compute_validation_loss(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
-            sp=sp,
+            graph_compiler=graph_compiler,
             batch=batch,
             is_training=False,
         )
@@ -751,7 +748,7 @@ def train_one_epoch(
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -795,13 +792,7 @@ def train_one_epoch(
     tot_loss = MetricsTracker()
-    cur_batch_idx = params.get("cur_batch_idx", 0)
     for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
@@ -810,7 +801,7 @@ def train_one_epoch(
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
-                    sp=sp,
+                    graph_compiler=graph_compiler,
                     batch=batch,
                     is_training=True,
                 )
@@ -827,7 +818,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, sp=sp)
+            display_and_save_batch(batch, params=params)
             raise
         if params.print_diagnostics and batch_idx == 5:
@@ -848,7 +839,6 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -861,7 +851,6 @@ def train_one_epoch(
                 scaler=scaler,
                 rank=rank,
             )
-            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
@@ -873,12 +862,16 @@ def train_one_epoch(
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()
-            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+            if cur_grad_scale < 1.0 or (
+                cur_grad_scale < 8.0 and batch_idx % 400 == 0
+            ):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
+                    raise RuntimeError(
+                        f"grad_scale is too small, exiting: {cur_grad_scale}"
+                    )
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
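Note on the grad-scale check above: wrapping the condition changes only its layout. The heuristic itself stays the same: if the AMP grad scale has collapsed below 1.0, or is still below 8.0 on every 400th batch, it is doubled by hand, and training aborts once the scale drops under 1e-5. A standalone sketch of that logic, detached from GradScaler for illustration (the helper name is made up):

def adjust_grad_scale(cur_grad_scale: float, batch_idx: int) -> float:
    # Double a collapsed scale, or nudge a small one upward every 400 batches.
    if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
        cur_grad_scale *= 2.0
    if cur_grad_scale < 0.01:
        print(f"Grad scale is small: {cur_grad_scale}")
        if cur_grad_scale < 1.0e-05:
            raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
    return cur_grad_scale


assert adjust_grad_scale(0.5, 123) == 1.0    # collapsed scale is doubled
assert adjust_grad_scale(4.0, 400) == 8.0    # periodic growth while still < 8.0
assert adjust_grad_scale(16.0, 400) == 16.0  # left alone once large enough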
@@ -888,8 +881,12 @@
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
-                f"lr: {cur_lr:.2e}, " +
-                (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                f"lr: {cur_lr:.2e}, "
+                + (
+                    f"grad_scale: {scaler._scale.item()}"
+                    if params.use_fp16
+                    else ""
+                )
             )
             if tb_writer is not None:
@@ -905,23 +902,28 @@ def train_one_epoch(
                 )
                 if params.use_fp16:
                     tb_writer.add_scalar(
-                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
                     )
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if (
+            batch_idx % params.valid_interval == 0
+            and not params.print_diagnostics
+        ):
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
                 model=model,
-                sp=sp,
+                graph_compiler=graph_compiler,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
             if tb_writer is not None:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
@@ -948,8 +950,6 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
-    if params.full_libri is False:
-        params.valid_interval = 1600
     fix_random_seed(params.seed)
     if world_size > 1:
@@ -968,12 +968,14 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    lexicon = Lexicon(params.lang_dir)
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1
     logging.info(params)
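Note on the change above: the BPE model is replaced by a character lexicon. blank_id is looked up in the lexicon's token table and vocab_size becomes the largest token ID plus one (IDs need not be contiguous). A rough sketch of how the two values are derived, using a made-up dict in place of the real Lexicon/token table from data/lang_char:

# Hypothetical token table of the kind Lexicon builds from tokens.txt.
token_table = {"<blk>": 0, "<unk>": 1, "你": 2, "好": 3, "世": 5}

blank_id = token_table["<blk>"]             # -> 0
vocab_size = max(token_table.values()) + 1  # -> 6, i.e. the largest id + 1

print(blank_id, vocab_size)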
@@ -997,12 +999,11 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank],
-                    find_unused_parameters=True)
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
-    optimizer = ScaledAdam(model.parameters(),
-                           lr=params.base_lr,
-                           clipping_scale=2.0)
+    optimizer = ScaledAdam(
+        model.parameters(), lr=params.base_lr, clipping_scale=2.0
+    )
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
@@ -1027,26 +1028,26 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)
-    librispeech = LibriSpeechAsrDataModule(args)
-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    wenetspeech = WenetSpeechAsrDataModule(args)
+    train_cuts = wenetspeech.train_cuts()
+    valid_cuts = wenetspeech.valid_cuts()
     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
+        # Keep only utterances with duration between 1 second and 19 seconds
         #
-        # Caution: There is a reason to select 20.0 here. Please see
+        # Caution: There is a reason to select 19.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 20.0
+        return 1.0 <= c.duration <= 19.0
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
         # saved in the middle of an epoch
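Note on remove_short_and_long_utt above: the upper duration bound moves from 20 s to 19 s for WenetSpeech, following the statistics from ../local/display_manifest_statistics.py; the lower bound stays at 1 s. In the recipe the predicate receives a lhotse Cut and is applied lazily via CutSet.filter; the sketch below only illustrates the predicate itself on plain numbers (toy durations, not real data):

def keep(duration: float) -> bool:
    # Same bounds as remove_short_and_long_utt, expressed over a duration in seconds.
    return 1.0 <= duration <= 19.0


durations = [0.4, 2.7, 12.3, 19.0, 25.8]
print([d for d in durations if keep(d)])  # [2.7, 12.3, 19.0]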
@@ -1054,25 +1055,20 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None
-    train_dl = librispeech.train_dataloaders(
+    train_dl = wenetspeech.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
-    if not params.print_diagnostics:
+    if False and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
             optimizer=optimizer,
-            sp=sp,
+            graph_compiler=graph_compiler,
             params=params,
         )
-    scaler = GradScaler(enabled=params.use_fp16,
-                        init_scale=1.0)
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1093,7 +1089,7 @@ def run(rank, world_size, args):
             model_avg=model_avg,
             optimizer=optimizer,
             scheduler=scheduler,
-            sp=sp,
+            graph_compiler=graph_compiler,
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
@@ -1127,7 +1123,6 @@ def run(rank, world_size, args):
 def display_and_save_batch(
     batch: dict,
     params: AttributeDict,
-    sp: spm.SentencePieceProcessor,
 ) -> None:
     """Display the batch statistics and save the batch into disk.
@@ -1137,8 +1132,6 @@ def display_and_save_batch(
         for the content in it.
       params:
        Parameters for training. See :func:`get_params`.
-      sp:
-        The BPE model.
     """
     from lhotse.utils import uuid4
@@ -1146,13 +1139,13 @@ def display_and_save_batch(
     logging.info(f"Saving batch to {filename}")
     torch.save(batch, filename)
-    supervisions = batch["supervisions"]
     features = batch["inputs"]
     logging.info(f"features shape: {features.shape}")
-    y = sp.encode(supervisions["text"], out_type=int)
-    num_tokens = sum(len(i) for i in y)
+    texts = batch["supervisions"]["text"]
+    num_tokens = sum(len(i) for i in texts)
     logging.info(f"num tokens: {num_tokens}")
@@ -1160,7 +1153,7 @@ def scan_pessimistic_batches_for_oom(
     model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -1176,7 +1169,7 @@ def scan_pessimistic_batches_for_oom(
             loss, _ = compute_loss(
                 params=params,
                 model=model,
-                sp=sp,
+                graph_compiler=graph_compiler,
                 batch=batch,
                 is_training=True,
             )
@@ -1191,15 +1184,18 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(batch, params=params, sp=sp)
+            display_and_save_batch(batch, params=params)
             raise
-    logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+    logging.info(
+        f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+    )
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    WenetSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.lang_dir = Path(args.lang_dir)
     args.exp_dir = Path(args.exp_dir)
     world_size = args.world_size