mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 19:42:19 +00:00
Remove padding in the model to make the results reproducible.
This commit is contained in:
parent
2e37b29e66
commit
ece74b7542
5
.github/workflows/run-yesno-recipe.yml
vendored
5
.github/workflows/run-yesno-recipe.yml
vendored
@ -74,8 +74,5 @@ jobs:
|
||||
cd egs/yesno/ASR
|
||||
./prepare.sh
|
||||
python3 ./tdnn/train.py
|
||||
python3 ./tdnn/decode.py --avg 2
|
||||
python3 ./tdnn/decode.py --avg 3
|
||||
python3 ./tdnn/decode.py --avg 4
|
||||
python3 ./tdnn/decode.py --avg 5
|
||||
python3 ./tdnn/decode.py
|
||||
# TODO: Check that the WER is less than some value
|
||||
|
@ -32,14 +32,14 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=9,
|
||||
default=14,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=4,
|
||||
default=2,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
@ -104,16 +104,11 @@ def decode_one_batch(
|
||||
nnet_output = model(feature)
|
||||
# nnet_output is [N, T, C]
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
supervisions["start_frame"],
|
||||
supervisions["num_frames"],
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
|
||||
batch_size = nnet_output.shape[0]
|
||||
supervision_segments = torch.tensor(
|
||||
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
|
@ -23,7 +23,6 @@ class Tdnn(nn.Module):
|
||||
in_channels=num_features,
|
||||
out_channels=32,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.BatchNorm1d(num_features=32, affine=False),
|
||||
@ -31,7 +30,6 @@ class Tdnn(nn.Module):
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
kernel_size=5,
|
||||
padding=4,
|
||||
dilation=2,
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
@ -40,7 +38,6 @@ class Tdnn(nn.Module):
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
kernel_size=5,
|
||||
padding=8,
|
||||
dilation=4,
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
|
@ -24,12 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -61,7 +56,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=10,
|
||||
default=15,
|
||||
help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
@ -129,7 +124,7 @@ def get_params() -> AttributeDict:
|
||||
{
|
||||
"exp_dir": Path("tdnn/exp"),
|
||||
"lang_dir": Path("data/lang_phone"),
|
||||
"lr": 1e-1,
|
||||
"lr": 1e-2,
|
||||
"feature_dim": 23,
|
||||
"weight_decay": 1e-6,
|
||||
"start_epoch": 0,
|
||||
@ -277,9 +272,14 @@ def compute_loss(
|
||||
# different duration in decreasing order, required by
|
||||
# `k2.intersect_dense` called in `k2.ctc_loss`
|
||||
supervisions = batch["supervisions"]
|
||||
supervision_segments, texts = encode_supervisions(
|
||||
supervisions, subsampling_factor=1
|
||||
texts = supervisions["text"]
|
||||
|
||||
batch_size = nnet_output.shape[0]
|
||||
supervision_segments = torch.tensor(
|
||||
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
decoding_graph = graph_compiler.compile(texts)
|
||||
|
||||
dense_fsa_vec = k2.DenseFsaVec(
|
||||
|
Loading…
x
Reference in New Issue
Block a user