Merge branch 'k2-fsa:master' into fix/k2ssl-multi-gpu

This commit is contained in:
Yifan Yang 2024-05-09 20:27:58 +08:00 committed by GitHub
commit 967bf92d87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 13 deletions

View File

@ -181,7 +181,7 @@ class YesNoAsrDataModule(DataModule):
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(
FbankConfig(sampling_rate=8000, num_mel_bins=23) Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=23))
), ),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
@ -222,9 +222,11 @@ class YesNoAsrDataModule(DataModule):
logging.debug("About to create test dataset") logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset( test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23)))
if self.args.on_the_fly_feats if self.args.on_the_fly_feats
else PrecomputedFeatures(), else PrecomputedFeatures()
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
sampler = DynamicBucketingSampler( sampler = DynamicBucketingSampler(

View File

@ -110,13 +110,6 @@ def str2bool(v):
raise argparse.ArgumentTypeError("Boolean value expected.") raise argparse.ArgumentTypeError("Boolean value expected.")
def clear_log_handlers():
logger = logging.getLogger()
handlers = logger.handlers[:]
for handler in handlers:
logger.removeHandler(handler)
def setup_logger( def setup_logger(
log_filename: Pathlike, log_filename: Pathlike,
log_level: str = "info", log_level: str = "info",
@ -133,8 +126,6 @@ def setup_logger(
use_console: use_console:
True to also print logs to console. True to also print logs to console.
""" """
clear_log_handlers()
now = datetime.now() now = datetime.now()
date_time = now.strftime("%Y-%m-%d-%H-%M-%S") date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():