mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00
Merge master.
This commit is contained in:
parent
b09224fb3a
commit
69a2bd5179
@ -1,7 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||
# Apache 2.0
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
import warnings
|
||||
@ -396,7 +409,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
:,
|
||||
self.pe.size(1) // 2
|
||||
- x.size(1)
|
||||
+ 1 : self.pe.size(1) // 2
|
||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||
+ x.size(1),
|
||||
]
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
|
@ -1,8 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# (still working in progress)
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
@ -45,28 +57,63 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=9,
|
||||
default=34,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=1,
|
||||
default=20,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="attention-decoder",
|
||||
help="""Decoding method.
|
||||
Supported values are:
|
||||
- (1) 1best. Extract the best path from the decoding lattice as the
|
||||
decoding result.
|
||||
- (2) nbest. Extract n paths from the decoding lattice; the path
|
||||
with the highest score is the decoding result.
|
||||
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
|
||||
the highest score is the decoding result.
|
||||
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
|
||||
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
|
||||
is the decoding result.
|
||||
- (5) attention-decoder. Extract n paths from the LM rescored
|
||||
lattice, the path with the highest score is the decoding result.
|
||||
- (6) nbest-oracle. Its WER is the lower bound of any n-best
|
||||
rescoring method can achieve. Useful for debugging n-best
|
||||
rescoring method.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""Number of paths for n-best based decoding method.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lattice-score-scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The scale to be applied to `lattice.scores`."
|
||||
"It's needed if you use any kinds of n-best based rescoring. "
|
||||
"Currently, it is used when the decoding method is: nbest, "
|
||||
"nbest-rescoring, attention-decoder, and nbest-oracle. "
|
||||
"A smaller value results in more unique paths.",
|
||||
help="""The scale to be applied to `lattice.scores`.
|
||||
It's needed if you use any kinds of n-best based rescoring.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
|
||||
A smaller value results in more unique paths.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
@ -92,21 +139,6 @@ def get_params() -> AttributeDict:
|
||||
"min_active_states": 30,
|
||||
"max_active_states": 10000,
|
||||
"use_double_scores": True,
|
||||
# Possible values for method:
|
||||
# - 1best
|
||||
# - nbest
|
||||
# - nbest-rescoring
|
||||
# - whole-lattice-rescoring
|
||||
# - attention-decoder
|
||||
# - nbest-oracle
|
||||
# "method": "nbest",
|
||||
# "method": "nbest-rescoring",
|
||||
# "method": "whole-lattice-rescoring",
|
||||
"method": "attention-decoder",
|
||||
# "method": "nbest-oracle",
|
||||
# num_paths is used when method is "nbest", "nbest-rescoring",
|
||||
# attention-decoder, and nbest-oracle
|
||||
"num_paths": 100,
|
||||
}
|
||||
)
|
||||
return params
|
||||
@ -117,7 +149,7 @@ def decode_one_batch(
|
||||
model: nn.Module,
|
||||
HLG: k2.Fsa,
|
||||
batch: dict,
|
||||
lexicon: Lexicon,
|
||||
word_table: k2.SymbolTable,
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
@ -151,8 +183,8 @@ def decode_one_batch(
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
lexicon:
|
||||
It contains word symbol table.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
sos_id:
|
||||
The token ID of the SOS.
|
||||
eos_id:
|
||||
@ -205,7 +237,7 @@ def decode_one_batch(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=supervisions["text"],
|
||||
lexicon=lexicon,
|
||||
word_table=word_table,
|
||||
scale=params.lattice_score_scale,
|
||||
)
|
||||
|
||||
@ -225,7 +257,7 @@ def decode_one_batch(
|
||||
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
return {key: hyps}
|
||||
|
||||
assert params.method in [
|
||||
@ -271,7 +303,7 @@ def decode_one_batch(
|
||||
ans = dict()
|
||||
for lm_scale_str, best_path in best_path_dict.items():
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
ans[lm_scale_str] = hyps
|
||||
return ans
|
||||
|
||||
@ -281,7 +313,7 @@ def decode_dataset(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
HLG: k2.Fsa,
|
||||
lexicon: Lexicon,
|
||||
word_table: k2.SymbolTable,
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
@ -297,8 +329,8 @@ def decode_dataset(
|
||||
The neural model.
|
||||
HLG:
|
||||
The decoding graph.
|
||||
lexicon:
|
||||
It contains word symbol table.
|
||||
word_table:
|
||||
It is the word symbol table.
|
||||
sos_id:
|
||||
The token ID for SOS.
|
||||
eos_id:
|
||||
@ -332,7 +364,7 @@ def decode_dataset(
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
batch=batch,
|
||||
lexicon=lexicon,
|
||||
word_table=word_table,
|
||||
G=G,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
@ -528,7 +560,7 @@ def main():
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
lexicon=lexicon,
|
||||
word_table=lexicon.word_table,
|
||||
G=G,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
|
@ -1,350 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from conformer import Conformer
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.utils import AttributeDict, get_texts
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--words-file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to words.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--HLG", type=str, required=True, help="Path to HLG.pt."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="1best",
|
||||
help="""Decoding method.
|
||||
Possible values are:
|
||||
(1) 1best - Use the best path as decoding output. Only
|
||||
the transformer encoder output is used for decoding.
|
||||
We call it HLG decoding.
|
||||
(2) whole-lattice-rescoring - Use an LM to rescore the
|
||||
decoding lattice and then use 1best to decode the
|
||||
rescored lattice.
|
||||
We call it HLG decoding + n-gram LM rescoring.
|
||||
(3) attention-decoder - Extract n paths from he rescored
|
||||
lattice and use the transformer attention decoder for
|
||||
rescoring.
|
||||
We call it HLG decoding + n-gram LM rescoring + attention
|
||||
decoder rescoring.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--G",
|
||||
type=str,
|
||||
help="""An LM for rescoring.
|
||||
Used only when method is
|
||||
whole-lattice-rescoring or attention-decoder.
|
||||
It's usually a 4-gram LM.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""
|
||||
Used only when method is attention-decoder.
|
||||
It specifies the size of n-best list.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=1.3,
|
||||
help="""
|
||||
Used only when method is whole-lattice-rescoring and attention-decoder.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
(Note: You need to tune it on a dataset.)
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--attention-decoder-scale",
|
||||
type=float,
|
||||
default=1.2,
|
||||
help="""
|
||||
Used only when method is attention-decoder.
|
||||
It specifies the scale for attention decoder scores.
|
||||
(Note: You need to tune it on a dataset.)
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lattice-score-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""
|
||||
Used only when method is attention-decoder.
|
||||
It specifies the scale for lattice.scores when
|
||||
extracting n-best lists. A smaller value results in
|
||||
more unique number of paths with the risk of missing
|
||||
the best path.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sos-id",
|
||||
type=float,
|
||||
default=1,
|
||||
help="""
|
||||
Used only when method is attention-decoder.
|
||||
It specifies ID for the SOS token.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--eos-id",
|
||||
type=float,
|
||||
default=1,
|
||||
help="""
|
||||
Used only when method is attention-decoder.
|
||||
It specifies ID for the EOS token.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
"feature_dim": 80,
|
||||
"nhead": 8,
|
||||
"num_classes": 5000,
|
||||
"sample_rate": 16000,
|
||||
"attention_dim": 512,
|
||||
"subsampling_factor": 4,
|
||||
"num_decoder_layers": 6,
|
||||
"vgg_frontend": False,
|
||||
"is_espnet_structure": True,
|
||||
"mmi_loss": False,
|
||||
"use_feat_batchnorm": True,
|
||||
"search_beam": 20,
|
||||
"output_beam": 8,
|
||||
"min_active_states": 30,
|
||||
"max_active_states": 10000,
|
||||
"use_double_scores": True,
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
nhead=params.nhead,
|
||||
d_model=params.attention_dim,
|
||||
num_classes=params.num_classes,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
num_decoder_layers=params.num_decoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
is_espnet_structure=params.is_espnet_structure,
|
||||
mmi_loss=params.mmi_loss,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
||||
logging.info(f"Loading G from {params.G}")
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||
G = G.to(device)
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
G.lm_scores = G.scores.clone()
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info(f"Decoding started")
|
||||
features = fbank(waves)
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
|
||||
# Note: We don't use key padding mask for attention during decoding
|
||||
with torch.no_grad():
|
||||
nnet_output, memory, memory_key_padding_mask = model(features)
|
||||
|
||||
batch_size = nnet_output.shape[0]
|
||||
supervision_segments = torch.tensor(
|
||||
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
HLG=HLG,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
if params.method == "1best":
|
||||
logging.info("Use HLG decoding")
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
elif params.method == "whole-lattice-rescoring":
|
||||
logging.info("Use HLG decoding + LM rescoring")
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=[params.ngram_lm_scale],
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
elif params.method == "attention-decoder":
|
||||
logging.info("Use HLG + LM rescoring + attention decoder rescoring")
|
||||
rescored_lattice = rescore_with_whole_lattice(
|
||||
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
|
||||
)
|
||||
best_path_dict = rescore_with_attention_decoder(
|
||||
lattice=rescored_lattice,
|
||||
num_paths=params.num_paths,
|
||||
model=model,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
sos_id=params.sos_id,
|
||||
eos_id=params.eos_id,
|
||||
scale=params.lattice_score_scale,
|
||||
ngram_lm_scale=params.ngram_lm_scale,
|
||||
attention_scale=params.attention_decoder_scale,
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info(f"Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py
Symbolic link
@ -0,0 +1 @@
|
||||
../conformer_ctc/pretrained.py
|
@ -1,3 +1,20 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -1,33 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from subsampling import Conv2dSubsampling
|
||||
from subsampling import VggSubsampling
|
||||
import torch
|
||||
|
||||
|
||||
def test_conv2d_subsampling():
|
||||
N = 3
|
||||
odim = 2
|
||||
|
||||
for T in range(7, 19):
|
||||
for idim in range(7, 20):
|
||||
model = Conv2dSubsampling(idim=idim, odim=odim)
|
||||
x = torch.empty(N, T, idim)
|
||||
y = model(x)
|
||||
assert y.shape[0] == N
|
||||
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
|
||||
assert y.shape[2] == odim
|
||||
|
||||
|
||||
def test_vgg_subsampling():
|
||||
N = 3
|
||||
odim = 2
|
||||
|
||||
for T in range(7, 19):
|
||||
for idim in range(7, 20):
|
||||
model = VggSubsampling(idim=idim, odim=odim)
|
||||
x = torch.empty(N, T, idim)
|
||||
y = model(x)
|
||||
assert y.shape[0] == N
|
||||
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
|
||||
assert y.shape[2] == odim
|
@ -1,89 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
from transformer import (
|
||||
Transformer,
|
||||
encoder_padding_mask,
|
||||
generate_square_subsequent_mask,
|
||||
decoder_padding_mask,
|
||||
add_sos,
|
||||
add_eos,
|
||||
)
|
||||
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def test_encoder_padding_mask():
|
||||
supervisions = {
|
||||
"sequence_idx": torch.tensor([0, 1, 2]),
|
||||
"start_frame": torch.tensor([0, 0, 0]),
|
||||
"num_frames": torch.tensor([18, 7, 13]),
|
||||
}
|
||||
|
||||
max_len = ((18 - 1) // 2 - 1) // 2
|
||||
mask = encoder_padding_mask(max_len, supervisions)
|
||||
expected_mask = torch.tensor(
|
||||
[
|
||||
[False, False, False], # ((18 - 1)//2 - 1)//2 = 3,
|
||||
[False, True, True], # ((7 - 1)//2 - 1)//2 = 1,
|
||||
[False, False, True], # ((13 - 1)//2 - 1)//2 = 2,
|
||||
]
|
||||
)
|
||||
assert torch.all(torch.eq(mask, expected_mask))
|
||||
|
||||
|
||||
def test_transformer():
|
||||
num_features = 40
|
||||
num_classes = 87
|
||||
model = Transformer(num_features=num_features, num_classes=num_classes)
|
||||
|
||||
N = 31
|
||||
|
||||
for T in range(7, 30):
|
||||
x = torch.rand(N, T, num_features)
|
||||
y, _, _ = model(x)
|
||||
assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
|
||||
|
||||
|
||||
def test_generate_square_subsequent_mask():
|
||||
s = 5
|
||||
mask = generate_square_subsequent_mask(s)
|
||||
inf = float("inf")
|
||||
expected_mask = torch.tensor(
|
||||
[
|
||||
[0.0, -inf, -inf, -inf, -inf],
|
||||
[0.0, 0.0, -inf, -inf, -inf],
|
||||
[0.0, 0.0, 0.0, -inf, -inf],
|
||||
[0.0, 0.0, 0.0, 0.0, -inf],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
]
|
||||
)
|
||||
assert torch.all(torch.eq(mask, expected_mask))
|
||||
|
||||
|
||||
def test_decoder_padding_mask():
|
||||
x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
|
||||
y = pad_sequence(x, batch_first=True, padding_value=-1)
|
||||
mask = decoder_padding_mask(y, ignore_id=-1)
|
||||
expected_mask = torch.tensor(
|
||||
[
|
||||
[False, False, True],
|
||||
[False, True, True],
|
||||
[False, False, False],
|
||||
]
|
||||
)
|
||||
assert torch.all(torch.eq(mask, expected_mask))
|
||||
|
||||
|
||||
def test_add_sos():
|
||||
x = [[1, 2], [3], [2, 5, 8]]
|
||||
y = add_sos(x, sos_id=0)
|
||||
expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
|
||||
assert y == expected_y
|
||||
|
||||
|
||||
def test_add_eos():
|
||||
x = [[1, 2], [3], [2, 5, 8]]
|
||||
y = add_eos(x, eos_id=0)
|
||||
expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
|
||||
assert y == expected_y
|
@ -1,6 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This is just at the very beginning ...
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
@ -60,6 +74,23 @@ def get_parser():
|
||||
help="Should various information be logged in tensorboard.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=35,
|
||||
help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""Resume training from from this epoch.
|
||||
If it is positive, it will load checkpoint from
|
||||
conformer_ctc/exp/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -89,11 +120,6 @@ def get_params() -> AttributeDict:
|
||||
|
||||
- subsampling_factor: The subsampling factor for the model.
|
||||
|
||||
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
|
||||
and continue training from that checkpoint.
|
||||
|
||||
- num_epochs: Number of epochs to train.
|
||||
|
||||
- best_train_loss: Best training loss so far. It is used to select
|
||||
the model that has the lowest training loss. It is
|
||||
updated during the training.
|
||||
@ -124,13 +150,11 @@ def get_params() -> AttributeDict:
|
||||
"""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"exp_dir": Path("conformer_ctc_embedding_scale/exp"),
|
||||
"exp_dir": Path("conformer_ctc/exp"),
|
||||
"lang_dir": Path("data/lang_bpe"),
|
||||
"feature_dim": 80,
|
||||
"weight_decay": 1e-6,
|
||||
"subsampling_factor": 4,
|
||||
"start_epoch": 0,
|
||||
"num_epochs": 20,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
|
@ -1,5 +1,19 @@
|
||||
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||
# Apache 2.0
|
||||
# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@ -641,7 +655,7 @@ class PositionalEncoding(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.pos_scale = 1. / math.sqrt(self.d_model)
|
||||
self.pos_scale = 1.0 / math.sqrt(self.d_model)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
self.pe = None
|
||||
|
||||
@ -780,7 +794,8 @@ class Noam(object):
|
||||
|
||||
class LabelSmoothingLoss(nn.Module):
|
||||
"""
|
||||
Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w)
|
||||
Label-smoothing loss. KL-divergence between
|
||||
q_{smoothed ground truth prob.}(w)
|
||||
and p_{prob. computed by model}(w) is minimized.
|
||||
Modified from
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa
|
||||
@ -865,7 +880,8 @@ def encoder_padding_mask(
|
||||
frames, before subsampling)
|
||||
|
||||
Returns:
|
||||
Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices.
|
||||
Tensor: Mask tensor of dimension (batch_size, input_length),
|
||||
True denote the masked indices.
|
||||
"""
|
||||
if supervisions is None:
|
||||
return None
|
||||
|
Loading…
x
Reference in New Issue
Block a user