Disable SpecAug for yesno.

Also replace Adam with SGD.
This commit is contained in:
Fangjun Kuang 2021-08-23 13:57:46 +08:00
parent 6c2c9b9d74
commit 2e37b29e66
4 changed files with 9 additions and 30 deletions

View File

@ -69,21 +69,13 @@ jobs:
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
echo $PYTHONPATH
ls -lh
# The following three lines are for macOS
lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
echo "lib_path: $lib_path"
export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
ls -lh $lib_path
cd egs/yesno/ASR
./prepare.sh
python3 ./tdnn/train.py --num-epochs 100
python3 ./tdnn/decode.py --epoch 99
python3 ./tdnn/decode.py --epoch 95
python3 ./tdnn/decode.py --epoch 90
python3 ./tdnn/decode.py --epoch 80
python3 ./tdnn/decode.py --epoch 70
python3 ./tdnn/decode.py --epoch 60
python3 ./tdnn/train.py
python3 ./tdnn/decode.py --avg 2
python3 ./tdnn/decode.py --avg 3
python3 ./tdnn/decode.py --avg 4
python3 ./tdnn/decode.py --avg 5
# TODO: Check that the WER is less than some value

View File

@ -27,7 +27,6 @@ from lhotse.dataset import (
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
@ -163,18 +162,8 @@ class YesNoAsrDataModule(DataModule):
)
] + transforms
input_transforms = [
SpecAugment(
num_frame_masks=2,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
]
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@ -194,7 +183,6 @@ class YesNoAsrDataModule(DataModule):
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=23))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)

View File

@ -39,7 +39,7 @@ def get_parser():
parser.add_argument(
"--avg",
type=int,
default=15,
default=4,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",

View File

@ -61,7 +61,7 @@ def get_parser():
parser.add_argument(
"--num-epochs",
type=int,
default=50,
default=10,
help="Number of epochs to train.",
)
@ -129,11 +129,10 @@ def get_params() -> AttributeDict:
{
"exp_dir": Path("tdnn/exp"),
"lang_dir": Path("data/lang_phone"),
"lr": 1e-3,
"lr": 1e-1,
"feature_dim": 23,
"weight_decay": 1e-6,
"start_epoch": 0,
"num_epochs": 50,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
@ -491,7 +490,7 @@ def run(rank, world_size, args):
if world_size > 1:
model = DDP(model, device_ids=[rank])
optimizer = optim.AdamW(
optimizer = optim.SGD(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,