Minor fixes

PingFeng Luo 2022-01-14 12:12:38 +08:00
parent 0928939e38
commit 8f21e92b5f
2 changed files with 8 additions and 19 deletions

asr_datamodule.py

@@ -29,6 +29,7 @@ from lhotse import (
     load_manifest,
     set_caching_enabled,
 )
+from lhotse.cut import Cut
 from lhotse.dataset import (
     DynamicBucketingSampler,
     CutConcatenate,
@@ -101,7 +102,7 @@ class WenetSpeechDataModule:
         group.add_argument(
             "--num-buckets",
             type=int,
-            default=30,
+            default=300,
             help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
@@ -285,6 +286,11 @@ class WenetSpeechDataModule:
         )
         logging.info("About to create train dataloader")

+        def remove_short_and_long_utt(c: Cut):
+            # Keep only utterances with duration between 1 second and 16 seconds
+            return 1.0 <= c.duration <= 16.0
+
+        train_sampler.filter(remove_short_and_long_utt)
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
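The two hunks above raise the bucket count for the DynamicBucketingSampler and attach the duration filter to the sampler itself rather than to the cut set. A minimal sketch of how these pieces fit together, assuming a typical lhotse setup (the manifest path and max_duration value are illustrative, not from this commit):

```python
from lhotse import CutSet
from lhotse.cut import Cut
from lhotse.dataset import DynamicBucketingSampler

cuts = CutSet.from_file("cuts_train.jsonl.gz")  # illustrative path

# More buckets -> a narrower duration range per bucket, so each batch
# needs less padding; useful for a large corpus like WenetSpeech.
sampler = DynamicBucketingSampler(
    cuts,
    max_duration=200.0,  # illustrative value: total seconds per batch
    num_buckets=300,
    shuffle=True,
)

def remove_short_and_long_utt(c: Cut) -> bool:
    # Keep only utterances between 1 and 16 seconds.
    return 1.0 <= c.duration <= 16.0

# The predicate is applied while sampling, so no eager pass over the
# (possibly lazily loaded) manifest is required.
sampler.filter(remove_short_and_long_utt)
```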

train.py

@@ -33,7 +33,6 @@ from asr_datamodule import WenetSpeechDataModule
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
@@ -196,7 +195,7 @@ def get_params() -> AttributeDict:
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             # parameters for Noam
-            "warm_step": 80000,  # For the 100h subset, use 8k
+            "warm_step": 1600000,  # For the 100h subset, use 8k
             "env_info": get_env_info(),
         }
     )
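For context on the warm_step change above: warm_step feeds the Noam schedule, which in its standard transformer form ramps the learning rate up for warm_step updates and then decays it as step^-0.5, so raising it from 80k to 1.6M lowers the peak learning rate and pushes it much later into training. A small sketch of that formula (the model size and factor values are illustrative, not taken from this recipe):

```python
def noam_lr(step: int, model_size: int = 512, factor: float = 1.0,
            warm_step: int = 1600000) -> float:
    # Standard Noam/transformer schedule: roughly linear warmup for
    # `warm_step` updates, then inverse-square-root decay.
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warm_step ** -1.5)

# The peak is reached at step == warm_step and scales as warm_step ** -0.5,
# so warm_step=1_600_000 peaks ~4.5x lower (and 20x later) than 80_000.
for warm in (80_000, 1_600_000):
    print(f"warm_step={warm}: peak lr ~ {noam_lr(warm, warm_step=warm):.2e}")
```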
@@ -591,22 +590,6 @@ def run(rank, world_size, args):
     wenetspeech = WenetSpeechDataModule(args)

     train_cuts = wenetspeech.train_cuts()

-    def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
-        return 1.0 <= c.duration <= 20.0
-
-    num_in_total = len(train_cuts)
-
-    train_cuts = train_cuts.filter(remove_short_and_long_utt)
-
-    num_left = len(train_cuts)
-    num_removed = num_in_total - num_left
-    removed_percent = num_removed / num_in_total * 100
-
-    logging.info(f"Before removing short and long utterances: {num_in_total}")
-    logging.info(f"After removing short and long utterances: {num_left}")
-    logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
-
     train_dl = wenetspeech.train_dataloaders(train_cuts)
     valid_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts())