Revert "Apply new Black style changes"

This commit is contained in:
Fangjun Kuang 2022-11-17 20:19:32 +08:00 committed by GitHub
parent a7fbb18bdc
commit 60317120ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
441 changed files with 14535 additions and 6789 deletions

View File

@ -1,2 +0,0 @@
# Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890)
d110b04ad389134c82fa314e3aafc7b40043efb0

View File

@ -45,18 +45,17 @@ jobs:
- name: Install Python dependencies - name: Install Python dependencies
run: | run: |
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
# Click issue fixed in https://github.com/psf/black/pull/2966 # See https://github.com/psf/black/issues/2964
# The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
- name: Run flake8 - name: Run flake8
shell: bash shell: bash
working-directory: ${{github.workspace}} working-directory: ${{github.workspace}}
run: | run: |
# stop the build if there are Python syntax errors or undefined names # stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics flake8 . --count --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 .
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503
- name: Run black - name: Run black
shell: bash shell: bash

View File

@ -1,38 +1,26 @@
repos: repos:
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 22.3.0 rev: 21.6b0
hooks: hooks:
- id: black - id: black
args: ["--line-length=88"] args: [--line-length=80]
additional_dependencies: ['click==8.1.0'] additional_dependencies: ['click==8.0.1']
exclude: icefall\/__init__\.py exclude: icefall\/__init__\.py
- repo: https://github.com/PyCQA/flake8 - repo: https://github.com/PyCQA/flake8
rev: 5.0.4 rev: 3.9.2
hooks: hooks:
- id: flake8 - id: flake8
args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"] args: [--max-line-length=80]
# What are we ignoring here?
# E203: whitespace before ':'
# E266: too many leading '#' for block comment
# E501: line too long
# F401: module imported but unused
# E402: module level import not at top of file
# F403: 'from module import *' used; unable to detect undefined names
# F841: local variable is assigned to but never used
# W503: line break before binary operator
# In addition, the default ignore list is:
# E121,E123,E126,E226,E24,E704,W503,W504
- repo: https://github.com/pycqa/isort - repo: https://github.com/pycqa/isort
rev: 5.10.1 rev: 5.9.2
hooks: hooks:
- id: isort - id: isort
args: ["--profile=black"] args: [--profile=black, --line-length=80]
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0 rev: v4.0.1
hooks: hooks:
- id: check-executables-have-shebangs - id: check-executables-have-shebangs
- id: end-of-file-fixer - id: end-of-file-fixer

View File

@ -88,3 +88,4 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall WORKDIR /workspace/icefall

View File

@ -19,3 +19,4 @@ It can be downloaded from `<https://www.openslr.org/33/>`_
tdnn_lstm_ctc tdnn_lstm_ctc
conformer_ctc conformer_ctc
stateless_transducer stateless_transducer

View File

@ -6,3 +6,4 @@ TIMIT
tdnn_ligru_ctc tdnn_ligru_ctc
tdnn_lstm_ctc tdnn_lstm_ctc

View File

@ -87,7 +87,9 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
) )
if "train" in partition: if "train" in partition:
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
) )
cut_set = cut_set.compute_and_store_features( cut_set = cut_set.compute_and_store_features(
extractor=extractor, extractor=extractor,
@ -114,7 +116,9 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state cur_state = loop_state
word = word2id[word] word = word2id[word]
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces] pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1): for i in range(len(pieces) - 1):
w = word if i == 0 else eps w = word if i == 0 else eps
@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False return False
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table. """Generate a lexicon from a word list and token_sym_table.
Args: Args:

View File

@ -317,7 +317,9 @@ def lexicon_to_fst(
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") parser.add_argument(
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
return parser.parse_args() return parser.parse_args()

View File

@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L") fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig") fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -50,15 +50,15 @@ def get_parser():
"-n", "-n",
default=1, default=1,
type=int, type=int,
help=( help="number of characters to split, i.e., \
"number of characters to split, i.e., aabb -> a a b" aabb -> a a b b with -n 1 and aa bb with -n 2",
" b with -n 1 and aa bb with -n 2"
),
) )
parser.add_argument( parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns" "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
) )
parser.add_argument("--space", default="<space>", type=str, help="space symbol") parser.add_argument(
"--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument( parser.add_argument(
"--non-lang-syms", "--non-lang-syms",
"-l", "-l",
@ -66,7 +66,9 @@ def get_parser():
type=str, type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.", help="list of non-linguistic symobles, e.g., <NOISE> etc.",
) )
parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument(
"text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument( parser.add_argument(
"--trans_type", "--trans_type",
"-t", "-t",
@ -106,7 +108,8 @@ def token2id(
if token_type == "lazy_pinyin": if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list) text = lazy_pinyin(chars_list)
sub_ids = [ sub_ids = [
token_table[txt] if txt in token_table else oov_id for txt in text token_table[txt] if txt in token_table else oov_id
for txt in text
] ]
ids.append(sub_ids) ids.append(sub_ids)
else: # token_type = "pinyin" else: # token_type = "pinyin"
@ -132,7 +135,9 @@ def main():
if args.text: if args.text:
f = codecs.open(args.text, encoding="utf-8") f = codecs.open(args.text, encoding="utf-8")
else: else:
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) f = codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer
)
sys.stdout = codecs.getwriter("utf-8")( sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer sys.stdout if is_python2 else sys.stdout.buffer

View File

@ -113,3 +113,4 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
./local/prepare_char.py ./local/prepare_char.py
fi fi
fi fi

View File

@ -81,12 +81,10 @@ class Aidatatang_200zhAsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group( group = parser.add_argument_group(
title="ASR data related options", title="ASR data related options",
description=( description="These options are used for the preparation of "
"These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data " "effective batch sizes, sampling strategies, applied data "
"augmentations, etc." "augmentations, etc.",
),
) )
group.add_argument( group.add_argument(
"--manifest-dir", "--manifest-dir",
@ -98,91 +96,75 @@ class Aidatatang_200zhAsrDataModule:
"--max-duration", "--max-duration",
type=int, type=int,
default=200.0, default=200.0,
help=( help="Maximum pooled recordings duration (seconds) in a "
"Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.",
"single batch. You can reduce it if it causes CUDA OOM."
),
) )
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, the batches will come from buckets of "
"When enabled, the batches will come from buckets of " "similar duration (saves padding frames).",
"similar duration (saves padding frames)."
),
) )
group.add_argument( group.add_argument(
"--num-buckets", "--num-buckets",
type=int, type=int,
default=300, default=300,
help=( help="The number of buckets for the DynamicBucketingSampler"
"The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).",
"(you might want to increase it for larger datasets)."
),
) )
group.add_argument( group.add_argument(
"--concatenate-cuts", "--concatenate-cuts",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="When enabled, utterances (cuts) will be concatenated "
"When enabled, utterances (cuts) will be concatenated " "to minimize the amount of padding.",
"to minimize the amount of padding."
),
) )
group.add_argument( group.add_argument(
"--duration-factor", "--duration-factor",
type=float, type=float,
default=1.0, default=1.0,
help=( help="Determines the maximum duration of a concatenated cut "
"Determines the maximum duration of a concatenated cut " "relative to the duration of the longest cut in a batch.",
"relative to the duration of the longest cut in a batch."
),
) )
group.add_argument( group.add_argument(
"--gap", "--gap",
type=float, type=float,
default=1.0, default=1.0,
help=( help="The amount of padding (in seconds) inserted between "
"The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when " "concatenated cuts. This padding is filled with noise when "
"noise augmentation is used." "noise augmentation is used.",
),
) )
group.add_argument( group.add_argument(
"--on-the-fly-feats", "--on-the-fly-feats",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="When enabled, use on-the-fly cut mixing and feature "
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests " "extraction. Will drop existing precomputed feature manifests "
"if available." "if available.",
),
) )
group.add_argument( group.add_argument(
"--shuffle", "--shuffle",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled (=default), the examples will be "
"When enabled (=default), the examples will be shuffled for each epoch." "shuffled for each epoch.",
),
) )
group.add_argument( group.add_argument(
"--return-cuts", "--return-cuts",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, each batch will have the "
"When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that " "field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it." "were used to construct it.",
),
) )
group.add_argument( group.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
default=2, default=2,
help="The number of training dataloader workers that collect the batches.", help="The number of training dataloader workers that "
"collect the batches.",
) )
group.add_argument( group.add_argument(
@ -196,22 +178,18 @@ class Aidatatang_200zhAsrDataModule:
"--spec-aug-time-warp-factor", "--spec-aug-time-warp-factor",
type=int, type=int,
default=80, default=80,
help=( help="Used only when --enable-spec-aug is True. "
"Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. " "It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. " "Larger values mean more warping. "
"A value less than 1 means to disable time warp." "A value less than 1 means to disable time warp.",
),
) )
group.add_argument( group.add_argument(
"--enable-musan", "--enable-musan",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, select noise from MUSAN and mix it"
"When enabled, select noise from MUSAN and mix it" "with training dataset. ",
"with training dataset. "
),
) )
def train_dataloaders( def train_dataloaders(
@ -227,20 +205,24 @@ class Aidatatang_200zhAsrDataModule:
The state dict for the training sampler. The state dict for the training sampler.
""" """
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = [] transforms = []
if self.args.enable_musan: if self.args.enable_musan:
logging.info("Enable MUSAN") logging.info("Enable MUSAN")
transforms.append( transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
logging.info( logging.info(
"Using cut concatenation with duration factor " f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}." f"{self.args.duration_factor} and gap {self.args.gap}."
) )
# Cut concatenation should be the first transform in the list, # Cut concatenation should be the first transform in the list,
@ -255,7 +237,9 @@ class Aidatatang_200zhAsrDataModule:
input_transforms = [] input_transforms = []
if self.args.enable_spec_aug: if self.args.enable_spec_aug:
logging.info("Enable SpecAugment") logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version. # Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is # In different Lhotse's versions, the default of num_frame_masks is
# different. # different.
@ -298,7 +282,9 @@ class Aidatatang_200zhAsrDataModule:
# Drop feats to be on the safe side. # Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
@ -354,7 +340,9 @@ class Aidatatang_200zhAsrDataModule:
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
else: else:

View File

@ -69,7 +69,11 @@ from beam_search import (
) )
from train import get_params, get_transducer_model from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -88,30 +92,25 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=28, default=28,
help=( help="It specifies the checkpoint to use for decoding."
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." "Note: Epoch counts from 0.",
),
) )
parser.add_argument( parser.add_argument(
"--batch", "--batch",
type=int, type=int,
default=None, default=None,
help=( help="It specifies the batch checkpoint to use for decoding."
"It specifies the batch checkpoint to use for decoding." "Note: Epoch counts from 0.",
"Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. " "'--epoch'. ",
),
) )
parser.add_argument( parser.add_argument(
@ -193,7 +192,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -249,7 +249,9 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
@ -264,7 +266,10 @@ def decode_one_batch(
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -310,7 +315,11 @@ def decode_one_batch(
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -381,7 +390,9 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -414,7 +425,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)

View File

@ -62,20 +62,17 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=28, default=28,
help=( help="It specifies the checkpoint to use for decoding."
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." "Note: Epoch counts from 0.",
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. " "'--epoch'. ",
),
) )
parser.add_argument( parser.add_argument(
@ -106,7 +103,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
return parser return parser
@ -175,7 +173,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -85,11 +85,9 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help=( help="Path to the checkpoint. "
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()." "icefall.checkpoint.save_checkpoint().",
),
) )
parser.add_argument( parser.add_argument(
@ -114,12 +112,10 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help=( help="The input sound file(s) to transcribe. "
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz." "The sample rate has to be 16kHz.",
),
) )
parser.add_argument( parser.add_argument(
@ -166,7 +162,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -196,9 +193,10 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert ( assert sample_rate == expected_sample_rate, (
sample_rate == expected_sample_rate f"expected sample rate: {expected_sample_rate}. "
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" f"Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -259,7 +257,9 @@ def main():
features = fbank(waves) features = fbank(waves)
feature_lengths = [f.size(0) for f in features] feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
@ -284,7 +284,10 @@ def main():
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -336,7 +339,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -81,7 +81,9 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
@ -185,45 +187,42 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--prune-range", "--prune-range",
type=int, type=int,
default=5, default=5,
help=( help="The prune range for rnnt loss, it means how many symbols(context)"
"The prune range for rnnt loss, it means how many symbols(context)" "we are using to compute the loss",
"we are using to compute the loss"
),
) )
parser.add_argument( parser.add_argument(
"--lm-scale", "--lm-scale",
type=float, type=float,
default=0.25, default=0.25,
help=( help="The scale to smooth the loss with lm "
"The scale to smooth the loss with lm (output of prediction network) part." "(output of prediction network) part.",
),
) )
parser.add_argument( parser.add_argument(
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)part.", help="The scale to smooth the loss with am (output of encoder network)"
"part.",
) )
parser.add_argument( parser.add_argument(
"--simple-loss-scale", "--simple-loss-scale",
type=float, type=float,
default=0.5, default=0.5,
help=( help="To get pruning ranges, we will calculate a simple version"
"To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for" "loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss" "training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss." "with this parameter before adding to the final loss.",
),
) )
parser.add_argument( parser.add_argument(
@ -543,15 +542,22 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge, # overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet. # in case it had not fully learned the alignment yet.
pruned_loss_scale = ( pruned_loss_scale = (
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) 0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -705,7 +711,9 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")

View File

@ -157,7 +157,9 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True, normalize_before: bool = True,
) -> None: ) -> None:
super(ConformerEncoderLayer, self).__init__() super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential( self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward), nn.Linear(d_model, dim_feedforward),
@ -175,14 +177,18 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff_macaron = nn.LayerNorm(
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5 self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.norm_final = nn.LayerNorm(
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
@ -216,7 +222,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src residual = src
if self.normalize_before: if self.normalize_before:
src = self.norm_ff_macaron(src) src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before: if not self.normalize_before:
src = self.norm_ff_macaron(src) src = self.norm_ff_macaron(src)
@ -335,7 +343,9 @@ class RelPositionalEncoding(torch.nn.Module):
""" """
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object.""" """Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__() super(RelPositionalEncoding, self).__init__()
self.d_model = d_model self.d_model = d_model
@ -351,7 +361,9 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1 # the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1: if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device # Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device) self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return return
# Suppose `i` means to the position of query vector and `j` means the # Suppose `i` means to the position of query vector and `j` means the
@ -621,9 +633,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value): if torch.equal(query, key) and torch.equal(key, value):
# self-attention # self-attention
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( q, k, v = nn.functional.linear(
3, dim=-1 query, in_proj_weight, in_proj_bias
) ).chunk(3, dim=-1)
elif torch.equal(key, value): elif torch.equal(key, value):
# encoder-decoder attention # encoder-decoder attention
@ -691,25 +703,33 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0) attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError("The size of the 2D attn_mask is not correct.") raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [ if list(attn_mask.size()) != [
bsz * num_heads, bsz * num_heads,
query.size(0), query.size(0),
key.size(0), key.size(0),
]: ]:
raise RuntimeError("The size of the 3D attn_mask is not correct.") raise RuntimeError(
"The size of the 3D attn_mask is not correct."
)
else: else:
raise RuntimeError( raise RuntimeError(
"attn_mask's dimension {} is not supported".format(attn_mask.dim()) "attn_mask's dimension {} is not supported".format(
attn_mask.dim()
)
) )
# attn_mask's dim is 3 now. # attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool # convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: if (
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
warnings.warn( warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor" "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
" instead."
) )
key_padding_mask = key_padding_mask.to(torch.bool) key_padding_mask = key_padding_mask.to(torch.bool)
@ -746,7 +766,9 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c # first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d # compute matrix b and matrix d
matrix_bd = torch.matmul( matrix_bd = torch.matmul(
@ -758,7 +780,9 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2) ) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [ assert list(attn_output_weights.size()) == [
bsz * num_heads, bsz * num_heads,
@ -792,9 +816,13 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v) attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = ( attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
) )
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights: if need_weights:
# average attention weights over heads # average attention weights over heads
@ -817,7 +845,9 @@ class ConvolutionModule(nn.Module):
""" """
def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: def __init__(
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object.""" """Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__() super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding # kernerl_size should be a odd number for 'SAME' padding

View File

@ -58,19 +58,16 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=49, default=49,
help=( help="It specifies the checkpoint to use for decoding."
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." "Note: Epoch counts from 0.",
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=20, default=20,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. " "'--epoch'. ",
),
) )
parser.add_argument( parser.add_argument(
@ -404,7 +401,9 @@ def decode_dataset(
if batch_idx % 100 == 0: if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -432,7 +431,9 @@ def save_results(
# we compute CER for aishell dataset. # we compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@ -440,7 +441,9 @@ def save_results(
test_set_wers[key] = wer test_set_wers[key] = wer
if enable_log: if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename)) logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@ -559,7 +562,9 @@ def main():
eos_id=eos_id, eos_id=eos_id,
) )
save_results(params=params, test_set_name=test_set, results_dict=results_dict) save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!") logging.info("Done!")

View File

@ -40,20 +40,17 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=84, default=84,
help=( help="It specifies the checkpoint to use for decoding."
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." "Note: Epoch counts from 0.",
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=25, default=25,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. " "'--epoch'. ",
),
) )
parser.add_argument( parser.add_argument(
@ -160,7 +157,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -46,29 +46,27 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help=( help="Path to the checkpoint. "
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()." "icefall.checkpoint.save_checkpoint().",
),
) )
parser.add_argument( parser.add_argument(
"--tokens-file", "--tokens-file",
type=str, type=str,
help="Path to tokens.txtUsed only when method is ctc-decoding", help="Path to tokens.txt" "Used only when method is ctc-decoding",
) )
parser.add_argument( parser.add_argument(
"--words-file", "--words-file",
type=str, type=str,
help="Path to words.txtUsed when method is NOT ctc-decoding", help="Path to words.txt" "Used when method is NOT ctc-decoding",
) )
parser.add_argument( parser.add_argument(
"--HLG", "--HLG",
type=str, type=str,
help="Path to HLG.pt.Used when method is NOT ctc-decoding", help="Path to HLG.pt." "Used when method is NOT ctc-decoding",
) )
parser.add_argument( parser.add_argument(
@ -165,12 +163,10 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help=( help="The input sound file(s) to transcribe. "
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz." "The sample rate has to be 16kHz.",
),
) )
return parser return parser
@ -214,9 +210,10 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert ( assert sample_rate == expected_sample_rate, (
sample_rate == expected_sample_rate f"expected sample rate: {expected_sample_rate}. "
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" f"Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -277,7 +274,9 @@ def main():
logging.info("Decoding started") logging.info("Decoding started")
features = fbank(waves) features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
# Note: We don't use key padding mask for attention during decoding # Note: We don't use key padding mask for attention during decoding
with torch.no_grad(): with torch.no_grad():
@ -372,7 +371,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module):
assert idim >= 7 assert idim >= 7
super().__init__() super().__init__()
self.conv = nn.Sequential( self.conv = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.Conv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(), nn.ReLU(),
nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.Conv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(), nn.ReLU(),
) )
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@ -128,13 +132,17 @@ class VggSubsampling(nn.Module):
) )
) )
layers.append( layers.append(
torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) torch.nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True
)
) )
cur_channels = block_dim cur_channels = block_dim
self.layers = nn.Sequential(*layers) self.layers = nn.Sequential(*layers)
self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) self.out = nn.Linear(
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x. """Subsample x.

View File

@ -16,8 +16,9 @@
# limitations under the License. # limitations under the License.
from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch import torch
from subsampling import Conv2dSubsampling, VggSubsampling
def test_conv2d_subsampling(): def test_conv2d_subsampling():

View File

@ -382,7 +382,9 @@ def compute_loss(
# #
# See https://github.com/k2-fsa/icefall/issues/97 # See https://github.com/k2-fsa/icefall/issues/97
# for more details # for more details
unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) unsorted_token_ids = graph_compiler.texts_to_ids(
supervisions["text"]
)
att_loss = mmodel.decoder_forward( att_loss = mmodel.decoder_forward(
encoder_memory, encoder_memory,
memory_mask, memory_mask,
@ -518,7 +520,9 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")
@ -626,7 +630,9 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate cur_lr = optimizer._rate
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0: if rank == 0:

View File

@ -149,7 +149,9 @@ class Transformer(nn.Module):
norm=decoder_norm, norm=decoder_norm,
) )
self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_output_layer = torch.nn.Linear(
d_model, self.decoder_num_class
)
self.decoder_criterion = LabelSmoothingLoss() self.decoder_criterion = LabelSmoothingLoss()
else: else:
@ -181,7 +183,9 @@ class Transformer(nn.Module):
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x) x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)
x = self.ctc_output(encoder_memory) x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask return x, encoder_memory, memory_key_padding_mask
@ -262,17 +266,23 @@ class Transformer(nn.Module):
""" """
ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in] ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out] ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device device = memory.device
ys_in_pad = ys_in_pad.to(device) ys_in_pad = ys_in_pad.to(device)
ys_out_pad = ys_out_pad.to(device) ys_out_pad = ys_out_pad.to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask # TODO: Use length information to create the decoder padding mask
@ -333,17 +343,23 @@ class Transformer(nn.Module):
ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in] ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out] ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask # TODO: Use length information to create the decoder padding mask
@ -616,7 +632,9 @@ def _get_activation_fn(activation: str):
elif activation == "gelu": elif activation == "gelu":
return nn.functional.gelu return nn.functional.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
class PositionalEncoding(nn.Module): class PositionalEncoding(nn.Module):
@ -818,7 +836,9 @@ def encoder_padding_mask(
1, 1,
).to(torch.int32) ).to(torch.int32)
lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] lengths = [
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
]
for idx in range(supervision_segments.size(0)): for idx in range(supervision_segments.size(0)):
# Note: TorchScript doesn't allow to unpack tensors as tuples # Note: TorchScript doesn't allow to unpack tensors as tuples
sequence_idx = supervision_segments[idx, 0].item() sequence_idx = supervision_segments[idx, 0].item()
@ -839,7 +859,9 @@ def encoder_padding_mask(
return mask return mask
def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: def decoder_padding_mask(
ys_pad: torch.Tensor, ignore_id: int = -1
) -> torch.Tensor:
"""Generate a length mask for input. """Generate a length mask for input.
The masked position are filled with True, The masked position are filled with True,

View File

@ -157,7 +157,9 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True, normalize_before: bool = True,
) -> None: ) -> None:
super(ConformerEncoderLayer, self).__init__() super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential( self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward), nn.Linear(d_model, dim_feedforward),
@ -175,14 +177,18 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff_macaron = nn.LayerNorm(
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5 self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.norm_final = nn.LayerNorm(
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
@ -216,7 +222,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src residual = src
if self.normalize_before: if self.normalize_before:
src = self.norm_ff_macaron(src) src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before: if not self.normalize_before:
src = self.norm_ff_macaron(src) src = self.norm_ff_macaron(src)
@ -335,7 +343,9 @@ class RelPositionalEncoding(torch.nn.Module):
""" """
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object.""" """Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__() super(RelPositionalEncoding, self).__init__()
self.d_model = d_model self.d_model = d_model
@ -351,7 +361,9 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1 # the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1: if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device # Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device) self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return return
# Suppose `i` means to the position of query vector and `j` means the # Suppose `i` means to the position of query vector and `j` means the
@ -621,9 +633,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value): if torch.equal(query, key) and torch.equal(key, value):
# self-attention # self-attention
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( q, k, v = nn.functional.linear(
3, dim=-1 query, in_proj_weight, in_proj_bias
) ).chunk(3, dim=-1)
elif torch.equal(key, value): elif torch.equal(key, value):
# encoder-decoder attention # encoder-decoder attention
@ -691,25 +703,33 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0) attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError("The size of the 2D attn_mask is not correct.") raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [ if list(attn_mask.size()) != [
bsz * num_heads, bsz * num_heads,
query.size(0), query.size(0),
key.size(0), key.size(0),
]: ]:
raise RuntimeError("The size of the 3D attn_mask is not correct.") raise RuntimeError(
"The size of the 3D attn_mask is not correct."
)
else: else:
raise RuntimeError( raise RuntimeError(
"attn_mask's dimension {} is not supported".format(attn_mask.dim()) "attn_mask's dimension {} is not supported".format(
attn_mask.dim()
)
) )
# attn_mask's dim is 3 now. # attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool # convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: if (
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
warnings.warn( warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor" "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
" instead."
) )
key_padding_mask = key_padding_mask.to(torch.bool) key_padding_mask = key_padding_mask.to(torch.bool)
@ -746,7 +766,9 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c # first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d # compute matrix b and matrix d
matrix_bd = torch.matmul( matrix_bd = torch.matmul(
@ -758,7 +780,9 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2) ) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [ assert list(attn_output_weights.size()) == [
bsz * num_heads, bsz * num_heads,
@ -792,9 +816,13 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v) attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = ( attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
) )
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights: if need_weights:
# average attention weights over heads # average attention weights over heads
@ -817,7 +845,9 @@ class ConvolutionModule(nn.Module):
""" """
def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: def __init__(
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object.""" """Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__() super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding # kernerl_size should be a odd number for 'SAME' padding

View File

@ -59,19 +59,16 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=49, default=49,
help=( help="It specifies the checkpoint to use for decoding."
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." "Note: Epoch counts from 0.",
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=20, default=20,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. " "'--epoch'. ",
),
) )
parser.add_argument( parser.add_argument(
@ -416,7 +413,9 @@ def decode_dataset(
if batch_idx % 100 == 0: if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -444,7 +443,9 @@ def save_results(
# we compute CER for aishell dataset. # we compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@ -452,7 +453,9 @@ def save_results(
test_set_wers[key] = wer test_set_wers[key] = wer
if enable_log: if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename)) logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@ -547,7 +550,9 @@ def main():
if params.export: if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return return
model.to(device) model.to(device)
@ -576,7 +581,9 @@ def main():
eos_id=eos_id, eos_id=eos_id,
) )
save_results(params=params, test_set_name=test_set, results_dict=results_dict) save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!") logging.info("Done!")

View File

@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module):
assert idim >= 7 assert idim >= 7
super().__init__() super().__init__()
self.conv = nn.Sequential( self.conv = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.Conv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(), nn.ReLU(),
nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.Conv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(), nn.ReLU(),
) )
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@ -128,13 +132,17 @@ class VggSubsampling(nn.Module):
) )
) )
layers.append( layers.append(
torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) torch.nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True
)
) )
cur_channels = block_dim cur_channels = block_dim
self.layers = nn.Sequential(*layers) self.layers = nn.Sequential(*layers)
self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) self.out = nn.Linear(
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x. """Subsample x.

View File

@ -511,7 +511,9 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")
@ -623,7 +625,9 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate cur_lr = optimizer._rate
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0: if rank == 0:

View File

@ -149,7 +149,9 @@ class Transformer(nn.Module):
norm=decoder_norm, norm=decoder_norm,
) )
self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_output_layer = torch.nn.Linear(
d_model, self.decoder_num_class
)
self.decoder_criterion = LabelSmoothingLoss() self.decoder_criterion = LabelSmoothingLoss()
else: else:
@ -181,7 +183,9 @@ class Transformer(nn.Module):
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x) x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)
x = self.ctc_output(encoder_memory) x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask return x, encoder_memory, memory_key_padding_mask
@ -262,17 +266,23 @@ class Transformer(nn.Module):
""" """
ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in] ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out] ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device device = memory.device
ys_in_pad = ys_in_pad.to(device) ys_in_pad = ys_in_pad.to(device)
ys_out_pad = ys_out_pad.to(device) ys_out_pad = ys_out_pad.to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask # TODO: Use length information to create the decoder padding mask
@ -333,17 +343,23 @@ class Transformer(nn.Module):
ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in] ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out] ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask # TODO: Use length information to create the decoder padding mask
@ -616,7 +632,9 @@ def _get_activation_fn(activation: str):
elif activation == "gelu": elif activation == "gelu":
return nn.functional.gelu return nn.functional.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
class PositionalEncoding(nn.Module): class PositionalEncoding(nn.Module):
@ -818,7 +836,9 @@ def encoder_padding_mask(
1, 1,
).to(torch.int32) ).to(torch.int32)
lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] lengths = [
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
]
for idx in range(supervision_segments.size(0)): for idx in range(supervision_segments.size(0)):
# Note: TorchScript doesn't allow to unpack tensors as tuples # Note: TorchScript doesn't allow to unpack tensors as tuples
sequence_idx = supervision_segments[idx, 0].item() sequence_idx = supervision_segments[idx, 0].item()
@ -839,7 +859,9 @@ def encoder_padding_mask(
return mask return mask
def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: def decoder_padding_mask(
ys_pad: torch.Tensor, ignore_id: int = -1
) -> torch.Tensor:
"""Generate a length mask for input. """Generate a length mask for input.
The masked position are filled with True, The masked position are filled with True,

View File

@ -87,7 +87,9 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
) )
if "train" in partition: if "train" in partition:
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
) )
cut_set = cut_set.compute_and_store_features( cut_set = cut_set.compute_and_store_features(
extractor=extractor, extractor=extractor,
@ -114,7 +116,9 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -83,7 +83,9 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
) )
if "train" in partition: if "train" in partition:
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
) )
cut_set = cut_set.compute_and_store_features( cut_set = cut_set.compute_and_store_features(
extractor=extractor, extractor=extractor,
@ -109,7 +111,9 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state cur_state = loop_state
word = word2id[word] word = word2id[word]
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces] pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1): for i in range(len(pieces) - 1):
w = word if i == 0 else eps w = word if i == 0 else eps
@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False return False
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table. """Generate a lexicon from a word list and token_sym_table.
Args: Args:

View File

@ -317,7 +317,9 @@ def lexicon_to_fst(
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") parser.add_argument(
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
return parser.parse_args() return parser.parse_args()

View File

@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L") fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig") fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -76,7 +76,11 @@ from beam_search import (
) )
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -114,11 +118,9 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'" "'--epoch' and '--iter'",
),
) )
parser.add_argument( parser.add_argument(
@ -186,7 +188,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -246,7 +249,9 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -258,7 +263,10 @@ def decode_one_batch(
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -302,7 +310,11 @@ def decode_one_batch(
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -375,7 +387,9 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -401,7 +415,9 @@ def save_results(
# we compute CER for aishell dataset. # we compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -412,7 +428,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
@ -456,7 +473,9 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -485,7 +504,8 @@ def main():
] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(

View File

@ -50,7 +50,11 @@ from pathlib import Path
import torch import torch
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import str2bool from icefall.utils import str2bool
@ -83,11 +87,9 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'" "'--epoch' and '--iter'",
),
) )
parser.add_argument( parser.add_argument(
@ -118,7 +120,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -154,7 +157,8 @@ def main():
] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(
@ -187,7 +191,9 @@ def main():
model.__class__.forward = torch.jit.ignore(model.__class__.forward) model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" filename = (
params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
)
model.save(str(filename)) model.save(str(filename))
logging.info(f"Saved to {filename}") logging.info(f"Saved to {filename}")
else: else:
@ -195,14 +201,17 @@ def main():
# Save it using a format so that it can be loaded # Save it using a format so that it can be loaded
# by :func:`load_checkpoint` # by :func:`load_checkpoint`
filename = ( filename = (
params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" params.exp_dir
/ f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
) )
torch.save({"model": model.state_dict()}, str(filename)) torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}") logging.info(f"Saved to {filename}")
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -87,11 +87,9 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help=( help="Path to the checkpoint. "
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()." "icefall.checkpoint.save_checkpoint().",
),
) )
parser.add_argument( parser.add_argument(
@ -117,12 +115,10 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help=( help="The input sound file(s) to transcribe. "
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz." "The sample rate has to be 16kHz.",
),
) )
parser.add_argument( parser.add_argument(
@ -169,16 +165,15 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
default=1, default=1,
help=( help="Maximum number of symbols per frame. "
"Maximum number of symbols per frame. " "Use only when --method is greedy_search",
"Use only when --method is greedy_search"
),
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -201,9 +196,10 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert ( assert sample_rate == expected_sample_rate, (
sample_rate == expected_sample_rate f"expected sample rate: {expected_sample_rate}. "
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" f"Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -260,9 +256,13 @@ def main():
feature_lens = [f.size(0) for f in features] feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device) feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyp_list = [] hyp_list = []
@ -310,7 +310,9 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
else: else:
raise ValueError(f"Unsupported decoding method: {params.method}") raise ValueError(
f"Unsupported decoding method: {params.method}"
)
hyp_list.append(hyp) hyp_list.append(hyp)
hyps = [] hyps = []
@ -327,7 +329,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -49,6 +49,7 @@ import optim
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AishellAsrDataModule from asr_datamodule import AishellAsrDataModule
from conformer import Conformer from conformer import Conformer
from decoder import Decoder from decoder import Decoder
@ -74,7 +75,9 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def add_model_arguments(parser: argparse.ArgumentParser): def add_model_arguments(parser: argparse.ArgumentParser):
@ -200,7 +203,8 @@ def get_parser():
"--initial-lr", "--initial-lr",
type=float, type=float,
default=0.003, default=0.003,
help="The initial learning rate. This value should not need to be changed.", help="The initial learning rate. This value should not need "
"to be changed.",
) )
parser.add_argument( parser.add_argument(
@ -223,45 +227,42 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--prune-range", "--prune-range",
type=int, type=int,
default=5, default=5,
help=( help="The prune range for rnnt loss, it means how many symbols(context)"
"The prune range for rnnt loss, it means how many symbols(context)" "we are using to compute the loss",
"we are using to compute the loss"
),
) )
parser.add_argument( parser.add_argument(
"--lm-scale", "--lm-scale",
type=float, type=float,
default=0.25, default=0.25,
help=( help="The scale to smooth the loss with lm "
"The scale to smooth the loss with lm (output of prediction network) part." "(output of prediction network) part.",
),
) )
parser.add_argument( parser.add_argument(
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)part.", help="The scale to smooth the loss with am (output of encoder network)"
"part.",
) )
parser.add_argument( parser.add_argument(
"--simple-loss-scale", "--simple-loss-scale",
type=float, type=float,
default=0.5, default=0.5,
help=( help="To get pruning ranges, we will calculate a simple version"
"To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for" "loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss" "training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss." "with this parameter before adding to the final loss.",
),
) )
parser.add_argument( parser.add_argument(
@ -560,7 +561,11 @@ def compute_loss(
warmup: a floating point value which increases throughout training; warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present. values >= 1.0 are fully warmed up and have all modules present.
""" """
device = model.device if isinstance(model, DDP) else next(model.parameters()).device device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"] feature = batch["inputs"]
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
assert feature.ndim == 3 assert feature.ndim == 3
@ -588,16 +593,23 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge, # overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet. # in case it had not fully learned the alignment yet.
pruned_loss_scale = ( pruned_loss_scale = (
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) 0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -713,7 +725,9 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
except: # noqa except: # noqa
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
raise raise
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 5:
@ -1015,7 +1029,9 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} " f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..." f"(={crit_values[criterion]}) ..."
) )
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
raise raise

View File

@ -121,24 +121,20 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'" "'--epoch' and '--iter'",
),
) )
parser.add_argument( parser.add_argument(
"--use-averaged-model", "--use-averaged-model",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="Whether to load averaged model. Currently it only supports "
"Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model " "using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`." "over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and " "Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. " "`epoch` are loaded for averaging. ",
),
) )
parser.add_argument( parser.add_argument(
@ -206,7 +202,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -266,7 +263,9 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -278,7 +277,10 @@ def decode_one_batch(
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -322,7 +324,11 @@ def decode_one_batch(
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -395,7 +401,9 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -421,7 +429,9 @@ def save_results(
# we compute CER for aishell dataset. # we compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -432,7 +442,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tCER", file=f) print("settings\tCER", file=f)
@ -477,7 +488,9 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -505,12 +518,13 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ filenames = find_checkpoints(
: params.avg params.exp_dir, iteration=-params.iter
] )[: params.avg]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(
@ -537,12 +551,13 @@ def main():
) )
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ filenames = find_checkpoints(
: params.avg + 1 params.exp_dir, iteration=-params.iter
] )[: params.avg + 1]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg + 1: elif len(filenames) < params.avg + 1:
raise ValueError( raise ValueError(
@ -571,7 +586,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info( logging.info(
"Calculating the averaged model over epoch range from " f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}" f"{start} (excluded) to {params.epoch}"
) )
model.to(device) model.to(device)

View File

@ -88,24 +88,20 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'" "'--epoch' and '--iter'",
),
) )
parser.add_argument( parser.add_argument(
"--use-averaged-model", "--use-averaged-model",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="Whether to load averaged model. Currently it only supports "
"Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model " "using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`." "over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and " "Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. " "`epoch` are loaded for averaging. ",
),
) )
parser.add_argument( parser.add_argument(
@ -136,7 +132,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -169,12 +166,13 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ filenames = find_checkpoints(
: params.avg params.exp_dir, iteration=-params.iter
] )[: params.avg]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(
@ -197,12 +195,13 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ filenames = find_checkpoints(
: params.avg + 1 params.exp_dir, iteration=-params.iter
] )[: params.avg + 1]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg + 1: elif len(filenames) < params.avg + 1:
raise ValueError( raise ValueError(
@ -230,7 +229,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info( logging.info(
"Calculating the averaged model over epoch range from " f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}" f"{start} (excluded) to {params.epoch}"
) )
model.to(device) model.to(device)
@ -253,7 +252,9 @@ def main():
model.__class__.forward = torch.jit.ignore(model.__class__.forward) model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" filename = (
params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
)
model.save(str(filename)) model.save(str(filename))
logging.info(f"Saved to {filename}") logging.info(f"Saved to {filename}")
else: else:
@ -261,14 +262,17 @@ def main():
# Save it using a format so that it can be loaded # Save it using a format so that it can be loaded
# by :func:`load_checkpoint` # by :func:`load_checkpoint`
filename = ( filename = (
params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" params.exp_dir
/ f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
) )
torch.save({"model": model.state_dict()}, str(filename)) torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}") logging.info(f"Saved to {filename}")
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -84,7 +84,9 @@ class Transducer(nn.Module):
self.decoder_datatang = decoder_datatang self.decoder_datatang = decoder_datatang
self.joiner_datatang = joiner_datatang self.joiner_datatang = joiner_datatang
self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if decoder_datatang is not None: if decoder_datatang is not None:
@ -177,7 +179,9 @@ class Transducer(nn.Module):
y_padded = y.pad(mode="constant", padding_value=0) y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64) y_padded = y_padded.to(torch.int64)
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens boundary[:, 3] = encoder_out_lens

View File

@ -87,11 +87,9 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help=( help="Path to the checkpoint. "
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()." "icefall.checkpoint.save_checkpoint().",
),
) )
parser.add_argument( parser.add_argument(
@ -117,12 +115,10 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help=( help="The input sound file(s) to transcribe. "
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz." "The sample rate has to be 16kHz.",
),
) )
parser.add_argument( parser.add_argument(
@ -169,16 +165,15 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
default=1, default=1,
help=( help="Maximum number of symbols per frame. "
"Maximum number of symbols per frame. " "Use only when --method is greedy_search",
"Use only when --method is greedy_search"
),
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -201,9 +196,10 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert ( assert sample_rate == expected_sample_rate, (
sample_rate == expected_sample_rate f"expected sample rate: {expected_sample_rate}. "
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" f"Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -261,9 +257,13 @@ def main():
feature_lens = [f.size(0) for f in features] feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device) feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyp_list = [] hyp_list = []
@ -311,7 +311,9 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
else: else:
raise ValueError(f"Unsupported decoding method: {params.method}") raise ValueError(
f"Unsupported decoding method: {params.method}"
)
hyp_list.append(hyp) hyp_list.append(hyp)
hyps = [] hyps = []
@ -328,7 +330,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -96,7 +96,9 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def add_model_arguments(parser: argparse.ArgumentParser): def add_model_arguments(parser: argparse.ArgumentParser):
@ -222,7 +224,8 @@ def get_parser():
"--initial-lr", "--initial-lr",
type=float, type=float,
default=0.003, default=0.003,
help="The initial learning rate. This value should not need to be changed.", help="The initial learning rate. This value should not need "
"to be changed.",
) )
parser.add_argument( parser.add_argument(
@ -245,45 +248,42 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--prune-range", "--prune-range",
type=int, type=int,
default=5, default=5,
help=( help="The prune range for rnnt loss, it means how many symbols(context)"
"The prune range for rnnt loss, it means how many symbols(context)" "we are using to compute the loss",
"we are using to compute the loss"
),
) )
parser.add_argument( parser.add_argument(
"--lm-scale", "--lm-scale",
type=float, type=float,
default=0.25, default=0.25,
help=( help="The scale to smooth the loss with lm "
"The scale to smooth the loss with lm (output of prediction network) part." "(output of prediction network) part.",
),
) )
parser.add_argument( parser.add_argument(
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)part.", help="The scale to smooth the loss with am (output of encoder network)"
"part.",
) )
parser.add_argument( parser.add_argument(
"--simple-loss-scale", "--simple-loss-scale",
type=float, type=float,
default=0.5, default=0.5,
help=( help="To get pruning ranges, we will calculate a simple version"
"To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for" "loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss" "training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss." "with this parameter before adding to the final loss.",
),
) )
parser.add_argument( parser.add_argument(
@ -635,7 +635,11 @@ def compute_loss(
warmup: a floating point value which increases throughout training; warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present. values >= 1.0 are fully warmed up and have all modules present.
""" """
device = model.device if isinstance(model, DDP) else next(model.parameters()).device device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"] feature = batch["inputs"]
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
assert feature.ndim == 3 assert feature.ndim == 3
@ -666,16 +670,23 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge, # overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet. # in case it had not fully learned the alignment yet.
pruned_loss_scale = ( pruned_loss_scale = (
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) 0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -813,7 +824,9 @@ def train_one_epoch(
) )
# summary stats # summary stats
if datatang_train_dl is not None: if datatang_train_dl is not None:
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info tot_loss = (
tot_loss * (1 - 1 / params.reset_interval)
) + loss_info
if aishell: if aishell:
aishell_tot_loss = ( aishell_tot_loss = (
@ -834,7 +847,9 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
except: # noqa except: # noqa
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
raise raise
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 5:
@ -877,7 +892,9 @@ def train_one_epoch(
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]
if datatang_train_dl is not None: if datatang_train_dl is not None:
datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], " datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, " tot_loss_str = (
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
)
else: else:
tot_loss_str = "" tot_loss_str = ""
datatang_str = "" datatang_str = ""
@ -1059,7 +1076,9 @@ def run(rank, world_size, args):
train_cuts = filter_short_and_long_utterances(train_cuts) train_cuts = filter_short_and_long_utterances(train_cuts)
if args.enable_musan: if args.enable_musan:
cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") cuts_musan = load_manifest(
Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
)
else: else:
cuts_musan = None cuts_musan = None
@ -1074,7 +1093,9 @@ def run(rank, world_size, args):
if params.datatang_prob > 0: if params.datatang_prob > 0:
datatang = AIDatatang200zh(manifest_dir=args.manifest_dir) datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
train_datatang_cuts = datatang.train_cuts() train_datatang_cuts = datatang.train_cuts()
train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts) train_datatang_cuts = filter_short_and_long_utterances(
train_datatang_cuts
)
train_datatang_cuts = train_datatang_cuts.repeat(times=None) train_datatang_cuts = train_datatang_cuts.repeat(times=None)
datatang_train_dl = asr_datamodule.train_dataloaders( datatang_train_dl = asr_datamodule.train_dataloaders(
train_datatang_cuts, train_datatang_cuts,
@ -1228,7 +1249,9 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} " f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..." f"(={crit_values[criterion]}) ..."
) )
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
raise raise

View File

@ -64,12 +64,10 @@ class AishellAsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group( group = parser.add_argument_group(
title="ASR data related options", title="ASR data related options",
description=( description="These options are used for the preparation of "
"These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data " "effective batch sizes, sampling strategies, applied data "
"augmentations, etc." "augmentations, etc.",
),
) )
group.add_argument( group.add_argument(
"--manifest-dir", "--manifest-dir",
@ -81,74 +79,59 @@ class AishellAsrDataModule:
"--max-duration", "--max-duration",
type=int, type=int,
default=200.0, default=200.0,
help=( help="Maximum pooled recordings duration (seconds) in a "
"Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.",
"single batch. You can reduce it if it causes CUDA OOM."
),
) )
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, the batches will come from buckets of "
"When enabled, the batches will come from buckets of " "similar duration (saves padding frames).",
"similar duration (saves padding frames)."
),
) )
group.add_argument( group.add_argument(
"--num-buckets", "--num-buckets",
type=int, type=int,
default=30, default=30,
help=( help="The number of buckets for the DynamicBucketingSampler"
"The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).",
"(you might want to increase it for larger datasets)."
),
) )
group.add_argument( group.add_argument(
"--concatenate-cuts", "--concatenate-cuts",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="When enabled, utterances (cuts) will be concatenated "
"When enabled, utterances (cuts) will be concatenated " "to minimize the amount of padding.",
"to minimize the amount of padding."
),
) )
group.add_argument( group.add_argument(
"--duration-factor", "--duration-factor",
type=float, type=float,
default=1.0, default=1.0,
help=( help="Determines the maximum duration of a concatenated cut "
"Determines the maximum duration of a concatenated cut " "relative to the duration of the longest cut in a batch.",
"relative to the duration of the longest cut in a batch."
),
) )
group.add_argument( group.add_argument(
"--gap", "--gap",
type=float, type=float,
default=1.0, default=1.0,
help=( help="The amount of padding (in seconds) inserted between "
"The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when " "concatenated cuts. This padding is filled with noise when "
"noise augmentation is used." "noise augmentation is used.",
),
) )
group.add_argument( group.add_argument(
"--on-the-fly-feats", "--on-the-fly-feats",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="When enabled, use on-the-fly cut mixing and feature "
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests " "extraction. Will drop existing precomputed feature manifests "
"if available." "if available.",
),
) )
group.add_argument( group.add_argument(
"--shuffle", "--shuffle",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled (=default), the examples will be "
"When enabled (=default), the examples will be shuffled for each epoch." "shuffled for each epoch.",
),
) )
group.add_argument( group.add_argument(
"--drop-last", "--drop-last",
@ -160,18 +143,17 @@ class AishellAsrDataModule:
"--return-cuts", "--return-cuts",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, each batch will have the "
"When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that " "field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it." "were used to construct it.",
),
) )
group.add_argument( group.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
default=2, default=2,
help="The number of training dataloader workers that collect the batches.", help="The number of training dataloader workers that "
"collect the batches.",
) )
group.add_argument( group.add_argument(
@ -185,40 +167,40 @@ class AishellAsrDataModule:
"--spec-aug-time-warp-factor", "--spec-aug-time-warp-factor",
type=int, type=int,
default=80, default=80,
help=( help="Used only when --enable-spec-aug is True. "
"Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. " "It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. " "Larger values mean more warping. "
"A value less than 1 means to disable time warp." "A value less than 1 means to disable time warp.",
),
) )
group.add_argument( group.add_argument(
"--enable-musan", "--enable-musan",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, select noise from MUSAN and mix it"
"When enabled, select noise from MUSAN and mix it" "with training dataset. ",
"with training dataset. "
),
) )
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = [] transforms = []
if self.args.enable_musan: if self.args.enable_musan:
logging.info("Enable MUSAN") logging.info("Enable MUSAN")
transforms.append( transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
logging.info( logging.info(
"Using cut concatenation with duration factor " f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}." f"{self.args.duration_factor} and gap {self.args.gap}."
) )
# Cut concatenation should be the first transform in the list, # Cut concatenation should be the first transform in the list,
@ -233,7 +215,9 @@ class AishellAsrDataModule:
input_transforms = [] input_transforms = []
if self.args.enable_spec_aug: if self.args.enable_spec_aug:
logging.info("Enable SpecAugment") logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version. # Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is # In different Lhotse's versions, the default of num_frame_masks is
# different. # different.
@ -276,7 +260,9 @@ class AishellAsrDataModule:
# Drop feats to be on the safe side. # Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
@ -322,7 +308,9 @@ class AishellAsrDataModule:
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
else: else:
@ -378,9 +366,13 @@ class AishellAsrDataModule:
@lru_cache() @lru_cache()
def valid_cuts(self) -> CutSet: def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts") logging.info("About to get dev cuts")
return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz") return load_manifest_lazy(
self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz"
)
@lru_cache() @lru_cache()
def test_cuts(self) -> List[CutSet]: def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts") logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz") return load_manifest_lazy(
self.args.manifest_dir / "aishell_cuts_test.jsonl.gz"
)

View File

@ -49,19 +49,16 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=19, default=19,
help=( help="It specifies the checkpoint to use for decoding."
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." "Note: Epoch counts from 0.",
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=5, default=5,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. " "'--epoch'. ",
),
) )
parser.add_argument( parser.add_argument(
"--method", "--method",
@ -268,7 +265,9 @@ def decode_dataset(
if batch_idx % 100 == 0: if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -290,7 +289,9 @@ def save_results(
# We compute CER for aishell dataset. # We compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results_char) wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
test_set_wers[key] = wer test_set_wers[key] = wer
@ -334,7 +335,9 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device) HLG = HLG.to(device)
assert HLG.requires_grad is False assert HLG.requires_grad is False
@ -359,7 +362,9 @@ def main():
if params.export: if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
model.to(device) model.to(device)
model.eval() model.eval()
@ -387,7 +392,9 @@ def main():
lexicon=lexicon, lexicon=lexicon,
) )
save_results(params=params, test_set_name=test_set, results_dict=results_dict) save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!") logging.info("Done!")

View File

@ -66,7 +66,10 @@ class TdnnLstm(nn.Module):
nn.BatchNorm1d(num_features=500, affine=False), nn.BatchNorm1d(num_features=500, affine=False),
) )
self.lstms = nn.ModuleList( self.lstms = nn.ModuleList(
[nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)] [
nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
for _ in range(5)
]
) )
self.lstm_bnorms = nn.ModuleList( self.lstm_bnorms = nn.ModuleList(
[nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]

View File

@ -41,11 +41,9 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help=( help="Path to the checkpoint. "
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()." "icefall.checkpoint.save_checkpoint().",
),
) )
parser.add_argument( parser.add_argument(
@ -55,7 +53,9 @@ def get_parser():
help="Path to words.txt", help="Path to words.txt",
) )
parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") parser.add_argument(
"--HLG", type=str, required=True, help="Path to HLG.pt."
)
parser.add_argument( parser.add_argument(
"--method", "--method",
@ -71,12 +71,10 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help=( help="The input sound file(s) to transcribe. "
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz." "The sample rate has to be 16kHz.",
),
) )
return parser return parser
@ -114,9 +112,10 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert ( assert sample_rate == expected_sample_rate, (
sample_rate == expected_sample_rate f"expected sample rate: {expected_sample_rate}. "
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" f"Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -174,7 +173,9 @@ def main():
logging.info("Decoding started") logging.info("Decoding started")
features = fbank(waves) features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = features.permute(0, 2, 1) # now features is [N, C, T] features = features.permute(0, 2, 1) # now features is [N, C, T]
with torch.no_grad(): with torch.no_grad():
@ -218,7 +219,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -49,7 +49,12 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool from icefall.utils import (
AttributeDict,
encode_supervisions,
setup_logger,
str2bool,
)
def get_parser(): def get_parser():

View File

@ -47,9 +47,9 @@ def greedy_search(
device = model.device device = model.device
decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( decoder_input = torch.tensor(
1, context_size [blank_id] * context_size, device=device
) ).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)
@ -81,9 +81,9 @@ def greedy_search(
y = logits.argmax().item() y = logits.argmax().item()
if y != blank_id: if y != blank_id:
hyp.append(y) hyp.append(y)
decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( decoder_input = torch.tensor(
1, context_size [hyp[-context_size:]], device=device
) ).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)
@ -157,7 +157,9 @@ class HypothesisList(object):
""" """
if length_norm: if length_norm:
return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) return max(
self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
)
else: else:
return max(self._data.values(), key=lambda hyp: hyp.log_prob) return max(self._data.values(), key=lambda hyp: hyp.log_prob)
@ -244,9 +246,9 @@ def beam_search(
device = model.device device = model.device
decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( decoder_input = torch.tensor(
1, context_size [blank_id] * context_size, device=device
) ).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)

View File

@ -155,7 +155,9 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True, normalize_before: bool = True,
) -> None: ) -> None:
super(ConformerEncoderLayer, self).__init__() super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential( self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward), nn.Linear(d_model, dim_feedforward),
@ -173,14 +175,18 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff_macaron = nn.LayerNorm(
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5 self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.norm_final = nn.LayerNorm(
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
@ -214,7 +220,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src residual = src
if self.normalize_before: if self.normalize_before:
src = self.norm_ff_macaron(src) src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before: if not self.normalize_before:
src = self.norm_ff_macaron(src) src = self.norm_ff_macaron(src)
@ -333,7 +341,9 @@ class RelPositionalEncoding(torch.nn.Module):
""" """
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object.""" """Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__() super(RelPositionalEncoding, self).__init__()
self.d_model = d_model self.d_model = d_model
@ -349,7 +359,9 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1 # the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1: if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device # Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device) self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return return
# Suppose `i` means to the position of query vector and `j` means the # Suppose `i` means to the position of query vector and `j` means the
@ -619,9 +631,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value): if torch.equal(query, key) and torch.equal(key, value):
# self-attention # self-attention
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( q, k, v = nn.functional.linear(
3, dim=-1 query, in_proj_weight, in_proj_bias
) ).chunk(3, dim=-1)
elif torch.equal(key, value): elif torch.equal(key, value):
# encoder-decoder attention # encoder-decoder attention
@ -689,25 +701,33 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0) attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError("The size of the 2D attn_mask is not correct.") raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [ if list(attn_mask.size()) != [
bsz * num_heads, bsz * num_heads,
query.size(0), query.size(0),
key.size(0), key.size(0),
]: ]:
raise RuntimeError("The size of the 3D attn_mask is not correct.") raise RuntimeError(
"The size of the 3D attn_mask is not correct."
)
else: else:
raise RuntimeError( raise RuntimeError(
"attn_mask's dimension {} is not supported".format(attn_mask.dim()) "attn_mask's dimension {} is not supported".format(
attn_mask.dim()
)
) )
# attn_mask's dim is 3 now. # attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool # convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: if (
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
warnings.warn( warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor" "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
" instead."
) )
key_padding_mask = key_padding_mask.to(torch.bool) key_padding_mask = key_padding_mask.to(torch.bool)
@ -744,7 +764,9 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c # first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d # compute matrix b and matrix d
matrix_bd = torch.matmul( matrix_bd = torch.matmul(
@ -756,7 +778,9 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2) ) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [ assert list(attn_output_weights.size()) == [
bsz * num_heads, bsz * num_heads,
@ -790,9 +814,13 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v) attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = ( attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
) )
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights: if need_weights:
# average attention weights over heads # average attention weights over heads
@ -815,7 +843,9 @@ class ConvolutionModule(nn.Module):
""" """
def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: def __init__(
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object.""" """Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__() super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding # kernerl_size should be a odd number for 'SAME' padding

View File

@ -52,19 +52,16 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=30, default=30,
help=( help="It specifies the checkpoint to use for decoding."
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." "Note: Epoch counts from 0.",
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=10, default=10,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. " "'--epoch'. ",
),
) )
parser.add_argument( parser.add_argument(
@ -102,7 +99,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -229,7 +227,9 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = [] hyps = []
batch_size = encoder_out.size(0) batch_size = encoder_out.size(0)
@ -248,7 +248,9 @@ def decode_one_batch(
model=model, encoder_out=encoder_out_i, beam=params.beam_size model=model, encoder_out=encoder_out_i, beam=params.beam_size
) )
else: else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}") raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[i] for i in hyp]) hyps.append([lexicon.token_table[i] for i in hyp])
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
@ -317,7 +319,9 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -342,7 +346,9 @@ def save_results(
# we compute CER for aishell dataset. # we compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -353,7 +359,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tCER", file=f) print("settings\tCER", file=f)
@ -423,7 +430,9 @@ def main():
if params.export: if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return return
model.to(device) model.to(device)

View File

@ -86,7 +86,9 @@ class Decoder(nn.Module):
if self.context_size > 1: if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1) embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True: if need_pad is True:
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
else: else:
# During inference time, there is no need to do extra padding # During inference time, there is no need to do extra padding
# as we only need one output # as we only need one output

View File

@ -69,20 +69,17 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=20, default=20,
help=( help="It specifies the checkpoint to use for decoding."
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." "Note: Epoch counts from 0.",
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=10, default=10,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. " "'--epoch'. ",
),
) )
parser.add_argument( parser.add_argument(
@ -113,7 +110,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
return parser return parser
@ -245,7 +243,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -103,7 +103,9 @@ class Transducer(nn.Module):
y_padded = y.pad(mode="constant", padding_value=0) y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64) y_padded = y_padded.to(torch.int64)
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens boundary[:, 2] = y_lens
boundary[:, 3] = x_lens boundary[:, 3] = x_lens

View File

@ -73,11 +73,9 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help=( help="Path to the checkpoint. "
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()." "icefall.checkpoint.save_checkpoint().",
),
) )
parser.add_argument( parser.add_argument(
@ -102,12 +100,10 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help=( help="The input sound file(s) to transcribe. "
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz." "The sample rate has to be 16kHz.",
),
) )
parser.add_argument( parser.add_argument(
@ -121,7 +117,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -214,9 +211,10 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert ( assert sample_rate == expected_sample_rate, (
sample_rate == expected_sample_rate f"expected sample rate: {expected_sample_rate}. "
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" f"Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -275,7 +273,9 @@ def main():
features = fbank(waves) features = fbank(waves)
feature_lengths = [f.size(0) for f in features] feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
@ -319,7 +319,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -126,7 +126,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -388,7 +389,9 @@ def compute_loss(
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -501,7 +504,9 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")
@ -620,7 +625,9 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate cur_lr = optimizer._rate
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0: if rank == 0:

View File

@ -250,7 +250,9 @@ def _get_activation_fn(activation: str):
elif activation == "gelu": elif activation == "gelu":
return nn.functional.gelu return nn.functional.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
class PositionalEncoding(nn.Module): class PositionalEncoding(nn.Module):

View File

@ -29,7 +29,10 @@ from lhotse.dataset import (
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures from lhotse.dataset.input_strategies import (
OnTheFlyFeatures,
PrecomputedFeatures,
)
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from icefall.utils import str2bool from icefall.utils import str2bool
@ -43,69 +46,59 @@ class AsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group( group = parser.add_argument_group(
title="ASR data related options", title="ASR data related options",
description=( description="These options are used for the preparation of "
"These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data " "effective batch sizes, sampling strategies, applied data "
"augmentations, etc." "augmentations, etc.",
),
) )
group.add_argument( group.add_argument(
"--max-duration", "--max-duration",
type=int, type=int,
default=200.0, default=200.0,
help=( help="Maximum pooled recordings duration (seconds) in a "
"Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.",
"single batch. You can reduce it if it causes CUDA OOM."
),
) )
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, the batches will come from buckets of "
"When enabled, the batches will come from buckets of " "similar duration (saves padding frames).",
"similar duration (saves padding frames)."
),
) )
group.add_argument( group.add_argument(
"--num-buckets", "--num-buckets",
type=int, type=int,
default=30, default=30,
help=( help="The number of buckets for the DynamicBucketingSampler "
"The number of buckets for the DynamicBucketingSampler " "(you might want to increase it for larger datasets).",
"(you might want to increase it for larger datasets)."
),
) )
group.add_argument( group.add_argument(
"--shuffle", "--shuffle",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled (=default), the examples will be "
"When enabled (=default), the examples will be shuffled for each epoch." "shuffled for each epoch.",
),
) )
group.add_argument( group.add_argument(
"--return-cuts", "--return-cuts",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, each batch will have the "
"When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that " "field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it." "were used to construct it.",
),
) )
group.add_argument( group.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
default=2, default=2,
help="The number of training dataloader workers that collect the batches.", help="The number of training dataloader workers that "
"collect the batches.",
) )
group.add_argument( group.add_argument(
@ -119,22 +112,18 @@ class AsrDataModule:
"--spec-aug-time-warp-factor", "--spec-aug-time-warp-factor",
type=int, type=int,
default=80, default=80,
help=( help="Used only when --enable-spec-aug is True. "
"Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. " "It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. " "Larger values mean more warping. "
"A value less than 1 means to disable time warp." "A value less than 1 means to disable time warp.",
),
) )
group.add_argument( group.add_argument(
"--enable-musan", "--enable-musan",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, select noise from MUSAN and mix it"
"When enabled, select noise from MUSAN and mix it" "with training dataset. ",
"with training dataset. "
),
) )
group.add_argument( group.add_argument(
@ -148,11 +137,9 @@ class AsrDataModule:
"--on-the-fly-feats", "--on-the-fly-feats",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="When enabled, use on-the-fly cut mixing and feature "
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests " "extraction. Will drop existing precomputed feature manifests "
"if available. Used only in dev/test CutSet" "if available. Used only in dev/test CutSet",
),
) )
def train_dataloaders( def train_dataloaders(
@ -175,7 +162,9 @@ class AsrDataModule:
if cuts_musan is not None: if cuts_musan is not None:
logging.info("Enable MUSAN") logging.info("Enable MUSAN")
transforms.append( transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
@ -184,7 +173,9 @@ class AsrDataModule:
if self.args.enable_spec_aug: if self.args.enable_spec_aug:
logging.info("Enable SpecAugment") logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version. # Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is # In different Lhotse's versions, the default of num_frame_masks is
# different. # different.
@ -261,7 +252,9 @@ class AsrDataModule:
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
else: else:

View File

@ -93,19 +93,16 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=30, default=30,
help=( help="It specifies the checkpoint to use for decoding."
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." "Note: Epoch counts from 0.",
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=10, default=10,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. " "'--epoch'. ",
),
) )
parser.add_argument( parser.add_argument(
@ -173,7 +170,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -229,7 +227,9 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -241,7 +241,10 @@ def decode_one_batch(
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -285,7 +288,11 @@ def decode_one_batch(
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -358,7 +365,9 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -384,7 +393,9 @@ def save_results(
# we compute CER for aishell dataset. # we compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -395,7 +406,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tCER", file=f) print("settings\tCER", file=f)
@ -436,7 +448,9 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

View File

@ -68,20 +68,17 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=20, default=20,
help=( help="It specifies the checkpoint to use for decoding."
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." "Note: Epoch counts from 0.",
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=10, default=10,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. " "'--epoch'. ",
),
) )
parser.add_argument( parser.add_argument(
@ -112,7 +109,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
return parser return parser
@ -243,7 +241,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -87,11 +87,9 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help=( help="Path to the checkpoint. "
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()." "icefall.checkpoint.save_checkpoint().",
),
) )
parser.add_argument( parser.add_argument(
@ -117,12 +115,10 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help=( help="The input sound file(s) to transcribe. "
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz." "The sample rate has to be 16kHz.",
),
) )
parser.add_argument( parser.add_argument(
@ -169,16 +165,15 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
default=1, default=1,
help=( help="Maximum number of symbols per frame. "
"Maximum number of symbols per frame. " "Use only when --method is greedy_search",
"Use only when --method is greedy_search"
),
) )
return parser return parser
@ -199,9 +194,10 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert ( assert sample_rate == expected_sample_rate, (
sample_rate == expected_sample_rate f"expected sample rate: {expected_sample_rate}. "
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" f"Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -258,9 +254,13 @@ def main():
feature_lens = [f.size(0) for f in features] feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device) feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyp_list = [] hyp_list = []
@ -308,7 +308,9 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
else: else:
raise ValueError(f"Unsupported decoding method: {params.method}") raise ValueError(
f"Unsupported decoding method: {params.method}"
)
hyp_list.append(hyp) hyp_list.append(hyp)
hyps = [] hyps = []
@ -325,7 +327,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -149,7 +149,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -167,7 +168,8 @@ def get_parser():
"--datatang-prob", "--datatang-prob",
type=float, type=float,
default=0.2, default=0.2,
help="The probability to select a batch from the aidatatang_200zh dataset", help="The probability to select a batch from the "
"aidatatang_200zh dataset",
) )
return parser return parser
@ -447,7 +449,9 @@ def compute_loss(
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -601,7 +605,9 @@ def train_one_epoch(
f"train/current_{prefix}_", f"train/current_{prefix}_",
params.batch_idx_train, params.batch_idx_train,
) )
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
aishell_tot_loss.write_summary( aishell_tot_loss.write_summary(
tb_writer, "train/aishell_tot_", params.batch_idx_train tb_writer, "train/aishell_tot_", params.batch_idx_train
) )
@ -729,7 +735,9 @@ def run(rank, world_size, args):
train_datatang_cuts = train_datatang_cuts.repeat(times=None) train_datatang_cuts = train_datatang_cuts.repeat(times=None)
if args.enable_musan: if args.enable_musan:
cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") cuts_musan = load_manifest(
Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
)
else: else:
cuts_musan = None cuts_musan = None
@ -768,7 +776,9 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate cur_lr = optimizer._rate
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0: if rank == 0:

View File

@ -94,19 +94,16 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=30, default=30,
help=( help="It specifies the checkpoint to use for decoding."
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." "Note: Epoch counts from 0.",
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=10, default=10,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. " "'--epoch'. ",
),
) )
parser.add_argument( parser.add_argument(
@ -174,7 +171,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -233,7 +231,9 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -245,7 +245,10 @@ def decode_one_batch(
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -289,7 +292,11 @@ def decode_one_batch(
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -362,7 +369,9 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -388,7 +397,9 @@ def save_results(
# we compute CER for aishell dataset. # we compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -399,7 +410,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tCER", file=f) print("settings\tCER", file=f)
@ -440,7 +452,9 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

View File

@ -68,20 +68,17 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=20, default=20,
help=( help="It specifies the checkpoint to use for decoding."
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." "Note: Epoch counts from 0.",
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=10, default=10,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. " "'--epoch'. ",
),
) )
parser.add_argument( parser.add_argument(
@ -112,7 +109,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
return parser return parser
@ -243,7 +241,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -87,11 +87,9 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help=( help="Path to the checkpoint. "
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()." "icefall.checkpoint.save_checkpoint().",
),
) )
parser.add_argument( parser.add_argument(
@ -117,12 +115,10 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help=( help="The input sound file(s) to transcribe. "
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz." "The sample rate has to be 16kHz.",
),
) )
parser.add_argument( parser.add_argument(
@ -169,16 +165,15 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
default=1, default=1,
help=( help="Maximum number of symbols per frame. "
"Maximum number of symbols per frame. " "Use only when --method is greedy_search",
"Use only when --method is greedy_search"
),
) )
return parser return parser
@ -199,9 +194,10 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert ( assert sample_rate == expected_sample_rate, (
sample_rate == expected_sample_rate f"expected sample rate: {expected_sample_rate}. "
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" f"Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -258,9 +254,13 @@ def main():
feature_lens = [f.size(0) for f in features] feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device) feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyp_list = [] hyp_list = []
@ -308,7 +308,9 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
else: else:
raise ValueError(f"Unsupported decoding method: {params.method}") raise ValueError(
f"Unsupported decoding method: {params.method}"
)
hyp_list.append(hyp) hyp_list.append(hyp)
hyps = [] hyps = []
@ -325,7 +327,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -142,7 +142,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -413,7 +414,9 @@ def compute_loss(
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -526,7 +529,9 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")
@ -652,7 +657,9 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate cur_lr = optimizer._rate
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0: if rank == 0:

0
egs/aishell2/ASR/local/__init__.py Normal file → Executable file
View File

View File

@ -83,7 +83,9 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
) )
if "train" in partition: if "train" in partition:
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
) )
cut_set = cut_set.compute_and_store_features( cut_set = cut_set.compute_and_store_features(
extractor=extractor, extractor=extractor,
@ -109,7 +111,9 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)

View File

View File

@ -76,12 +76,10 @@ class AiShell2AsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group( group = parser.add_argument_group(
title="ASR data related options", title="ASR data related options",
description=( description="These options are used for the preparation of "
"These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data " "effective batch sizes, sampling strategies, applied data "
"augmentations, etc." "augmentations, etc.",
),
) )
group.add_argument( group.add_argument(
"--manifest-dir", "--manifest-dir",
@ -93,74 +91,59 @@ class AiShell2AsrDataModule:
"--max-duration", "--max-duration",
type=int, type=int,
default=200.0, default=200.0,
help=( help="Maximum pooled recordings duration (seconds) in a "
"Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.",
"single batch. You can reduce it if it causes CUDA OOM."
),
) )
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, the batches will come from buckets of "
"When enabled, the batches will come from buckets of " "similar duration (saves padding frames).",
"similar duration (saves padding frames)."
),
) )
group.add_argument( group.add_argument(
"--num-buckets", "--num-buckets",
type=int, type=int,
default=30, default=30,
help=( help="The number of buckets for the DynamicBucketingSampler"
"The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).",
"(you might want to increase it for larger datasets)."
),
) )
group.add_argument( group.add_argument(
"--concatenate-cuts", "--concatenate-cuts",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="When enabled, utterances (cuts) will be concatenated "
"When enabled, utterances (cuts) will be concatenated " "to minimize the amount of padding.",
"to minimize the amount of padding."
),
) )
group.add_argument( group.add_argument(
"--duration-factor", "--duration-factor",
type=float, type=float,
default=1.0, default=1.0,
help=( help="Determines the maximum duration of a concatenated cut "
"Determines the maximum duration of a concatenated cut " "relative to the duration of the longest cut in a batch.",
"relative to the duration of the longest cut in a batch."
),
) )
group.add_argument( group.add_argument(
"--gap", "--gap",
type=float, type=float,
default=1.0, default=1.0,
help=( help="The amount of padding (in seconds) inserted between "
"The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when " "concatenated cuts. This padding is filled with noise when "
"noise augmentation is used." "noise augmentation is used.",
),
) )
group.add_argument( group.add_argument(
"--on-the-fly-feats", "--on-the-fly-feats",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="When enabled, use on-the-fly cut mixing and feature "
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests " "extraction. Will drop existing precomputed feature manifests "
"if available." "if available.",
),
) )
group.add_argument( group.add_argument(
"--shuffle", "--shuffle",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled (=default), the examples will be "
"When enabled (=default), the examples will be shuffled for each epoch." "shuffled for each epoch.",
),
) )
group.add_argument( group.add_argument(
"--drop-last", "--drop-last",
@ -172,18 +155,17 @@ class AiShell2AsrDataModule:
"--return-cuts", "--return-cuts",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, each batch will have the "
"When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that " "field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it." "were used to construct it.",
),
) )
group.add_argument( group.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
default=2, default=2,
help="The number of training dataloader workers that collect the batches.", help="The number of training dataloader workers that "
"collect the batches.",
) )
group.add_argument( group.add_argument(
@ -197,22 +179,18 @@ class AiShell2AsrDataModule:
"--spec-aug-time-warp-factor", "--spec-aug-time-warp-factor",
type=int, type=int,
default=80, default=80,
help=( help="Used only when --enable-spec-aug is True. "
"Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. " "It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. " "Larger values mean more warping. "
"A value less than 1 means to disable time warp." "A value less than 1 means to disable time warp.",
),
) )
group.add_argument( group.add_argument(
"--enable-musan", "--enable-musan",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, select noise from MUSAN and mix it"
"When enabled, select noise from MUSAN and mix it" "with training dataset. ",
"with training dataset. "
),
) )
group.add_argument( group.add_argument(
@ -238,16 +216,20 @@ class AiShell2AsrDataModule:
if self.args.enable_musan: if self.args.enable_musan:
logging.info("Enable MUSAN") logging.info("Enable MUSAN")
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms.append( transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
logging.info( logging.info(
"Using cut concatenation with duration factor " f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}." f"{self.args.duration_factor} and gap {self.args.gap}."
) )
# Cut concatenation should be the first transform in the list, # Cut concatenation should be the first transform in the list,
@ -262,7 +244,9 @@ class AiShell2AsrDataModule:
input_transforms = [] input_transforms = []
if self.args.enable_spec_aug: if self.args.enable_spec_aug:
logging.info("Enable SpecAugment") logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version. # Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is # In different Lhotse's versions, the default of num_frame_masks is
# different. # different.
@ -306,7 +290,9 @@ class AiShell2AsrDataModule:
# Drop feats to be on the safe side. # Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
@ -362,7 +348,9 @@ class AiShell2AsrDataModule:
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
else: else:
@ -418,7 +406,9 @@ class AiShell2AsrDataModule:
@lru_cache() @lru_cache()
def valid_cuts(self) -> CutSet: def valid_cuts(self) -> CutSet:
logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz") logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz") return load_manifest_lazy(
self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
)
@lru_cache() @lru_cache()
def test_cuts(self) -> CutSet: def test_cuts(self) -> CutSet:

View File

@ -168,24 +168,20 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'" "'--epoch' and '--iter'",
),
) )
parser.add_argument( parser.add_argument(
"--use-averaged-model", "--use-averaged-model",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="Whether to load averaged model. Currently it only supports "
"Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model " "using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`." "over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and " "Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. " "`epoch` are loaded for averaging. ",
),
) )
parser.add_argument( parser.add_argument(
@ -273,7 +269,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -351,7 +348,9 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
@ -410,7 +409,10 @@ def decode_one_batch(
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -536,7 +538,9 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -569,7 +573,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
@ -620,7 +625,9 @@ def main():
if "LG" in params.decoding_method: if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -654,12 +661,13 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ filenames = find_checkpoints(
: params.avg params.exp_dir, iteration=-params.iter
] )[: params.avg]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(
@ -682,12 +690,13 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ filenames = find_checkpoints(
: params.avg + 1 params.exp_dir, iteration=-params.iter
] )[: params.avg + 1]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg + 1: elif len(filenames) < params.avg + 1:
raise ValueError( raise ValueError(
@ -715,7 +724,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info( logging.info(
"Calculating the averaged model over epoch range from " f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}" f"{start} (excluded) to {params.epoch}"
) )
model.to(device) model.to(device)
@ -740,7 +749,9 @@ def main():
) )
decoding_graph.scores *= params.ngram_lm_scale decoding_graph.scores *= params.ngram_lm_scale
else: else:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else: else:
decoding_graph = None decoding_graph = None

View File

@ -89,24 +89,20 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'" "'--epoch' and '--iter'",
),
) )
parser.add_argument( parser.add_argument(
"--use-averaged-model", "--use-averaged-model",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="Whether to load averaged model. Currently it only supports "
"Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model " "using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`." "over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and " "Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. " "`epoch` are loaded for averaging. ",
),
) )
parser.add_argument( parser.add_argument(
@ -137,7 +133,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -170,12 +167,13 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ filenames = find_checkpoints(
: params.avg params.exp_dir, iteration=-params.iter
] )[: params.avg]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(
@ -198,12 +196,13 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ filenames = find_checkpoints(
: params.avg + 1 params.exp_dir, iteration=-params.iter
] )[: params.avg + 1]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg + 1: elif len(filenames) < params.avg + 1:
raise ValueError( raise ValueError(
@ -231,7 +230,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info( logging.info(
"Calculating the averaged model over epoch range from " f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}" f"{start} (excluded) to {params.epoch}"
) )
model.to(device) model.to(device)
@ -267,7 +266,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -81,11 +81,9 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help=( help="Path to the checkpoint. "
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()." "icefall.checkpoint.save_checkpoint().",
),
) )
parser.add_argument( parser.add_argument(
@ -111,12 +109,10 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help=( help="The input sound file(s) to transcribe. "
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz." "The sample rate has to be 16kHz.",
),
) )
parser.add_argument( parser.add_argument(
@ -163,7 +159,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -194,9 +191,10 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert ( assert sample_rate == expected_sample_rate, (
sample_rate == expected_sample_rate f"expected sample rate: {expected_sample_rate}. "
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" f"Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -256,11 +254,15 @@ def main():
features = fbank(waves) features = fbank(waves)
feature_lengths = [f.size(0) for f in features] feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyps = [] hyps = []
@ -332,7 +334,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -92,7 +92,9 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def add_model_arguments(parser: argparse.ArgumentParser): def add_model_arguments(parser: argparse.ArgumentParser):
@ -218,7 +220,8 @@ def get_parser():
"--initial-lr", "--initial-lr",
type=float, type=float,
default=0.003, default=0.003,
help="The initial learning rate. This value should not need to be changed.", help="The initial learning rate. This value should not need "
"to be changed.",
) )
parser.add_argument( parser.add_argument(
@ -241,45 +244,42 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--prune-range", "--prune-range",
type=int, type=int,
default=5, default=5,
help=( help="The prune range for rnnt loss, it means how many symbols(context)"
"The prune range for rnnt loss, it means how many symbols(context)" "we are using to compute the loss",
"we are using to compute the loss"
),
) )
parser.add_argument( parser.add_argument(
"--lm-scale", "--lm-scale",
type=float, type=float,
default=0.25, default=0.25,
help=( help="The scale to smooth the loss with lm "
"The scale to smooth the loss with lm (output of prediction network) part." "(output of prediction network) part.",
),
) )
parser.add_argument( parser.add_argument(
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)part.", help="The scale to smooth the loss with am (output of encoder network)"
"part.",
) )
parser.add_argument( parser.add_argument(
"--simple-loss-scale", "--simple-loss-scale",
type=float, type=float,
default=0.5, default=0.5,
help=( help="To get pruning ranges, we will calculate a simple version"
"To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for" "loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss" "training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss." "with this parameter before adding to the final loss.",
),
) )
parser.add_argument( parser.add_argument(
@ -603,7 +603,11 @@ def compute_loss(
warmup: a floating point value which increases throughout training; warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present. values >= 1.0 are fully warmed up and have all modules present.
""" """
device = model.device if isinstance(model, DDP) else next(model.parameters()).device device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"] feature = batch["inputs"]
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
assert feature.ndim == 3 assert feature.ndim == 3
@ -632,16 +636,23 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge, # overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet. # in case it had not fully learned the alignment yet.
pruned_loss_scale = ( pruned_loss_scale = (
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) 0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -760,7 +771,9 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
except: # noqa except: # noqa
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
raise raise
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 5:
@ -816,7 +829,9 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")
@ -1089,7 +1104,9 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} " f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..." f"(={crit_values[criterion]}) ..."
) )
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
raise raise

View File

@ -85,7 +85,9 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
) )
if "train" in partition: if "train" in partition:
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
) )
cut_set = cut_set.compute_and_store_features( cut_set = cut_set.compute_and_store_features(
extractor=extractor, extractor=extractor,
@ -118,7 +120,9 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state cur_state = loop_state
word = word2id[word] word = word2id[word]
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces] pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1): for i in range(len(pieces) - 1):
w = word if i == 0 else eps w = word if i == 0 else eps
@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False return False
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table. """Generate a lexicon from a word list and token_sym_table.
Args: Args:

View File

@ -317,7 +317,9 @@ def lexicon_to_fst(
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") parser.add_argument(
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
return parser.parse_args() return parser.parse_args()

View File

@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L") fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig") fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -50,15 +50,15 @@ def get_parser():
"-n", "-n",
default=1, default=1,
type=int, type=int,
help=( help="number of characters to split, i.e., \
"number of characters to split, i.e., aabb -> a a b" aabb -> a a b b with -n 1 and aa bb with -n 2",
" b with -n 1 and aa bb with -n 2"
),
) )
parser.add_argument( parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns" "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
) )
parser.add_argument("--space", default="<space>", type=str, help="space symbol") parser.add_argument(
"--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument( parser.add_argument(
"--non-lang-syms", "--non-lang-syms",
"-l", "-l",
@ -66,7 +66,9 @@ def get_parser():
type=str, type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.", help="list of non-linguistic symobles, e.g., <NOISE> etc.",
) )
parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument(
"text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument( parser.add_argument(
"--trans_type", "--trans_type",
"-t", "-t",
@ -106,7 +108,8 @@ def token2id(
if token_type == "lazy_pinyin": if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list) text = lazy_pinyin(chars_list)
sub_ids = [ sub_ids = [
token_table[txt] if txt in token_table else oov_id for txt in text token_table[txt] if txt in token_table else oov_id
for txt in text
] ]
ids.append(sub_ids) ids.append(sub_ids)
else: # token_type = "pinyin" else: # token_type = "pinyin"
@ -132,7 +135,9 @@ def main():
if args.text: if args.text:
f = codecs.open(args.text, encoding="utf-8") f = codecs.open(args.text, encoding="utf-8")
else: else:
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) f = codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer
)
sys.stdout = codecs.getwriter("utf-8")( sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer sys.stdout if is_python2 else sys.stdout.buffer

View File

@ -74,12 +74,10 @@ class Aishell4AsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group( group = parser.add_argument_group(
title="ASR data related options", title="ASR data related options",
description=( description="These options are used for the preparation of "
"These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data " "effective batch sizes, sampling strategies, applied data "
"augmentations, etc." "augmentations, etc.",
),
) )
group.add_argument( group.add_argument(
@ -93,81 +91,66 @@ class Aishell4AsrDataModule:
"--max-duration", "--max-duration",
type=int, type=int,
default=200.0, default=200.0,
help=( help="Maximum pooled recordings duration (seconds) in a "
"Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.",
"single batch. You can reduce it if it causes CUDA OOM."
),
) )
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, the batches will come from buckets of "
"When enabled, the batches will come from buckets of " "similar duration (saves padding frames).",
"similar duration (saves padding frames)."
),
) )
group.add_argument( group.add_argument(
"--num-buckets", "--num-buckets",
type=int, type=int,
default=300, default=300,
help=( help="The number of buckets for the DynamicBucketingSampler"
"The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).",
"(you might want to increase it for larger datasets)."
),
) )
group.add_argument( group.add_argument(
"--concatenate-cuts", "--concatenate-cuts",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="When enabled, utterances (cuts) will be concatenated "
"When enabled, utterances (cuts) will be concatenated " "to minimize the amount of padding.",
"to minimize the amount of padding."
),
) )
group.add_argument( group.add_argument(
"--duration-factor", "--duration-factor",
type=float, type=float,
default=1.0, default=1.0,
help=( help="Determines the maximum duration of a concatenated cut "
"Determines the maximum duration of a concatenated cut " "relative to the duration of the longest cut in a batch.",
"relative to the duration of the longest cut in a batch."
),
) )
group.add_argument( group.add_argument(
"--gap", "--gap",
type=float, type=float,
default=1.0, default=1.0,
help=( help="The amount of padding (in seconds) inserted between "
"The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when " "concatenated cuts. This padding is filled with noise when "
"noise augmentation is used." "noise augmentation is used.",
),
) )
group.add_argument( group.add_argument(
"--on-the-fly-feats", "--on-the-fly-feats",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="When enabled, use on-the-fly cut mixing and feature "
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests " "extraction. Will drop existing precomputed feature manifests "
"if available." "if available.",
),
) )
group.add_argument( group.add_argument(
"--shuffle", "--shuffle",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled (=default), the examples will be "
"When enabled (=default), the examples will be shuffled for each epoch." "shuffled for each epoch.",
),
) )
group.add_argument( group.add_argument(
@ -181,18 +164,17 @@ class Aishell4AsrDataModule:
"--return-cuts", "--return-cuts",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, each batch will have the "
"When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that " "field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it." "were used to construct it.",
),
) )
group.add_argument( group.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
default=2, default=2,
help="The number of training dataloader workers that collect the batches.", help="The number of training dataloader workers that "
"collect the batches.",
) )
group.add_argument( group.add_argument(
@ -206,22 +188,18 @@ class Aishell4AsrDataModule:
"--spec-aug-time-warp-factor", "--spec-aug-time-warp-factor",
type=int, type=int,
default=80, default=80,
help=( help="Used only when --enable-spec-aug is True. "
"Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. " "It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. " "Larger values mean more warping. "
"A value less than 1 means to disable time warp." "A value less than 1 means to disable time warp.",
),
) )
group.add_argument( group.add_argument(
"--enable-musan", "--enable-musan",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, select noise from MUSAN and mix it"
"When enabled, select noise from MUSAN and mix it" "with training dataset. ",
"with training dataset. "
),
) )
group.add_argument( group.add_argument(
@ -244,20 +222,24 @@ class Aishell4AsrDataModule:
The state dict for the training sampler. The state dict for the training sampler.
""" """
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = [] transforms = []
if self.args.enable_musan: if self.args.enable_musan:
logging.info("Enable MUSAN") logging.info("Enable MUSAN")
transforms.append( transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
logging.info( logging.info(
"Using cut concatenation with duration factor " f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}." f"{self.args.duration_factor} and gap {self.args.gap}."
) )
# Cut concatenation should be the first transform in the list, # Cut concatenation should be the first transform in the list,
@ -272,7 +254,9 @@ class Aishell4AsrDataModule:
input_transforms = [] input_transforms = []
if self.args.enable_spec_aug: if self.args.enable_spec_aug:
logging.info("Enable SpecAugment") logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version. # Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is # In different Lhotse's versions, the default of num_frame_masks is
# different. # different.
@ -316,7 +300,9 @@ class Aishell4AsrDataModule:
# Drop feats to be on the safe side. # Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
@ -373,7 +359,9 @@ class Aishell4AsrDataModule:
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
else: else:

View File

@ -117,24 +117,20 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'" "'--epoch' and '--iter'",
),
) )
parser.add_argument( parser.add_argument(
"--use-averaged-model", "--use-averaged-model",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="Whether to load averaged model. Currently it only supports "
"Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model " "using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`." "over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and " "Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. " "`epoch` are loaded for averaging. ",
),
) )
parser.add_argument( parser.add_argument(
@ -205,7 +201,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -263,7 +260,9 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
@ -278,7 +277,10 @@ def decode_one_batch(
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -324,7 +326,11 @@ def decode_one_batch(
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -395,7 +401,9 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -428,7 +436,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
@ -471,7 +480,9 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -499,12 +510,13 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ filenames = find_checkpoints(
: params.avg params.exp_dir, iteration=-params.iter
] )[: params.avg]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(
@ -531,12 +543,13 @@ def main():
) )
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ filenames = find_checkpoints(
: params.avg + 1 params.exp_dir, iteration=-params.iter
] )[: params.avg + 1]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg + 1: elif len(filenames) < params.avg + 1:
raise ValueError( raise ValueError(
@ -565,7 +578,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info( logging.info(
"Calculating the averaged model over epoch range from " f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}" f"{start} (excluded) to {params.epoch}"
) )
model.to(device) model.to(device)

View File

@ -89,24 +89,20 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'" "'--epoch' and '--iter'",
),
) )
parser.add_argument( parser.add_argument(
"--use-averaged-model", "--use-averaged-model",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="Whether to load averaged model. Currently it only supports "
"Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model " "using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`." "over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and " "Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. " "`epoch` are loaded for averaging. ",
),
) )
parser.add_argument( parser.add_argument(
@ -140,7 +136,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -172,12 +169,13 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ filenames = find_checkpoints(
: params.avg params.exp_dir, iteration=-params.iter
] )[: params.avg]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(
@ -204,12 +202,13 @@ def main():
) )
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ filenames = find_checkpoints(
: params.avg + 1 params.exp_dir, iteration=-params.iter
] )[: params.avg + 1]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg + 1: elif len(filenames) < params.avg + 1:
raise ValueError( raise ValueError(
@ -238,7 +237,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info( logging.info(
"Calculating the averaged model over epoch range from " f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}" f"{start} (excluded) to {params.epoch}"
) )
model.to(device) model.to(device)
@ -277,7 +276,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -94,11 +94,9 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help=( help="Path to the checkpoint. "
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()." "icefall.checkpoint.save_checkpoint().",
),
) )
parser.add_argument( parser.add_argument(
@ -124,12 +122,10 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help=( help="The input sound file(s) to transcribe. "
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz." "The sample rate has to be 16kHz.",
),
) )
parser.add_argument( parser.add_argument(
@ -176,7 +172,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -207,9 +204,10 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert ( assert sample_rate == expected_sample_rate, (
sample_rate == expected_sample_rate f"expected sample rate: {expected_sample_rate}. "
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" f"Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -268,11 +266,15 @@ def main():
features = fbank(waves) features = fbank(waves)
feature_lengths = [f.size(0) for f in features] feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyps = [] hyps = []
@ -304,7 +306,10 @@ def main():
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -345,7 +350,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -85,7 +85,9 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def add_model_arguments(parser: argparse.ArgumentParser): def add_model_arguments(parser: argparse.ArgumentParser):
@ -211,7 +213,8 @@ def get_parser():
"--initial-lr", "--initial-lr",
type=float, type=float,
default=0.003, default=0.003,
help="The initial learning rate. This value should not need to be changed.", help="The initial learning rate. This value should not need "
"to be changed.",
) )
parser.add_argument( parser.add_argument(
@ -234,45 +237,42 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--prune-range", "--prune-range",
type=int, type=int,
default=5, default=5,
help=( help="The prune range for rnnt loss, it means how many symbols(context)"
"The prune range for rnnt loss, it means how many symbols(context)" "we are using to compute the loss",
"we are using to compute the loss"
),
) )
parser.add_argument( parser.add_argument(
"--lm-scale", "--lm-scale",
type=float, type=float,
default=0.25, default=0.25,
help=( help="The scale to smooth the loss with lm "
"The scale to smooth the loss with lm (output of prediction network) part." "(output of prediction network) part.",
),
) )
parser.add_argument( parser.add_argument(
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)part.", help="The scale to smooth the loss with am (output of encoder network)"
"part.",
) )
parser.add_argument( parser.add_argument(
"--simple-loss-scale", "--simple-loss-scale",
type=float, type=float,
default=0.5, default=0.5,
help=( help="To get pruning ranges, we will calculate a simple version"
"To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for" "loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss" "training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss." "with this parameter before adding to the final loss.",
),
) )
parser.add_argument( parser.add_argument(
@ -599,7 +599,11 @@ def compute_loss(
warmup: a floating point value which increases throughout training; warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present. values >= 1.0 are fully warmed up and have all modules present.
""" """
device = model.device if isinstance(model, DDP) else next(model.parameters()).device device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"] feature = batch["inputs"]
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
assert feature.ndim == 3 assert feature.ndim == 3
@ -629,15 +633,22 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge, # overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet. # in case it had not fully learned the alignment yet.
pruned_loss_scale = ( pruned_loss_scale = (
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) 0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -816,7 +827,9 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")

View File

@ -84,7 +84,9 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
) )
if "train" in partition: if "train" in partition:
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
) )
cur_num_jobs = num_jobs if ex is None else 80 cur_num_jobs = num_jobs if ex is None else 80
cur_num_jobs = min(cur_num_jobs, len(cut_set)) cur_num_jobs = min(cur_num_jobs, len(cut_set))
@ -119,7 +121,9 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state cur_state = loop_state
word = word2id[word] word = word2id[word]
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces] pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1): for i in range(len(pieces) - 1):
w = word if i == 0 else eps w = word if i == 0 else eps
@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False return False
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table. """Generate a lexicon from a word list and token_sym_table.
Args: Args:

View File

@ -317,7 +317,9 @@ def lexicon_to_fst(
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") parser.add_argument(
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
return parser.parse_args() return parser.parse_args()

View File

@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L") fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig") fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -30,8 +30,8 @@ with word segmenting:
import argparse import argparse
import jieba
import paddle import paddle
import jieba
from tqdm import tqdm from tqdm import tqdm
paddle.enable_static() paddle.enable_static()

View File

@ -50,15 +50,15 @@ def get_parser():
"-n", "-n",
default=1, default=1,
type=int, type=int,
help=( help="number of characters to split, i.e., \
"number of characters to split, i.e., aabb -> a a b" aabb -> a a b b with -n 1 and aa bb with -n 2",
" b with -n 1 and aa bb with -n 2"
),
) )
parser.add_argument( parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns" "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
) )
parser.add_argument("--space", default="<space>", type=str, help="space symbol") parser.add_argument(
"--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument( parser.add_argument(
"--non-lang-syms", "--non-lang-syms",
"-l", "-l",
@ -66,7 +66,9 @@ def get_parser():
type=str, type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.", help="list of non-linguistic symobles, e.g., <NOISE> etc.",
) )
parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument(
"text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument( parser.add_argument(
"--trans_type", "--trans_type",
"-t", "-t",
@ -106,7 +108,8 @@ def token2id(
if token_type == "lazy_pinyin": if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list) text = lazy_pinyin(chars_list)
sub_ids = [ sub_ids = [
token_table[txt] if txt in token_table else oov_id for txt in text token_table[txt] if txt in token_table else oov_id
for txt in text
] ]
ids.append(sub_ids) ids.append(sub_ids)
else: # token_type = "pinyin" else: # token_type = "pinyin"
@ -132,7 +135,9 @@ def main():
if args.text: if args.text:
f = codecs.open(args.text, encoding="utf-8") f = codecs.open(args.text, encoding="utf-8")
else: else:
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) f = codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer
)
sys.stdout = codecs.getwriter("utf-8")( sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer sys.stdout if is_python2 else sys.stdout.buffer

View File

@ -81,12 +81,10 @@ class AlimeetingAsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group( group = parser.add_argument_group(
title="ASR data related options", title="ASR data related options",
description=( description="These options are used for the preparation of "
"These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data " "effective batch sizes, sampling strategies, applied data "
"augmentations, etc." "augmentations, etc.",
),
) )
group.add_argument( group.add_argument(
"--manifest-dir", "--manifest-dir",
@ -98,91 +96,75 @@ class AlimeetingAsrDataModule:
"--max-duration", "--max-duration",
type=int, type=int,
default=200.0, default=200.0,
help=( help="Maximum pooled recordings duration (seconds) in a "
"Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.",
"single batch. You can reduce it if it causes CUDA OOM."
),
) )
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, the batches will come from buckets of "
"When enabled, the batches will come from buckets of " "similar duration (saves padding frames).",
"similar duration (saves padding frames)."
),
) )
group.add_argument( group.add_argument(
"--num-buckets", "--num-buckets",
type=int, type=int,
default=300, default=300,
help=( help="The number of buckets for the DynamicBucketingSampler"
"The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).",
"(you might want to increase it for larger datasets)."
),
) )
group.add_argument( group.add_argument(
"--concatenate-cuts", "--concatenate-cuts",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="When enabled, utterances (cuts) will be concatenated "
"When enabled, utterances (cuts) will be concatenated " "to minimize the amount of padding.",
"to minimize the amount of padding."
),
) )
group.add_argument( group.add_argument(
"--duration-factor", "--duration-factor",
type=float, type=float,
default=1.0, default=1.0,
help=( help="Determines the maximum duration of a concatenated cut "
"Determines the maximum duration of a concatenated cut " "relative to the duration of the longest cut in a batch.",
"relative to the duration of the longest cut in a batch."
),
) )
group.add_argument( group.add_argument(
"--gap", "--gap",
type=float, type=float,
default=1.0, default=1.0,
help=( help="The amount of padding (in seconds) inserted between "
"The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when " "concatenated cuts. This padding is filled with noise when "
"noise augmentation is used." "noise augmentation is used.",
),
) )
group.add_argument( group.add_argument(
"--on-the-fly-feats", "--on-the-fly-feats",
type=str2bool, type=str2bool,
default=False, default=False,
help=( help="When enabled, use on-the-fly cut mixing and feature "
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests " "extraction. Will drop existing precomputed feature manifests "
"if available." "if available.",
),
) )
group.add_argument( group.add_argument(
"--shuffle", "--shuffle",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled (=default), the examples will be "
"When enabled (=default), the examples will be shuffled for each epoch." "shuffled for each epoch.",
),
) )
group.add_argument( group.add_argument(
"--return-cuts", "--return-cuts",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, each batch will have the "
"When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that " "field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it." "were used to construct it.",
),
) )
group.add_argument( group.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
default=2, default=2,
help="The number of training dataloader workers that collect the batches.", help="The number of training dataloader workers that "
"collect the batches.",
) )
group.add_argument( group.add_argument(
@ -196,22 +178,18 @@ class AlimeetingAsrDataModule:
"--spec-aug-time-warp-factor", "--spec-aug-time-warp-factor",
type=int, type=int,
default=80, default=80,
help=( help="Used only when --enable-spec-aug is True. "
"Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. " "It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. " "Larger values mean more warping. "
"A value less than 1 means to disable time warp." "A value less than 1 means to disable time warp.",
),
) )
group.add_argument( group.add_argument(
"--enable-musan", "--enable-musan",
type=str2bool, type=str2bool,
default=True, default=True,
help=( help="When enabled, select noise from MUSAN and mix it"
"When enabled, select noise from MUSAN and mix it" "with training dataset. ",
"with training dataset. "
),
) )
def train_dataloaders( def train_dataloaders(
@ -227,20 +205,24 @@ class AlimeetingAsrDataModule:
The state dict for the training sampler. The state dict for the training sampler.
""" """
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = [] transforms = []
if self.args.enable_musan: if self.args.enable_musan:
logging.info("Enable MUSAN") logging.info("Enable MUSAN")
transforms.append( transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
logging.info( logging.info(
"Using cut concatenation with duration factor " f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}." f"{self.args.duration_factor} and gap {self.args.gap}."
) )
# Cut concatenation should be the first transform in the list, # Cut concatenation should be the first transform in the list,
@ -255,7 +237,9 @@ class AlimeetingAsrDataModule:
input_transforms = [] input_transforms = []
if self.args.enable_spec_aug: if self.args.enable_spec_aug:
logging.info("Enable SpecAugment") logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version. # Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is # In different Lhotse's versions, the default of num_frame_masks is
# different. # different.
@ -298,7 +282,9 @@ class AlimeetingAsrDataModule:
# Drop feats to be on the safe side. # Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
@ -355,7 +341,9 @@ class AlimeetingAsrDataModule:
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
else: else:

View File

@ -70,7 +70,11 @@ from beam_search import (
from lhotse.cut import Cut from lhotse.cut import Cut
from train import get_params, get_transducer_model from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -89,30 +93,25 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=28, default=28,
help=( help="It specifies the checkpoint to use for decoding."
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." "Note: Epoch counts from 0.",
),
) )
parser.add_argument( parser.add_argument(
"--batch", "--batch",
type=int, type=int,
default=None, default=None,
help=( help="It specifies the batch checkpoint to use for decoding."
"It specifies the batch checkpoint to use for decoding." "Note: Epoch counts from 0.",
"Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. " "'--epoch'. ",
),
) )
parser.add_argument( parser.add_argument(
@ -194,7 +193,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -249,7 +249,9 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
@ -264,7 +266,10 @@ def decode_one_batch(
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -310,7 +315,11 @@ def decode_one_batch(
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -381,7 +390,9 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -414,7 +425,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
@ -551,7 +563,8 @@ def main():
) )
dev_shards = [ dev_shards = [
str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) str(path)
for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
] ]
cuts_dev_webdataset = CutSet.from_webdataset( cuts_dev_webdataset = CutSet.from_webdataset(
dev_shards, dev_shards,
@ -561,7 +574,8 @@ def main():
) )
test_shards = [ test_shards = [
str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar"))) str(path)
for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
] ]
cuts_test_webdataset = CutSet.from_webdataset( cuts_test_webdataset = CutSet.from_webdataset(
test_shards, test_shards,
@ -574,7 +588,9 @@ def main():
return 1.0 <= c.duration return 1.0 <= c.duration
cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt) cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt)
cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt) cuts_test_webdataset = cuts_test_webdataset.filter(
remove_short_and_long_utt
)
dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset) dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset)
test_dl = alimeeting.test_dataloaders(cuts_test_webdataset) test_dl = alimeeting.test_dataloaders(cuts_test_webdataset)

View File

@ -62,20 +62,17 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=28, default=28,
help=( help="It specifies the checkpoint to use for decoding."
"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." "Note: Epoch counts from 0.",
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help=( help="Number of checkpoints to average. Automatically select "
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. " "'--epoch'. ",
),
) )
parser.add_argument( parser.add_argument(
@ -106,7 +103,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
return parser return parser
@ -175,7 +173,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -85,11 +85,9 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help=( help="Path to the checkpoint. "
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint()." "icefall.checkpoint.save_checkpoint().",
),
) )
parser.add_argument( parser.add_argument(
@ -114,12 +112,10 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help=( help="The input sound file(s) to transcribe. "
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz." "The sample rate has to be 16kHz.",
),
) )
parser.add_argument( parser.add_argument(
@ -166,7 +162,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -196,9 +193,10 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert ( assert sample_rate == expected_sample_rate, (
sample_rate == expected_sample_rate f"expected sample rate: {expected_sample_rate}. "
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" f"Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -259,7 +257,9 @@ def main():
features = fbank(waves) features = fbank(waves)
feature_lengths = [f.size(0) for f in features] feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
@ -284,7 +284,10 @@ def main():
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -336,7 +339,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -81,7 +81,9 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
@ -185,45 +187,42 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--prune-range", "--prune-range",
type=int, type=int,
default=5, default=5,
help=( help="The prune range for rnnt loss, it means how many symbols(context)"
"The prune range for rnnt loss, it means how many symbols(context)" "we are using to compute the loss",
"we are using to compute the loss"
),
) )
parser.add_argument( parser.add_argument(
"--lm-scale", "--lm-scale",
type=float, type=float,
default=0.25, default=0.25,
help=( help="The scale to smooth the loss with lm "
"The scale to smooth the loss with lm (output of prediction network) part." "(output of prediction network) part.",
),
) )
parser.add_argument( parser.add_argument(
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)part.", help="The scale to smooth the loss with am (output of encoder network)"
"part.",
) )
parser.add_argument( parser.add_argument(
"--simple-loss-scale", "--simple-loss-scale",
type=float, type=float,
default=0.5, default=0.5,
help=( help="To get pruning ranges, we will calculate a simple version"
"To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for" "loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss" "training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss." "with this parameter before adding to the final loss.",
),
) )
parser.add_argument( parser.add_argument(
@ -543,15 +542,22 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge, # overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet. # in case it had not fully learned the alignment yet.
pruned_loss_scale = ( pruned_loss_scale = (
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) 0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -705,7 +711,9 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")

View File

@ -25,10 +25,15 @@ from random import Random
from typing import List, Tuple from typing import List, Tuple
import torch import torch
from lhotse import ( # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on from lhotse import (
CutSet, CutSet,
Fbank, Fbank,
FbankConfig, FbankConfig,
# fmt: off
# See the following for why LilcomChunkyWriter is preferred
# https://github.com/k2-fsa/icefall/pull/404
# https://github.com/lhotse-speech/lhotse/pull/527
# fmt: on
LilcomChunkyWriter, LilcomChunkyWriter,
RecordingSet, RecordingSet,
SupervisionSet, SupervisionSet,
@ -76,13 +81,17 @@ def make_cutset_blueprints(
cut_sets.append((f"eval{i}", cut_set)) cut_sets.append((f"eval{i}", cut_set))
# Create train and valid cuts # Create train and valid cuts
logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.") logging.info(
"Loading, trimming, and shuffling the remaining core+noncore cuts."
)
recording_set = RecordingSet.from_file( recording_set = RecordingSet.from_file(
manifest_dir / "csj_recordings_core.jsonl.gz" manifest_dir / "csj_recordings_core.jsonl.gz"
) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
supervision_set = SupervisionSet.from_file( supervision_set = SupervisionSet.from_file(
manifest_dir / "csj_supervisions_core.jsonl.gz" manifest_dir / "csj_supervisions_core.jsonl.gz"
) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz") ) + SupervisionSet.from_file(
manifest_dir / "csj_supervisions_noncore.jsonl.gz"
)
cut_set = CutSet.from_manifests( cut_set = CutSet.from_manifests(
recordings=recording_set, recordings=recording_set,
@ -92,12 +101,15 @@ def make_cutset_blueprints(
cut_set = cut_set.shuffle(Random(RNG_SEED)) cut_set = cut_set.shuffle(Random(RNG_SEED))
logging.info( logging.info(
f"Creating valid and train cuts from core and noncore,split at {split}." "Creating valid and train cuts from core and noncore,"
f"split at {split}."
) )
valid_set = CutSet.from_cuts(islice(cut_set, 0, split)) valid_set = CutSet.from_cuts(islice(cut_set, 0, split))
train_set = CutSet.from_cuts(islice(cut_set, split, None)) train_set = CutSet.from_cuts(islice(cut_set, split, None))
train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) train_set = (
train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
)
cut_sets.extend([("valid", valid_set), ("train", train_set)]) cut_sets.extend([("valid", valid_set), ("train", train_set)])
@ -110,9 +122,15 @@ def get_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
) )
parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") parser.add_argument(
parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") "--manifest-dir", type=Path, help="Path to save manifests"
parser.add_argument("--split", type=int, default=4000, help="Split at this index") )
parser.add_argument(
"--fbank-dir", type=Path, help="Path to save fbank features"
)
parser.add_argument(
"--split", type=int, default=4000, help="Split at this index"
)
return parser.parse_args() return parser.parse_args()
@ -123,7 +141,9 @@ def main():
extractor = Fbank(FbankConfig(num_mel_bins=80)) extractor = Fbank(FbankConfig(num_mel_bins=80))
num_jobs = min(16, os.cpu_count()) num_jobs = min(16, os.cpu_count())
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -26,6 +26,7 @@ from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor from icefall.utils import get_executor
ARGPARSE_DESCRIPTION = """ ARGPARSE_DESCRIPTION = """
This file computes fbank features of the musan dataset. This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests. It looks for manifests in the directory data/manifests.
@ -83,7 +84,9 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
# create chunks of Musan with duration 5 - 10 seconds # create chunks of Musan with duration 5 - 10 seconds
musan_cuts = ( musan_cuts = (
CutSet.from_manifests( CutSet.from_manifests(
recordings=combine(part["recordings"] for part in manifests.values()) recordings=combine(
part["recordings"] for part in manifests.values()
)
) )
.cut_into_windows(10.0) .cut_into_windows(10.0)
.filter(lambda c: c.duration > 5) .filter(lambda c: c.duration > 5)
@ -104,15 +107,21 @@ def get_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
) )
parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") parser.add_argument(
parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") "--manifest-dir", type=Path, help="Path to save manifests"
)
parser.add_argument(
"--fbank-dir", type=Path, help="Path to save fbank features"
)
return parser.parse_args() return parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan(args.manifest_dir, args.fbank_dir) compute_fbank_musan(args.manifest_dir, args.fbank_dir)

View File

@ -318,3 +318,4 @@ spk_id = 2
= ǐa = ǐa
= ǐu = ǐu
= ǐo = ǐo

View File

@ -318,3 +318,4 @@ spk_id = 2
= ǐa = ǐa
= ǐu = ǐu
= ǐo = ǐo

Some files were not shown because too many files have changed in this diff Show More