Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 09:32:20 +00:00
* propagate changes from #525 to other librispeech recipes
* refactor display_and_save_batch to utils
* fixed typo
* reformat code style
This commit is contained in:
parent 669401869d
commit c74cec59e9
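The recurring change in the diffs below is that compute_loss() now calls the model with reduction="none", drops utterances whose loss is inf or nan, and only then sums. A minimal sketch of that filtering pattern, using made-up per-utterance loss tensors in place of real model output (the tensor values and the batch_size variable are illustrative, not from the commit):

import torch

# Stand-ins for the per-utterance losses returned with reduction="none";
# the second utterance has a non-finite pruned loss.
simple_loss = torch.tensor([1.2, 0.8, 2.5])
pruned_loss = torch.tensor([0.4, float("inf"), 0.9])
batch_size = simple_loss.numel()

simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite

if not torch.all(is_finite):
    # The recipes also log the losses and call display_and_save_batch()
    # here so the offending batch can be inspected later.
    simple_loss = simple_loss[simple_loss_is_finite]
    pruned_loss = pruned_loss[pruned_loss_is_finite]

# For large batches (>= 10 utterances) where *all* losses are bad,
# the recipes raise instead of silently dropping everything.
if batch_size >= 10 and (
    torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite)
):
    raise ValueError("Too many utterances in this batch lead to inf or nan losses.")

# Reduce to scalars only after filtering.
simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()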
@@ -66,6 +66,7 @@ class Transducer(nn.Module):
 prune_range: int = 5,
 am_scale: float = 0.0,
 lm_scale: float = 0.0,
+reduction: str = "sum",
 ) -> torch.Tensor:
 """
 Args:
@@ -86,6 +87,10 @@ class Transducer(nn.Module):
 lm_scale:
 The scale to smooth the loss with lm (output of predictor network)
 part
+reduction:
+"sum" to sum the losses over all utterances in the batch.
+"none" to return the loss in a 1-D tensor for each utterance
+in the batch.
 Returns:
 Return the transducer loss.

@@ -95,6 +100,7 @@ class Transducer(nn.Module):
 lm_scale * lm_probs + am_scale * am_probs +
 (1-lm_scale-am_scale) * combined_probs
 """
+assert reduction in ("sum", "none"), reduction
 assert x.ndim == 3, x.shape
 assert x_lens.ndim == 1, x_lens.shape
 assert y.num_axes == 2, y.num_axes
@@ -136,7 +142,7 @@ class Transducer(nn.Module):
 lm_only_scale=lm_scale,
 am_only_scale=am_scale,
 boundary=boundary,
-reduction="sum",
+reduction=reduction,
 return_grad=True,
 )

@@ -163,7 +169,7 @@ class Transducer(nn.Module):
 ranges=ranges,
 termination_symbol=blank_id,
 boundary=boundary,
-reduction="sum",
+reduction=reduction,
 )

 return (simple_loss, pruned_loss)
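For reference, the two reduction modes described in the docstring above differ only in when the per-utterance values are summed; a toy tensor stands in for the per-utterance loss here (not actual model output):

import torch

per_utterance = torch.tensor([0.7, 1.3, 2.0])  # what reduction="none" returns: one value per utterance
total = per_utterance.sum()                    # what reduction="sum" returns: a single scalar
print(per_utterance.shape, total.item())       # torch.Size([3]) 4.0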
@@ -78,6 +78,7 @@ from icefall.env import get_env_info
 from icefall.utils import (
 AttributeDict,
 MetricsTracker,
+display_and_save_batch,
 measure_gradient_norms,
 measure_weight_norms,
 optim_step_and_measure_param_change,
@@ -544,7 +545,36 @@ def compute_loss(
 prune_range=params.prune_range,
 am_scale=params.am_scale,
 lm_scale=params.lm_scale,
+reduction="none",
 )
+simple_loss_is_finite = torch.isfinite(simple_loss)
+pruned_loss_is_finite = torch.isfinite(pruned_loss)
+is_finite = simple_loss_is_finite & pruned_loss_is_finite
+if not torch.all(is_finite):
+logging.info(
+"Not all losses are finite!\n"
+f"simple_loss: {simple_loss}\n"
+f"pruned_loss: {pruned_loss}"
+)
+display_and_save_batch(batch, params=params, sp=sp)
+simple_loss = simple_loss[simple_loss_is_finite]
+pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+# If the batch contains more than 10 utterances AND
+# if either all simple_loss or pruned_loss is inf or nan,
+# we stop the training process by raising an exception
+if feature.size(0) >= 10:
+if torch.all(~simple_loss_is_finite) or torch.all(
+~pruned_loss_is_finite
+):
+raise ValueError(
+"There are too many utterances in this batch "
+"leading to inf or nan losses."
+)
+
+simple_loss = simple_loss.sum()
+pruned_loss = pruned_loss.sum()
+
 loss = params.simple_loss_scale * simple_loss + pruned_loss

 assert loss.requires_grad == is_training
@@ -552,6 +582,10 @@ def compute_loss(
 info = MetricsTracker()
 with warnings.catch_warnings():
 warnings.simplefilter("ignore")
+# info["frames"] is an approximate number for two reasons:
+# (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2
+# (2) If some utterances in the batch lead to inf/nan loss, they
+# are filtered out.
 info["frames"] = (
 (feature_lens // params.subsampling_factor).sum().item()
 )
@@ -88,7 +88,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+AttributeDict,
+MetricsTracker,
+display_and_save_batch,
+setup_logger,
+str2bool,
+)

 LRSchedulerType = Union[
 torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -600,7 +606,35 @@ def compute_loss(
 am_scale=params.am_scale,
 lm_scale=params.lm_scale,
 warmup=warmup,
+reduction="none",
 )
+simple_loss_is_finite = torch.isfinite(simple_loss)
+pruned_loss_is_finite = torch.isfinite(pruned_loss)
+is_finite = simple_loss_is_finite & pruned_loss_is_finite
+if not torch.all(is_finite):
+logging.info(
+"Not all losses are finite!\n"
+f"simple_loss: {simple_loss}\n"
+f"pruned_loss: {pruned_loss}"
+)
+display_and_save_batch(batch, params=params, sp=sp)
+simple_loss = simple_loss[simple_loss_is_finite]
+pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+# If the batch contains more than 10 utterances AND
+# if either all simple_loss or pruned_loss is inf or nan,
+# we stop the training process by raising an exception
+if feature.size(0) >= 10:
+if torch.all(~simple_loss_is_finite) or torch.all(
+~pruned_loss_is_finite
+):
+raise ValueError(
+"There are too many utterances in this batch "
+"leading to inf or nan losses."
+)
+
+simple_loss = simple_loss.sum()
+pruned_loss = pruned_loss.sum()
 # after the main warmup step, we keep pruned_loss_scale small
 # for the same amount of time (model_warm_step), to avoid
 # overwhelming the simple_loss and causing it to diverge,
@@ -620,6 +654,10 @@ def compute_loss(
 info = MetricsTracker()
 with warnings.catch_warnings():
 warnings.simplefilter("ignore")
+# info["frames"] is an approximate number for two reasons:
+# (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2
+# (2) If some utterances in the batch lead to inf/nan loss, they
+# are filtered out.
 info["frames"] = (
 (feature_lens // params.subsampling_factor).sum().item()
 )
@@ -993,38 +1031,6 @@ def run(rank, world_size, args):
 cleanup_dist()


-def display_and_save_batch(
-batch: dict,
-params: AttributeDict,
-sp: spm.SentencePieceProcessor,
-) -> None:
-"""Display the batch statistics and save the batch into disk.
-
-Args:
-batch:
-A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
-for the content in it.
-params:
-Parameters for training. See :func:`get_params`.
-sp:
-The BPE model.
-"""
-from lhotse.utils import uuid4
-
-filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
-logging.info(f"Saving batch to {filename}")
-torch.save(batch, filename)
-
-supervisions = batch["supervisions"]
-features = batch["inputs"]
-
-logging.info(f"features shape: {features.shape}")
-
-y = sp.encode(supervisions["text"], out_type=int)
-num_tokens = sum(len(i) for i in y)
-logging.info(f"num tokens: {num_tokens}")
-
-
 def scan_pessimistic_batches_for_oom(
 model: nn.Module,
 train_dl: torch.utils.data.DataLoader,
@@ -105,6 +105,7 @@ class Transducer(nn.Module):
 am_scale: float = 0.0,
 lm_scale: float = 0.0,
 warmup: float = 1.0,
+reduction: str = "sum",
 ) -> torch.Tensor:
 """
 Args:
@@ -131,6 +132,10 @@ class Transducer(nn.Module):
 warmup:
 A value warmup >= 0 that determines which modules are active, values
 warmup > 1 "are fully warmed up" and all modules will be active.
+reduction:
+"sum" to sum the losses over all utterances in the batch.
+"none" to return the loss in a 1-D tensor for each utterance
+in the batch.
 Returns:
 Return the transducer loss.

@@ -140,6 +145,7 @@ class Transducer(nn.Module):
 lm_scale * lm_probs + am_scale * am_probs +
 (1-lm_scale-am_scale) * combined_probs
 """
+assert reduction in ("sum", "none"), reduction
 assert x.ndim == 3, x.shape
 assert x_lens.ndim == 1, x_lens.shape
 assert y.num_axes == 2, y.num_axes
@@ -196,7 +202,7 @@ class Transducer(nn.Module):
 lm_only_scale=lm_scale,
 am_only_scale=am_scale,
 boundary=boundary,
-reduction="sum",
+reduction=reduction,
 return_grad=True,
 )

@@ -229,7 +235,7 @@ class Transducer(nn.Module):
 ranges=ranges,
 termination_symbol=blank_id,
 boundary=boundary,
-reduction="sum",
+reduction=reduction,
 )

 return (simple_loss, pruned_loss)
@@ -84,7 +84,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+AttributeDict,
+MetricsTracker,
+display_and_save_batch,
+setup_logger,
+str2bool,
+)

 LRSchedulerType = Union[
 torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -637,7 +643,35 @@ def compute_loss(
 am_scale=params.am_scale,
 lm_scale=params.lm_scale,
 warmup=warmup,
+reduction="none",
 )
+simple_loss_is_finite = torch.isfinite(simple_loss)
+pruned_loss_is_finite = torch.isfinite(pruned_loss)
+is_finite = simple_loss_is_finite & pruned_loss_is_finite
+if not torch.all(is_finite):
+logging.info(
+"Not all losses are finite!\n"
+f"simple_loss: {simple_loss}\n"
+f"pruned_loss: {pruned_loss}"
+)
+display_and_save_batch(batch, params=params, sp=sp)
+simple_loss = simple_loss[simple_loss_is_finite]
+pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+# If the batch contains more than 10 utterances AND
+# if either all simple_loss or pruned_loss is inf or nan,
+# we stop the training process by raising an exception
+if feature.size(0) >= 10:
+if torch.all(~simple_loss_is_finite) or torch.all(
+~pruned_loss_is_finite
+):
+raise ValueError(
+"There are too many utterances in this batch "
+"leading to inf or nan losses."
+)
+
+simple_loss = simple_loss.sum()
+pruned_loss = pruned_loss.sum()
 # after the main warmup step, we keep pruned_loss_scale small
 # for the same amount of time (model_warm_step), to avoid
 # overwhelming the simple_loss and causing it to diverge,
@@ -657,6 +691,10 @@ def compute_loss(
 info = MetricsTracker()
 with warnings.catch_warnings():
 warnings.simplefilter("ignore")
+# info["frames"] is an approximate number for two reasons:
+# (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2
+# (2) If some utterances in the batch lead to inf/nan loss, they
+# are filtered out.
 info["frames"] = (
 (feature_lens // params.subsampling_factor).sum().item()
 )
@@ -93,7 +93,13 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+AttributeDict,
+MetricsTracker,
+display_and_save_batch,
+setup_logger,
+str2bool,
+)

 LRSchedulerType = Union[
 torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -630,7 +636,35 @@ def compute_loss(
 am_scale=params.am_scale,
 lm_scale=params.lm_scale,
 warmup=warmup,
+reduction="none",
 )
+simple_loss_is_finite = torch.isfinite(simple_loss)
+pruned_loss_is_finite = torch.isfinite(pruned_loss)
+is_finite = simple_loss_is_finite & pruned_loss_is_finite
+if not torch.all(is_finite):
+logging.info(
+"Not all losses are finite!\n"
+f"simple_loss: {simple_loss}\n"
+f"pruned_loss: {pruned_loss}"
+)
+display_and_save_batch(batch, params=params, sp=sp)
+simple_loss = simple_loss[simple_loss_is_finite]
+pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+# If the batch contains more than 10 utterances AND
+# if either all simple_loss or pruned_loss is inf or nan,
+# we stop the training process by raising an exception
+if feature.size(0) >= 10:
+if torch.all(~simple_loss_is_finite) or torch.all(
+~pruned_loss_is_finite
+):
+raise ValueError(
+"There are too many utterances in this batch "
+"leading to inf or nan losses."
+)
+
+simple_loss = simple_loss.sum()
+pruned_loss = pruned_loss.sum()
 # after the main warmup step, we keep pruned_loss_scale small
 # for the same amount of time (model_warm_step), to avoid
 # overwhelming the simple_loss and causing it to diverge,
@@ -650,6 +684,10 @@ def compute_loss(
 info = MetricsTracker()
 with warnings.catch_warnings():
 warnings.simplefilter("ignore")
+# info["frames"] is an approximate number for two reasons:
+# (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2
+# (2) If some utterances in the batch lead to inf/nan loss, they
+# are filtered out.
 info["frames"] = (
 (feature_lens // params.subsampling_factor).sum().item()
 )
@@ -81,7 +81,13 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+AttributeDict,
+MetricsTracker,
+display_and_save_batch,
+setup_logger,
+str2bool,
+)

 LRSchedulerType = Union[
 torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -670,7 +676,7 @@ def compute_loss(
 simple_loss = simple_loss[simple_loss_is_finite]
 pruned_loss = pruned_loss[pruned_loss_is_finite]

-# If the batch contains more than 10 utterance AND
+# If the batch contains more than 10 utterances AND
 # if either all simple_loss or pruned_loss is inf or nan,
 # we stop the training process by raising an exception
 if feature.size(0) >= 10:
@@ -1108,38 +1114,6 @@ def run(rank, world_size, args):
 cleanup_dist()


-def display_and_save_batch(
-batch: dict,
-params: AttributeDict,
-sp: spm.SentencePieceProcessor,
-) -> None:
-"""Display the batch statistics and save the batch into disk.
-
-Args:
-batch:
-A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
-for the content in it.
-params:
-Parameters for training. See :func:`get_params`.
-sp:
-The BPE model.
-"""
-from lhotse.utils import uuid4
-
-filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
-logging.info(f"Saving batch to {filename}")
-torch.save(batch, filename)
-
-supervisions = batch["supervisions"]
-features = batch["inputs"]
-
-logging.info(f"features shape: {features.shape}")
-
-y = sp.encode(supervisions["text"], out_type=int)
-num_tokens = sum(len(i) for i in y)
-logging.info(f"num tokens: {num_tokens}")
-
-
 def scan_pessimistic_batches_for_oom(
 model: Union[nn.Module, DDP],
 train_dl: torch.utils.data.DataLoader,
@@ -89,6 +89,7 @@ class Transducer(nn.Module):
 am_scale: float = 0.0,
 lm_scale: float = 0.0,
 warmup: float = 1.0,
+reduction: str = "sum",
 codebook_indexes: torch.Tensor = None,
 ) -> torch.Tensor:
 """
@@ -113,6 +114,10 @@ class Transducer(nn.Module):
 warmup:
 A value warmup >= 0 that determines which modules are active, values
 warmup > 1 "are fully warmed up" and all modules will be active.
+reduction:
+"sum" to sum the losses over all utterances in the batch.
+"none" to return the loss in a 1-D tensor for each utterance
+in the batch.
 codebook_indexes:
 codebook_indexes extracted from a teacher model.
 Returns:
@@ -124,6 +129,7 @@ class Transducer(nn.Module):
 lm_scale * lm_probs + am_scale * am_probs +
 (1-lm_scale-am_scale) * combined_probs
 """
+assert reduction in ("sum", "none"), reduction
 assert x.ndim == 3, x.shape
 assert x_lens.ndim == 1, x_lens.shape
 assert y.num_axes == 2, y.num_axes
@@ -184,7 +190,7 @@ class Transducer(nn.Module):
 lm_only_scale=lm_scale,
 am_only_scale=am_scale,
 boundary=boundary,
-reduction="sum",
+reduction=reduction,
 return_grad=True,
 )

@@ -217,7 +223,7 @@ class Transducer(nn.Module):
 ranges=ranges,
 termination_symbol=blank_id,
 boundary=boundary,
-reduction="sum",
+reduction=reduction,
 )

 return (simple_loss, pruned_loss, codebook_loss)
@@ -93,7 +93,13 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+AttributeDict,
+MetricsTracker,
+display_and_save_batch,
+setup_logger,
+str2bool,
+)

 LRSchedulerType = Union[
 torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -631,8 +637,35 @@ def compute_loss(
 am_scale=params.am_scale,
 lm_scale=params.lm_scale,
 warmup=warmup,
+reduction="none",
 codebook_indexes=codebook_indexes,
 )
+simple_loss_is_finite = torch.isfinite(simple_loss)
+pruned_loss_is_finite = torch.isfinite(pruned_loss)
+is_finite = simple_loss_is_finite & pruned_loss_is_finite
+if not torch.all(is_finite):
+logging.info(
+"Not all losses are finite!\n"
+f"simple_loss: {simple_loss}\n"
+f"pruned_loss: {pruned_loss}"
+)
+display_and_save_batch(batch, params=params, sp=sp)
+simple_loss = simple_loss[simple_loss_is_finite]
+pruned_loss = pruned_loss[pruned_loss_is_finite]
+# If the batch contains more than 10 utterances AND
+# if either all simple_loss or pruned_loss is inf or nan,
+# we stop the training process by raising an exception
+if feature.size(0) >= 10:
+if torch.all(~simple_loss_is_finite) or torch.all(
+~pruned_loss_is_finite
+):
+raise ValueError(
+"There are too many utterances in this batch "
+"leading to inf or nan losses."
+)
+
+simple_loss = simple_loss.sum()
+pruned_loss = pruned_loss.sum()
 # after the main warmup step, we keep pruned_loss_scale small
 # for the same amount of time (model_warm_step), to avoid
 # overwhelming the simple_loss and causing it to diverge,
@@ -654,6 +687,10 @@ def compute_loss(

 with warnings.catch_warnings():
 warnings.simplefilter("ignore")
+# info["frames"] is an approximate number for two reasons:
+# (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2
+# (2) If some utterances in the batch lead to inf/nan loss, they
+# are filtered out.
 info["frames"] = (
 (feature_lens // params.subsampling_factor).sum().item()
 )
@@ -944,3 +944,35 @@ def tokenize_by_bpe_model(
 txt_with_bpe = "/".join(tokens)

 return txt_with_bpe
+
+
+def display_and_save_batch(
+batch: dict,
+params: AttributeDict,
+sp: spm.SentencePieceProcessor,
+) -> None:
+"""Display the batch statistics and save the batch into disk.
+
+Args:
+batch:
+A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+for the content in it.
+params:
+Parameters for training. See :func:`get_params`.
+sp:
+The BPE model.
+"""
+from lhotse.utils import uuid4
+
+filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+logging.info(f"Saving batch to {filename}")
+torch.save(batch, filename)
+
+supervisions = batch["supervisions"]
+features = batch["inputs"]
+
+logging.info(f"features shape: {features.shape}")
+
+y = sp.encode(supervisions["text"], out_type=int)
+num_tokens = sum(len(i) for i in y)
+logging.info(f"num tokens: {num_tokens}")
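With display_and_save_batch now living in icefall/utils.py, the recipes import it instead of defining their own copy. A hypothetical usage sketch; the exp_dir, the BPE model path, and the toy batch below are assumptions for illustration, not part of this commit:

import logging

import sentencepiece as spm
import torch

from icefall.utils import AttributeDict, display_and_save_batch

logging.basicConfig(level=logging.INFO)

params = AttributeDict({"exp_dir": "./exp"})  # assumed output directory
sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # assumed BPE model path

# A toy batch shaped like lhotse's K2SpeechRecognitionDataset output.
batch = {
    "inputs": torch.randn(2, 100, 80),
    "supervisions": {"text": ["HELLO WORLD", "GOOD MORNING"]},
}

# Saves the batch to {exp_dir}/batch-<uuid>.pt and logs its statistics.
display_and_save_batch(batch, params=params, sp=sp)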