mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00
Disable SpecAug for yesno.
Also replace AdamW with SGD.
This commit is contained in:
parent
6c2c9b9d74
commit
2e37b29e66
18
.github/workflows/run-yesno-recipe.yml
vendored
18
.github/workflows/run-yesno-recipe.yml
vendored
@ -69,21 +69,13 @@ jobs:
|
||||
run: |
|
||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
echo $PYTHONPATH
|
||||
ls -lh
|
||||
|
||||
# The following four lines are for macOS
|
||||
lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
|
||||
echo "lib_path: $lib_path"
|
||||
export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
|
||||
ls -lh $lib_path
|
||||
|
||||
cd egs/yesno/ASR
|
||||
./prepare.sh
|
||||
python3 ./tdnn/train.py --num-epochs 100
|
||||
python3 ./tdnn/decode.py --epoch 99
|
||||
python3 ./tdnn/decode.py --epoch 95
|
||||
python3 ./tdnn/decode.py --epoch 90
|
||||
python3 ./tdnn/decode.py --epoch 80
|
||||
python3 ./tdnn/decode.py --epoch 70
|
||||
python3 ./tdnn/decode.py --epoch 60
|
||||
python3 ./tdnn/train.py
|
||||
python3 ./tdnn/decode.py --avg 2
|
||||
python3 ./tdnn/decode.py --avg 3
|
||||
python3 ./tdnn/decode.py --avg 4
|
||||
python3 ./tdnn/decode.py --avg 5
|
||||
# TODO: Check that the WER is less than some value
|
||||
|
@ -27,7 +27,6 @@ from lhotse.dataset import (
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from torch.utils.data import DataLoader
|
||||
@ -163,18 +162,8 @@ class YesNoAsrDataModule(DataModule):
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = [
|
||||
SpecAugment(
|
||||
num_frame_masks=2,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
]
|
||||
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
@ -194,7 +183,6 @@ class YesNoAsrDataModule(DataModule):
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=23))
|
||||
),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
|
@ -39,7 +39,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
default=4,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
|
@ -61,7 +61,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=50,
|
||||
default=10,
|
||||
help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
@ -129,11 +129,10 @@ def get_params() -> AttributeDict:
|
||||
{
|
||||
"exp_dir": Path("tdnn/exp"),
|
||||
"lang_dir": Path("data/lang_phone"),
|
||||
"lr": 1e-3,
|
||||
"lr": 1e-1,
|
||||
"feature_dim": 23,
|
||||
"weight_decay": 1e-6,
|
||||
"start_epoch": 0,
|
||||
"num_epochs": 50,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
@ -491,7 +490,7 @@ def run(rank, world_size, args):
|
||||
if world_size > 1:
|
||||
model = DDP(model, device_ids=[rank])
|
||||
|
||||
optimizer = optim.AdamW(
|
||||
optimizer = optim.SGD(
|
||||
model.parameters(),
|
||||
lr=params.lr,
|
||||
weight_decay=params.weight_decay,
|
||||
|
Loading…
x
Reference in New Issue
Block a user