Update decode.py and train.py to use periodically averaged models.

Fangjun Kuang 2022-05-13 23:22:30 +08:00
parent 7b786ce0b9
commit 2ce48a2c21
3 changed files with 204 additions and 85 deletions

decode.py

@@ -20,40 +20,40 @@
 Usage:
 (1) greedy search
 ./pruned_transducer_stateless4/decode.py \
     --epoch 30 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless4/exp \
     --max-duration 600 \
     --decoding-method greedy_search

 (2) beam search (not recommended)
 ./pruned_transducer_stateless4/decode.py \
     --epoch 30 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless4/exp \
     --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4

 (3) modified beam search
 ./pruned_transducer_stateless4/decode.py \
     --epoch 30 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless4/exp \
     --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4

 (4) fast beam search
 ./pruned_transducer_stateless4/decode.py \
     --epoch 30 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless4/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 4 \
     --max-contexts 4 \
     --max-states 8
 """
@@ -502,7 +502,7 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <unk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()

decode.py

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
-# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Zengwei Yao)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -78,6 +79,7 @@ from train import add_model_arguments, get_params, get_transducer_model
 from icefall.checkpoint import (
     average_checkpoints,
+    average_checkpoints_with_averaged_model,
     find_checkpoints,
     load_checkpoint,
 )
@@ -85,6 +87,7 @@ from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
+    str2bool,
     write_error_stats,
 )
@@ -97,9 +100,9 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=28,
+        default=30,
         help="""It specifies the checkpoint to use for decoding.
-        Note: Epoch counts from 0.
+        Note: Epoch counts from 1.
         You can specify --avg to use more checkpoints for model averaging.""",
     )
@@ -122,6 +125,17 @@ def get_parser():
         "'--epoch' and '--iter'",
     )

+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=False,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -238,7 +252,7 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
       the returned dict.
     """
-    device = model.device
+    device = next(model.parameters()).device

     feature = batch["inputs"]
     assert feature.ndim == 3
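
The new idiom reads the device off the parameters themselves, so decode.py no longer depends on the hand-attached `model.device` attribute that `main()` used to set (and removes below). A minimal sketch of the pattern, which works for any `nn.Module`:

    import torch
    import torch.nn as nn

    model = nn.Linear(80, 512)
    # All parameters of a module live on its device, so inspecting the first
    # one recovers it without any custom `model.device` attribute.
    device = next(model.parameters()).device
    assert device == torch.device("cpu")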
@@ -475,6 +489,9 @@ def main():
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
@@ -497,38 +514,85 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)

-    if params.iter > 0:
-        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-            : params.avg
-        ]
-        if len(filenames) == 0:
-            raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
-            )
-        elif len(filenames) < params.avg:
-            raise ValueError(
-                f"Not enough checkpoints ({len(filenames)}) found for"
-                f" --iter {params.iter}, --avg {params.avg}"
-            )
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
-    elif params.avg == 1:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-    else:
-        start = params.epoch - params.avg + 1
-        filenames = []
-        for i in range(start, params.epoch + 1):
-            if start >= 0:
-                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0
+            start = params.epoch - params.avg
+            assert start >= 1
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )

     model.to(device)
     model.eval()
-    model.device = device

     if params.decoding_method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
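
The `--use-averaged-model` branch loads only the two endpoint checkpoints: since each saved `model_avg` is a running average of all parameters from the start of training, the average over the range `(start, epoch]` can be recovered from the endpoints alone. A minimal sketch of that arithmetic, assuming each checkpoint stores a `model_avg` state dict together with its `batch_idx_train` count (the real logic lives in `average_checkpoints_with_averaged_model` in icefall/checkpoint.py and handles more details):

    import torch

    def average_over_range(filename_start: str, filename_end: str, device):
        # Assumed layout: ckpt["model_avg"][k] is the mean of parameter k
        # over the first ckpt["batch_idx_train"] batches.
        ckpt_a = torch.load(filename_start, map_location=device)
        ckpt_b = torch.load(filename_end, map_location=device)
        n_a = ckpt_a["batch_idx_train"]
        n_b = ckpt_b["batch_idx_train"]
        avg_a = ckpt_a["model_avg"]
        avg_b = ckpt_b["model_avg"]
        # Mean over batches n_a+1 .. n_b:
        #   (n_b * avg_b - n_a * avg_a) / (n_b - n_a)
        return {
            k: (n_b * v - n_a * avg_a[k]) / (n_b - n_a)
            for k, v in avg_b.items()
            if v.is_floating_point()
        }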

train.py

@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
-#                                       Wei Kang
-#                                       Mingshuang Luo)
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+#                                            Wei Kang,
+#                                            Mingshuang Luo,)
+#                                            Zengwei Yao)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -24,7 +25,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 ./pruned_transducer_stateless5/train.py \
   --world-size 4 \
   --num-epochs 30 \
-  --start-epoch 0 \
+  --start-epoch 1 \
   --exp-dir pruned_transducer_stateless5/exp \
   --full-libri 1 \
   --max-duration 300
@@ -34,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 ./pruned_transducer_stateless5/train.py \
   --world-size 4 \
   --num-epochs 30 \
-  --start-epoch 0 \
+  --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir pruned_transducer_stateless5/exp \
   --full-libri 1 \
@@ -44,6 +45,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 import argparse
+import copy
 import logging
 import warnings
 from pathlib import Path
@@ -73,7 +75,10 @@ from torch.utils.tensorboard import SummaryWriter
 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.checkpoint import save_checkpoint_with_global_batch_idx
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -166,10 +171,10 @@ def get_parser():
     parser.add_argument(
         "--start-epoch",
         type=int,
-        default=0,
-        help="""Resume training from from this epoch.
-        If it is positive, it will load checkpoint from
-        transducer_stateless2/exp/epoch-{start_epoch-1}.pt
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
         """,
     )
@@ -282,7 +287,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=8000,
+        default=4000,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -295,7 +300,7 @@ def get_parser():
     parser.add_argument(
         "--keep-last-k",
         type=int,
-        default=20,
+        default=30,
         help="""Only keep this number of checkpoints on disk.
         For instance, if it is 3, there are only 3 checkpoints
         in the exp-dir with filenames `checkpoint-xxx.pt`.
@@ -303,6 +308,19 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
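
The formula in the help string above is an incremental running mean: scaling the old average by `(n - p)/n` and the current model by `p/n` (with `n = batch_idx_train` and `p = average_period`) keeps `model_avg` equal to the mean of the parameters over all `n` batches, provided updates happen exactly every `p` batches. A minimal sketch of that update applied in place to the floating-point entries (the actual helper is `update_averaged_model` from icefall.checkpoint, imported above):

    import torch
    import torch.nn as nn

    @torch.no_grad()
    def update_model_avg_sketch(
        model_cur: nn.Module,
        model_avg: nn.Module,
        batch_idx_train: int,
        average_period: int,
    ) -> None:
        weight_cur = average_period / batch_idx_train
        weight_avg = 1.0 - weight_cur
        cur = model_cur.state_dict()
        avg = model_avg.state_dict()
        for k, v in avg.items():
            if v.is_floating_point():
                # model_avg = model * (p/n) + model_avg * ((n - p)/n)
                v.mul_(weight_avg).add_(cur[k], alpha=weight_cur)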
@@ -434,6 +452,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
+    model_avg: nn.Module = None,
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
 ) -> Optional[Dict[str, Any]]:
@@ -441,7 +460,7 @@ def load_checkpoint_if_available(
     If params.start_batch is positive, it will load the checkpoint from
     `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
-    params.start_epoch is positive, it will load the checkpoint from
+    params.start_epoch is larger than 1, it will load the checkpoint from
     `params.start_epoch - 1`.

     Apart from loading state dict for `model` and `optimizer` it also updates
@@ -453,6 +472,8 @@ def load_checkpoint_if_available(
         The return value of :func:`get_params`.
       model:
         The training model.
+      model_avg:
+        The stored model averaged from the start of training.
       optimizer:
         The optimizer that we are using.
       scheduler:
@@ -462,7 +483,7 @@ def load_checkpoint_if_available(
     """
     if params.start_batch > 0:
         filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
-    elif params.start_epoch > 0:
+    elif params.start_epoch > 1:
         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     else:
         return None
@@ -472,6 +493,7 @@ def load_checkpoint_if_available(
     saved_params = load_checkpoint(
         filename,
         model=model,
+        model_avg=model_avg,
         optimizer=optimizer,
         scheduler=scheduler,
     )
@@ -498,7 +520,8 @@ def load_checkpoint_if_available(
 def save_checkpoint(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
@@ -512,6 +535,8 @@ def save_checkpoint(
         It is returned by :func:`get_params`.
       model:
         The training model.
+      model_avg:
+        The stored model averaged from the start of training.
       optimizer:
         The optimizer used in the training.
       sampler:
@@ -525,6 +550,7 @@ def save_checkpoint(
     save_checkpoint_impl(
         filename=filename,
         model=model,
+        model_avg=model_avg,
         params=params,
         optimizer=optimizer,
         scheduler=scheduler,
@@ -544,7 +570,7 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
@@ -568,7 +594,11 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -624,7 +654,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
@@ -658,13 +688,14 @@ def compute_validation_loss(
 def train_one_epoch(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -690,6 +721,8 @@ def train_one_epoch(
         Dataloader for the validation dataset.
       scaler:
         The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
       tb_writer:
         Writer to write log messages to tensorboard.
       world_size:
@@ -739,6 +772,17 @@ def train_one_epoch(
         if params.print_diagnostics and batch_idx == 5:
             return

+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
         if (
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
@@ -748,6 +792,7 @@ def train_one_epoch(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
                 model=model,
+                model_avg=model_avg,
                 params=params,
                 optimizer=optimizer,
                 scheduler=scheduler,
@@ -855,13 +900,21 @@ def run(rank, world_size, args):
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoints = load_checkpoint_if_available(params=params, model=model)
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )

     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank])
-    model.device = device

     optimizer = Eve(model.parameters(), lr=params.initial_lr)
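
Two details in this hunk are easy to miss: `model_avg` is deep-copied from the plain module before DDP wrapping, so rank 0 keeps an unwrapped `nn.Module` it can update in place, and the `save_every_n >= average_period` assertion ensures the averaging period never exceeds the save period, so periodic checkpoints do not capture a stale average. A minimal sketch of that initialization order, with a hypothetical helper name:

    import copy
    from typing import Optional

    import torch.nn as nn

    def init_model_avg(model: nn.Module, rank: int) -> Optional[nn.Module]:
        # Call this before wrapping `model` in DDP: only rank 0 maintains
        # and saves the averaged copy; all other ranks carry None.
        return copy.deepcopy(model) if rank == 0 else None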
@@ -934,10 +987,10 @@ def run(rank, world_size, args):
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])

-    for epoch in range(params.start_epoch, params.num_epochs):
-        scheduler.step_epoch(epoch)
-        fix_random_seed(params.seed + epoch)
-        train_dl.sampler.set_epoch(epoch)
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)

         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
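
Epoch numbering is now 1-based throughout (matching the new `--start-epoch` default and the "Epoch counts from 1" note in decode.py), while 0-based consumers such as the scheduler, the random seed, and the sampler receive `epoch - 1`, so the effective schedule is unchanged. For example:

    # With --num-epochs 30 the 1-based loop visits epochs 1..30 inclusive;
    # 0-based consumers see 0..29, exactly as the old 0-based loop did.
    start_epoch, num_epochs = 1, 30
    visited = list(range(start_epoch, num_epochs + 1))
    assert visited[0] == 1 and visited[-1] == 30
    assert [e - 1 for e in visited] == list(range(30))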
@@ -947,6 +1000,7 @@ def run(rank, world_size, args):
         train_one_epoch(
             params=params,
             model=model,
+            model_avg=model_avg,
             optimizer=optimizer,
             scheduler=scheduler,
             sp=sp,
@@ -965,6 +1019,7 @@ def run(rank, world_size, args):
         save_checkpoint(
             params=params,
             model=model,
+            model_avg=model_avg,
             optimizer=optimizer,
             scheduler=scheduler,
             sampler=train_dl.sampler,
@@ -1012,7 +1067,7 @@ def display_and_save_batch(
 def scan_pessimistic_batches_for_oom(
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
@@ -1021,7 +1076,7 @@ def scan_pessimistic_batches_for_oom(
     from lhotse.dataset import find_pessimistic_batches

     logging.info(
-        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
     )
     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
     for criterion, cuts in batches.items():