mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
[yesno] Remove padding in TDNN (#21)
* Disable SpecAug for yesno. Also replace AdamW with SGD.
* Remove padding in the model to make the results reproducible.
This commit is contained in:
parent
6c2c9b9d74
commit
57cb611665
15
.github/workflows/run-yesno-recipe.yml
vendored
15
.github/workflows/run-yesno-recipe.yml
vendored
@ -69,21 +69,10 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
echo $PYTHONPATH
|
echo $PYTHONPATH
|
||||||
ls -lh
|
|
||||||
|
|
||||||
# The following three lines are for macOS
|
|
||||||
lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
|
|
||||||
echo "lib_path: $lib_path"
|
|
||||||
export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
|
|
||||||
ls -lh $lib_path
|
|
||||||
|
|
||||||
cd egs/yesno/ASR
|
cd egs/yesno/ASR
|
||||||
./prepare.sh
|
./prepare.sh
|
||||||
python3 ./tdnn/train.py --num-epochs 100
|
python3 ./tdnn/train.py
|
||||||
python3 ./tdnn/decode.py --epoch 99
|
python3 ./tdnn/decode.py
|
||||||
python3 ./tdnn/decode.py --epoch 95
|
|
||||||
python3 ./tdnn/decode.py --epoch 90
|
|
||||||
python3 ./tdnn/decode.py --epoch 80
|
|
||||||
python3 ./tdnn/decode.py --epoch 70
|
|
||||||
python3 ./tdnn/decode.py --epoch 60
|
|
||||||
# TODO: Check that the WER is less than some value
|
# TODO: Check that the WER is less than some value
|
||||||
|
@ -27,7 +27,6 @@ from lhotse.dataset import (
|
|||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SingleCutSampler,
|
SingleCutSampler,
|
||||||
SpecAugment,
|
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -163,18 +162,8 @@ class YesNoAsrDataModule(DataModule):
|
|||||||
)
|
)
|
||||||
] + transforms
|
] + transforms
|
||||||
|
|
||||||
input_transforms = [
|
|
||||||
SpecAugment(
|
|
||||||
num_frame_masks=2,
|
|
||||||
features_mask_size=27,
|
|
||||||
num_feature_masks=2,
|
|
||||||
frames_mask_size=100,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
train = K2SpeechRecognitionDataset(
|
train = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_transforms=input_transforms,
|
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -194,7 +183,6 @@ class YesNoAsrDataModule(DataModule):
|
|||||||
input_strategy=OnTheFlyFeatures(
|
input_strategy=OnTheFlyFeatures(
|
||||||
Fbank(FbankConfig(num_mel_bins=23))
|
Fbank(FbankConfig(num_mel_bins=23))
|
||||||
),
|
),
|
||||||
input_transforms=input_transforms,
|
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -32,14 +32,14 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--epoch",
|
"--epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=9,
|
default=14,
|
||||||
help="It specifies the checkpoint to use for decoding."
|
help="It specifies the checkpoint to use for decoding."
|
||||||
"Note: Epoch counts from 0.",
|
"Note: Epoch counts from 0.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
default=15,
|
default=2,
|
||||||
help="Number of checkpoints to average. Automatically select "
|
help="Number of checkpoints to average. Automatically select "
|
||||||
"consecutive checkpoints before the checkpoint specified by "
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
"'--epoch'. ",
|
"'--epoch'. ",
|
||||||
@ -104,16 +104,11 @@ def decode_one_batch(
|
|||||||
nnet_output = model(feature)
|
nnet_output = model(feature)
|
||||||
# nnet_output is [N, T, C]
|
# nnet_output is [N, T, C]
|
||||||
|
|
||||||
supervisions = batch["supervisions"]
|
batch_size = nnet_output.shape[0]
|
||||||
|
supervision_segments = torch.tensor(
|
||||||
supervision_segments = torch.stack(
|
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
|
||||||
(
|
dtype=torch.int32,
|
||||||
supervisions["sequence_idx"],
|
)
|
||||||
supervisions["start_frame"],
|
|
||||||
supervisions["num_frames"],
|
|
||||||
),
|
|
||||||
1,
|
|
||||||
).to(torch.int32)
|
|
||||||
|
|
||||||
lattice = get_lattice(
|
lattice = get_lattice(
|
||||||
nnet_output=nnet_output,
|
nnet_output=nnet_output,
|
||||||
|
@ -23,7 +23,6 @@ class Tdnn(nn.Module):
|
|||||||
in_channels=num_features,
|
in_channels=num_features,
|
||||||
out_channels=32,
|
out_channels=32,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
padding=1,
|
|
||||||
),
|
),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.BatchNorm1d(num_features=32, affine=False),
|
nn.BatchNorm1d(num_features=32, affine=False),
|
||||||
@ -31,7 +30,6 @@ class Tdnn(nn.Module):
|
|||||||
in_channels=32,
|
in_channels=32,
|
||||||
out_channels=32,
|
out_channels=32,
|
||||||
kernel_size=5,
|
kernel_size=5,
|
||||||
padding=4,
|
|
||||||
dilation=2,
|
dilation=2,
|
||||||
),
|
),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
@ -40,7 +38,6 @@ class Tdnn(nn.Module):
|
|||||||
in_channels=32,
|
in_channels=32,
|
||||||
out_channels=32,
|
out_channels=32,
|
||||||
kernel_size=5,
|
kernel_size=5,
|
||||||
padding=8,
|
|
||||||
dilation=4,
|
dilation=4,
|
||||||
),
|
),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
|
@ -24,12 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
|||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||||
AttributeDict,
|
|
||||||
encode_supervisions,
|
|
||||||
setup_logger,
|
|
||||||
str2bool,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -61,7 +56,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-epochs",
|
"--num-epochs",
|
||||||
type=int,
|
type=int,
|
||||||
default=50,
|
default=15,
|
||||||
help="Number of epochs to train.",
|
help="Number of epochs to train.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -129,11 +124,10 @@ def get_params() -> AttributeDict:
|
|||||||
{
|
{
|
||||||
"exp_dir": Path("tdnn/exp"),
|
"exp_dir": Path("tdnn/exp"),
|
||||||
"lang_dir": Path("data/lang_phone"),
|
"lang_dir": Path("data/lang_phone"),
|
||||||
"lr": 1e-3,
|
"lr": 1e-2,
|
||||||
"feature_dim": 23,
|
"feature_dim": 23,
|
||||||
"weight_decay": 1e-6,
|
"weight_decay": 1e-6,
|
||||||
"start_epoch": 0,
|
"start_epoch": 0,
|
||||||
"num_epochs": 50,
|
|
||||||
"best_train_loss": float("inf"),
|
"best_train_loss": float("inf"),
|
||||||
"best_valid_loss": float("inf"),
|
"best_valid_loss": float("inf"),
|
||||||
"best_train_epoch": -1,
|
"best_train_epoch": -1,
|
||||||
@ -278,9 +272,14 @@ def compute_loss(
|
|||||||
# different duration in decreasing order, required by
|
# different duration in decreasing order, required by
|
||||||
# `k2.intersect_dense` called in `k2.ctc_loss`
|
# `k2.intersect_dense` called in `k2.ctc_loss`
|
||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
supervision_segments, texts = encode_supervisions(
|
texts = supervisions["text"]
|
||||||
supervisions, subsampling_factor=1
|
|
||||||
|
batch_size = nnet_output.shape[0]
|
||||||
|
supervision_segments = torch.tensor(
|
||||||
|
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
|
||||||
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
|
|
||||||
decoding_graph = graph_compiler.compile(texts)
|
decoding_graph = graph_compiler.compile(texts)
|
||||||
|
|
||||||
dense_fsa_vec = k2.DenseFsaVec(
|
dense_fsa_vec = k2.DenseFsaVec(
|
||||||
@ -491,7 +490,7 @@ def run(rank, world_size, args):
|
|||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
model = DDP(model, device_ids=[rank])
|
model = DDP(model, device_ids=[rank])
|
||||||
|
|
||||||
optimizer = optim.AdamW(
|
optimizer = optim.SGD(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
lr=params.lr,
|
lr=params.lr,
|
||||||
weight_decay=params.weight_decay,
|
weight_decay=params.weight_decay,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user