Remove padding in the model to make the results reproducible.

This commit is contained in:
Fangjun Kuang 2021-08-23 15:48:01 +08:00
parent 2e37b29e66
commit ece74b7542
4 changed files with 18 additions and 29 deletions

View File

@ -74,8 +74,5 @@ jobs:
cd egs/yesno/ASR
./prepare.sh
python3 ./tdnn/train.py
python3 ./tdnn/decode.py --avg 2
python3 ./tdnn/decode.py --avg 3
python3 ./tdnn/decode.py --avg 4
python3 ./tdnn/decode.py --avg 5
python3 ./tdnn/decode.py
# TODO: Check that the WER is less than some value

View File

@ -32,14 +32,14 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=9,
default=14,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=4,
default=2,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
@ -104,16 +104,11 @@ def decode_one_batch(
nnet_output = model(feature)
# nnet_output is [N, T, C]
supervisions = batch["supervisions"]
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"],
supervisions["num_frames"],
),
1,
).to(torch.int32)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,

View File

@ -23,7 +23,6 @@ class Tdnn(nn.Module):
in_channels=num_features,
out_channels=32,
kernel_size=3,
padding=1,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
@ -31,7 +30,6 @@ class Tdnn(nn.Module):
in_channels=32,
out_channels=32,
kernel_size=5,
padding=4,
dilation=2,
),
nn.ReLU(inplace=True),
@ -40,7 +38,6 @@ class Tdnn(nn.Module):
in_channels=32,
out_channels=32,
kernel_size=5,
padding=8,
dilation=4,
),
nn.ReLU(inplace=True),

View File

@ -24,12 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
encode_supervisions,
setup_logger,
str2bool,
)
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():
@ -61,7 +56,7 @@ def get_parser():
parser.add_argument(
"--num-epochs",
type=int,
default=10,
default=15,
help="Number of epochs to train.",
)
@ -129,7 +124,7 @@ def get_params() -> AttributeDict:
{
"exp_dir": Path("tdnn/exp"),
"lang_dir": Path("data/lang_phone"),
"lr": 1e-1,
"lr": 1e-2,
"feature_dim": 23,
"weight_decay": 1e-6,
"start_epoch": 0,
@ -277,9 +272,14 @@ def compute_loss(
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervisions = batch["supervisions"]
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=1
texts = supervisions["text"]
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
decoding_graph = graph_compiler.compile(texts)
dense_fsa_vec = k2.DenseFsaVec(