[yesno] Remove padding in TDNN (#21)

* Disable SpecAug for yesno.

Also replace AdamW with SGD.

* Remove padding in the model to make the results reproducible.
Fangjun Kuang 2021-08-23 15:59:36 +08:00 committed by GitHub
parent 6c2c9b9d74
commit 57cb611665
5 changed files with 20 additions and 52 deletions

.github/workflows/run-yesno-recipe.yml

@@ -69,21 +69,10 @@ jobs:
         run: |
           export PYTHONPATH=$PWD:$PYTHONPATH
           echo $PYTHONPATH
-          ls -lh
-          # The following three lines are for macOS
-          lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
-          echo "lib_path: $lib_path"
-          export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
-          ls -lh $lib_path
 
           cd egs/yesno/ASR
           ./prepare.sh
 
-          python3 ./tdnn/train.py --num-epochs 100
-          python3 ./tdnn/decode.py --epoch 99
-          python3 ./tdnn/decode.py --epoch 95
-          python3 ./tdnn/decode.py --epoch 90
-          python3 ./tdnn/decode.py --epoch 80
-          python3 ./tdnn/decode.py --epoch 70
-          python3 ./tdnn/decode.py --epoch 60
+          python3 ./tdnn/train.py
+          python3 ./tdnn/decode.py
           # TODO: Check that the WER is less than some value
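(The bare train.py/decode.py invocations work because this same commit changes the defaults below: train.py now runs 15 epochs instead of 50, and decode.py defaults to --epoch 14 --avg 2, i.e. the last checkpoint written, since epochs count from 0.)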

egs/yesno/ASR/tdnn/asr_datamodule.py

@@ -27,7 +27,6 @@ from lhotse.dataset import (
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
     SingleCutSampler,
-    SpecAugment,
 )
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
 from torch.utils.data import DataLoader
@@ -163,18 +162,8 @@ class YesNoAsrDataModule(DataModule):
             )
         ] + transforms
 
-        input_transforms = [
-            SpecAugment(
-                num_frame_masks=2,
-                features_mask_size=27,
-                num_feature_masks=2,
-                frames_mask_size=100,
-            )
-        ]
-
         train = K2SpeechRecognitionDataset(
             cut_transforms=transforms,
-            input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
         )
 
@@ -194,7 +183,6 @@ class YesNoAsrDataModule(DataModule):
             input_strategy=OnTheFlyFeatures(
                 Fbank(FbankConfig(num_mel_bins=23))
             ),
-            input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
         )
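For context, here is a minimal sketch (not part of the diff) of how an input transform such as the removed SpecAugment plugs into lhotse's K2SpeechRecognitionDataset; the parameter values are exactly the ones deleted above:

    from lhotse.dataset import K2SpeechRecognitionDataset, SpecAugment

    # The deleted transform applied 2 time masks and 2 frequency masks
    # of the configured sizes to every training example.
    spec_augment = SpecAugment(
        num_frame_masks=2,
        features_mask_size=27,
        num_feature_masks=2,
        frames_mask_size=100,
    )

    # With input_transforms omitted (as in this commit), the features reach
    # the network unmodified, removing one source of training randomness.
    train = K2SpeechRecognitionDataset(input_transforms=[spec_augment])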

egs/yesno/ASR/tdnn/decode.py

@@ -32,14 +32,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=9,
+        default=14,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=15,
+        default=2,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -104,16 +104,11 @@ def decode_one_batch(
     nnet_output = model(feature)
     # nnet_output is [N, T, C]
 
-    supervisions = batch["supervisions"]
-    supervision_segments = torch.stack(
-        (
-            supervisions["sequence_idx"],
-            supervisions["start_frame"],
-            supervisions["num_frames"],
-        ),
-        1,
-    ).to(torch.int32)
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
+    )
 
     lattice = get_lattice(
         nnet_output=nnet_output,
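Each row of supervision_segments has the layout (sequence_idx, start_frame, num_frames), indexed in the network's output frames, which is the format k2 expects. Since the unpadded convolutions shrink the output, the per-utterance num_frames from the supervisions no longer line up with the output length, so the new code simply spans the full output for every sequence. A small worked example (the batch shape is illustrative):

    import torch

    # Assume 3 utterances whose features were padded to a common 100 frames.
    # Without padding, the three convs trim (3-1)*1 + (5-1)*2 + (5-1)*4 = 26
    # frames, so every sequence has the same 74 output frames.
    T_out = 100 - 26
    batch_size = 3
    supervision_segments = torch.tensor(
        [[i, 0, T_out] for i in range(batch_size)], dtype=torch.int32
    )
    print(supervision_segments)
    # tensor([[ 0,  0, 74],
    #         [ 1,  0, 74],
    #         [ 2,  0, 74]], dtype=torch.int32)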

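The new --epoch 14 --avg 2 defaults make decoding average the checkpoints from epochs 13 and 14. icefall provides average_checkpoints for this; below is a simplified torch-only sketch, assuming checkpoints are saved as epoch-N.pt with the weights under a "model" key. It also glosses over integer buffers such as BatchNorm's num_batches_tracked, which real code must not average:

    import torch

    # --epoch 14 --avg 2 selects the two consecutive checkpoints ending at 14.
    filenames = ["tdnn/exp/epoch-13.pt", "tdnn/exp/epoch-14.pt"]
    states = [torch.load(f, map_location="cpu")["model"] for f in filenames]

    # Element-wise mean of the corresponding tensors from each checkpoint;
    # model.load_state_dict(avg) would then decode with the averaged weights.
    avg = {k: sum(s[k] for s in states) / len(states) for k in states[0]}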
egs/yesno/ASR/tdnn/model.py

@@ -23,7 +23,6 @@ class Tdnn(nn.Module):
                 in_channels=num_features,
                 out_channels=32,
                 kernel_size=3,
-                padding=1,
             ),
             nn.ReLU(inplace=True),
             nn.BatchNorm1d(num_features=32, affine=False),
@@ -31,7 +30,6 @@ class Tdnn(nn.Module):
                 in_channels=32,
                 out_channels=32,
                 kernel_size=5,
-                padding=4,
                 dilation=2,
             ),
             nn.ReLU(inplace=True),
@@ -40,7 +38,6 @@ class Tdnn(nn.Module):
                 in_channels=32,
                 out_channels=32,
                 kernel_size=5,
-                padding=8,
                 dilation=4,
             ),
             nn.ReLU(inplace=True),
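A self-contained sketch of the conv stack after this change, reconstructed from the hunks above (the real Tdnn module also has an output layer, omitted here). It shows why removing the padding makes the output length a fixed function of the input length: each Conv1d now trims (kernel_size - 1) * dilation frames instead of zero-padding the edges:

    import torch
    import torch.nn as nn

    tdnn = nn.Sequential(
        nn.Conv1d(in_channels=23, out_channels=32, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(num_features=32, affine=False),
        nn.Conv1d(in_channels=32, out_channels=32, kernel_size=5, dilation=2),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(num_features=32, affine=False),
        nn.Conv1d(in_channels=32, out_channels=32, kernel_size=5, dilation=4),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(num_features=32, affine=False),
    )

    x = torch.randn(1, 23, 100)  # (N, C, T): 23-dim fbank, 100 frames
    y = tdnn(x)
    print(y.shape)  # torch.Size([1, 32, 74]): 2 + 8 + 16 = 26 frames trimmed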

egs/yesno/ASR/tdnn/train.py

@@ -24,12 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    encode_supervisions,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, setup_logger, str2bool
 
 
 def get_parser():
@@ -61,7 +56,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=50,
+        default=15,
         help="Number of epochs to train.",
     )
@@ -129,11 +124,10 @@ def get_params() -> AttributeDict:
         {
             "exp_dir": Path("tdnn/exp"),
             "lang_dir": Path("data/lang_phone"),
-            "lr": 1e-3,
+            "lr": 1e-2,
             "feature_dim": 23,
             "weight_decay": 1e-6,
             "start_epoch": 0,
-            "num_epochs": 50,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -278,9 +272,14 @@ def compute_loss(
     # different duration in decreasing order, required by
     # `k2.intersect_dense` called in `k2.ctc_loss`
     supervisions = batch["supervisions"]
-    supervision_segments, texts = encode_supervisions(
-        supervisions, subsampling_factor=1
+    texts = supervisions["text"]
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
     )
 
     decoding_graph = graph_compiler.compile(texts)
     dense_fsa_vec = k2.DenseFsaVec(
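The comment about decreasing duration order is also why encode_supervisions could be dropped: it sorted segments by decreasing num_frames, and with every row now carrying the same num_frames that ordering holds trivially. A minimal sketch of feeding such segments to k2 (the shapes are illustrative):

    import k2
    import torch

    # Batch of 2 sequences, 74 output frames each, 4 output classes.
    log_probs = torch.randn(2, 74, 4).log_softmax(dim=-1)
    supervision_segments = torch.tensor(
        [[0, 0, 74], [1, 0, 74]], dtype=torch.int32
    )

    # Equal num_frames means the decreasing-order requirement of
    # k2.intersect_dense is satisfied without any sorting.
    dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)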
@@ -491,7 +490,7 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
 
-    optimizer = optim.AdamW(
+    optimizer = optim.SGD(
         model.parameters(),
         lr=params.lr,
         weight_decay=params.weight_decay,
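For reference, a minimal sketch of the new optimizer setup with the values from get_params() above; the stand-in model is illustrative:

    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(23, 4)  # stand-in for the Tdnn model

    # lr was raised from 1e-3 to 1e-2 together with the AdamW -> SGD swap;
    # plain SGD keeps no per-parameter state, a simpler and more easily
    # reproducible choice for a dataset as small as yesno.
    optimizer = optim.SGD(
        model.parameters(),
        lr=1e-2,
        weight_decay=1e-6,
    )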