mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Minor fixes
This commit is contained in:
parent
0928939e38
commit
8f21e92b5f
@ -29,6 +29,7 @@ from lhotse import (
|
||||
load_manifest,
|
||||
set_caching_enabled,
|
||||
)
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset import (
|
||||
DynamicBucketingSampler,
|
||||
CutConcatenate,
|
||||
@ -101,7 +102,7 @@ class WenetSpeechDataModule:
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
default=300,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
@ -285,6 +286,11 @@ class WenetSpeechDataModule:
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
# Keep only utterances with duration between 1 second and 16 seconds
|
||||
return 1.0 <= c.duration <= 16.0
|
||||
|
||||
train_sampler.filter(remove_short_and_long_utt)
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
|
@ -33,7 +33,6 @@ from asr_datamodule import WenetSpeechDataModule
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from torch import Tensor
|
||||
@ -196,7 +195,7 @@ def get_params() -> AttributeDict:
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
# parameters for Noam
|
||||
"warm_step": 80000, # For the 100h subset, use 8k
|
||||
"warm_step": 1600000, # For the 100h subset, use 8k
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
@ -591,22 +590,6 @@ def run(rank, world_size, args):
|
||||
wenetspeech = WenetSpeechDataModule(args)
|
||||
train_cuts = wenetspeech.train_cuts()
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
return 1.0 <= c.duration <= 20.0
|
||||
|
||||
num_in_total = len(train_cuts)
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
num_left = len(train_cuts)
|
||||
num_removed = num_in_total - num_left
|
||||
removed_percent = num_removed / num_in_total * 100
|
||||
|
||||
logging.info(f"Before removing short and long utterances: {num_in_total}")
|
||||
logging.info(f"After removing short and long utterances: {num_left}")
|
||||
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
|
||||
|
||||
train_dl = wenetspeech.train_dataloaders(train_cuts)
|
||||
valid_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts())
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user