mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
apply black on all files
This commit is contained in:
parent
b3920e5ab5
commit
107df3b115
11
.github/workflows/style_check.yml
vendored
11
.github/workflows/style_check.yml
vendored
@ -45,17 +45,18 @@ jobs:
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
|
||||
# See https://github.com/psf/black/issues/2964
|
||||
# The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
|
||||
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
|
||||
# Click issue fixed in https://github.com/psf/black/pull/2966
|
||||
|
||||
- name: Run flake8
|
||||
shell: bash
|
||||
working-directory: ${{github.workspace}}
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --count --show-source --statistics
|
||||
flake8 .
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
|
||||
--statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503
|
||||
|
||||
- name: Run black
|
||||
shell: bash
|
||||
|
@ -1,26 +1,38 @@
|
||||
repos:
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 21.6b0
|
||||
rev: 22.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
args: [--line-length=80]
|
||||
additional_dependencies: ['click==8.0.1']
|
||||
args: ["--line-length=88"]
|
||||
additional_dependencies: ['click==8.1.0']
|
||||
exclude: icefall\/__init__\.py
|
||||
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 3.9.2
|
||||
rev: 5.0.4
|
||||
hooks:
|
||||
- id: flake8
|
||||
args: [--max-line-length=80]
|
||||
args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
|
||||
|
||||
# What are we ignoring here?
|
||||
# E203: whitespace before ':'
|
||||
# E266: too many leading '#' for block comment
|
||||
# E501: line too long
|
||||
# F401: module imported but unused
|
||||
# E402: module level import not at top of file
|
||||
# F403: 'from module import *' used; unable to detect undefined names
|
||||
# F841: local variable is assigned to but never used
|
||||
# W503: line break before binary operator
|
||||
# In addition, the default ignore list is:
|
||||
# E121,E123,E126,E226,E24,E704,W503,W504
|
||||
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.9.2
|
||||
rev: 5.10.1
|
||||
hooks:
|
||||
- id: isort
|
||||
args: [--profile=black, --line-length=80]
|
||||
args: ["--profile=black"]
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.0.1
|
||||
rev: v4.2.0
|
||||
hooks:
|
||||
- id: check-executables-have-shebangs
|
||||
- id: end-of-file-fixer
|
||||
|
@ -88,4 +88,3 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
|
||||
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
|
||||
|
||||
WORKDIR /workspace/icefall
|
||||
|
||||
|
@ -19,4 +19,3 @@ It can be downloaded from `<https://www.openslr.org/33/>`_
|
||||
tdnn_lstm_ctc
|
||||
conformer_ctc
|
||||
stateless_transducer
|
||||
|
||||
|
@ -6,4 +6,3 @@ TIMIT
|
||||
|
||||
tdnn_ligru_ctc
|
||||
tdnn_lstm_ctc
|
||||
|
||||
|
@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
||||
)
|
||||
if "train" in partition:
|
||||
cut_set = (
|
||||
cut_set
|
||||
+ cut_set.perturb_speed(0.9)
|
||||
+ cut_set.perturb_speed(1.1)
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
@ -116,9 +114,7 @@ def get_args():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
|
@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
pieces = [
|
||||
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
|
||||
]
|
||||
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
|
||||
|
||||
for i in range(len(pieces) - 1):
|
||||
w = word if i == 0 else eps
|
||||
@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def generate_lexicon(
|
||||
token_sym_table: Dict[str, int], words: List[str]
|
||||
) -> Lexicon:
|
||||
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
|
||||
"""Generate a lexicon from a word list and token_sym_table.
|
||||
|
||||
Args:
|
||||
|
@ -317,9 +317,7 @@ def lexicon_to_fst(
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
|
||||
)
|
||||
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
|
||||
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
|
||||
fsa.draw("L.pdf", title="L")
|
||||
|
||||
fsa_disambig = lexicon_to_fst(
|
||||
lexicon_disambig, phone2id=phone2id, word2id=word2id
|
||||
)
|
||||
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
|
||||
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
|
||||
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
|
||||
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
|
||||
|
@ -56,9 +56,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--space", default="<space>", type=str, help="space symbol"
|
||||
)
|
||||
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
|
||||
parser.add_argument(
|
||||
"--non-lang-syms",
|
||||
"-l",
|
||||
@ -66,9 +64,7 @@ def get_parser():
|
||||
type=str,
|
||||
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"text", type=str, default=False, nargs="?", help="input text"
|
||||
)
|
||||
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
|
||||
parser.add_argument(
|
||||
"--trans_type",
|
||||
"-t",
|
||||
@ -108,8 +104,7 @@ def token2id(
|
||||
if token_type == "lazy_pinyin":
|
||||
text = lazy_pinyin(chars_list)
|
||||
sub_ids = [
|
||||
token_table[txt] if txt in token_table else oov_id
|
||||
for txt in text
|
||||
token_table[txt] if txt in token_table else oov_id for txt in text
|
||||
]
|
||||
ids.append(sub_ids)
|
||||
else: # token_type = "pinyin"
|
||||
@ -135,9 +130,7 @@ def main():
|
||||
if args.text:
|
||||
f = codecs.open(args.text, encoding="utf-8")
|
||||
else:
|
||||
f = codecs.getreader("utf-8")(
|
||||
sys.stdin if is_python2 else sys.stdin.buffer
|
||||
)
|
||||
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
|
||||
|
||||
sys.stdout = codecs.getwriter("utf-8")(
|
||||
sys.stdout if is_python2 else sys.stdout.buffer
|
||||
|
@ -113,4 +113,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
./local/prepare_char.py
|
||||
fi
|
||||
fi
|
||||
|
||||
|
@ -205,17 +205,13 @@ class Aidatatang_200zhAsrDataModule:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "musan_cuts.jsonl.gz"
|
||||
)
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
transforms.append(
|
||||
CutMix(
|
||||
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||
)
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
@ -237,9 +233,7 @@ class Aidatatang_200zhAsrDataModule:
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(
|
||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||
)
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
@ -282,9 +276,7 @@ class Aidatatang_200zhAsrDataModule:
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
@ -340,9 +332,7 @@ class Aidatatang_200zhAsrDataModule:
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
|
@ -69,11 +69,7 @@ from beam_search import (
|
||||
)
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -192,8 +188,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -249,9 +244,7 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
@ -266,10 +259,7 @@ def decode_one_batch(
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
@ -390,9 +380,7 @@ def decode_dataset(
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -425,8 +413,7 @@ def save_results(
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
|
@ -103,8 +103,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
return parser
|
||||
@ -173,9 +172,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -162,8 +162,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -194,8 +193,7 @@ def read_sound_files(
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
@ -257,9 +255,7 @@ def main():
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
@ -284,10 +280,7 @@ def main():
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
@ -339,9 +332,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -81,9 +81,7 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
]
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||
|
||||
@ -187,8 +185,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -211,8 +208,7 @@ def get_parser():
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network)"
|
||||
"part.",
|
||||
help="The scale to smooth the loss with am (output of encoder network)" "part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -542,22 +538,15 @@ def compute_loss(
|
||||
# overwhelming the simple_loss and causing it to diverge,
|
||||
# in case it had not fully learned the alignment yet.
|
||||
pruned_loss_scale = (
|
||||
0.0
|
||||
if warmup < 1.0
|
||||
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
)
|
||||
loss = (
|
||||
params.simple_loss_scale * simple_loss
|
||||
+ pruned_loss_scale * pruned_loss
|
||||
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
)
|
||||
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
@ -711,9 +700,7 @@ def train_one_epoch(
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
|
@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
normalize_before: bool = True,
|
||||
) -> None:
|
||||
super(ConformerEncoderLayer, self).__init__()
|
||||
self.self_attn = RelPositionMultiheadAttention(
|
||||
d_model, nhead, dropout=0.0
|
||||
)
|
||||
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
|
||||
|
||||
self.feed_forward = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||
|
||||
self.norm_ff_macaron = nn.LayerNorm(
|
||||
d_model
|
||||
) # for the macaron style FNN module
|
||||
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
|
||||
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
|
||||
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
|
||||
|
||||
self.ff_scale = 0.5
|
||||
|
||||
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
|
||||
self.norm_final = nn.LayerNorm(
|
||||
d_model
|
||||
) # for the final output of the block
|
||||
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
src = residual + self.ff_scale * self.dropout(
|
||||
self.feed_forward_macaron(src)
|
||||
)
|
||||
src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
|
||||
if not self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
|
||||
@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, d_model: int, dropout_rate: float, max_len: int = 5000
|
||||
) -> None:
|
||||
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(RelPositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
):
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
# Suppose `i` means to the position of query vector and `j` means the
|
||||
@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
if torch.equal(query, key) and torch.equal(key, value):
|
||||
# self-attention
|
||||
q, k, v = nn.functional.linear(
|
||||
query, in_proj_weight, in_proj_bias
|
||||
).chunk(3, dim=-1)
|
||||
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
|
||||
3, dim=-1
|
||||
)
|
||||
|
||||
elif torch.equal(key, value):
|
||||
# encoder-decoder attention
|
||||
@ -703,31 +691,22 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
if attn_mask.dim() == 2:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
||||
raise RuntimeError(
|
||||
"The size of the 2D attn_mask is not correct."
|
||||
)
|
||||
raise RuntimeError("The size of the 2D attn_mask is not correct.")
|
||||
elif attn_mask.dim() == 3:
|
||||
if list(attn_mask.size()) != [
|
||||
bsz * num_heads,
|
||||
query.size(0),
|
||||
key.size(0),
|
||||
]:
|
||||
raise RuntimeError(
|
||||
"The size of the 3D attn_mask is not correct."
|
||||
)
|
||||
raise RuntimeError("The size of the 3D attn_mask is not correct.")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"attn_mask's dimension {} is not supported".format(
|
||||
attn_mask.dim()
|
||||
)
|
||||
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
|
||||
)
|
||||
# attn_mask's dim is 3 now.
|
||||
|
||||
# convert ByteTensor key_padding_mask to bool
|
||||
if (
|
||||
key_padding_mask is not None
|
||||
and key_padding_mask.dtype == torch.uint8
|
||||
):
|
||||
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
|
||||
warnings.warn(
|
||||
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
|
||||
)
|
||||
@ -766,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
# first compute matrix a and matrix c
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
|
||||
matrix_ac = torch.matmul(
|
||||
q_with_bias_u, k
|
||||
) # (batch, head, time1, time2)
|
||||
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
|
||||
|
||||
# compute matrix b and matrix d
|
||||
matrix_bd = torch.matmul(
|
||||
@ -780,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
matrix_ac + matrix_bd
|
||||
) * scaling # (batch, head, time1, time2)
|
||||
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, -1
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
|
||||
|
||||
assert list(attn_output_weights.size()) == [
|
||||
bsz * num_heads,
|
||||
@ -816,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
attn_output = torch.bmm(attn_output_weights, v)
|
||||
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
||||
attn_output = (
|
||||
attn_output.transpose(0, 1)
|
||||
.contiguous()
|
||||
.view(tgt_len, bsz, embed_dim)
|
||||
)
|
||||
attn_output = nn.functional.linear(
|
||||
attn_output, out_proj_weight, out_proj_bias
|
||||
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||
)
|
||||
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
|
||||
|
||||
if need_weights:
|
||||
# average attention weights over heads
|
||||
@ -845,9 +816,7 @@ class ConvolutionModule(nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, channels: int, kernel_size: int, bias: bool = True
|
||||
) -> None:
|
||||
def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
# kernerl_size should be a odd number for 'SAME' padding
|
||||
|
@ -401,9 +401,7 @@ def decode_dataset(
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -431,9 +429,7 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
|
||||
@ -441,9 +437,7 @@ def save_results(
|
||||
test_set_wers[key] = wer
|
||||
|
||||
if enable_log:
|
||||
logging.info(
|
||||
"Wrote detailed error stats to {}".format(errs_filename)
|
||||
)
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
|
||||
@ -562,9 +556,7 @@ def main():
|
||||
eos_id=eos_id,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params, test_set_name=test_set, results_dict=results_dict
|
||||
)
|
||||
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
@ -157,9 +157,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -211,8 +211,7 @@ def read_sound_files(
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
@ -274,9 +273,7 @@ def main():
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
# Note: We don't use key padding mask for attention during decoding
|
||||
with torch.no_grad():
|
||||
@ -371,9 +368,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module):
|
||||
assert idim >= 7
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=1, out_channels=odim, kernel_size=3, stride=2
|
||||
),
|
||||
nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
|
||||
),
|
||||
nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
|
||||
@ -132,17 +128,13 @@ class VggSubsampling(nn.Module):
|
||||
)
|
||||
)
|
||||
layers.append(
|
||||
torch.nn.MaxPool2d(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True
|
||||
)
|
||||
torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
)
|
||||
cur_channels = block_dim
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
self.out = nn.Linear(
|
||||
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
|
||||
)
|
||||
self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Subsample x.
|
||||
|
@ -16,9 +16,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from subsampling import Conv2dSubsampling
|
||||
from subsampling import VggSubsampling
|
||||
import torch
|
||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
||||
|
||||
|
||||
def test_conv2d_subsampling():
|
||||
|
@ -382,9 +382,7 @@ def compute_loss(
|
||||
#
|
||||
# See https://github.com/k2-fsa/icefall/issues/97
|
||||
# for more details
|
||||
unsorted_token_ids = graph_compiler.texts_to_ids(
|
||||
supervisions["text"]
|
||||
)
|
||||
unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
|
||||
att_loss = mmodel.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
@ -520,9 +518,7 @@ def train_one_epoch(
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
@ -630,9 +626,7 @@ def run(rank, world_size, args):
|
||||
|
||||
cur_lr = optimizer._rate
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||
)
|
||||
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
if rank == 0:
|
||||
|
@ -149,9 +149,7 @@ class Transformer(nn.Module):
|
||||
norm=decoder_norm,
|
||||
)
|
||||
|
||||
self.decoder_output_layer = torch.nn.Linear(
|
||||
d_model, self.decoder_num_class
|
||||
)
|
||||
self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
|
||||
|
||||
self.decoder_criterion = LabelSmoothingLoss()
|
||||
else:
|
||||
@ -183,9 +181,7 @@ class Transformer(nn.Module):
|
||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||
x = self.feat_batchnorm(x)
|
||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||
encoder_memory, memory_key_padding_mask = self.run_encoder(
|
||||
x, supervision
|
||||
)
|
||||
encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
|
||||
x = self.ctc_output(encoder_memory)
|
||||
return x, encoder_memory, memory_key_padding_mask
|
||||
|
||||
@ -266,23 +262,17 @@ class Transformer(nn.Module):
|
||||
"""
|
||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||
ys_in = [torch.tensor(y) for y in ys_in]
|
||||
ys_in_pad = pad_sequence(
|
||||
ys_in, batch_first=True, padding_value=float(eos_id)
|
||||
)
|
||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
|
||||
|
||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
||||
ys_out = [torch.tensor(y) for y in ys_out]
|
||||
ys_out_pad = pad_sequence(
|
||||
ys_out, batch_first=True, padding_value=float(-1)
|
||||
)
|
||||
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
|
||||
|
||||
device = memory.device
|
||||
ys_in_pad = ys_in_pad.to(device)
|
||||
ys_out_pad = ys_out_pad.to(device)
|
||||
|
||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
|
||||
device
|
||||
)
|
||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
|
||||
|
||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||
# TODO: Use length information to create the decoder padding mask
|
||||
@ -343,23 +333,17 @@ class Transformer(nn.Module):
|
||||
|
||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||
ys_in = [torch.tensor(y) for y in ys_in]
|
||||
ys_in_pad = pad_sequence(
|
||||
ys_in, batch_first=True, padding_value=float(eos_id)
|
||||
)
|
||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
|
||||
|
||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
||||
ys_out = [torch.tensor(y) for y in ys_out]
|
||||
ys_out_pad = pad_sequence(
|
||||
ys_out, batch_first=True, padding_value=float(-1)
|
||||
)
|
||||
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
|
||||
|
||||
device = memory.device
|
||||
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
|
||||
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
|
||||
|
||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
|
||||
device
|
||||
)
|
||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
|
||||
|
||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||
# TODO: Use length information to create the decoder padding mask
|
||||
@ -632,9 +616,7 @@ def _get_activation_fn(activation: str):
|
||||
elif activation == "gelu":
|
||||
return nn.functional.gelu
|
||||
|
||||
raise RuntimeError(
|
||||
"activation should be relu/gelu, not {}".format(activation)
|
||||
)
|
||||
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
@ -836,9 +818,7 @@ def encoder_padding_mask(
|
||||
1,
|
||||
).to(torch.int32)
|
||||
|
||||
lengths = [
|
||||
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
|
||||
]
|
||||
lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
|
||||
for idx in range(supervision_segments.size(0)):
|
||||
# Note: TorchScript doesn't allow to unpack tensors as tuples
|
||||
sequence_idx = supervision_segments[idx, 0].item()
|
||||
@ -859,9 +839,7 @@ def encoder_padding_mask(
|
||||
return mask
|
||||
|
||||
|
||||
def decoder_padding_mask(
|
||||
ys_pad: torch.Tensor, ignore_id: int = -1
|
||||
) -> torch.Tensor:
|
||||
def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
|
||||
"""Generate a length mask for input.
|
||||
|
||||
The masked position are filled with True,
|
||||
|
@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
normalize_before: bool = True,
|
||||
) -> None:
|
||||
super(ConformerEncoderLayer, self).__init__()
|
||||
self.self_attn = RelPositionMultiheadAttention(
|
||||
d_model, nhead, dropout=0.0
|
||||
)
|
||||
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
|
||||
|
||||
self.feed_forward = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||
|
||||
self.norm_ff_macaron = nn.LayerNorm(
|
||||
d_model
|
||||
) # for the macaron style FNN module
|
||||
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
|
||||
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
|
||||
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
|
||||
|
||||
self.ff_scale = 0.5
|
||||
|
||||
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
|
||||
self.norm_final = nn.LayerNorm(
|
||||
d_model
|
||||
) # for the final output of the block
|
||||
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
src = residual + self.ff_scale * self.dropout(
|
||||
self.feed_forward_macaron(src)
|
||||
)
|
||||
src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
|
||||
if not self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
|
||||
@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, d_model: int, dropout_rate: float, max_len: int = 5000
|
||||
) -> None:
|
||||
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(RelPositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
):
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
# Suppose `i` means to the position of query vector and `j` means the
|
||||
@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
if torch.equal(query, key) and torch.equal(key, value):
|
||||
# self-attention
|
||||
q, k, v = nn.functional.linear(
|
||||
query, in_proj_weight, in_proj_bias
|
||||
).chunk(3, dim=-1)
|
||||
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
|
||||
3, dim=-1
|
||||
)
|
||||
|
||||
elif torch.equal(key, value):
|
||||
# encoder-decoder attention
|
||||
@ -703,31 +691,22 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
if attn_mask.dim() == 2:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
||||
raise RuntimeError(
|
||||
"The size of the 2D attn_mask is not correct."
|
||||
)
|
||||
raise RuntimeError("The size of the 2D attn_mask is not correct.")
|
||||
elif attn_mask.dim() == 3:
|
||||
if list(attn_mask.size()) != [
|
||||
bsz * num_heads,
|
||||
query.size(0),
|
||||
key.size(0),
|
||||
]:
|
||||
raise RuntimeError(
|
||||
"The size of the 3D attn_mask is not correct."
|
||||
)
|
||||
raise RuntimeError("The size of the 3D attn_mask is not correct.")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"attn_mask's dimension {} is not supported".format(
|
||||
attn_mask.dim()
|
||||
)
|
||||
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
|
||||
)
|
||||
# attn_mask's dim is 3 now.
|
||||
|
||||
# convert ByteTensor key_padding_mask to bool
|
||||
if (
|
||||
key_padding_mask is not None
|
||||
and key_padding_mask.dtype == torch.uint8
|
||||
):
|
||||
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
|
||||
warnings.warn(
|
||||
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
|
||||
)
|
||||
@ -766,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
# first compute matrix a and matrix c
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
|
||||
matrix_ac = torch.matmul(
|
||||
q_with_bias_u, k
|
||||
) # (batch, head, time1, time2)
|
||||
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
|
||||
|
||||
# compute matrix b and matrix d
|
||||
matrix_bd = torch.matmul(
|
||||
@ -780,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
matrix_ac + matrix_bd
|
||||
) * scaling # (batch, head, time1, time2)
|
||||
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, -1
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
|
||||
|
||||
assert list(attn_output_weights.size()) == [
|
||||
bsz * num_heads,
|
||||
@ -816,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
attn_output = torch.bmm(attn_output_weights, v)
|
||||
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
||||
attn_output = (
|
||||
attn_output.transpose(0, 1)
|
||||
.contiguous()
|
||||
.view(tgt_len, bsz, embed_dim)
|
||||
)
|
||||
attn_output = nn.functional.linear(
|
||||
attn_output, out_proj_weight, out_proj_bias
|
||||
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||
)
|
||||
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
|
||||
|
||||
if need_weights:
|
||||
# average attention weights over heads
|
||||
@ -845,9 +816,7 @@ class ConvolutionModule(nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, channels: int, kernel_size: int, bias: bool = True
|
||||
) -> None:
|
||||
def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
# kernerl_size should be a odd number for 'SAME' padding
|
||||
|
@ -413,9 +413,7 @@ def decode_dataset(
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -443,9 +441,7 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
|
||||
@ -453,9 +449,7 @@ def save_results(
|
||||
test_set_wers[key] = wer
|
||||
|
||||
if enable_log:
|
||||
logging.info(
|
||||
"Wrote detailed error stats to {}".format(errs_filename)
|
||||
)
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
|
||||
@ -550,9 +544,7 @@ def main():
|
||||
|
||||
if params.export:
|
||||
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
|
||||
torch.save(
|
||||
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
|
||||
)
|
||||
torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
|
||||
return
|
||||
|
||||
model.to(device)
|
||||
@ -581,9 +573,7 @@ def main():
|
||||
eos_id=eos_id,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params, test_set_name=test_set, results_dict=results_dict
|
||||
)
|
||||
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module):
|
||||
assert idim >= 7
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=1, out_channels=odim, kernel_size=3, stride=2
|
||||
),
|
||||
nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
|
||||
),
|
||||
nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
|
||||
@ -132,17 +128,13 @@ class VggSubsampling(nn.Module):
|
||||
)
|
||||
)
|
||||
layers.append(
|
||||
torch.nn.MaxPool2d(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True
|
||||
)
|
||||
torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
)
|
||||
cur_channels = block_dim
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
self.out = nn.Linear(
|
||||
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
|
||||
)
|
||||
self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Subsample x.
|
||||
|
@ -511,9 +511,7 @@ def train_one_epoch(
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
@ -625,9 +623,7 @@ def run(rank, world_size, args):
|
||||
|
||||
cur_lr = optimizer._rate
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||
)
|
||||
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
if rank == 0:
|
||||
|
@ -149,9 +149,7 @@ class Transformer(nn.Module):
|
||||
norm=decoder_norm,
|
||||
)
|
||||
|
||||
self.decoder_output_layer = torch.nn.Linear(
|
||||
d_model, self.decoder_num_class
|
||||
)
|
||||
self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
|
||||
|
||||
self.decoder_criterion = LabelSmoothingLoss()
|
||||
else:
|
||||
@ -183,9 +181,7 @@ class Transformer(nn.Module):
|
||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||
x = self.feat_batchnorm(x)
|
||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||
encoder_memory, memory_key_padding_mask = self.run_encoder(
|
||||
x, supervision
|
||||
)
|
||||
encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
|
||||
x = self.ctc_output(encoder_memory)
|
||||
return x, encoder_memory, memory_key_padding_mask
|
||||
|
||||
@ -266,23 +262,17 @@ class Transformer(nn.Module):
|
||||
"""
|
||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||
ys_in = [torch.tensor(y) for y in ys_in]
|
||||
ys_in_pad = pad_sequence(
|
||||
ys_in, batch_first=True, padding_value=float(eos_id)
|
||||
)
|
||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
|
||||
|
||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
||||
ys_out = [torch.tensor(y) for y in ys_out]
|
||||
ys_out_pad = pad_sequence(
|
||||
ys_out, batch_first=True, padding_value=float(-1)
|
||||
)
|
||||
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
|
||||
|
||||
device = memory.device
|
||||
ys_in_pad = ys_in_pad.to(device)
|
||||
ys_out_pad = ys_out_pad.to(device)
|
||||
|
||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
|
||||
device
|
||||
)
|
||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
|
||||
|
||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||
# TODO: Use length information to create the decoder padding mask
|
||||
@ -343,23 +333,17 @@ class Transformer(nn.Module):
|
||||
|
||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||
ys_in = [torch.tensor(y) for y in ys_in]
|
||||
ys_in_pad = pad_sequence(
|
||||
ys_in, batch_first=True, padding_value=float(eos_id)
|
||||
)
|
||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
|
||||
|
||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
||||
ys_out = [torch.tensor(y) for y in ys_out]
|
||||
ys_out_pad = pad_sequence(
|
||||
ys_out, batch_first=True, padding_value=float(-1)
|
||||
)
|
||||
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
|
||||
|
||||
device = memory.device
|
||||
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
|
||||
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
|
||||
|
||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
|
||||
device
|
||||
)
|
||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
|
||||
|
||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||
# TODO: Use length information to create the decoder padding mask
|
||||
@ -632,9 +616,7 @@ def _get_activation_fn(activation: str):
|
||||
elif activation == "gelu":
|
||||
return nn.functional.gelu
|
||||
|
||||
raise RuntimeError(
|
||||
"activation should be relu/gelu, not {}".format(activation)
|
||||
)
|
||||
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
@ -836,9 +818,7 @@ def encoder_padding_mask(
|
||||
1,
|
||||
).to(torch.int32)
|
||||
|
||||
lengths = [
|
||||
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
|
||||
]
|
||||
lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
|
||||
for idx in range(supervision_segments.size(0)):
|
||||
# Note: TorchScript doesn't allow to unpack tensors as tuples
|
||||
sequence_idx = supervision_segments[idx, 0].item()
|
||||
@ -859,9 +839,7 @@ def encoder_padding_mask(
|
||||
return mask
|
||||
|
||||
|
||||
def decoder_padding_mask(
|
||||
ys_pad: torch.Tensor, ignore_id: int = -1
|
||||
) -> torch.Tensor:
|
||||
def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
|
||||
"""Generate a length mask for input.
|
||||
|
||||
The masked position are filled with True,
|
||||
|
@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
||||
)
|
||||
if "train" in partition:
|
||||
cut_set = (
|
||||
cut_set
|
||||
+ cut_set.perturb_speed(0.9)
|
||||
+ cut_set.perturb_speed(1.1)
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
@ -116,9 +114,7 @@ def get_args():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
|
@ -83,9 +83,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
|
||||
)
|
||||
if "train" in partition:
|
||||
cut_set = (
|
||||
cut_set
|
||||
+ cut_set.perturb_speed(0.9)
|
||||
+ cut_set.perturb_speed(1.1)
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
@ -111,9 +109,7 @@ def get_args():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
|
@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
pieces = [
|
||||
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
|
||||
]
|
||||
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
|
||||
|
||||
for i in range(len(pieces) - 1):
|
||||
w = word if i == 0 else eps
|
||||
@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def generate_lexicon(
|
||||
token_sym_table: Dict[str, int], words: List[str]
|
||||
) -> Lexicon:
|
||||
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
|
||||
"""Generate a lexicon from a word list and token_sym_table.
|
||||
|
||||
Args:
|
||||
|
@ -317,9 +317,7 @@ def lexicon_to_fst(
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
|
||||
)
|
||||
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
|
||||
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
|
||||
fsa.draw("L.pdf", title="L")
|
||||
|
||||
fsa_disambig = lexicon_to_fst(
|
||||
lexicon_disambig, phone2id=phone2id, word2id=word2id
|
||||
)
|
||||
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
|
||||
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
|
||||
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
|
||||
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
|
||||
|
@ -76,11 +76,7 @@ from beam_search import (
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -188,8 +184,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -249,9 +244,7 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
@ -263,10 +256,7 @@ def decode_one_batch(
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
@ -387,9 +377,7 @@ def decode_dataset(
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -415,9 +403,7 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
@ -428,8 +414,7 @@ def save_results(
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
@ -473,9 +458,7 @@ def main():
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -504,8 +487,7 @@ def main():
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
|
@ -50,11 +50,7 @@ from pathlib import Path
|
||||
import torch
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
@ -120,8 +116,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
@ -157,8 +152,7 @@ def main():
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
@ -191,9 +185,7 @@ def main():
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = (
|
||||
params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
)
|
||||
filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
@ -201,17 +193,14 @@ def main():
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = (
|
||||
params.exp_dir
|
||||
/ f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
)
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -165,8 +165,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -197,8 +196,7 @@ def read_sound_files(
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
@ -256,13 +254,9 @@ def main():
|
||||
feature_lens = [f.size(0) for f in features]
|
||||
feature_lens = torch.tensor(feature_lens, device=device)
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features, x_lens=feature_lens
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyp_list = []
|
||||
@ -310,9 +304,7 @@ def main():
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.method}"
|
||||
)
|
||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||
hyp_list.append(hyp)
|
||||
|
||||
hyps = []
|
||||
@ -329,9 +321,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -49,7 +49,6 @@ import optim
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from asr_datamodule import AishellAsrDataModule
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
@ -75,9 +74,7 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
]
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
|
||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
@ -203,8 +200,7 @@ def get_parser():
|
||||
"--initial-lr",
|
||||
type=float,
|
||||
default=0.003,
|
||||
help="The initial learning rate. This value should not need "
|
||||
"to be changed.",
|
||||
help="The initial learning rate. This value should not need " "to be changed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -227,8 +223,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -251,8 +246,7 @@ def get_parser():
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network)"
|
||||
"part.",
|
||||
help="The scale to smooth the loss with am (output of encoder network)" "part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -561,11 +555,7 @@ def compute_loss(
|
||||
warmup: a floating point value which increases throughout training;
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
device = (
|
||||
model.device
|
||||
if isinstance(model, DDP)
|
||||
else next(model.parameters()).device
|
||||
)
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
@ -593,23 +583,16 @@ def compute_loss(
|
||||
# overwhelming the simple_loss and causing it to diverge,
|
||||
# in case it had not fully learned the alignment yet.
|
||||
pruned_loss_scale = (
|
||||
0.0
|
||||
if warmup < 1.0
|
||||
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
)
|
||||
loss = (
|
||||
params.simple_loss_scale * simple_loss
|
||||
+ pruned_loss_scale * pruned_loss
|
||||
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
)
|
||||
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
@ -725,9 +708,7 @@ def train_one_epoch(
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
except: # noqa
|
||||
display_and_save_batch(
|
||||
batch, params=params, graph_compiler=graph_compiler
|
||||
)
|
||||
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
|
||||
raise
|
||||
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
@ -1029,9 +1010,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
f"Failing criterion: {criterion} "
|
||||
f"(={crit_values[criterion]}) ..."
|
||||
)
|
||||
display_and_save_batch(
|
||||
batch, params=params, graph_compiler=graph_compiler
|
||||
)
|
||||
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
|
||||
raise
|
||||
|
||||
|
||||
|
@ -202,8 +202,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -263,9 +262,7 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
@ -277,10 +274,7 @@ def decode_one_batch(
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
@ -401,9 +395,7 @@ def decode_dataset(
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -429,9 +421,7 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
@ -442,8 +432,7 @@ def save_results(
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tCER", file=f)
|
||||
@ -488,9 +477,7 @@ def main():
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -518,9 +505,9 @@ def main():
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -551,9 +538,9 @@ def main():
|
||||
)
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
|
@ -132,8 +132,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
@ -166,9 +165,9 @@ def main():
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -195,9 +194,9 @@ def main():
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -252,9 +251,7 @@ def main():
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = (
|
||||
params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
)
|
||||
filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
@ -262,17 +259,14 @@ def main():
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = (
|
||||
params.exp_dir
|
||||
/ f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
)
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -84,9 +84,7 @@ class Transducer(nn.Module):
|
||||
self.decoder_datatang = decoder_datatang
|
||||
self.joiner_datatang = joiner_datatang
|
||||
|
||||
self.simple_am_proj = ScaledLinear(
|
||||
encoder_dim, vocab_size, initial_speed=0.5
|
||||
)
|
||||
self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
|
||||
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
|
||||
|
||||
if decoder_datatang is not None:
|
||||
@ -179,9 +177,7 @@ class Transducer(nn.Module):
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(x.size(0), 4), dtype=torch.int64, device=x.device
|
||||
)
|
||||
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = encoder_out_lens
|
||||
|
||||
|
@ -165,8 +165,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -197,8 +196,7 @@ def read_sound_files(
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
@ -257,13 +255,9 @@ def main():
|
||||
feature_lens = [f.size(0) for f in features]
|
||||
feature_lens = torch.tensor(feature_lens, device=device)
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features, x_lens=feature_lens
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyp_list = []
|
||||
@ -311,9 +305,7 @@ def main():
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.method}"
|
||||
)
|
||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||
hyp_list.append(hyp)
|
||||
|
||||
hyps = []
|
||||
@ -330,9 +322,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -96,9 +96,7 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
]
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
|
||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
@ -224,8 +222,7 @@ def get_parser():
|
||||
"--initial-lr",
|
||||
type=float,
|
||||
default=0.003,
|
||||
help="The initial learning rate. This value should not need "
|
||||
"to be changed.",
|
||||
help="The initial learning rate. This value should not need " "to be changed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -248,8 +245,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -272,8 +268,7 @@ def get_parser():
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network)"
|
||||
"part.",
|
||||
help="The scale to smooth the loss with am (output of encoder network)" "part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -635,11 +630,7 @@ def compute_loss(
|
||||
warmup: a floating point value which increases throughout training;
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
device = (
|
||||
model.device
|
||||
if isinstance(model, DDP)
|
||||
else next(model.parameters()).device
|
||||
)
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
@ -670,23 +661,16 @@ def compute_loss(
|
||||
# overwhelming the simple_loss and causing it to diverge,
|
||||
# in case it had not fully learned the alignment yet.
|
||||
pruned_loss_scale = (
|
||||
0.0
|
||||
if warmup < 1.0
|
||||
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
)
|
||||
loss = (
|
||||
params.simple_loss_scale * simple_loss
|
||||
+ pruned_loss_scale * pruned_loss
|
||||
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
)
|
||||
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
@ -824,9 +808,7 @@ def train_one_epoch(
|
||||
)
|
||||
# summary stats
|
||||
if datatang_train_dl is not None:
|
||||
tot_loss = (
|
||||
tot_loss * (1 - 1 / params.reset_interval)
|
||||
) + loss_info
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
|
||||
if aishell:
|
||||
aishell_tot_loss = (
|
||||
@ -847,9 +829,7 @@ def train_one_epoch(
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
except: # noqa
|
||||
display_and_save_batch(
|
||||
batch, params=params, graph_compiler=graph_compiler
|
||||
)
|
||||
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
|
||||
raise
|
||||
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
@ -892,9 +872,7 @@ def train_one_epoch(
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
if datatang_train_dl is not None:
|
||||
datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
|
||||
tot_loss_str = (
|
||||
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||
)
|
||||
tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||
else:
|
||||
tot_loss_str = ""
|
||||
datatang_str = ""
|
||||
@ -1076,9 +1054,7 @@ def run(rank, world_size, args):
|
||||
train_cuts = filter_short_and_long_utterances(train_cuts)
|
||||
|
||||
if args.enable_musan:
|
||||
cuts_musan = load_manifest(
|
||||
Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
|
||||
)
|
||||
cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
|
||||
else:
|
||||
cuts_musan = None
|
||||
|
||||
@ -1093,9 +1069,7 @@ def run(rank, world_size, args):
|
||||
if params.datatang_prob > 0:
|
||||
datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
|
||||
train_datatang_cuts = datatang.train_cuts()
|
||||
train_datatang_cuts = filter_short_and_long_utterances(
|
||||
train_datatang_cuts
|
||||
)
|
||||
train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
|
||||
train_datatang_cuts = train_datatang_cuts.repeat(times=None)
|
||||
datatang_train_dl = asr_datamodule.train_dataloaders(
|
||||
train_datatang_cuts,
|
||||
@ -1249,9 +1223,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
f"Failing criterion: {criterion} "
|
||||
f"(={crit_values[criterion]}) ..."
|
||||
)
|
||||
display_and_save_batch(
|
||||
batch, params=params, graph_compiler=graph_compiler
|
||||
)
|
||||
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
|
||||
raise
|
||||
|
||||
|
||||
|
@ -183,17 +183,13 @@ class AishellAsrDataModule:
|
||||
|
||||
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "musan_cuts.jsonl.gz"
|
||||
)
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
transforms.append(
|
||||
CutMix(
|
||||
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||
)
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
@ -215,9 +211,7 @@ class AishellAsrDataModule:
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(
|
||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||
)
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
@ -260,9 +254,7 @@ class AishellAsrDataModule:
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
@ -308,9 +300,7 @@ class AishellAsrDataModule:
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
@ -366,13 +356,9 @@ class AishellAsrDataModule:
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz"
|
||||
)
|
||||
return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> List[CutSet]:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "aishell_cuts_test.jsonl.gz"
|
||||
)
|
||||
return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz")
|
||||
|
@ -265,9 +265,7 @@ def decode_dataset(
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -289,9 +287,7 @@ def save_results(
|
||||
# We compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
|
||||
test_set_wers[key] = wer
|
||||
@ -335,9 +331,7 @@ def main():
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
|
||||
)
|
||||
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
|
||||
HLG = HLG.to(device)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
@ -362,9 +356,7 @@ def main():
|
||||
|
||||
if params.export:
|
||||
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
|
||||
torch.save(
|
||||
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
|
||||
)
|
||||
torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
@ -392,9 +384,7 @@ def main():
|
||||
lexicon=lexicon,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params, test_set_name=test_set, results_dict=results_dict
|
||||
)
|
||||
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
@ -66,10 +66,7 @@ class TdnnLstm(nn.Module):
|
||||
nn.BatchNorm1d(num_features=500, affine=False),
|
||||
)
|
||||
self.lstms = nn.ModuleList(
|
||||
[
|
||||
nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
|
||||
for _ in range(5)
|
||||
]
|
||||
[nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
|
||||
)
|
||||
self.lstm_bnorms = nn.ModuleList(
|
||||
[nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
|
||||
|
@ -53,9 +53,7 @@ def get_parser():
|
||||
help="Path to words.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--HLG", type=str, required=True, help="Path to HLG.pt."
|
||||
)
|
||||
parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
@ -113,8 +111,7 @@ def read_sound_files(
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
@ -173,9 +170,7 @@ def main():
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
features = features.permute(0, 2, 1) # now features is [N, C, T]
|
||||
|
||||
with torch.no_grad():
|
||||
@ -219,9 +214,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -49,12 +49,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -47,9 +47,9 @@ def greedy_search(
|
||||
|
||||
device = model.device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size, device=device
|
||||
).reshape(1, context_size)
|
||||
decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
|
||||
1, context_size
|
||||
)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
|
||||
@ -81,9 +81,9 @@ def greedy_search(
|
||||
y = logits.argmax().item()
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = torch.tensor(
|
||||
[hyp[-context_size:]], device=device
|
||||
).reshape(1, context_size)
|
||||
decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
|
||||
1, context_size
|
||||
)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
|
||||
@ -157,9 +157,7 @@ class HypothesisList(object):
|
||||
|
||||
"""
|
||||
if length_norm:
|
||||
return max(
|
||||
self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
|
||||
)
|
||||
return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
|
||||
else:
|
||||
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
|
||||
|
||||
@ -246,9 +244,9 @@ def beam_search(
|
||||
|
||||
device = model.device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size, device=device
|
||||
).reshape(1, context_size)
|
||||
decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
|
||||
1, context_size
|
||||
)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
|
||||
|
@ -155,9 +155,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
normalize_before: bool = True,
|
||||
) -> None:
|
||||
super(ConformerEncoderLayer, self).__init__()
|
||||
self.self_attn = RelPositionMultiheadAttention(
|
||||
d_model, nhead, dropout=0.0
|
||||
)
|
||||
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
|
||||
|
||||
self.feed_forward = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
@ -175,18 +173,14 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||
|
||||
self.norm_ff_macaron = nn.LayerNorm(
|
||||
d_model
|
||||
) # for the macaron style FNN module
|
||||
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
|
||||
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
|
||||
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
|
||||
|
||||
self.ff_scale = 0.5
|
||||
|
||||
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
|
||||
self.norm_final = nn.LayerNorm(
|
||||
d_model
|
||||
) # for the final output of the block
|
||||
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@ -220,9 +214,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
src = residual + self.ff_scale * self.dropout(
|
||||
self.feed_forward_macaron(src)
|
||||
)
|
||||
src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
|
||||
if not self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
|
||||
@ -341,9 +333,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, d_model: int, dropout_rate: float, max_len: int = 5000
|
||||
) -> None:
|
||||
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(RelPositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
@ -359,9 +349,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
):
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
# Suppose `i` means to the position of query vector and `j` means the
|
||||
@ -631,9 +619,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
if torch.equal(query, key) and torch.equal(key, value):
|
||||
# self-attention
|
||||
q, k, v = nn.functional.linear(
|
||||
query, in_proj_weight, in_proj_bias
|
||||
).chunk(3, dim=-1)
|
||||
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
|
||||
3, dim=-1
|
||||
)
|
||||
|
||||
elif torch.equal(key, value):
|
||||
# encoder-decoder attention
|
||||
@ -701,31 +689,22 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
if attn_mask.dim() == 2:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
||||
raise RuntimeError(
|
||||
"The size of the 2D attn_mask is not correct."
|
||||
)
|
||||
raise RuntimeError("The size of the 2D attn_mask is not correct.")
|
||||
elif attn_mask.dim() == 3:
|
||||
if list(attn_mask.size()) != [
|
||||
bsz * num_heads,
|
||||
query.size(0),
|
||||
key.size(0),
|
||||
]:
|
||||
raise RuntimeError(
|
||||
"The size of the 3D attn_mask is not correct."
|
||||
)
|
||||
raise RuntimeError("The size of the 3D attn_mask is not correct.")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"attn_mask's dimension {} is not supported".format(
|
||||
attn_mask.dim()
|
||||
)
|
||||
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
|
||||
)
|
||||
# attn_mask's dim is 3 now.
|
||||
|
||||
# convert ByteTensor key_padding_mask to bool
|
||||
if (
|
||||
key_padding_mask is not None
|
||||
and key_padding_mask.dtype == torch.uint8
|
||||
):
|
||||
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
|
||||
warnings.warn(
|
||||
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
|
||||
)
|
||||
@ -764,9 +743,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
# first compute matrix a and matrix c
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
|
||||
matrix_ac = torch.matmul(
|
||||
q_with_bias_u, k
|
||||
) # (batch, head, time1, time2)
|
||||
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
|
||||
|
||||
# compute matrix b and matrix d
|
||||
matrix_bd = torch.matmul(
|
||||
@ -778,9 +755,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
matrix_ac + matrix_bd
|
||||
) * scaling # (batch, head, time1, time2)
|
||||
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, -1
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
|
||||
|
||||
assert list(attn_output_weights.size()) == [
|
||||
bsz * num_heads,
|
||||
@ -814,13 +789,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
attn_output = torch.bmm(attn_output_weights, v)
|
||||
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
||||
attn_output = (
|
||||
attn_output.transpose(0, 1)
|
||||
.contiguous()
|
||||
.view(tgt_len, bsz, embed_dim)
|
||||
)
|
||||
attn_output = nn.functional.linear(
|
||||
attn_output, out_proj_weight, out_proj_bias
|
||||
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||
)
|
||||
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
|
||||
|
||||
if need_weights:
|
||||
# average attention weights over heads
|
||||
@ -843,9 +814,7 @@ class ConvolutionModule(nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, channels: int, kernel_size: int, bias: bool = True
|
||||
) -> None:
|
||||
def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
# kernerl_size should be a odd number for 'SAME' padding
|
||||
|
@ -99,8 +99,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -227,9 +226,7 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
hyps = []
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
@ -248,9 +245,7 @@ def decode_one_batch(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||
hyps.append([lexicon.token_table[i] for i in hyp])
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
@ -319,9 +314,7 @@ def decode_dataset(
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -346,9 +339,7 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
@ -359,8 +350,7 @@ def save_results(
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tCER", file=f)
|
||||
@ -430,9 +420,7 @@ def main():
|
||||
|
||||
if params.export:
|
||||
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
|
||||
torch.save(
|
||||
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
|
||||
)
|
||||
torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
|
||||
return
|
||||
|
||||
model.to(device)
|
||||
|
@ -86,9 +86,7 @@ class Decoder(nn.Module):
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
embedding_out = F.pad(
|
||||
embedding_out, pad=(self.context_size - 1, 0)
|
||||
)
|
||||
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
|
||||
else:
|
||||
# During inference time, there is no need to do extra padding
|
||||
# as we only need one output
|
||||
|
@ -110,8 +110,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
return parser
|
||||
@ -243,9 +242,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -103,9 +103,7 @@ class Transducer(nn.Module):
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(x.size(0), 4), dtype=torch.int64, device=x.device
|
||||
)
|
||||
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
|
@ -117,8 +117,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -212,8 +211,7 @@ def read_sound_files(
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
@ -273,9 +271,7 @@ def main():
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
@ -319,9 +315,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -126,8 +126,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -389,9 +388,7 @@ def compute_loss(
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
@ -504,9 +501,7 @@ def train_one_epoch(
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
@ -625,9 +620,7 @@ def run(rank, world_size, args):
|
||||
|
||||
cur_lr = optimizer._rate
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||
)
|
||||
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
if rank == 0:
|
||||
|
@ -250,9 +250,7 @@ def _get_activation_fn(activation: str):
|
||||
elif activation == "gelu":
|
||||
return nn.functional.gelu
|
||||
|
||||
raise RuntimeError(
|
||||
"activation should be relu/gelu, not {}".format(activation)
|
||||
)
|
||||
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
|
@ -29,10 +29,7 @@ from lhotse.dataset import (
|
||||
K2SpeechRecognitionDataset,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import (
|
||||
OnTheFlyFeatures,
|
||||
PrecomputedFeatures,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
@ -162,9 +159,7 @@ class AsrDataModule:
|
||||
if cuts_musan is not None:
|
||||
logging.info("Enable MUSAN")
|
||||
transforms.append(
|
||||
CutMix(
|
||||
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||
)
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
@ -173,9 +168,7 @@ class AsrDataModule:
|
||||
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(
|
||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||
)
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
@ -252,9 +245,7 @@ class AsrDataModule:
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
|
@ -170,8 +170,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -227,9 +226,7 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
@ -241,10 +238,7 @@ def decode_one_batch(
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
@ -365,9 +359,7 @@ def decode_dataset(
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -393,9 +385,7 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
@ -406,8 +396,7 @@ def save_results(
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tCER", file=f)
|
||||
@ -448,9 +437,7 @@ def main():
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
@ -109,8 +109,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
return parser
|
||||
@ -241,9 +240,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -165,8 +165,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -195,8 +194,7 @@ def read_sound_files(
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
@ -254,13 +252,9 @@ def main():
|
||||
feature_lens = [f.size(0) for f in features]
|
||||
feature_lens = torch.tensor(feature_lens, device=device)
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features, x_lens=feature_lens
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyp_list = []
|
||||
@ -308,9 +302,7 @@ def main():
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.method}"
|
||||
)
|
||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||
hyp_list.append(hyp)
|
||||
|
||||
hyps = []
|
||||
@ -327,9 +319,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -149,8 +149,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -168,8 +167,7 @@ def get_parser():
|
||||
"--datatang-prob",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="The probability to select a batch from the "
|
||||
"aidatatang_200zh dataset",
|
||||
help="The probability to select a batch from the " "aidatatang_200zh dataset",
|
||||
)
|
||||
|
||||
return parser
|
||||
@ -449,9 +447,7 @@ def compute_loss(
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
@ -605,9 +601,7 @@ def train_one_epoch(
|
||||
f"train/current_{prefix}_",
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
aishell_tot_loss.write_summary(
|
||||
tb_writer, "train/aishell_tot_", params.batch_idx_train
|
||||
)
|
||||
@ -735,9 +729,7 @@ def run(rank, world_size, args):
|
||||
train_datatang_cuts = train_datatang_cuts.repeat(times=None)
|
||||
|
||||
if args.enable_musan:
|
||||
cuts_musan = load_manifest(
|
||||
Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
|
||||
)
|
||||
cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
|
||||
else:
|
||||
cuts_musan = None
|
||||
|
||||
@ -776,9 +768,7 @@ def run(rank, world_size, args):
|
||||
|
||||
cur_lr = optimizer._rate
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||
)
|
||||
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
if rank == 0:
|
||||
|
@ -171,8 +171,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -231,9 +230,7 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
@ -245,10 +242,7 @@ def decode_one_batch(
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
@ -369,9 +363,7 @@ def decode_dataset(
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -397,9 +389,7 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
@ -410,8 +400,7 @@ def save_results(
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tCER", file=f)
|
||||
@ -452,9 +441,7 @@ def main():
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
@ -109,8 +109,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
return parser
|
||||
@ -241,9 +240,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -165,8 +165,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -195,8 +194,7 @@ def read_sound_files(
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
@ -254,13 +252,9 @@ def main():
|
||||
feature_lens = [f.size(0) for f in features]
|
||||
feature_lens = torch.tensor(feature_lens, device=device)
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features, x_lens=feature_lens
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyp_list = []
|
||||
@ -308,9 +302,7 @@ def main():
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.method}"
|
||||
)
|
||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||
hyp_list.append(hyp)
|
||||
|
||||
hyps = []
|
||||
@ -327,9 +319,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -142,8 +142,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -414,9 +413,7 @@ def compute_loss(
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
@ -529,9 +526,7 @@ def train_one_epoch(
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
@ -657,9 +652,7 @@ def run(rank, world_size, args):
|
||||
|
||||
cur_lr = optimizer._rate
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||
)
|
||||
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
if rank == 0:
|
||||
|
0
egs/aishell2/ASR/local/__init__.py
Executable file → Normal file
0
egs/aishell2/ASR/local/__init__.py
Executable file → Normal file
@ -83,9 +83,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
|
||||
)
|
||||
if "train" in partition:
|
||||
cut_set = (
|
||||
cut_set
|
||||
+ cut_set.perturb_speed(0.9)
|
||||
+ cut_set.perturb_speed(1.1)
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
@ -111,9 +109,7 @@ def get_args():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
|
0
egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py
Executable file → Normal file
0
egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py
Executable file → Normal file
24
egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
Executable file → Normal file
24
egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
Executable file → Normal file
@ -216,13 +216,9 @@ class AiShell2AsrDataModule:
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "musan_cuts.jsonl.gz"
|
||||
)
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(
|
||||
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||
)
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
@ -244,9 +240,7 @@ class AiShell2AsrDataModule:
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(
|
||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||
)
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
@ -290,9 +284,7 @@ class AiShell2AsrDataModule:
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
@ -348,9 +340,7 @@ class AiShell2AsrDataModule:
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
@ -406,9 +396,7 @@ class AiShell2AsrDataModule:
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
|
||||
)
|
||||
return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
|
@ -269,8 +269,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -348,9 +347,7 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
@ -409,10 +406,7 @@ def decode_one_batch(
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
@ -538,9 +532,7 @@ def decode_dataset(
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -573,8 +565,7 @@ def save_results(
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
@ -625,9 +616,7 @@ def main():
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -661,9 +650,9 @@ def main():
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -690,9 +679,9 @@ def main():
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -749,9 +738,7 @@ def main():
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
decoding_graph = k2.trivial_graph(
|
||||
params.vocab_size - 1, device=device
|
||||
)
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
|
@ -133,8 +133,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
@ -167,9 +166,9 @@ def main():
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -196,9 +195,9 @@ def main():
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -266,9 +265,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -159,8 +159,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -192,8 +191,7 @@ def read_sound_files(
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
@ -254,15 +252,11 @@ def main():
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features, x_lens=feature_lengths
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
@ -334,9 +328,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -92,9 +92,7 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
]
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
|
||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
@ -220,8 +218,7 @@ def get_parser():
|
||||
"--initial-lr",
|
||||
type=float,
|
||||
default=0.003,
|
||||
help="The initial learning rate. This value should not need "
|
||||
"to be changed.",
|
||||
help="The initial learning rate. This value should not need " "to be changed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -244,8 +241,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -268,8 +264,7 @@ def get_parser():
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network)"
|
||||
"part.",
|
||||
help="The scale to smooth the loss with am (output of encoder network)" "part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -603,11 +598,7 @@ def compute_loss(
|
||||
warmup: a floating point value which increases throughout training;
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
device = (
|
||||
model.device
|
||||
if isinstance(model, DDP)
|
||||
else next(model.parameters()).device
|
||||
)
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
@ -636,23 +627,16 @@ def compute_loss(
|
||||
# overwhelming the simple_loss and causing it to diverge,
|
||||
# in case it had not fully learned the alignment yet.
|
||||
pruned_loss_scale = (
|
||||
0.0
|
||||
if warmup < 1.0
|
||||
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
)
|
||||
loss = (
|
||||
params.simple_loss_scale * simple_loss
|
||||
+ pruned_loss_scale * pruned_loss
|
||||
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
)
|
||||
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
@ -771,9 +755,7 @@ def train_one_epoch(
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
except: # noqa
|
||||
display_and_save_batch(
|
||||
batch, params=params, graph_compiler=graph_compiler
|
||||
)
|
||||
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
|
||||
raise
|
||||
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
@ -829,9 +811,7 @@ def train_one_epoch(
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
@ -1104,9 +1084,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
f"Failing criterion: {criterion} "
|
||||
f"(={crit_values[criterion]}) ..."
|
||||
)
|
||||
display_and_save_batch(
|
||||
batch, params=params, graph_compiler=graph_compiler
|
||||
)
|
||||
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
|
||||
raise
|
||||
|
||||
|
||||
|
@ -85,9 +85,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
|
||||
)
|
||||
if "train" in partition:
|
||||
cut_set = (
|
||||
cut_set
|
||||
+ cut_set.perturb_speed(0.9)
|
||||
+ cut_set.perturb_speed(1.1)
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
@ -120,9 +118,7 @@ def get_args():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
|
@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
pieces = [
|
||||
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
|
||||
]
|
||||
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
|
||||
|
||||
for i in range(len(pieces) - 1):
|
||||
w = word if i == 0 else eps
|
||||
@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def generate_lexicon(
|
||||
token_sym_table: Dict[str, int], words: List[str]
|
||||
) -> Lexicon:
|
||||
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
|
||||
"""Generate a lexicon from a word list and token_sym_table.
|
||||
|
||||
Args:
|
||||
|
@ -317,9 +317,7 @@ def lexicon_to_fst(
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
|
||||
)
|
||||
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
|
||||
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
|
||||
fsa.draw("L.pdf", title="L")
|
||||
|
||||
fsa_disambig = lexicon_to_fst(
|
||||
lexicon_disambig, phone2id=phone2id, word2id=word2id
|
||||
)
|
||||
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
|
||||
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
|
||||
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
|
||||
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
|
||||
|
@ -56,9 +56,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--space", default="<space>", type=str, help="space symbol"
|
||||
)
|
||||
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
|
||||
parser.add_argument(
|
||||
"--non-lang-syms",
|
||||
"-l",
|
||||
@ -66,9 +64,7 @@ def get_parser():
|
||||
type=str,
|
||||
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"text", type=str, default=False, nargs="?", help="input text"
|
||||
)
|
||||
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
|
||||
parser.add_argument(
|
||||
"--trans_type",
|
||||
"-t",
|
||||
@ -108,8 +104,7 @@ def token2id(
|
||||
if token_type == "lazy_pinyin":
|
||||
text = lazy_pinyin(chars_list)
|
||||
sub_ids = [
|
||||
token_table[txt] if txt in token_table else oov_id
|
||||
for txt in text
|
||||
token_table[txt] if txt in token_table else oov_id for txt in text
|
||||
]
|
||||
ids.append(sub_ids)
|
||||
else: # token_type = "pinyin"
|
||||
@ -135,9 +130,7 @@ def main():
|
||||
if args.text:
|
||||
f = codecs.open(args.text, encoding="utf-8")
|
||||
else:
|
||||
f = codecs.getreader("utf-8")(
|
||||
sys.stdin if is_python2 else sys.stdin.buffer
|
||||
)
|
||||
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
|
||||
|
||||
sys.stdout = codecs.getwriter("utf-8")(
|
||||
sys.stdout if is_python2 else sys.stdout.buffer
|
||||
|
@ -222,17 +222,13 @@ class Aishell4AsrDataModule:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "musan_cuts.jsonl.gz"
|
||||
)
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
transforms.append(
|
||||
CutMix(
|
||||
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||
)
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
@ -254,9 +250,7 @@ class Aishell4AsrDataModule:
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(
|
||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||
)
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
@ -300,9 +294,7 @@ class Aishell4AsrDataModule:
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
@ -359,9 +351,7 @@ class Aishell4AsrDataModule:
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
|
@ -201,8 +201,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -260,9 +259,7 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
@ -277,10 +274,7 @@ def decode_one_batch(
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
@ -401,9 +395,7 @@ def decode_dataset(
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -436,8 +428,7 @@ def save_results(
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
@ -480,9 +471,7 @@ def main():
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -510,9 +499,9 @@ def main():
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -543,9 +532,9 @@ def main():
|
||||
)
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
|
@ -136,8 +136,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
@ -169,9 +168,9 @@ def main():
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -202,9 +201,9 @@ def main():
|
||||
)
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -276,9 +275,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -172,8 +172,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -205,8 +204,7 @@ def read_sound_files(
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
@ -266,15 +264,11 @@ def main():
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features, x_lens=feature_lengths
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
@ -306,10 +300,7 @@ def main():
|
||||
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
@ -350,9 +341,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -85,9 +85,7 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
]
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
|
||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
@ -213,8 +211,7 @@ def get_parser():
|
||||
"--initial-lr",
|
||||
type=float,
|
||||
default=0.003,
|
||||
help="The initial learning rate. This value should not need "
|
||||
"to be changed.",
|
||||
help="The initial learning rate. This value should not need " "to be changed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -237,8 +234,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -261,8 +257,7 @@ def get_parser():
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network)"
|
||||
"part.",
|
||||
help="The scale to smooth the loss with am (output of encoder network)" "part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -599,11 +594,7 @@ def compute_loss(
|
||||
warmup: a floating point value which increases throughout training;
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
device = (
|
||||
model.device
|
||||
if isinstance(model, DDP)
|
||||
else next(model.parameters()).device
|
||||
)
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
@ -633,22 +624,15 @@ def compute_loss(
|
||||
# overwhelming the simple_loss and causing it to diverge,
|
||||
# in case it had not fully learned the alignment yet.
|
||||
pruned_loss_scale = (
|
||||
0.0
|
||||
if warmup < 1.0
|
||||
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
)
|
||||
loss = (
|
||||
params.simple_loss_scale * simple_loss
|
||||
+ pruned_loss_scale * pruned_loss
|
||||
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
)
|
||||
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
@ -827,9 +811,7 @@ def train_one_epoch(
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
|
@ -84,9 +84,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
|
||||
)
|
||||
if "train" in partition:
|
||||
cut_set = (
|
||||
cut_set
|
||||
+ cut_set.perturb_speed(0.9)
|
||||
+ cut_set.perturb_speed(1.1)
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
cur_num_jobs = num_jobs if ex is None else 80
|
||||
cur_num_jobs = min(cur_num_jobs, len(cut_set))
|
||||
@ -121,9 +119,7 @@ def get_args():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
|
@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
pieces = [
|
||||
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
|
||||
]
|
||||
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
|
||||
|
||||
for i in range(len(pieces) - 1):
|
||||
w = word if i == 0 else eps
|
||||
@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def generate_lexicon(
|
||||
token_sym_table: Dict[str, int], words: List[str]
|
||||
) -> Lexicon:
|
||||
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
|
||||
"""Generate a lexicon from a word list and token_sym_table.
|
||||
|
||||
Args:
|
||||
|
@ -317,9 +317,7 @@ def lexicon_to_fst(
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
|
||||
)
|
||||
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
|
||||
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
|
||||
fsa.draw("L.pdf", title="L")
|
||||
|
||||
fsa_disambig = lexicon_to_fst(
|
||||
lexicon_disambig, phone2id=phone2id, word2id=word2id
|
||||
)
|
||||
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
|
||||
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
|
||||
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
|
||||
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
|
||||
|
@ -30,8 +30,8 @@ with word segmenting:
|
||||
|
||||
import argparse
|
||||
|
||||
import paddle
|
||||
import jieba
|
||||
import paddle
|
||||
from tqdm import tqdm
|
||||
|
||||
paddle.enable_static()
|
||||
|
@ -56,9 +56,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--space", default="<space>", type=str, help="space symbol"
|
||||
)
|
||||
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
|
||||
parser.add_argument(
|
||||
"--non-lang-syms",
|
||||
"-l",
|
||||
@ -66,9 +64,7 @@ def get_parser():
|
||||
type=str,
|
||||
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"text", type=str, default=False, nargs="?", help="input text"
|
||||
)
|
||||
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
|
||||
parser.add_argument(
|
||||
"--trans_type",
|
||||
"-t",
|
||||
@ -108,8 +104,7 @@ def token2id(
|
||||
if token_type == "lazy_pinyin":
|
||||
text = lazy_pinyin(chars_list)
|
||||
sub_ids = [
|
||||
token_table[txt] if txt in token_table else oov_id
|
||||
for txt in text
|
||||
token_table[txt] if txt in token_table else oov_id for txt in text
|
||||
]
|
||||
ids.append(sub_ids)
|
||||
else: # token_type = "pinyin"
|
||||
@ -135,9 +130,7 @@ def main():
|
||||
if args.text:
|
||||
f = codecs.open(args.text, encoding="utf-8")
|
||||
else:
|
||||
f = codecs.getreader("utf-8")(
|
||||
sys.stdin if is_python2 else sys.stdin.buffer
|
||||
)
|
||||
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
|
||||
|
||||
sys.stdout = codecs.getwriter("utf-8")(
|
||||
sys.stdout if is_python2 else sys.stdout.buffer
|
||||
|
@ -205,17 +205,13 @@ class AlimeetingAsrDataModule:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "musan_cuts.jsonl.gz"
|
||||
)
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
transforms.append(
|
||||
CutMix(
|
||||
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||
)
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
@ -237,9 +233,7 @@ class AlimeetingAsrDataModule:
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(
|
||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||
)
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
@ -282,9 +276,7 @@ class AlimeetingAsrDataModule:
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
@ -341,9 +333,7 @@ class AlimeetingAsrDataModule:
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
|
@ -70,11 +70,7 @@ from beam_search import (
|
||||
from lhotse.cut import Cut
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -193,8 +189,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -249,9 +244,7 @@ def decode_one_batch(
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
@ -266,10 +259,7 @@ def decode_one_batch(
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
@ -390,9 +380,7 @@ def decode_dataset(
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -425,8 +413,7 @@ def save_results(
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
@ -563,8 +550,7 @@ def main():
|
||||
)
|
||||
|
||||
dev_shards = [
|
||||
str(path)
|
||||
for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
|
||||
str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
|
||||
]
|
||||
cuts_dev_webdataset = CutSet.from_webdataset(
|
||||
dev_shards,
|
||||
@ -574,8 +560,7 @@ def main():
|
||||
)
|
||||
|
||||
test_shards = [
|
||||
str(path)
|
||||
for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
|
||||
str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
|
||||
]
|
||||
cuts_test_webdataset = CutSet.from_webdataset(
|
||||
test_shards,
|
||||
@ -588,9 +573,7 @@ def main():
|
||||
return 1.0 <= c.duration
|
||||
|
||||
cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt)
|
||||
cuts_test_webdataset = cuts_test_webdataset.filter(
|
||||
remove_short_and_long_utt
|
||||
)
|
||||
cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt)
|
||||
|
||||
dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset)
|
||||
test_dl = alimeeting.test_dataloaders(cuts_test_webdataset)
|
||||
|
@ -103,8 +103,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
return parser
|
||||
@ -173,9 +172,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -162,8 +162,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -194,8 +193,7 @@ def read_sound_files(
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
@ -257,9 +255,7 @@ def main():
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
@ -284,10 +280,7 @@ def main():
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
@ -339,9 +332,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
@ -81,9 +81,7 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
]
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||
|
||||
@ -187,8 +185,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -211,8 +208,7 @@ def get_parser():
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network)"
|
||||
"part.",
|
||||
help="The scale to smooth the loss with am (output of encoder network)" "part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -542,22 +538,15 @@ def compute_loss(
|
||||
# overwhelming the simple_loss and causing it to diverge,
|
||||
# in case it had not fully learned the alignment yet.
|
||||
pruned_loss_scale = (
|
||||
0.0
|
||||
if warmup < 1.0
|
||||
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
)
|
||||
loss = (
|
||||
params.simple_loss_scale * simple_loss
|
||||
+ pruned_loss_scale * pruned_loss
|
||||
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
)
|
||||
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
@ -711,9 +700,7 @@ def train_one_epoch(
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
|
@ -25,15 +25,10 @@ from random import Random
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from lhotse import (
|
||||
from lhotse import ( # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
# fmt: off
|
||||
# See the following for why LilcomChunkyWriter is preferred
|
||||
# https://github.com/k2-fsa/icefall/pull/404
|
||||
# https://github.com/lhotse-speech/lhotse/pull/527
|
||||
# fmt: on
|
||||
LilcomChunkyWriter,
|
||||
RecordingSet,
|
||||
SupervisionSet,
|
||||
@ -81,17 +76,13 @@ def make_cutset_blueprints(
|
||||
cut_sets.append((f"eval{i}", cut_set))
|
||||
|
||||
# Create train and valid cuts
|
||||
logging.info(
|
||||
"Loading, trimming, and shuffling the remaining core+noncore cuts."
|
||||
)
|
||||
logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.")
|
||||
recording_set = RecordingSet.from_file(
|
||||
manifest_dir / "csj_recordings_core.jsonl.gz"
|
||||
) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
|
||||
supervision_set = SupervisionSet.from_file(
|
||||
manifest_dir / "csj_supervisions_core.jsonl.gz"
|
||||
) + SupervisionSet.from_file(
|
||||
manifest_dir / "csj_supervisions_noncore.jsonl.gz"
|
||||
)
|
||||
) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz")
|
||||
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=recording_set,
|
||||
@ -101,15 +92,12 @@ def make_cutset_blueprints(
|
||||
cut_set = cut_set.shuffle(Random(RNG_SEED))
|
||||
|
||||
logging.info(
|
||||
"Creating valid and train cuts from core and noncore,"
|
||||
f"split at {split}."
|
||||
"Creating valid and train cuts from core and noncore," f"split at {split}."
|
||||
)
|
||||
valid_set = CutSet.from_cuts(islice(cut_set, 0, split))
|
||||
|
||||
train_set = CutSet.from_cuts(islice(cut_set, split, None))
|
||||
train_set = (
|
||||
train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
|
||||
)
|
||||
train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
|
||||
|
||||
cut_sets.extend([("valid", valid_set), ("train", train_set)])
|
||||
|
||||
@ -122,15 +110,9 @@ def get_args():
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manifest-dir", type=Path, help="Path to save manifests"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fbank-dir", type=Path, help="Path to save fbank features"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split", type=int, default=4000, help="Split at this index"
|
||||
)
|
||||
parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
|
||||
parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
|
||||
parser.add_argument("--split", type=int, default=4000, help="Split at this index")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@ -141,9 +123,7 @@ def main():
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=80))
|
||||
num_jobs = min(16, os.cpu_count())
|
||||
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
|
@ -26,7 +26,6 @@ from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
|
||||
ARGPARSE_DESCRIPTION = """
|
||||
This file computes fbank features of the musan dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
@ -84,9 +83,7 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
|
||||
# create chunks of Musan with duration 5 - 10 seconds
|
||||
musan_cuts = (
|
||||
CutSet.from_manifests(
|
||||
recordings=combine(
|
||||
part["recordings"] for part in manifests.values()
|
||||
)
|
||||
recordings=combine(part["recordings"] for part in manifests.values())
|
||||
)
|
||||
.cut_into_windows(10.0)
|
||||
.filter(lambda c: c.duration > 5)
|
||||
@ -107,21 +104,15 @@ def get_args():
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manifest-dir", type=Path, help="Path to save manifests"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fbank-dir", type=Path, help="Path to save fbank features"
|
||||
)
|
||||
parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
|
||||
parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
compute_fbank_musan(args.manifest_dir, args.fbank_dir)
|
||||
|
@ -318,4 +318,3 @@ spk_id = 2
|
||||
ャ = ǐa
|
||||
ュ = ǐu
|
||||
ョ = ǐo
|
||||
|
||||
|
@ -318,4 +318,3 @@ spk_id = 2
|
||||
ャ = ǐa
|
||||
ュ = ǐu
|
||||
ョ = ǐo
|
||||
|
||||
|
@ -318,4 +318,3 @@ spk_id = 2
|
||||
ャ = ǐa
|
||||
ュ = ǐu
|
||||
ョ = ǐo
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user