Disable SpecAug for yesno.

Also replace Adam with SGD.
This commit is contained in:
Fangjun Kuang 2021-08-23 13:57:46 +08:00
parent 6c2c9b9d74
commit 2e37b29e66
4 changed files with 9 additions and 30 deletions

View File

@ -69,21 +69,13 @@ jobs:
run: | run: |
export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=$PWD:$PYTHONPATH
echo $PYTHONPATH echo $PYTHONPATH
ls -lh
# The following three lines are for macOS
lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
echo "lib_path: $lib_path"
export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
ls -lh $lib_path
cd egs/yesno/ASR cd egs/yesno/ASR
./prepare.sh ./prepare.sh
python3 ./tdnn/train.py --num-epochs 100 python3 ./tdnn/train.py
python3 ./tdnn/decode.py --epoch 99 python3 ./tdnn/decode.py --avg 2
python3 ./tdnn/decode.py --epoch 95 python3 ./tdnn/decode.py --avg 3
python3 ./tdnn/decode.py --epoch 90 python3 ./tdnn/decode.py --avg 4
python3 ./tdnn/decode.py --epoch 80 python3 ./tdnn/decode.py --avg 5
python3 ./tdnn/decode.py --epoch 70
python3 ./tdnn/decode.py --epoch 60
# TODO: Check that the WER is less than some value # TODO: Check that the WER is less than some value

View File

@ -27,7 +27,6 @@ from lhotse.dataset import (
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SingleCutSampler,
SpecAugment,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -163,18 +162,8 @@ class YesNoAsrDataModule(DataModule):
) )
] + transforms ] + transforms
input_transforms = [
SpecAugment(
num_frame_masks=2,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
]
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
@ -194,7 +183,6 @@ class YesNoAsrDataModule(DataModule):
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=23)) Fbank(FbankConfig(num_mel_bins=23))
), ),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )

View File

@ -39,7 +39,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=15, default=4,
help="Number of checkpoints to average. Automatically select " help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. ",

View File

@ -61,7 +61,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--num-epochs", "--num-epochs",
type=int, type=int,
default=50, default=10,
help="Number of epochs to train.", help="Number of epochs to train.",
) )
@ -129,11 +129,10 @@ def get_params() -> AttributeDict:
{ {
"exp_dir": Path("tdnn/exp"), "exp_dir": Path("tdnn/exp"),
"lang_dir": Path("data/lang_phone"), "lang_dir": Path("data/lang_phone"),
"lr": 1e-3, "lr": 1e-1,
"feature_dim": 23, "feature_dim": 23,
"weight_decay": 1e-6, "weight_decay": 1e-6,
"start_epoch": 0, "start_epoch": 0,
"num_epochs": 50,
"best_train_loss": float("inf"), "best_train_loss": float("inf"),
"best_valid_loss": float("inf"), "best_valid_loss": float("inf"),
"best_train_epoch": -1, "best_train_epoch": -1,
@ -491,7 +490,7 @@ def run(rank, world_size, args):
if world_size > 1: if world_size > 1:
model = DDP(model, device_ids=[rank]) model = DDP(model, device_ids=[rank])
optimizer = optim.AdamW( optimizer = optim.SGD(
model.parameters(), model.parameters(),
lr=params.lr, lr=params.lr,
weight_decay=params.weight_decay, weight_decay=params.weight_decay,