propagate changes from #525 to other librispeech recipes (#531)

* propaga changes from #525 to other librispeech recipes

* refactor display_and_save_batch to utils

* fixed typo

* reformat code style
This commit is contained in:
marcoyang1998 2022-08-17 17:18:15 +08:00 committed by GitHub
parent 669401869d
commit c74cec59e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 253 additions and 76 deletions

View File

@ -66,6 +66,7 @@ class Transducer(nn.Module):
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
reduction: str = "sum",
) -> torch.Tensor:
"""
Args:
@ -86,6 +87,10 @@ class Transducer(nn.Module):
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
reduction:
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
Returns:
Return the transducer loss.
@ -95,6 +100,7 @@ class Transducer(nn.Module):
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert reduction in ("sum", "none"), reduction
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
@ -136,7 +142,7 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
reduction=reduction,
return_grad=True,
)
@ -163,7 +169,7 @@ class Transducer(nn.Module):
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
reduction=reduction,
)
return (simple_loss, pruned_loss)

View File

@ -78,6 +78,7 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
display_and_save_batch,
measure_gradient_norms,
measure_weight_norms,
optim_step_and_measure_param_change,
@ -544,7 +545,36 @@ def compute_loss(
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
reduction="none",
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]
# If the batch contains more than 10 utterances AND
# if either all simple_loss or pruned_loss is inf or nan,
# we stop the training process by raising an exception
if feature.size(0) >= 10:
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)
simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
loss = params.simple_loss_scale * simple_loss + pruned_loss
assert loss.requires_grad == is_training
@ -552,6 +582,10 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)

View File

@ -88,7 +88,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
display_and_save_batch,
setup_logger,
str2bool,
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@ -600,7 +606,35 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]
# If the batch contains more than 10 utterances AND
# if either all simple_loss or pruned_loss is inf or nan,
# we stop the training process by raising an exception
if feature.size(0) >= 10:
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)
simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
@ -620,6 +654,10 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
@ -993,38 +1031,6 @@ def run(rank, world_size, args):
cleanup_dist()
def display_and_save_batch(
batch: dict,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> None:
"""Display the batch statistics and save the batch into disk.
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
params:
Parameters for training. See :func:`get_params`.
sp:
The BPE model.
"""
from lhotse.utils import uuid4
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
supervisions = batch["supervisions"]
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
y = sp.encode(supervisions["text"], out_type=int)
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,

View File

@ -105,6 +105,7 @@ class Transducer(nn.Module):
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
) -> torch.Tensor:
"""
Args:
@ -131,6 +132,10 @@ class Transducer(nn.Module):
warmup:
A value warmup >= 0 that determines which modules are active, values
warmup > 1 "are fully warmed up" and all modules will be active.
reduction:
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
Returns:
Return the transducer loss.
@ -140,6 +145,7 @@ class Transducer(nn.Module):
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert reduction in ("sum", "none"), reduction
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
@ -196,7 +202,7 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
reduction=reduction,
return_grad=True,
)
@ -229,7 +235,7 @@ class Transducer(nn.Module):
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
reduction=reduction,
)
return (simple_loss, pruned_loss)

View File

@ -84,7 +84,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
display_and_save_batch,
setup_logger,
str2bool,
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@ -637,7 +643,35 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]
# If the batch contains more than 10 utterances AND
# if either all simple_loss or pruned_loss is inf or nan,
# we stop the training process by raising an exception
if feature.size(0) >= 10:
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)
simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
@ -657,6 +691,10 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)

View File

@ -93,7 +93,13 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
display_and_save_batch,
setup_logger,
str2bool,
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@ -630,7 +636,35 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]
# If the batch contains more than 10 utterances AND
# if either all simple_loss or pruned_loss is inf or nan,
# we stop the training process by raising an exception
if feature.size(0) >= 10:
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)
simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
@ -650,6 +684,10 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)

View File

@ -81,7 +81,13 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
display_and_save_batch,
setup_logger,
str2bool,
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@ -670,7 +676,7 @@ def compute_loss(
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]
# If the batch contains more than 10 utterance AND
# If the batch contains more than 10 utterances AND
# if either all simple_loss or pruned_loss is inf or nan,
# we stop the training process by raising an exception
if feature.size(0) >= 10:
@ -1108,38 +1114,6 @@ def run(rank, world_size, args):
cleanup_dist()
def display_and_save_batch(
batch: dict,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> None:
"""Display the batch statistics and save the batch into disk.
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
params:
Parameters for training. See :func:`get_params`.
sp:
The BPE model.
"""
from lhotse.utils import uuid4
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
supervisions = batch["supervisions"]
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
y = sp.encode(supervisions["text"], out_type=int)
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")
def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,

View File

@ -89,6 +89,7 @@ class Transducer(nn.Module):
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
codebook_indexes: torch.Tensor = None,
) -> torch.Tensor:
"""
@ -113,6 +114,10 @@ class Transducer(nn.Module):
warmup:
A value warmup >= 0 that determines which modules are active, values
warmup > 1 "are fully warmed up" and all modules will be active.
reduction:
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
codebook_indexes:
codebook_indexes extracted from a teacher model.
Returns:
@ -124,6 +129,7 @@ class Transducer(nn.Module):
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert reduction in ("sum", "none"), reduction
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
@ -184,7 +190,7 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
reduction=reduction,
return_grad=True,
)
@ -217,7 +223,7 @@ class Transducer(nn.Module):
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
reduction=reduction,
)
return (simple_loss, pruned_loss, codebook_loss)

View File

@ -93,7 +93,13 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
display_and_save_batch,
setup_logger,
str2bool,
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@ -631,8 +637,35 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
codebook_indexes=codebook_indexes,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]
# If the batch contains more than 10 utterances AND
# if either all simple_loss or pruned_loss is inf or nan,
# we stop the training process by raising an exception
if feature.size(0) >= 10:
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)
simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
@ -654,6 +687,10 @@ def compute_loss(
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)

View File

@ -944,3 +944,35 @@ def tokenize_by_bpe_model(
txt_with_bpe = "/".join(tokens)
return txt_with_bpe
def display_and_save_batch(
batch: dict,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> None:
"""Display the batch statistics and save the batch into disk.
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
params:
Parameters for training. See :func:`get_params`.
sp:
The BPE model.
"""
from lhotse.utils import uuid4
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
supervisions = batch["supervisions"]
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
y = sp.encode(supervisions["text"], out_type=int)
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")