update comments

This commit is contained in:
glynpu 2023-03-17 12:17:08 +08:00
parent c055f0cc49
commit 0a5b639ec1

View File

@ -16,7 +16,6 @@
import argparse import argparse
import inspect
import logging import logging
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
@ -190,7 +189,7 @@ class HiMiaWuwDataModule:
"--input-strategy", "--input-strategy",
type=str, type=str,
default="PrecomputedFeatures", default="PrecomputedFeatures",
help="PrecomputedFeatures", help="AudioSamples or PrecomputedFeatures",
) )
group.add_argument( group.add_argument(
"--train-channel", "--train-channel",
@ -198,6 +197,8 @@ class HiMiaWuwDataModule:
default="_7_01", default="_7_01",
help="""channel of HI_MIA train dataset. help="""channel of HI_MIA train dataset.
All channels are used if it is set "all". All channels are used if it is set "all".
Please refer state 6 in prepare.sh for its meaning and other
potential values. Currently, Only "_7_01" is verified.
""", """,
) )
group.add_argument( group.add_argument(
@ -206,6 +207,8 @@ class HiMiaWuwDataModule:
default="_7_01", default="_7_01",
help="""channel of HI_MIA dev dataset. help="""channel of HI_MIA dev dataset.
All channels are used if it is set "all". All channels are used if it is set "all".
Please refer state 6 in prepare.sh for its meaning and other
potential values. Currently, Only "_7_01" is verified.
""", """,
) )
@ -248,22 +251,11 @@ class HiMiaWuwDataModule:
input_transforms = [] input_transforms = []
if self.args.enable_spec_aug: if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append( input_transforms.append(
SpecAugment( SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor, time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks, num_frame_masks=10,
features_mask_size=27, features_mask_size=27,
num_feature_masks=2, num_feature_masks=2,
frames_mask_size=100, frames_mask_size=100,