Minor fixes

PingFeng Luo 2022-01-14 12:12:38 +08:00
parent 0928939e38
commit 8f21e92b5f
2 changed files with 8 additions and 19 deletions

File 1:

@@ -29,6 +29,7 @@ from lhotse import (
     load_manifest,
     set_caching_enabled,
 )
+from lhotse.cut import Cut
 from lhotse.dataset import (
     DynamicBucketingSampler,
     CutConcatenate,
@@ -101,7 +102,7 @@ class WenetSpeechDataModule:
         group.add_argument(
             "--num-buckets",
             type=int,
-            default=30,
+            default=300,
             help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
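
Note: a minimal sketch (not from this commit) of how the bucket count is consumed. The manifest path and max_duration are illustrative; more buckets mean narrower duration bins and less padding per batch, which is why a corpus as large and varied as WenetSpeech benefits from 300 rather than 30.

    from lhotse import load_manifest
    from lhotse.dataset import DynamicBucketingSampler

    # Group cuts of similar duration into 300 buckets; batches are then
    # drawn within a bucket, so utterances in a batch pad to similar lengths.
    cuts = load_manifest("cuts_train.jsonl.gz")  # illustrative path
    train_sampler = DynamicBucketingSampler(
        cuts,
        max_duration=200.0,  # total seconds of audio per batch (illustrative)
        num_buckets=300,
        shuffle=True,
    )
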
@@ -285,6 +286,11 @@ class WenetSpeechDataModule:
         )
         logging.info("About to create train dataloader")
+
+        def remove_short_and_long_utt(c: Cut):
+            # Keep only utterances with duration between 1 second and 16 seconds
+            return 1.0 <= c.duration <= 16.0
+        train_sampler.filter(remove_short_and_long_utt)
 
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
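
Note on the new filter: attaching the predicate to the sampler, rather than to the CutSet, drops cuts lazily while batches are drawn, so no up-front pass over the manifest is needed. A sketch of the pattern, assuming train_sampler was built as in this file:

    from lhotse.cut import Cut

    def remove_short_and_long_utt(c: Cut) -> bool:
        # Keep only utterances between 1 and 16 seconds (the bounds this
        # commit uses); the exact cutoffs are a tuning choice, trading
        # uninformative short cuts and memory-heavy long cuts for coverage.
        return 1.0 <= c.duration <= 16.0

    train_sampler.filter(remove_short_and_long_utt)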

File 2:

@@ -33,7 +33,6 @@ from asr_datamodule import WenetSpeechDataModule
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
@@ -196,7 +195,7 @@ def get_params() -> AttributeDict:
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             # parameters for Noam
-            "warm_step": 80000,  # For the 100h subset, use 8k
+            "warm_step": 1600000,  # For the 100h subset, use 8k
             "env_info": get_env_info(),
         }
     )
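
The warm_step knob feeds a Noam learning-rate schedule (Vaswani et al., 2017). A sketch of the standard formula, with model_size and factor as illustrative values; warm_step is the only piece this commit changes (80k -> 1.6M):

    def noam_lr(step: int, warm_step: int = 1600000,
                model_size: int = 512, factor: float = 1.0) -> float:
        # Linear warm-up for warm_step steps, then decay as step**-0.5.
        assert step >= 1, "schedule is defined for step >= 1"
        return factor * model_size ** -0.5 * min(
            step ** -0.5, step * warm_step ** -1.5
        )

The peak rate is reached at step == warm_step and scales as warm_step**-0.5, so a 20x longer warm-up also lowers the peak, presumably to keep early training stable on the full WenetSpeech corpus.
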
@@ -591,22 +590,6 @@ def run(rank, world_size, args):
     wenetspeech = WenetSpeechDataModule(args)
     train_cuts = wenetspeech.train_cuts()
-
-    def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
-        return 1.0 <= c.duration <= 20.0
-
-    num_in_total = len(train_cuts)
-
-    train_cuts = train_cuts.filter(remove_short_and_long_utt)
-
-    num_left = len(train_cuts)
-    num_removed = num_in_total - num_left
-    removed_percent = num_removed / num_in_total * 100
-
-    logging.info(f"Before removing short and long utterances: {num_in_total}")
-    logging.info(f"After removing short and long utterances: {num_left}")
-    logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")

     train_dl = wenetspeech.train_dataloaders(train_cuts)
     valid_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts())
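
Why the deletion: counting and filtering the CutSet eagerly forces a full pass over the multi-million-utterance WenetSpeech manifest before training starts (and len() may not even be available on a lazily-opened CutSet), just to log three numbers. A sketch of the contrast, assuming train_cuts is lazy:

    # Eager (removed): materializes the manifest up front.
    # num_in_total = len(train_cuts)                      # full pass
    # train_cuts = train_cuts.filter(remove_short_and_long_utt)

    # Lazy (kept): the same predicate now lives on the sampler inside
    # WenetSpeechDataModule and is evaluated per cut as batches are drawn.
    train_dl = wenetspeech.train_dataloaders(train_cuts)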