Revert "Apply new Black style changes"

This commit is contained in:
Fangjun Kuang 2022-11-17 20:19:32 +08:00 committed by GitHub
parent a7fbb18bdc
commit 60317120ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
441 changed files with 14535 additions and 6789 deletions

View File

@ -1,2 +0,0 @@
# Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890)
d110b04ad389134c82fa314e3aafc7b40043efb0

View File

@ -45,18 +45,17 @@ jobs:
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
# Click issue fixed in https://github.com/psf/black/pull/2966
python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
# See https://github.com/psf/black/issues/2964
# The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
- name: Run flake8
shell: bash
working-directory: ${{github.workspace}}
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503
flake8 . --count --show-source --statistics
flake8 .
- name: Run black
shell: bash

View File

@ -1,38 +1,26 @@
repos:
- repo: https://github.com/psf/black
rev: 22.3.0
rev: 21.6b0
hooks:
- id: black
args: ["--line-length=88"]
additional_dependencies: ['click==8.1.0']
args: [--line-length=80]
additional_dependencies: ['click==8.0.1']
exclude: icefall\/__init__\.py
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
rev: 3.9.2
hooks:
- id: flake8
args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
# What are we ignoring here?
# E203: whitespace before ':'
# E266: too many leading '#' for block comment
# E501: line too long
# F401: module imported but unused
# E402: module level import not at top of file
# F403: 'from module import *' used; unable to detect undefined names
# F841: local variable is assigned to but never used
# W503: line break before binary operator
# In addition, the default ignore list is:
# E121,E123,E126,E226,E24,E704,W503,W504
args: [--max-line-length=80]
- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.9.2
hooks:
- id: isort
args: ["--profile=black"]
args: [--profile=black, --line-length=80]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0
rev: v4.0.1
hooks:
- id: check-executables-have-shebangs
- id: end-of-file-fixer

View File

@ -88,3 +88,4 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

View File

@ -19,3 +19,4 @@ It can be downloaded from `<https://www.openslr.org/33/>`_
tdnn_lstm_ctc
conformer_ctc
stateless_transducer

View File

@ -6,3 +6,4 @@ TIMIT
tdnn_ligru_ctc
tdnn_lstm_ctc

View File

@ -87,7 +87,9 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@ -114,7 +116,9 @@ def get_args():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state
word = word2id[word]
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:

View File

@ -317,7 +317,9 @@ def lexicon_to_fst(
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
parser.add_argument(
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
return parser.parse_args()

View File

@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -50,15 +50,15 @@ def get_parser():
"-n",
default=1,
type=int,
help=(
"number of characters to split, i.e., aabb -> a a b"
" b with -n 1 and aa bb with -n 2"
),
help="number of characters to split, i.e., \
aabb -> a a b b with -n 1 and aa bb with -n 2",
)
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
parser.add_argument(
"--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument(
"--non-lang-syms",
"-l",
@ -66,7 +66,9 @@ def get_parser():
type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
)
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument(
"--trans_type",
"-t",
@ -106,7 +108,8 @@ def token2id(
if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list)
sub_ids = [
token_table[txt] if txt in token_table else oov_id for txt in text
token_table[txt] if txt in token_table else oov_id
for txt in text
]
ids.append(sub_ids)
else: # token_type = "pinyin"
@ -132,7 +135,9 @@ def main():
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
f = codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer
)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer

View File

@ -113,3 +113,4 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
./local/prepare_char.py
fi
fi

View File

@ -81,12 +81,10 @@ class Aidatatang_200zhAsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description=(
"These options are used for the preparation of "
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc."
),
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
@ -98,91 +96,75 @@ class Aidatatang_200zhAsrDataModule:
"--max-duration",
type=int,
default=200.0,
help=(
"Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM."
),
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help=(
"When enabled, the batches will come from buckets of "
"similar duration (saves padding frames)."
),
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=300,
help=(
"The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets)."
),
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help=(
"When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding."
),
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help=(
"Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch."
),
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help=(
"The amount of padding (in seconds) inserted between "
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used."
),
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help=(
"When enabled, use on-the-fly cut mixing and feature "
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available."
),
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help=(
"When enabled (=default), the examples will be shuffled for each epoch."
),
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help=(
"When enabled, each batch will have the "
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it."
),
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that collect the batches.",
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
@ -196,22 +178,18 @@ class Aidatatang_200zhAsrDataModule:
"--spec-aug-time-warp-factor",
type=int,
default=80,
help=(
"Used only when --enable-spec-aug is True. "
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp."
),
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help=(
"When enabled, select noise from MUSAN and mix it"
"with training dataset. "
),
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(
@ -227,20 +205,24 @@ class Aidatatang_200zhAsrDataModule:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
"Using cut concatenation with duration factor "
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
@ -255,7 +237,9 @@ class Aidatatang_200zhAsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@ -298,7 +282,9 @@ class Aidatatang_200zhAsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@ -354,7 +340,9 @@ class Aidatatang_200zhAsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:

View File

@ -69,7 +69,11 @@ from beam_search import (
)
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@ -88,30 +92,25 @@ def get_parser():
"--epoch",
type=int,
default=28,
help=(
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--batch",
type=int,
default=None,
help=(
"It specifies the batch checkpoint to use for decoding."
"Note: Epoch counts from 0."
),
help="It specifies the batch checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. "
),
"'--epoch'. ",
)
parser.add_argument(
@ -193,7 +192,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -249,7 +249,9 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -264,7 +266,10 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -310,7 +315,11 @@ def decode_one_batch(
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -381,7 +390,9 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -414,7 +425,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)

View File

@ -62,20 +62,17 @@ def get_parser():
"--epoch",
type=int,
default=28,
help=(
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. "
),
"'--epoch'. ",
)
parser.add_argument(
@ -106,7 +103,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
@ -175,7 +173,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -85,11 +85,9 @@ def get_parser():
"--checkpoint",
type=str,
required=True,
help=(
"Path to the checkpoint. "
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()."
),
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
@ -114,12 +112,10 @@ def get_parser():
"sound_files",
type=str,
nargs="+",
help=(
"The input sound file(s) to transcribe. "
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz."
),
"The sample rate has to be 16kHz.",
)
parser.add_argument(
@ -166,7 +162,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
@ -196,9 +193,10 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@ -259,7 +257,9 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
@ -284,7 +284,10 @@ def main():
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -336,7 +339,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -81,7 +81,9 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
@ -185,45 +187,42 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--prune-range",
type=int,
default=5,
help=(
"The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss"
),
help="The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss",
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.25,
help=(
"The scale to smooth the loss with lm (output of prediction network) part."
),
help="The scale to smooth the loss with lm "
"(output of prediction network) part.",
)
parser.add_argument(
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)part.",
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
)
parser.add_argument(
"--simple-loss-scale",
type=float,
default=0.5,
help=(
"To get pruning ranges, we will calculate a simple version"
help="To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss."
),
"with this parameter before adding to the final loss.",
)
parser.add_argument(
@ -543,15 +542,22 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -705,7 +711,9 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")

View File

@ -157,7 +157,9 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
@ -175,14 +177,18 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
self.norm_ff_macaron = nn.LayerNorm(
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
self.norm_final = nn.LayerNorm(
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout)
@ -216,7 +222,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before:
src = self.norm_ff_macaron(src)
@ -335,7 +343,9 @@ class RelPositionalEncoding(torch.nn.Module):
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
@ -351,7 +361,9 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vector and `j` means the
@ -621,9 +633,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
3, dim=-1
)
q, k, v = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
elif torch.equal(key, value):
# encoder-decoder attention
@ -691,25 +703,33 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError("The size of the 2D attn_mask is not correct.")
raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError("The size of the 3D attn_mask is not correct.")
raise RuntimeError(
"The size of the 3D attn_mask is not correct."
)
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
"attn_mask's dimension {} is not supported".format(
attn_mask.dim()
)
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
if (
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor"
" instead."
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(torch.bool)
@ -746,7 +766,9 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
@ -758,7 +780,9 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
@ -792,9 +816,13 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
)
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
@ -817,7 +845,9 @@ class ConvolutionModule(nn.Module):
"""
def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding

View File

@ -58,19 +58,16 @@ def get_parser():
"--epoch",
type=int,
default=49,
help=(
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. "
),
"'--epoch'. ",
)
parser.add_argument(
@ -404,7 +401,9 @@ def decode_dataset(
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -432,7 +431,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@ -440,7 +441,9 @@ def save_results(
test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@ -559,7 +562,9 @@ def main():
eos_id=eos_id,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")

View File

@ -40,20 +40,17 @@ def get_parser():
"--epoch",
type=int,
default=84,
help=(
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=25,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. "
),
"'--epoch'. ",
)
parser.add_argument(
@ -160,7 +157,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -46,29 +46,27 @@ def get_parser():
"--checkpoint",
type=str,
required=True,
help=(
"Path to the checkpoint. "
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()."
),
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--tokens-file",
type=str,
help="Path to tokens.txtUsed only when method is ctc-decoding",
help="Path to tokens.txt" "Used only when method is ctc-decoding",
)
parser.add_argument(
"--words-file",
type=str,
help="Path to words.txtUsed when method is NOT ctc-decoding",
help="Path to words.txt" "Used when method is NOT ctc-decoding",
)
parser.add_argument(
"--HLG",
type=str,
help="Path to HLG.pt.Used when method is NOT ctc-decoding",
help="Path to HLG.pt." "Used when method is NOT ctc-decoding",
)
parser.add_argument(
@ -165,12 +163,10 @@ def get_parser():
"sound_files",
type=str,
nargs="+",
help=(
"The input sound file(s) to transcribe. "
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz."
),
"The sample rate has to be 16kHz.",
)
return parser
@ -214,9 +210,10 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@ -277,7 +274,9 @@ def main():
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
# Note: We don't use key padding mask for attention during decoding
with torch.no_grad():
@ -372,7 +371,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module):
assert idim >= 7
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
nn.Conv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
nn.Conv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@ -128,13 +132,17 @@ class VggSubsampling(nn.Module):
)
)
layers.append(
torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
torch.nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True
)
)
cur_channels = block_dim
self.layers = nn.Sequential(*layers)
self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
self.out = nn.Linear(
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.

View File

@ -16,8 +16,9 @@
# limitations under the License.
from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch
from subsampling import Conv2dSubsampling, VggSubsampling
def test_conv2d_subsampling():

View File

@ -382,7 +382,9 @@ def compute_loss(
#
# See https://github.com/k2-fsa/icefall/issues/97
# for more details
unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
unsorted_token_ids = graph_compiler.texts_to_ids(
supervisions["text"]
)
att_loss = mmodel.decoder_forward(
encoder_memory,
memory_mask,
@ -518,7 +520,9 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@ -626,7 +630,9 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:

View File

@ -149,7 +149,9 @@ class Transformer(nn.Module):
norm=decoder_norm,
)
self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
self.decoder_output_layer = torch.nn.Linear(
d_model, self.decoder_num_class
)
self.decoder_criterion = LabelSmoothingLoss()
else:
@ -181,7 +183,9 @@ class Transformer(nn.Module):
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)
x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask
@ -262,17 +266,23 @@ class Transformer(nn.Module):
"""
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device
ys_in_pad = ys_in_pad.to(device)
ys_out_pad = ys_out_pad.to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
@ -333,17 +343,23 @@ class Transformer(nn.Module):
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
@ -616,7 +632,9 @@ def _get_activation_fn(activation: str):
elif activation == "gelu":
return nn.functional.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
class PositionalEncoding(nn.Module):
@ -818,7 +836,9 @@ def encoder_padding_mask(
1,
).to(torch.int32)
lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
lengths = [
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
]
for idx in range(supervision_segments.size(0)):
# Note: TorchScript doesn't allow to unpack tensors as tuples
sequence_idx = supervision_segments[idx, 0].item()
@ -839,7 +859,9 @@ def encoder_padding_mask(
return mask
def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
def decoder_padding_mask(
ys_pad: torch.Tensor, ignore_id: int = -1
) -> torch.Tensor:
"""Generate a length mask for input.
The masked position are filled with True,

View File

@ -157,7 +157,9 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
@ -175,14 +177,18 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
self.norm_ff_macaron = nn.LayerNorm(
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
self.norm_final = nn.LayerNorm(
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout)
@ -216,7 +222,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before:
src = self.norm_ff_macaron(src)
@ -335,7 +343,9 @@ class RelPositionalEncoding(torch.nn.Module):
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
@ -351,7 +361,9 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vector and `j` means the
@ -621,9 +633,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
3, dim=-1
)
q, k, v = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
elif torch.equal(key, value):
# encoder-decoder attention
@ -691,25 +703,33 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError("The size of the 2D attn_mask is not correct.")
raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError("The size of the 3D attn_mask is not correct.")
raise RuntimeError(
"The size of the 3D attn_mask is not correct."
)
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
"attn_mask's dimension {} is not supported".format(
attn_mask.dim()
)
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
if (
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor"
" instead."
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(torch.bool)
@ -746,7 +766,9 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
@ -758,7 +780,9 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
@ -792,9 +816,13 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
)
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
@ -817,7 +845,9 @@ class ConvolutionModule(nn.Module):
"""
def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding

View File

@ -59,19 +59,16 @@ def get_parser():
"--epoch",
type=int,
default=49,
help=(
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. "
),
"'--epoch'. ",
)
parser.add_argument(
@ -416,7 +413,9 @@ def decode_dataset(
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -444,7 +443,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@ -452,7 +453,9 @@ def save_results(
test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@ -547,7 +550,9 @@ def main():
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return
model.to(device)
@ -576,7 +581,9 @@ def main():
eos_id=eos_id,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")

View File

@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module):
assert idim >= 7
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
nn.Conv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
nn.Conv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@ -128,13 +132,17 @@ class VggSubsampling(nn.Module):
)
)
layers.append(
torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
torch.nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True
)
)
cur_channels = block_dim
self.layers = nn.Sequential(*layers)
self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
self.out = nn.Linear(
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.

View File

@ -511,7 +511,9 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@ -623,7 +625,9 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:

View File

@ -149,7 +149,9 @@ class Transformer(nn.Module):
norm=decoder_norm,
)
self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
self.decoder_output_layer = torch.nn.Linear(
d_model, self.decoder_num_class
)
self.decoder_criterion = LabelSmoothingLoss()
else:
@ -181,7 +183,9 @@ class Transformer(nn.Module):
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)
x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask
@ -262,17 +266,23 @@ class Transformer(nn.Module):
"""
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device
ys_in_pad = ys_in_pad.to(device)
ys_out_pad = ys_out_pad.to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
@ -333,17 +343,23 @@ class Transformer(nn.Module):
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
@ -616,7 +632,9 @@ def _get_activation_fn(activation: str):
elif activation == "gelu":
return nn.functional.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
class PositionalEncoding(nn.Module):
@ -818,7 +836,9 @@ def encoder_padding_mask(
1,
).to(torch.int32)
lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
lengths = [
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
]
for idx in range(supervision_segments.size(0)):
# Note: TorchScript doesn't allow to unpack tensors as tuples
sequence_idx = supervision_segments[idx, 0].item()
@ -839,7 +859,9 @@ def encoder_padding_mask(
return mask
def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
def decoder_padding_mask(
ys_pad: torch.Tensor, ignore_id: int = -1
) -> torch.Tensor:
"""Generate a length mask for input.
The masked position are filled with True,

View File

@ -87,7 +87,9 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@ -114,7 +116,9 @@ def get_args():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -83,7 +83,9 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@ -109,7 +111,9 @@ def get_args():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state
word = word2id[word]
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:

View File

@ -317,7 +317,9 @@ def lexicon_to_fst(
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
parser.add_argument(
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
return parser.parse_args()

View File

@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -76,7 +76,11 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@ -114,11 +118,9 @@ def get_parser():
"--avg",
type=int,
default=15,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'"
),
"'--epoch' and '--iter'",
)
parser.add_argument(
@ -186,7 +188,8 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -246,7 +249,9 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
@ -258,7 +263,10 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -302,7 +310,11 @@ def decode_one_batch(
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -375,7 +387,9 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -401,7 +415,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -412,7 +428,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@ -456,7 +473,9 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -485,7 +504,8 @@ def main():
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(

View File

@ -50,7 +50,11 @@ from pathlib import Path
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
@ -83,11 +87,9 @@ def get_parser():
"--avg",
type=int,
default=15,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'"
),
"'--epoch' and '--iter'",
)
parser.add_argument(
@ -118,7 +120,8 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
@ -154,7 +157,8 @@ def main():
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
@ -187,7 +191,9 @@ def main():
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
filename = (
params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
)
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
@ -195,14 +201,17 @@ def main():
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = (
params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
params.exp_dir
/ f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
)
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -87,11 +87,9 @@ def get_parser():
"--checkpoint",
type=str,
required=True,
help=(
"Path to the checkpoint. "
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()."
),
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
@ -117,12 +115,10 @@ def get_parser():
"sound_files",
type=str,
nargs="+",
help=(
"The input sound file(s) to transcribe. "
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz."
),
"The sample rate has to be 16kHz.",
)
parser.add_argument(
@ -169,16 +165,15 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help=(
"Maximum number of symbols per frame. "
"Use only when --method is greedy_search"
),
help="Maximum number of symbols per frame. "
"Use only when --method is greedy_search",
)
add_model_arguments(parser)
@ -201,9 +196,10 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@ -260,9 +256,13 @@ def main():
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0)
hyp_list = []
@ -310,7 +310,9 @@ def main():
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
raise ValueError(
f"Unsupported decoding method: {params.method}"
)
hyp_list.append(hyp)
hyps = []
@ -327,7 +329,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -49,6 +49,7 @@ import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from conformer import Conformer
from decoder import Decoder
@ -74,7 +75,9 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def add_model_arguments(parser: argparse.ArgumentParser):
@ -200,7 +203,8 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
help="The initial learning rate. This value should not need to be changed.",
help="The initial learning rate. This value should not need "
"to be changed.",
)
parser.add_argument(
@ -223,45 +227,42 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--prune-range",
type=int,
default=5,
help=(
"The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss"
),
help="The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss",
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.25,
help=(
"The scale to smooth the loss with lm (output of prediction network) part."
),
help="The scale to smooth the loss with lm "
"(output of prediction network) part.",
)
parser.add_argument(
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)part.",
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
)
parser.add_argument(
"--simple-loss-scale",
type=float,
default=0.5,
help=(
"To get pruning ranges, we will calculate a simple version"
help="To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss."
),
"with this parameter before adding to the final loss.",
)
parser.add_argument(
@ -560,7 +561,11 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -588,16 +593,23 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -713,7 +725,9 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
raise
if params.print_diagnostics and batch_idx == 5:
@ -1015,7 +1029,9 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
raise

View File

@ -121,24 +121,20 @@ def get_parser():
"--avg",
type=int,
default=15,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'"
),
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help=(
"Whether to load averaged model. Currently it only supports "
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. "
),
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
@ -206,7 +202,8 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -266,7 +263,9 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
@ -278,7 +277,10 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -322,7 +324,11 @@ def decode_one_batch(
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -395,7 +401,9 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -421,7 +429,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -432,7 +442,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
@ -477,7 +488,9 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -505,12 +518,13 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
@ -537,12 +551,13 @@ def main():
)
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
@ -571,7 +586,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
"Calculating the averaged model over epoch range from "
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)

View File

@ -88,24 +88,20 @@ def get_parser():
"--avg",
type=int,
default=15,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'"
),
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help=(
"Whether to load averaged model. Currently it only supports "
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. "
),
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
@ -136,7 +132,8 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
@ -169,12 +166,13 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
@ -197,12 +195,13 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
@ -230,7 +229,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
"Calculating the averaged model over epoch range from "
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
@ -253,7 +252,9 @@ def main():
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
filename = (
params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
)
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
@ -261,14 +262,17 @@ def main():
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = (
params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
params.exp_dir
/ f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
)
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -84,7 +84,9 @@ class Transducer(nn.Module):
self.decoder_datatang = decoder_datatang
self.joiner_datatang = joiner_datatang
self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if decoder_datatang is not None:
@ -177,7 +179,9 @@ class Transducer(nn.Module):
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens

View File

@ -87,11 +87,9 @@ def get_parser():
"--checkpoint",
type=str,
required=True,
help=(
"Path to the checkpoint. "
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()."
),
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
@ -117,12 +115,10 @@ def get_parser():
"sound_files",
type=str,
nargs="+",
help=(
"The input sound file(s) to transcribe. "
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz."
),
"The sample rate has to be 16kHz.",
)
parser.add_argument(
@ -169,16 +165,15 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help=(
"Maximum number of symbols per frame. "
"Use only when --method is greedy_search"
),
help="Maximum number of symbols per frame. "
"Use only when --method is greedy_search",
)
add_model_arguments(parser)
@ -201,9 +196,10 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@ -261,9 +257,13 @@ def main():
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0)
hyp_list = []
@ -311,7 +311,9 @@ def main():
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
raise ValueError(
f"Unsupported decoding method: {params.method}"
)
hyp_list.append(hyp)
hyps = []
@ -328,7 +330,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -96,7 +96,9 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def add_model_arguments(parser: argparse.ArgumentParser):
@ -222,7 +224,8 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
help="The initial learning rate. This value should not need to be changed.",
help="The initial learning rate. This value should not need "
"to be changed.",
)
parser.add_argument(
@ -245,45 +248,42 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--prune-range",
type=int,
default=5,
help=(
"The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss"
),
help="The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss",
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.25,
help=(
"The scale to smooth the loss with lm (output of prediction network) part."
),
help="The scale to smooth the loss with lm "
"(output of prediction network) part.",
)
parser.add_argument(
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)part.",
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
)
parser.add_argument(
"--simple-loss-scale",
type=float,
default=0.5,
help=(
"To get pruning ranges, we will calculate a simple version"
help="To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss."
),
"with this parameter before adding to the final loss.",
)
parser.add_argument(
@ -635,7 +635,11 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -666,16 +670,23 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -813,7 +824,9 @@ def train_one_epoch(
)
# summary stats
if datatang_train_dl is not None:
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
tot_loss = (
tot_loss * (1 - 1 / params.reset_interval)
) + loss_info
if aishell:
aishell_tot_loss = (
@ -834,7 +847,9 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
raise
if params.print_diagnostics and batch_idx == 5:
@ -877,7 +892,9 @@ def train_one_epoch(
cur_lr = scheduler.get_last_lr()[0]
if datatang_train_dl is not None:
datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, "
tot_loss_str = (
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
)
else:
tot_loss_str = ""
datatang_str = ""
@ -1059,7 +1076,9 @@ def run(rank, world_size, args):
train_cuts = filter_short_and_long_utterances(train_cuts)
if args.enable_musan:
cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
cuts_musan = load_manifest(
Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
)
else:
cuts_musan = None
@ -1074,7 +1093,9 @@ def run(rank, world_size, args):
if params.datatang_prob > 0:
datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
train_datatang_cuts = datatang.train_cuts()
train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
train_datatang_cuts = filter_short_and_long_utterances(
train_datatang_cuts
)
train_datatang_cuts = train_datatang_cuts.repeat(times=None)
datatang_train_dl = asr_datamodule.train_dataloaders(
train_datatang_cuts,
@ -1228,7 +1249,9 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
raise

View File

@ -64,12 +64,10 @@ class AishellAsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description=(
"These options are used for the preparation of "
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc."
),
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
@ -81,74 +79,59 @@ class AishellAsrDataModule:
"--max-duration",
type=int,
default=200.0,
help=(
"Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM."
),
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help=(
"When enabled, the batches will come from buckets of "
"similar duration (saves padding frames)."
),
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help=(
"The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets)."
),
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help=(
"When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding."
),
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help=(
"Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch."
),
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help=(
"The amount of padding (in seconds) inserted between "
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used."
),
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help=(
"When enabled, use on-the-fly cut mixing and feature "
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available."
),
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help=(
"When enabled (=default), the examples will be shuffled for each epoch."
),
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
@ -160,18 +143,17 @@ class AishellAsrDataModule:
"--return-cuts",
type=str2bool,
default=True,
help=(
"When enabled, each batch will have the "
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it."
),
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that collect the batches.",
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
@ -185,40 +167,40 @@ class AishellAsrDataModule:
"--spec-aug-time-warp-factor",
type=int,
default=80,
help=(
"Used only when --enable-spec-aug is True. "
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp."
),
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help=(
"When enabled, select noise from MUSAN and mix it"
"with training dataset. "
),
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
"Using cut concatenation with duration factor "
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
@ -233,7 +215,9 @@ class AishellAsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@ -276,7 +260,9 @@ class AishellAsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@ -322,7 +308,9 @@ class AishellAsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
@ -378,9 +366,13 @@ class AishellAsrDataModule:
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz")
return load_manifest_lazy(
self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz")
return load_manifest_lazy(
self.args.manifest_dir / "aishell_cuts_test.jsonl.gz"
)

View File

@ -49,19 +49,16 @@ def get_parser():
"--epoch",
type=int,
default=19,
help=(
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=5,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. "
),
"'--epoch'. ",
)
parser.add_argument(
"--method",
@ -268,7 +265,9 @@ def decode_dataset(
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -290,7 +289,9 @@ def save_results(
# We compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
test_set_wers[key] = wer
@ -334,7 +335,9 @@ def main():
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device)
assert HLG.requires_grad is False
@ -359,7 +362,9 @@ def main():
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
model.to(device)
model.eval()
@ -387,7 +392,9 @@ def main():
lexicon=lexicon,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")

View File

@ -66,7 +66,10 @@ class TdnnLstm(nn.Module):
nn.BatchNorm1d(num_features=500, affine=False),
)
self.lstms = nn.ModuleList(
[nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
[
nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
for _ in range(5)
]
)
self.lstm_bnorms = nn.ModuleList(
[nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]

View File

@ -41,11 +41,9 @@ def get_parser():
"--checkpoint",
type=str,
required=True,
help=(
"Path to the checkpoint. "
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()."
),
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
@ -55,7 +53,9 @@ def get_parser():
help="Path to words.txt",
)
parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
parser.add_argument(
"--HLG", type=str, required=True, help="Path to HLG.pt."
)
parser.add_argument(
"--method",
@ -71,12 +71,10 @@ def get_parser():
"sound_files",
type=str,
nargs="+",
help=(
"The input sound file(s) to transcribe. "
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz."
),
"The sample rate has to be 16kHz.",
)
return parser
@ -114,9 +112,10 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@ -174,7 +173,9 @@ def main():
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = features.permute(0, 2, 1) # now features is [N, C, T]
with torch.no_grad():
@ -218,7 +219,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -49,7 +49,12 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
encode_supervisions,
setup_logger,
str2bool,
)
def get_parser():

View File

@ -47,9 +47,9 @@ def greedy_search(
device = model.device
decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
1, context_size
)
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
@ -81,9 +81,9 @@ def greedy_search(
y = logits.argmax().item()
if y != blank_id:
hyp.append(y)
decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
1, context_size
)
decoder_input = torch.tensor(
[hyp[-context_size:]], device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
@ -157,7 +157,9 @@ class HypothesisList(object):
"""
if length_norm:
return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
return max(
self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
)
else:
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
@ -244,9 +246,9 @@ def beam_search(
device = model.device
decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
1, context_size
)
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)

View File

@ -155,7 +155,9 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
@ -173,14 +175,18 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
self.norm_ff_macaron = nn.LayerNorm(
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
self.norm_final = nn.LayerNorm(
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout)
@ -214,7 +220,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before:
src = self.norm_ff_macaron(src)
@ -333,7 +341,9 @@ class RelPositionalEncoding(torch.nn.Module):
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
@ -349,7 +359,9 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vector and `j` means the
@ -619,9 +631,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
3, dim=-1
)
q, k, v = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
elif torch.equal(key, value):
# encoder-decoder attention
@ -689,25 +701,33 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError("The size of the 2D attn_mask is not correct.")
raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError("The size of the 3D attn_mask is not correct.")
raise RuntimeError(
"The size of the 3D attn_mask is not correct."
)
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
"attn_mask's dimension {} is not supported".format(
attn_mask.dim()
)
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
if (
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor"
" instead."
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(torch.bool)
@ -744,7 +764,9 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
@ -756,7 +778,9 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
@ -790,9 +814,13 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
)
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
@ -815,7 +843,9 @@ class ConvolutionModule(nn.Module):
"""
def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding

View File

@ -52,19 +52,16 @@ def get_parser():
"--epoch",
type=int,
default=30,
help=(
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. "
),
"'--epoch'. ",
)
parser.add_argument(
@ -102,7 +99,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -229,7 +227,9 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
batch_size = encoder_out.size(0)
@ -248,7 +248,9 @@ def decode_one_batch(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[i] for i in hyp])
if params.decoding_method == "greedy_search":
@ -317,7 +319,9 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -342,7 +346,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -353,7 +359,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
@ -423,7 +430,9 @@ def main():
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return
model.to(device)

View File

@ -86,7 +86,9 @@ class Decoder(nn.Module):
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
else:
# During inference time, there is no need to do extra padding
# as we only need one output

View File

@ -69,20 +69,17 @@ def get_parser():
"--epoch",
type=int,
default=20,
help=(
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. "
),
"'--epoch'. ",
)
parser.add_argument(
@ -113,7 +110,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
@ -245,7 +243,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -103,7 +103,9 @@ class Transducer(nn.Module):
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens

View File

@ -73,11 +73,9 @@ def get_parser():
"--checkpoint",
type=str,
required=True,
help=(
"Path to the checkpoint. "
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()."
),
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
@ -102,12 +100,10 @@ def get_parser():
"sound_files",
type=str,
nargs="+",
help=(
"The input sound file(s) to transcribe. "
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz."
),
"The sample rate has to be 16kHz.",
)
parser.add_argument(
@ -121,7 +117,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -214,9 +211,10 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@ -275,7 +273,9 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
@ -319,7 +319,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -126,7 +126,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
@ -388,7 +389,9 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -501,7 +504,9 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@ -620,7 +625,9 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:

View File

@ -250,7 +250,9 @@ def _get_activation_fn(activation: str):
elif activation == "gelu":
return nn.functional.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
class PositionalEncoding(nn.Module):

View File

@ -29,7 +29,10 @@ from lhotse.dataset import (
K2SpeechRecognitionDataset,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
from lhotse.dataset.input_strategies import (
OnTheFlyFeatures,
PrecomputedFeatures,
)
from torch.utils.data import DataLoader
from icefall.utils import str2bool
@ -43,69 +46,59 @@ class AsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description=(
"These options are used for the preparation of "
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc."
),
"augmentations, etc.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help=(
"Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM."
),
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help=(
"When enabled, the batches will come from buckets of "
"similar duration (saves padding frames)."
),
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help=(
"The number of buckets for the DynamicBucketingSampler "
"(you might want to increase it for larger datasets)."
),
help="The number of buckets for the DynamicBucketingSampler "
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help=(
"When enabled (=default), the examples will be shuffled for each epoch."
),
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help=(
"When enabled, each batch will have the "
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it."
),
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that collect the batches.",
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
@ -119,22 +112,18 @@ class AsrDataModule:
"--spec-aug-time-warp-factor",
type=int,
default=80,
help=(
"Used only when --enable-spec-aug is True. "
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp."
),
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help=(
"When enabled, select noise from MUSAN and mix it"
"with training dataset. "
),
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
@ -148,11 +137,9 @@ class AsrDataModule:
"--on-the-fly-feats",
type=str2bool,
default=False,
help=(
"When enabled, use on-the-fly cut mixing and feature "
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available. Used only in dev/test CutSet"
),
"if available. Used only in dev/test CutSet",
)
def train_dataloaders(
@ -175,7 +162,9 @@ class AsrDataModule:
if cuts_musan is not None:
logging.info("Enable MUSAN")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
@ -184,7 +173,9 @@ class AsrDataModule:
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@ -261,7 +252,9 @@ class AsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:

View File

@ -93,19 +93,16 @@ def get_parser():
"--epoch",
type=int,
default=30,
help=(
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. "
),
"'--epoch'. ",
)
parser.add_argument(
@ -173,7 +170,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
@ -229,7 +227,9 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
@ -241,7 +241,10 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -285,7 +288,11 @@ def decode_one_batch(
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -358,7 +365,9 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -384,7 +393,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -395,7 +406,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
@ -436,7 +448,9 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

View File

@ -68,20 +68,17 @@ def get_parser():
"--epoch",
type=int,
default=20,
help=(
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. "
),
"'--epoch'. ",
)
parser.add_argument(
@ -112,7 +109,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
@ -243,7 +241,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -87,11 +87,9 @@ def get_parser():
"--checkpoint",
type=str,
required=True,
help=(
"Path to the checkpoint. "
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()."
),
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
@ -117,12 +115,10 @@ def get_parser():
"sound_files",
type=str,
nargs="+",
help=(
"The input sound file(s) to transcribe. "
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz."
),
"The sample rate has to be 16kHz.",
)
parser.add_argument(
@ -169,16 +165,15 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help=(
"Maximum number of symbols per frame. "
"Use only when --method is greedy_search"
),
help="Maximum number of symbols per frame. "
"Use only when --method is greedy_search",
)
return parser
@ -199,9 +194,10 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@ -258,9 +254,13 @@ def main():
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0)
hyp_list = []
@ -308,7 +308,9 @@ def main():
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
raise ValueError(
f"Unsupported decoding method: {params.method}"
)
hyp_list.append(hyp)
hyps = []
@ -325,7 +327,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -149,7 +149,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
@ -167,7 +168,8 @@ def get_parser():
"--datatang-prob",
type=float,
default=0.2,
help="The probability to select a batch from the aidatatang_200zh dataset",
help="The probability to select a batch from the "
"aidatatang_200zh dataset",
)
return parser
@ -447,7 +449,9 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -601,7 +605,9 @@ def train_one_epoch(
f"train/current_{prefix}_",
params.batch_idx_train,
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
aishell_tot_loss.write_summary(
tb_writer, "train/aishell_tot_", params.batch_idx_train
)
@ -729,7 +735,9 @@ def run(rank, world_size, args):
train_datatang_cuts = train_datatang_cuts.repeat(times=None)
if args.enable_musan:
cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
cuts_musan = load_manifest(
Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
)
else:
cuts_musan = None
@ -768,7 +776,9 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:

View File

@ -94,19 +94,16 @@ def get_parser():
"--epoch",
type=int,
default=30,
help=(
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. "
),
"'--epoch'. ",
)
parser.add_argument(
@ -174,7 +171,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
@ -233,7 +231,9 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
@ -245,7 +245,10 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -289,7 +292,11 @@ def decode_one_batch(
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -362,7 +369,9 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -388,7 +397,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -399,7 +410,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
@ -440,7 +452,9 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

View File

@ -68,20 +68,17 @@ def get_parser():
"--epoch",
type=int,
default=20,
help=(
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. "
),
"'--epoch'. ",
)
parser.add_argument(
@ -112,7 +109,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
@ -243,7 +241,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -87,11 +87,9 @@ def get_parser():
"--checkpoint",
type=str,
required=True,
help=(
"Path to the checkpoint. "
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()."
),
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
@ -117,12 +115,10 @@ def get_parser():
"sound_files",
type=str,
nargs="+",
help=(
"The input sound file(s) to transcribe. "
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz."
),
"The sample rate has to be 16kHz.",
)
parser.add_argument(
@ -169,16 +165,15 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help=(
"Maximum number of symbols per frame. "
"Use only when --method is greedy_search"
),
help="Maximum number of symbols per frame. "
"Use only when --method is greedy_search",
)
return parser
@ -199,9 +194,10 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@ -258,9 +254,13 @@ def main():
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0)
hyp_list = []
@ -308,7 +308,9 @@ def main():
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
raise ValueError(
f"Unsupported decoding method: {params.method}"
)
hyp_list.append(hyp)
hyps = []
@ -325,7 +327,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -142,7 +142,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
@ -413,7 +414,9 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -526,7 +529,9 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@ -652,7 +657,9 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:

0
egs/aishell2/ASR/local/__init__.py Normal file → Executable file
View File

View File

@ -83,7 +83,9 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@ -109,7 +111,9 @@ def get_args():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)

View File

View File

@ -76,12 +76,10 @@ class AiShell2AsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description=(
"These options are used for the preparation of "
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc."
),
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
@ -93,74 +91,59 @@ class AiShell2AsrDataModule:
"--max-duration",
type=int,
default=200.0,
help=(
"Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM."
),
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help=(
"When enabled, the batches will come from buckets of "
"similar duration (saves padding frames)."
),
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help=(
"The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets)."
),
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help=(
"When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding."
),
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help=(
"Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch."
),
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help=(
"The amount of padding (in seconds) inserted between "
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used."
),
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help=(
"When enabled, use on-the-fly cut mixing and feature "
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available."
),
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help=(
"When enabled (=default), the examples will be shuffled for each epoch."
),
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
@ -172,18 +155,17 @@ class AiShell2AsrDataModule:
"--return-cuts",
type=str2bool,
default=True,
help=(
"When enabled, each batch will have the "
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it."
),
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that collect the batches.",
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
@ -197,22 +179,18 @@ class AiShell2AsrDataModule:
"--spec-aug-time-warp-factor",
type=int,
default=80,
help=(
"Used only when --enable-spec-aug is True. "
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp."
),
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help=(
"When enabled, select noise from MUSAN and mix it"
"with training dataset. "
),
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
@ -238,16 +216,20 @@ class AiShell2AsrDataModule:
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
"Using cut concatenation with duration factor "
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
@ -262,7 +244,9 @@ class AiShell2AsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@ -306,7 +290,9 @@ class AiShell2AsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@ -362,7 +348,9 @@ class AiShell2AsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
@ -418,7 +406,9 @@ class AiShell2AsrDataModule:
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz")
return load_manifest_lazy(
self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:

View File

@ -168,24 +168,20 @@ def get_parser():
"--avg",
type=int,
default=15,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'"
),
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help=(
"Whether to load averaged model. Currently it only supports "
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. "
),
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
@ -273,7 +269,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -351,7 +348,9 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -410,7 +409,10 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -536,7 +538,9 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -569,7 +573,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@ -620,7 +625,9 @@ def main():
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -654,12 +661,13 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
@ -682,12 +690,13 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
@ -715,7 +724,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
"Calculating the averaged model over epoch range from "
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
@ -740,7 +749,9 @@ def main():
)
decoding_graph.scores *= params.ngram_lm_scale
else:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None

View File

@ -89,24 +89,20 @@ def get_parser():
"--avg",
type=int,
default=15,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'"
),
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help=(
"Whether to load averaged model. Currently it only supports "
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. "
),
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
@ -137,7 +133,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
@ -170,12 +167,13 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
@ -198,12 +196,13 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
@ -231,7 +230,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
"Calculating the averaged model over epoch range from "
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
@ -267,7 +266,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -81,11 +81,9 @@ def get_parser():
"--checkpoint",
type=str,
required=True,
help=(
"Path to the checkpoint. "
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()."
),
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
@ -111,12 +109,10 @@ def get_parser():
"sound_files",
type=str,
nargs="+",
help=(
"The input sound file(s) to transcribe. "
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz."
),
"The sample rate has to be 16kHz.",
)
parser.add_argument(
@ -163,7 +159,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -194,9 +191,10 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@ -256,11 +254,15 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
@ -332,7 +334,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -92,7 +92,9 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def add_model_arguments(parser: argparse.ArgumentParser):
@ -218,7 +220,8 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
help="The initial learning rate. This value should not need to be changed.",
help="The initial learning rate. This value should not need "
"to be changed.",
)
parser.add_argument(
@ -241,45 +244,42 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--prune-range",
type=int,
default=5,
help=(
"The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss"
),
help="The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss",
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.25,
help=(
"The scale to smooth the loss with lm (output of prediction network) part."
),
help="The scale to smooth the loss with lm "
"(output of prediction network) part.",
)
parser.add_argument(
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)part.",
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
)
parser.add_argument(
"--simple-loss-scale",
type=float,
default=0.5,
help=(
"To get pruning ranges, we will calculate a simple version"
help="To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss."
),
"with this parameter before adding to the final loss.",
)
parser.add_argument(
@ -603,7 +603,11 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -632,16 +636,23 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -760,7 +771,9 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
raise
if params.print_diagnostics and batch_idx == 5:
@ -816,7 +829,9 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@ -1089,7 +1104,9 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
raise

View File

@ -85,7 +85,9 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@ -118,7 +120,9 @@ def get_args():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state
word = word2id[word]
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:

View File

@ -317,7 +317,9 @@ def lexicon_to_fst(
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
parser.add_argument(
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
return parser.parse_args()

View File

@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -50,15 +50,15 @@ def get_parser():
"-n",
default=1,
type=int,
help=(
"number of characters to split, i.e., aabb -> a a b"
" b with -n 1 and aa bb with -n 2"
),
help="number of characters to split, i.e., \
aabb -> a a b b with -n 1 and aa bb with -n 2",
)
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
parser.add_argument(
"--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument(
"--non-lang-syms",
"-l",
@ -66,7 +66,9 @@ def get_parser():
type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
)
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument(
"--trans_type",
"-t",
@ -106,7 +108,8 @@ def token2id(
if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list)
sub_ids = [
token_table[txt] if txt in token_table else oov_id for txt in text
token_table[txt] if txt in token_table else oov_id
for txt in text
]
ids.append(sub_ids)
else: # token_type = "pinyin"
@ -132,7 +135,9 @@ def main():
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
f = codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer
)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer

View File

@ -74,12 +74,10 @@ class Aishell4AsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description=(
"These options are used for the preparation of "
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc."
),
"augmentations, etc.",
)
group.add_argument(
@ -93,81 +91,66 @@ class Aishell4AsrDataModule:
"--max-duration",
type=int,
default=200.0,
help=(
"Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM."
),
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help=(
"When enabled, the batches will come from buckets of "
"similar duration (saves padding frames)."
),
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=300,
help=(
"The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets)."
),
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help=(
"When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding."
),
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help=(
"Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch."
),
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help=(
"The amount of padding (in seconds) inserted between "
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used."
),
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help=(
"When enabled, use on-the-fly cut mixing and feature "
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available."
),
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help=(
"When enabled (=default), the examples will be shuffled for each epoch."
),
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
@ -181,18 +164,17 @@ class Aishell4AsrDataModule:
"--return-cuts",
type=str2bool,
default=True,
help=(
"When enabled, each batch will have the "
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it."
),
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that collect the batches.",
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
@ -206,22 +188,18 @@ class Aishell4AsrDataModule:
"--spec-aug-time-warp-factor",
type=int,
default=80,
help=(
"Used only when --enable-spec-aug is True. "
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp."
),
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help=(
"When enabled, select noise from MUSAN and mix it"
"with training dataset. "
),
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
@ -244,20 +222,24 @@ class Aishell4AsrDataModule:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
"Using cut concatenation with duration factor "
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
@ -272,7 +254,9 @@ class Aishell4AsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@ -316,7 +300,9 @@ class Aishell4AsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@ -373,7 +359,9 @@ class Aishell4AsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:

View File

@ -117,24 +117,20 @@ def get_parser():
"--avg",
type=int,
default=15,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'"
),
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help=(
"Whether to load averaged model. Currently it only supports "
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. "
),
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
@ -205,7 +201,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -263,7 +260,9 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -278,7 +277,10 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -324,7 +326,11 @@ def decode_one_batch(
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -395,7 +401,9 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -428,7 +436,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@ -471,7 +480,9 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -499,12 +510,13 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
@ -531,12 +543,13 @@ def main():
)
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
@ -565,7 +578,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
"Calculating the averaged model over epoch range from "
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)

View File

@ -89,24 +89,20 @@ def get_parser():
"--avg",
type=int,
default=15,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'"
),
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help=(
"Whether to load averaged model. Currently it only supports "
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. "
),
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
@ -140,7 +136,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
@ -172,12 +169,13 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
@ -204,12 +202,13 @@ def main():
)
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
@ -238,7 +237,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
"Calculating the averaged model over epoch range from "
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
@ -277,7 +276,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -94,11 +94,9 @@ def get_parser():
"--checkpoint",
type=str,
required=True,
help=(
"Path to the checkpoint. "
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()."
),
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
@ -124,12 +122,10 @@ def get_parser():
"sound_files",
type=str,
nargs="+",
help=(
"The input sound file(s) to transcribe. "
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz."
),
"The sample rate has to be 16kHz.",
)
parser.add_argument(
@ -176,7 +172,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -207,9 +204,10 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@ -268,11 +266,15 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
@ -304,7 +306,10 @@ def main():
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -345,7 +350,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -85,7 +85,9 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def add_model_arguments(parser: argparse.ArgumentParser):
@ -211,7 +213,8 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
help="The initial learning rate. This value should not need to be changed.",
help="The initial learning rate. This value should not need "
"to be changed.",
)
parser.add_argument(
@ -234,45 +237,42 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--prune-range",
type=int,
default=5,
help=(
"The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss"
),
help="The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss",
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.25,
help=(
"The scale to smooth the loss with lm (output of prediction network) part."
),
help="The scale to smooth the loss with lm "
"(output of prediction network) part.",
)
parser.add_argument(
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)part.",
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
)
parser.add_argument(
"--simple-loss-scale",
type=float,
default=0.5,
help=(
"To get pruning ranges, we will calculate a simple version"
help="To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss."
),
"with this parameter before adding to the final loss.",
)
parser.add_argument(
@ -599,7 +599,11 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -629,15 +633,22 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -816,7 +827,9 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")

View File

@ -84,7 +84,9 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cur_num_jobs = num_jobs if ex is None else 80
cur_num_jobs = min(cur_num_jobs, len(cut_set))
@ -119,7 +121,9 @@ def get_args():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state
word = word2id[word]
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:

View File

@ -317,7 +317,9 @@ def lexicon_to_fst(
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
parser.add_argument(
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
return parser.parse_args()

View File

@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -30,8 +30,8 @@ with word segmenting:
import argparse
import jieba
import paddle
import jieba
from tqdm import tqdm
paddle.enable_static()

View File

@ -50,15 +50,15 @@ def get_parser():
"-n",
default=1,
type=int,
help=(
"number of characters to split, i.e., aabb -> a a b"
" b with -n 1 and aa bb with -n 2"
),
help="number of characters to split, i.e., \
aabb -> a a b b with -n 1 and aa bb with -n 2",
)
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
parser.add_argument(
"--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument(
"--non-lang-syms",
"-l",
@ -66,7 +66,9 @@ def get_parser():
type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
)
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument(
"--trans_type",
"-t",
@ -106,7 +108,8 @@ def token2id(
if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list)
sub_ids = [
token_table[txt] if txt in token_table else oov_id for txt in text
token_table[txt] if txt in token_table else oov_id
for txt in text
]
ids.append(sub_ids)
else: # token_type = "pinyin"
@ -132,7 +135,9 @@ def main():
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
f = codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer
)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer

View File

@ -81,12 +81,10 @@ class AlimeetingAsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description=(
"These options are used for the preparation of "
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc."
),
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
@ -98,91 +96,75 @@ class AlimeetingAsrDataModule:
"--max-duration",
type=int,
default=200.0,
help=(
"Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM."
),
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help=(
"When enabled, the batches will come from buckets of "
"similar duration (saves padding frames)."
),
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=300,
help=(
"The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets)."
),
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help=(
"When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding."
),
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help=(
"Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch."
),
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help=(
"The amount of padding (in seconds) inserted between "
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used."
),
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help=(
"When enabled, use on-the-fly cut mixing and feature "
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available."
),
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help=(
"When enabled (=default), the examples will be shuffled for each epoch."
),
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help=(
"When enabled, each batch will have the "
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it."
),
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that collect the batches.",
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
@ -196,22 +178,18 @@ class AlimeetingAsrDataModule:
"--spec-aug-time-warp-factor",
type=int,
default=80,
help=(
"Used only when --enable-spec-aug is True. "
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp."
),
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help=(
"When enabled, select noise from MUSAN and mix it"
"with training dataset. "
),
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(
@ -227,20 +205,24 @@ class AlimeetingAsrDataModule:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
"Using cut concatenation with duration factor "
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
@ -255,7 +237,9 @@ class AlimeetingAsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@ -298,7 +282,9 @@ class AlimeetingAsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@ -355,7 +341,9 @@ class AlimeetingAsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:

View File

@ -70,7 +70,11 @@ from beam_search import (
from lhotse.cut import Cut
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@ -89,30 +93,25 @@ def get_parser():
"--epoch",
type=int,
default=28,
help=(
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--batch",
type=int,
default=None,
help=(
"It specifies the batch checkpoint to use for decoding."
"Note: Epoch counts from 0."
),
help="It specifies the batch checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. "
),
"'--epoch'. ",
)
parser.add_argument(
@ -194,7 +193,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -249,7 +249,9 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -264,7 +266,10 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -310,7 +315,11 @@ def decode_one_batch(
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -381,7 +390,9 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -414,7 +425,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@ -551,7 +563,8 @@ def main():
)
dev_shards = [
str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
str(path)
for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
]
cuts_dev_webdataset = CutSet.from_webdataset(
dev_shards,
@ -561,7 +574,8 @@ def main():
)
test_shards = [
str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
str(path)
for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
]
cuts_test_webdataset = CutSet.from_webdataset(
test_shards,
@ -574,7 +588,9 @@ def main():
return 1.0 <= c.duration
cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt)
cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt)
cuts_test_webdataset = cuts_test_webdataset.filter(
remove_short_and_long_utt
)
dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset)
test_dl = alimeeting.test_dataloaders(cuts_test_webdataset)

View File

@ -62,20 +62,17 @@ def get_parser():
"--epoch",
type=int,
default=28,
help=(
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help=(
"Number of checkpoints to average. Automatically select "
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. "
),
"'--epoch'. ",
)
parser.add_argument(
@ -106,7 +103,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
@ -175,7 +173,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -85,11 +85,9 @@ def get_parser():
"--checkpoint",
type=str,
required=True,
help=(
"Path to the checkpoint. "
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()."
),
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
@ -114,12 +112,10 @@ def get_parser():
"sound_files",
type=str,
nargs="+",
help=(
"The input sound file(s) to transcribe. "
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz."
),
"The sample rate has to be 16kHz.",
)
parser.add_argument(
@ -166,7 +162,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
@ -196,9 +193,10 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@ -259,7 +257,9 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
@ -284,7 +284,10 @@ def main():
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -336,7 +339,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -81,7 +81,9 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
@ -185,45 +187,42 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--prune-range",
type=int,
default=5,
help=(
"The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss"
),
help="The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss",
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.25,
help=(
"The scale to smooth the loss with lm (output of prediction network) part."
),
help="The scale to smooth the loss with lm "
"(output of prediction network) part.",
)
parser.add_argument(
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)part.",
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
)
parser.add_argument(
"--simple-loss-scale",
type=float,
default=0.5,
help=(
"To get pruning ranges, we will calculate a simple version"
help="To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss."
),
"with this parameter before adding to the final loss.",
)
parser.add_argument(
@ -543,15 +542,22 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -705,7 +711,9 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")

View File

@ -25,10 +25,15 @@ from random import Random
from typing import List, Tuple
import torch
from lhotse import ( # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on
from lhotse import (
CutSet,
Fbank,
FbankConfig,
# fmt: off
# See the following for why LilcomChunkyWriter is preferred
# https://github.com/k2-fsa/icefall/pull/404
# https://github.com/lhotse-speech/lhotse/pull/527
# fmt: on
LilcomChunkyWriter,
RecordingSet,
SupervisionSet,
@ -76,13 +81,17 @@ def make_cutset_blueprints(
cut_sets.append((f"eval{i}", cut_set))
# Create train and valid cuts
logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.")
logging.info(
"Loading, trimming, and shuffling the remaining core+noncore cuts."
)
recording_set = RecordingSet.from_file(
manifest_dir / "csj_recordings_core.jsonl.gz"
) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
supervision_set = SupervisionSet.from_file(
manifest_dir / "csj_supervisions_core.jsonl.gz"
) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz")
) + SupervisionSet.from_file(
manifest_dir / "csj_supervisions_noncore.jsonl.gz"
)
cut_set = CutSet.from_manifests(
recordings=recording_set,
@ -92,12 +101,15 @@ def make_cutset_blueprints(
cut_set = cut_set.shuffle(Random(RNG_SEED))
logging.info(
f"Creating valid and train cuts from core and noncore,split at {split}."
"Creating valid and train cuts from core and noncore,"
f"split at {split}."
)
valid_set = CutSet.from_cuts(islice(cut_set, 0, split))
train_set = CutSet.from_cuts(islice(cut_set, split, None))
train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
train_set = (
train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
)
cut_sets.extend([("valid", valid_set), ("train", train_set)])
@ -110,9 +122,15 @@ def get_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
parser.add_argument("--split", type=int, default=4000, help="Split at this index")
parser.add_argument(
"--manifest-dir", type=Path, help="Path to save manifests"
)
parser.add_argument(
"--fbank-dir", type=Path, help="Path to save fbank features"
)
parser.add_argument(
"--split", type=int, default=4000, help="Split at this index"
)
return parser.parse_args()
@ -123,7 +141,9 @@ def main():
extractor = Fbank(FbankConfig(num_mel_bins=80))
num_jobs = min(16, os.cpu_count())
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -26,6 +26,7 @@ from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
ARGPARSE_DESCRIPTION = """
This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests.
@ -83,7 +84,9 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
# create chunks of Musan with duration 5 - 10 seconds
musan_cuts = (
CutSet.from_manifests(
recordings=combine(part["recordings"] for part in manifests.values())
recordings=combine(
part["recordings"] for part in manifests.values()
)
)
.cut_into_windows(10.0)
.filter(lambda c: c.duration > 5)
@ -104,15 +107,21 @@ def get_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
parser.add_argument(
"--manifest-dir", type=Path, help="Path to save manifests"
)
parser.add_argument(
"--fbank-dir", type=Path, help="Path to save fbank features"
)
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan(args.manifest_dir, args.fbank_dir)

View File

@ -318,3 +318,4 @@ spk_id = 2
= ǐa
= ǐu
= ǐo

View File

@ -318,3 +318,4 @@ spk_id = 2
= ǐa
= ǐu
= ǐo

Some files were not shown because too many files have changed in this diff Show More