mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Zipformer output length (#686)
* add assertion for output length * add comment in filter_cuts * add length filter to Zipformer recipes
This commit is contained in:
parent
855c76655b
commit
c8ce243255
@ -101,6 +101,9 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor):
|
|||||||
# Note: for ./lstm_transducer_stateless/lstm.py, the formula is
|
# Note: for ./lstm_transducer_stateless/lstm.py, the formula is
|
||||||
# T = ((num_frames - 3) // 2 - 1) // 2
|
# T = ((num_frames - 3) // 2 - 1) // 2
|
||||||
|
|
||||||
|
# Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is
|
||||||
|
# T = ((num_frames - 7) // 2 + 1) // 2
|
||||||
|
|
||||||
tokens = sp.encode(c.supervisions[0].text, out_type=str)
|
tokens = sp.encode(c.supervisions[0].text, out_type=str)
|
||||||
|
|
||||||
if T < len(tokens):
|
if T < len(tokens):
|
||||||
|
@ -59,7 +59,6 @@ import torch
|
|||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from zipformer import Zipformer
|
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
@ -71,6 +70,7 @@ from torch import Tensor
|
|||||||
from torch.cuda.amp import GradScaler
|
from torch.cuda.amp import GradScaler
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from zipformer import Zipformer
|
||||||
|
|
||||||
from icefall import diagnostics
|
from icefall import diagnostics
|
||||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||||
@ -79,9 +79,9 @@ from icefall.checkpoint import (
|
|||||||
save_checkpoint_with_global_batch_idx,
|
save_checkpoint_with_global_batch_idx,
|
||||||
update_averaged_model,
|
update_averaged_model,
|
||||||
)
|
)
|
||||||
from icefall.hooks import register_inf_check_hooks
|
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
LRSchedulerType = Union[
|
LRSchedulerType = Union[
|
||||||
@ -89,14 +89,12 @@ LRSchedulerType = Union[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def set_batch_count(
|
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
|
||||||
model: Union[nn.Module, DDP], batch_count: float
|
|
||||||
) -> None:
|
|
||||||
if isinstance(model, DDP):
|
if isinstance(model, DDP):
|
||||||
# get underlying nn.Module
|
# get underlying nn.Module
|
||||||
model = model.module
|
model = model.module
|
||||||
for module in model.modules():
|
for module in model.modules():
|
||||||
if hasattr(module, 'batch_count'):
|
if hasattr(module, "batch_count"):
|
||||||
module.batch_count = batch_count
|
module.batch_count = batch_count
|
||||||
|
|
||||||
|
|
||||||
@ -126,7 +124,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
"--encoder-dims",
|
"--encoder-dims",
|
||||||
type=str,
|
type=str,
|
||||||
default="384,384,384,384,384",
|
default="384,384,384,384,384",
|
||||||
help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated"
|
help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -134,7 +132,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
type=str,
|
type=str,
|
||||||
default="192,192,192,192,192",
|
default="192,192,192,192,192",
|
||||||
help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
|
help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
|
||||||
not the same as embedding dimension."""
|
not the same as embedding dimension.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -143,7 +141,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
default="256,256,256,256,256",
|
default="256,256,256,256,256",
|
||||||
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
|
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
|
||||||
"Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
|
"Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
|
||||||
" worse."
|
" worse.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -248,10 +246,7 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--base-lr",
|
"--base-lr", type=float, default=0.05, help="The base learning rate."
|
||||||
type=float,
|
|
||||||
default=0.05,
|
|
||||||
help="The base learning rate."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -451,11 +446,14 @@ def get_params() -> AttributeDict:
|
|||||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
# TODO: We can add an option to switch between Zipformer and Transformer
|
# TODO: We can add an option to switch between Zipformer and Transformer
|
||||||
def to_int_tuple(s: str):
|
def to_int_tuple(s: str):
|
||||||
return tuple(map(int, s.split(',')))
|
return tuple(map(int, s.split(",")))
|
||||||
|
|
||||||
encoder = Zipformer(
|
encoder = Zipformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
output_downsampling_factor=2,
|
output_downsampling_factor=2,
|
||||||
zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors),
|
zipformer_downsampling_factors=to_int_tuple(
|
||||||
|
params.zipformer_downsampling_factors
|
||||||
|
),
|
||||||
encoder_dims=to_int_tuple(params.encoder_dims),
|
encoder_dims=to_int_tuple(params.encoder_dims),
|
||||||
attention_dim=to_int_tuple(params.attention_dims),
|
attention_dim=to_int_tuple(params.attention_dims),
|
||||||
encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
|
encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
|
||||||
@ -479,7 +477,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
|
|
||||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
joiner = Joiner(
|
joiner = Joiner(
|
||||||
encoder_dim=int(params.encoder_dims.split(',')[-1]),
|
encoder_dim=int(params.encoder_dims.split(",")[-1]),
|
||||||
decoder_dim=params.decoder_dim,
|
decoder_dim=params.decoder_dim,
|
||||||
joiner_dim=params.joiner_dim,
|
joiner_dim=params.joiner_dim,
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
@ -496,7 +494,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
decoder=decoder,
|
decoder=decoder,
|
||||||
joiner=joiner,
|
joiner=joiner,
|
||||||
encoder_dim=int(params.encoder_dims.split(',')[-1]),
|
encoder_dim=int(params.encoder_dims.split(",")[-1]),
|
||||||
decoder_dim=params.decoder_dim,
|
decoder_dim=params.decoder_dim,
|
||||||
joiner_dim=params.joiner_dim,
|
joiner_dim=params.joiner_dim,
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
@ -682,18 +680,17 @@ def compute_loss(
|
|||||||
# take down the scale on the simple loss from 1.0 at the start
|
# take down the scale on the simple loss from 1.0 at the start
|
||||||
# to params.simple_loss scale by warm_step.
|
# to params.simple_loss scale by warm_step.
|
||||||
simple_loss_scale = (
|
simple_loss_scale = (
|
||||||
s if batch_idx_train >= warm_step
|
s
|
||||||
|
if batch_idx_train >= warm_step
|
||||||
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
||||||
)
|
)
|
||||||
pruned_loss_scale = (
|
pruned_loss_scale = (
|
||||||
1.0 if batch_idx_train >= warm_step
|
1.0
|
||||||
|
if batch_idx_train >= warm_step
|
||||||
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = (
|
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||||
simple_loss_scale * simple_loss +
|
|
||||||
pruned_loss_scale * pruned_loss
|
|
||||||
)
|
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
@ -873,12 +870,16 @@ def train_one_epoch(
|
|||||||
# of the grad scaler is configurable, but we can't configure it to have different
|
# of the grad scaler is configurable, but we can't configure it to have different
|
||||||
# behavior depending on the current grad scale.
|
# behavior depending on the current grad scale.
|
||||||
cur_grad_scale = scaler._scale.item()
|
cur_grad_scale = scaler._scale.item()
|
||||||
if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
|
if cur_grad_scale < 1.0 or (
|
||||||
|
cur_grad_scale < 8.0 and batch_idx % 400 == 0
|
||||||
|
):
|
||||||
scaler.update(cur_grad_scale * 2.0)
|
scaler.update(cur_grad_scale * 2.0)
|
||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
|
raise RuntimeError(
|
||||||
|
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||||
|
)
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
@ -888,8 +889,12 @@ def train_one_epoch(
|
|||||||
f"Epoch {params.cur_epoch}, "
|
f"Epoch {params.cur_epoch}, "
|
||||||
f"batch {batch_idx}, loss[{loss_info}], "
|
f"batch {batch_idx}, loss[{loss_info}], "
|
||||||
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||||
f"lr: {cur_lr:.2e}, " +
|
f"lr: {cur_lr:.2e}, "
|
||||||
(f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
+ (
|
||||||
|
f"grad_scale: {scaler._scale.item()}"
|
||||||
|
if params.use_fp16
|
||||||
|
else ""
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
@ -905,12 +910,15 @@ def train_one_epoch(
|
|||||||
)
|
)
|
||||||
if params.use_fp16:
|
if params.use_fp16:
|
||||||
tb_writer.add_scalar(
|
tb_writer.add_scalar(
|
||||||
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
"train/grad_scale",
|
||||||
|
cur_grad_scale,
|
||||||
|
params.batch_idx_train,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
batch_idx % params.valid_interval == 0
|
||||||
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
and not params.print_diagnostics
|
||||||
|
):
|
||||||
logging.info("Computing validation loss")
|
logging.info("Computing validation loss")
|
||||||
valid_info = compute_validation_loss(
|
valid_info = compute_validation_loss(
|
||||||
params=params,
|
params=params,
|
||||||
@ -921,7 +929,9 @@ def train_one_epoch(
|
|||||||
)
|
)
|
||||||
model.train()
|
model.train()
|
||||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||||
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
|
logging.info(
|
||||||
|
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||||
|
)
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
valid_info.write_summary(
|
valid_info.write_summary(
|
||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
@ -997,12 +1007,11 @@ def run(rank, world_size, args):
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
logging.info("Using DDP")
|
logging.info("Using DDP")
|
||||||
model = DDP(model, device_ids=[rank],
|
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||||
find_unused_parameters=True)
|
|
||||||
|
|
||||||
optimizer = ScaledAdam(model.parameters(),
|
optimizer = ScaledAdam(
|
||||||
lr=params.base_lr,
|
model.parameters(), lr=params.base_lr, clipping_scale=2.0
|
||||||
clipping_scale=2.0)
|
)
|
||||||
|
|
||||||
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
||||||
|
|
||||||
@ -1043,7 +1052,34 @@ def run(rank, world_size, args):
|
|||||||
# You should use ../local/display_manifest_statistics.py to get
|
# You should use ../local/display_manifest_statistics.py to get
|
||||||
# an utterance duration distribution for your dataset to select
|
# an utterance duration distribution for your dataset to select
|
||||||
# the threshold
|
# the threshold
|
||||||
return 1.0 <= c.duration <= 20.0
|
if c.duration < 1.0 or c.duration > 20.0:
|
||||||
|
logging.warning(
|
||||||
|
f"Exclude cut with ID {c.id} from training. "
|
||||||
|
f"Duration: {c.duration}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# In pruned RNN-T, we require that T >= S
|
||||||
|
# where T is the number of feature frames after subsampling
|
||||||
|
# and S is the number of tokens in the utterance
|
||||||
|
|
||||||
|
# In ./zipformer.py, the conv module uses the following expression
|
||||||
|
# for subsampling
|
||||||
|
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||||
|
tokens = sp.encode(c.supervisions[0].text, out_type=str)
|
||||||
|
|
||||||
|
if T < len(tokens):
|
||||||
|
logging.warning(
|
||||||
|
f"Exclude cut with ID {c.id} from training. "
|
||||||
|
f"Number of frames (before subsampling): {c.num_frames}. "
|
||||||
|
f"Number of frames (after subsampling): {T}. "
|
||||||
|
f"Text: {c.supervisions[0].text}. "
|
||||||
|
f"Tokens: {tokens}. "
|
||||||
|
f"Number of tokens: {len(tokens)}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||||
|
|
||||||
@ -1071,8 +1107,7 @@ def run(rank, world_size, args):
|
|||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
|
|
||||||
scaler = GradScaler(enabled=params.use_fp16,
|
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||||
init_scale=1.0)
|
|
||||||
if checkpoints and "grad_scaler" in checkpoints:
|
if checkpoints and "grad_scaler" in checkpoints:
|
||||||
logging.info("Loading grad scaler state dict")
|
logging.info("Loading grad scaler state dict")
|
||||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||||
@ -1193,7 +1228,9 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
)
|
)
|
||||||
display_and_save_batch(batch, params=params, sp=sp)
|
display_and_save_batch(batch, params=params, sp=sp)
|
||||||
raise
|
raise
|
||||||
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
|
logging.info(
|
||||||
|
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -1828,6 +1828,7 @@ def _test_zipformer_main():
|
|||||||
torch.randn(batch_size, seq_len, feature_dim),
|
torch.randn(batch_size, seq_len, feature_dim),
|
||||||
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
||||||
)
|
)
|
||||||
|
assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1])
|
||||||
f[0].sum().backward()
|
f[0].sum().backward()
|
||||||
c.eval()
|
c.eval()
|
||||||
f = c(
|
f = c(
|
||||||
|
@ -90,12 +90,7 @@ from icefall.checkpoint import (
|
|||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import (
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
AttributeDict,
|
|
||||||
MetricsTracker,
|
|
||||||
setup_logger,
|
|
||||||
str2bool,
|
|
||||||
)
|
|
||||||
|
|
||||||
LRSchedulerType = Union[
|
LRSchedulerType = Union[
|
||||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||||
@ -1045,7 +1040,9 @@ def train_one_epoch(
|
|||||||
params.best_train_loss = params.train_loss
|
params.best_train_loss = params.train_loss
|
||||||
|
|
||||||
|
|
||||||
def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
|
def filter_short_and_long_utterances(
|
||||||
|
cuts: CutSet, sp: spm.SentencePieceProcessor
|
||||||
|
) -> CutSet:
|
||||||
def remove_short_and_long_utt(c: Cut):
|
def remove_short_and_long_utt(c: Cut):
|
||||||
# Keep only utterances with duration between 1 second and 20 seconds
|
# Keep only utterances with duration between 1 second and 20 seconds
|
||||||
#
|
#
|
||||||
@ -1055,7 +1052,34 @@ def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
|
|||||||
# You should use ../local/display_manifest_statistics.py to get
|
# You should use ../local/display_manifest_statistics.py to get
|
||||||
# an utterance duration distribution for your dataset to select
|
# an utterance duration distribution for your dataset to select
|
||||||
# the threshold
|
# the threshold
|
||||||
return 1.0 <= c.duration <= 20.0
|
if c.duration < 1.0 or c.duration > 20.0:
|
||||||
|
logging.warning(
|
||||||
|
f"Exclude cut with ID {c.id} from training. "
|
||||||
|
f"Duration: {c.duration}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# In pruned RNN-T, we require that T >= S
|
||||||
|
# where T is the number of feature frames after subsampling
|
||||||
|
# and S is the number of tokens in the utterance
|
||||||
|
|
||||||
|
# In ./zipformer.py, the conv module uses the following expression
|
||||||
|
# for subsampling
|
||||||
|
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||||
|
tokens = sp.encode(c.supervisions[0].text, out_type=str)
|
||||||
|
|
||||||
|
if T < len(tokens):
|
||||||
|
logging.warning(
|
||||||
|
f"Exclude cut with ID {c.id} from training. "
|
||||||
|
f"Number of frames (before subsampling): {c.num_frames}. "
|
||||||
|
f"Number of frames (after subsampling): {T}. "
|
||||||
|
f"Text: {c.supervisions[0].text}. "
|
||||||
|
f"Tokens: {tokens}. "
|
||||||
|
f"Number of tokens: {len(tokens)}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
cuts = cuts.filter(remove_short_and_long_utt)
|
cuts = cuts.filter(remove_short_and_long_utt)
|
||||||
|
|
||||||
@ -1162,7 +1186,7 @@ def run(rank, world_size, args):
|
|||||||
train_cuts += librispeech.train_clean_360_cuts()
|
train_cuts += librispeech.train_clean_360_cuts()
|
||||||
train_cuts += librispeech.train_other_500_cuts()
|
train_cuts += librispeech.train_other_500_cuts()
|
||||||
|
|
||||||
train_cuts = filter_short_and_long_utterances(train_cuts)
|
train_cuts = filter_short_and_long_utterances(train_cuts, sp)
|
||||||
|
|
||||||
gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
|
gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
|
||||||
# XL 10k hours
|
# XL 10k hours
|
||||||
@ -1179,7 +1203,7 @@ def run(rank, world_size, args):
|
|||||||
logging.info("Using the S subset of GigaSpeech (250 hours)")
|
logging.info("Using the S subset of GigaSpeech (250 hours)")
|
||||||
train_giga_cuts = gigaspeech.train_S_cuts()
|
train_giga_cuts = gigaspeech.train_S_cuts()
|
||||||
|
|
||||||
train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts)
|
train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts, sp)
|
||||||
train_giga_cuts = train_giga_cuts.repeat(times=None)
|
train_giga_cuts = train_giga_cuts.repeat(times=None)
|
||||||
|
|
||||||
if args.enable_musan:
|
if args.enable_musan:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user