From ece74b7542eadbed8bcfecf3f4dc4020a427cd09 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 23 Aug 2021 15:48:01 +0800 Subject: [PATCH] Remove padding in the model to make the results reproducible. --- .github/workflows/run-yesno-recipe.yml | 5 +---- egs/yesno/ASR/tdnn/decode.py | 19 +++++++------------ egs/yesno/ASR/tdnn/model.py | 3 --- egs/yesno/ASR/tdnn/train.py | 20 ++++++++++---------- 4 files changed, 18 insertions(+), 29 deletions(-) diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 2c9b59aba..39a6a0e80 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -74,8 +74,5 @@ jobs: cd egs/yesno/ASR ./prepare.sh python3 ./tdnn/train.py - python3 ./tdnn/decode.py --avg 2 - python3 ./tdnn/decode.py --avg 3 - python3 ./tdnn/decode.py --avg 4 - python3 ./tdnn/decode.py --avg 5 + python3 ./tdnn/decode.py # TODO: Check that the WER is less than some value diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index 860ae3165..b600c182c 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -32,14 +32,14 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=9, + default=14, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, - default=4, + default=2, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", @@ -104,16 +104,11 @@ def decode_one_batch( nnet_output = model(feature) # nnet_output is [N, T, C] - supervisions = batch["supervisions"] - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"], - supervisions["num_frames"], - ), - 1, - ).to(torch.int32) + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) lattice = get_lattice( nnet_output=nnet_output, diff --git a/egs/yesno/ASR/tdnn/model.py b/egs/yesno/ASR/tdnn/model.py index df0aa246d..52cff37e0 100755 --- a/egs/yesno/ASR/tdnn/model.py +++ b/egs/yesno/ASR/tdnn/model.py @@ -23,7 +23,6 @@ class Tdnn(nn.Module): in_channels=num_features, out_channels=32, kernel_size=3, - padding=1, ), nn.ReLU(inplace=True), nn.BatchNorm1d(num_features=32, affine=False), @@ -31,7 +30,6 @@ class Tdnn(nn.Module): in_channels=32, out_channels=32, kernel_size=5, - padding=4, dilation=2, ), nn.ReLU(inplace=True), @@ -40,7 +38,6 @@ class Tdnn(nn.Module): in_channels=32, out_channels=32, kernel_size=5, - padding=8, dilation=4, ), nn.ReLU(inplace=True), diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py index 836dd2794..04e1ab698 100755 --- a/egs/yesno/ASR/tdnn/train.py +++ b/egs/yesno/ASR/tdnn/train.py @@ -24,12 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, setup_logger, str2bool def get_parser(): @@ -61,7 +56,7 @@ def get_parser(): parser.add_argument( "--num-epochs", type=int, - default=10, + default=15, help="Number of epochs to train.", ) @@ -129,7 +124,7 @@ def get_params() -> AttributeDict: { "exp_dir": Path("tdnn/exp"), "lang_dir": Path("data/lang_phone"), - "lr": 1e-1, + "lr": 1e-2, "feature_dim": 23, "weight_decay": 1e-6, "start_epoch": 0, @@ -277,9 +272,14 @@ def compute_loss( # different duration in decreasing order, required by # `k2.intersect_dense` called in `k2.ctc_loss` supervisions = batch["supervisions"] - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=1 + texts = supervisions["text"] + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, ) + decoding_graph = graph_compiler.compile(texts) dense_fsa_vec = k2.DenseFsaVec(